---
license: mit
library_name: pytorch
pipeline_tag: reinforcement-learning
tags:
- robotics
- reinforcement-learning
- world-model
- causal-representation-learning
- stable-worldmodel
---

# robomotic/causality-two-rooms
|
|
| LeWM checkpoints trained on the confounded **Glitched Hue TwoRoom** dataset for |
| causal world-model experiments. The goal is to test whether the model learns the |
| true teleport mechanism or the spurious background-hue correlation. |
|
|
| ## Model description |
|
|
| - **Architecture:** LeWM / JEPA-style world model with an autoregressive predictor |
| - **Domain:** `swm/GlitchedHueTwoRoom-v1` |
| - **Framework:** PyTorch + Lightning |
| - **Repository:** https://github.com/epokhcs/stable-worldmodel |
| - **Upload generated (UTC):** 2026-04-10 08:16:57 |
|
|
| ## Training data |
|
|
| - Dataset path: `/home/robomotic/.stable_worldmodel/glitched_hue_tworoom.h5` |
| - Dataset size: 9.09 GB |
| - Episodes: 10,000 |
| - Frames: 609,539 |
| - Pixel tensor shape: `(609539, 224, 224, 3)` |
| - Teleport events: 1,593 |
|
|
| The dataset was collected with: |
|
|
| ```bash |
| python scripts/data/collect_glitched_hue.py num_traj=10000 seed=3072 world.num_envs=10 |
| ``` |
|
|
| ## Training procedure |
|
|
| The checkpoints in this repo come from the 5-epoch LeWM training run used in the |
| causality experiment. The run completed successfully after resuming from the last |
| full trainer checkpoint. |
|
|
The training command was of the form:
|
|
| ```bash |
| python scripts/train/lewm.py data=glitched_hue_tworoom trainer.max_epochs=5 num_workers=1 loader.num_workers=1 loader.persistent_workers=False |
| ``` |
|
|
| ### Key hyperparameters |
|
|
| | Parameter | Value | |
| |---|---| |
| | `trainer.max_epochs` | `5` | |
| | `trainer.accelerator` | `gpu` | |
| | `trainer.precision` | `bf16` | |
| | `loader.batch_size` | `128` | |
| | `loader.num_workers` | `1` for the resumed run | |
| | `optimizer.lr` | `5e-5` | |
| | `wm.history_size` | `3` | |
| | `wm.num_preds` | `1` | |
| | `wm.embed_dim` | `192` | |
| | `loss.sigreg.weight` | `0.09` | |
| | `data.dataset.frameskip` | `5` | |

The full Hydra config used for the run is included as `config.yaml` in the repo root.
|
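The interaction of `wm.history_size`, `wm.num_preds`, and `data.dataset.frameskip` can be sketched as a window over raw frame indices. Note that the exact indexing convention used by stable-worldmodel is an assumption here, not taken from its source:

```python
def sample_window(start, history_size=3, num_preds=1, frameskip=5):
    """Return the raw frame indices for one training window.

    A JEPA-style world model conditions on `history_size` context frames and
    predicts `num_preds` future frames, with `frameskip` raw frames between
    consecutive model steps. This indexing scheme is an assumption for
    illustration, not stable-worldmodel's confirmed implementation.
    """
    total = history_size + num_preds
    return [start + i * frameskip for i in range(total)]

print(sample_window(0))    # [0, 5, 10, 15]: three context frames + one target
print(sample_window(100))  # [100, 105, 110, 115]
```

Under this reading, each training example spans 16 raw frames, so a ~61-frame episode yields a few dozen overlapping windows.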
|
| ## Epoch metrics (logged to W&B / Lightning) |
|
|
| The table below summarizes the epoch-end losses extracted from the local training |
| logs. The raw values are also included as `metrics/epoch_metrics.json` and |
| `metrics/epoch_metrics.csv`. |
|
|
| | Epoch | Global step | fit/loss | fit/pred_loss | fit/sigreg_loss | validate/loss | validate/pred_loss | validate/sigreg_loss | |
| |---:|---:|---:|---:|---:|---:|---:|---:| |
| | 1 | 3046 | 2.676974 | 0.005099 | 29.750000 | 2.766938 | 0.009671 | 30.635967 | |
| | 2 | 6092 | 1.143889 | 0.436858 | 7.843750 | 1.173404 | 0.444629 | 8.096438 | |
| | 3 | 9138 | 0.673702 | 0.310420 | 4.031250 | 0.746696 | 0.313994 | 4.807232 | |
| | 4 | 12184 | 0.526915 | 0.163634 | 4.031250 | 0.601880 | 0.166302 | 4.839801 | |
| | 5 | 15230 | 0.496008 | 0.156165 | 3.765625 | 0.520261 | 0.127601 | 4.362802 | |
|
|
| ## Files in this repo |
|
|
| | File | Purpose | Size | |
| |---|---|---:| |
| | `checkpoints/lewm_epoch_1_object.ckpt` | Serialized model object checkpoint | 69.00 MB | |
| | `checkpoints/lewm_epoch_2_object.ckpt` | Serialized model object checkpoint | 69.00 MB | |
| | `checkpoints/lewm_epoch_3_object.ckpt` | Serialized model object checkpoint | 69.00 MB | |
| | `checkpoints/lewm_epoch_4_object.ckpt` | Serialized model object checkpoint | 69.00 MB | |
| | `checkpoints/lewm_epoch_5_object.ckpt` | Serialized model object checkpoint | 69.00 MB | |
| | `checkpoints/lewm_weights.ckpt` | Full Lightning trainer checkpoint | 206.75 MB | |
| | `config.yaml` | Hydra config used for the run | 1.20 KB | |
| | `metrics/epoch_metrics.json` | Raw epoch metrics extracted from local logs | small | |
| | `metrics/epoch_metrics.csv` | Tabular epoch metrics for spreadsheets / plotting | small | |
|
|
| ## How to use |
|
|
| Load a serialized model-object checkpoint: |
|
|
```python
import torch

# The object checkpoint pickles the whole model, so PyTorch >= 2.6 needs
# weights_only=False to unpickle it; only do this for checkpoints you trust.
model = torch.load(
    'checkpoints/lewm_epoch_5_object.ckpt',
    map_location='cpu',
    weights_only=False,
)
model.eval()
```
|
|
| Load the full Lightning trainer checkpoint: |
|
|
```python
import torch

# The trainer checkpoint also contains optimizer and scheduler state, so it
# may likewise require weights_only=False on PyTorch >= 2.6.
checkpoint = torch.load(
    'checkpoints/lewm_weights.ckpt',
    map_location='cpu',
    weights_only=False,
)
print(checkpoint.keys())
```
|
|
| ## Intended uses |
|
|
| - Reproducing the causal disentanglement experiment in `research/runme.md` |
| - Running the Step 3 causal AAP analysis with `research/glitched_hue_experiment.py` |
| - Comparing epoch-wise world-model checkpoints during training |
|
|
| ## Limitations |
|
|
| - These checkpoints are research artifacts, not production control policies. |
| - Performance is specific to the Glitched Hue TwoRoom environment and the |
| confounded blue/green data collection procedure. |
| - The object checkpoints are convenient for inspection, while the full trainer |
| checkpoint is the correct file for resuming optimization. |
|
|