# robomotic/causality-two-rooms
LeWM checkpoints trained on the confounded Glitched Hue TwoRoom dataset for causal world-model experiments. The goal is to test whether the model learns the true teleport mechanism or the spurious background-hue correlation.
## Model description
- Architecture: LeWM / JEPA-style world model with an autoregressive predictor
- Domain: `swm/GlitchedHueTwoRoom-v1`
- Framework: PyTorch + Lightning
- Repository: https://github.com/epokhcs/stable-worldmodel
- Upload generated (UTC): 2026-04-10 08:16:57
## Training data
- Dataset path: `/home/robomotic/.stable_worldmodel/glitched_hue_tworoom.h5`
- Dataset size: 9.09 GB
- Episodes: 10,000
- Frames: 609,539
- Pixel tensor shape: `(609539, 224, 224, 3)`
- Teleport events: 1,593
The dataset was collected with:

```bash
python scripts/data/collect_glitched_hue.py num_traj=10000 seed=3072 world.num_envs=10
```
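As a sanity check on the figures above, the raw uint8 pixel volume implied by the tensor shape is roughly ten times larger than the 9.09 GB file, consistent with HDF5 chunk compression. The sketch below uses only the numbers quoted in this section:

```python
# Figures quoted in the dataset summary above.
frames, height, width, channels = 609_539, 224, 224, 3
file_size_gb = 9.09

# Raw uint8 pixel volume implied by the tensor shape (1 byte per channel).
raw_bytes = frames * height * width * channels
raw_gb = raw_bytes / 1e9
print(f"raw pixels: {raw_gb:.1f} GB")  # ~91.8 GB uncompressed

# The on-disk file is ~10x smaller, consistent with HDF5 gzip/chunk compression.
print(f"compression ratio: {raw_gb / file_size_gb:.1f}x")

# Average episode length before any frameskip.
print(f"frames per episode: {frames / 10_000:.1f}")
```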
## Training procedure
The checkpoints in this repo come from the 5-epoch LeWM training run used in the causality experiment. The run completed successfully after resuming from the last full trainer checkpoint.
Command family:

```bash
python scripts/train/lewm.py data=glitched_hue_tworoom trainer.max_epochs=5 num_workers=1 loader.num_workers=1 loader.persistent_workers=False
```
### Key hyperparameters

| Parameter | Value |
|---|---|
| `trainer.max_epochs` | 5 |
| `trainer.accelerator` | gpu |
| `trainer.precision` | bf16 |
| `loader.batch_size` | 128 |
| `loader.num_workers` | 1 (for the resumed run) |
| `optimizer.lr` | 5e-5 |
| `wm.history_size` | 3 |
| `wm.num_preds` | 1 |
| `wm.embed_dim` | 192 |
| `loss.sigreg.weight` | 0.09 |
| `data.dataset.frameskip` | 5 |
| `config.yaml` | included in the repo root |
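The `wm.history_size=3`, `wm.num_preds=1`, and `data.dataset.frameskip=5` settings determine how raw frame indices are grouped into training samples. The actual windowing logic lives in the repository's dataset code; the sketch below is only an illustration under the common convention that frameskip subsamples every 5th frame, and `make_windows` is a hypothetical helper, not part of stable-worldmodel:

```python
def make_windows(episode_len, history_size=3, num_preds=1, frameskip=5):
    """Enumerate (history, target) raw-frame index tuples for one episode,
    assuming frameskip keeps every `frameskip`-th frame."""
    span = (history_size + num_preds - 1) * frameskip
    windows = []
    for start in range(episode_len - span):
        idx = [start + k * frameskip for k in range(history_size + num_preds)]
        windows.append((idx[:history_size], idx[history_size:]))
    return windows

# A 61-frame episode (the dataset average: 609,539 frames / 10,000 episodes).
windows = make_windows(61)
print(windows[0])    # ([0, 5, 10], [15]): 3 history frames, 1 prediction target
print(len(windows))  # 46 windows from a 61-frame episode
```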
## Epoch metrics (logged to W&B / Lightning)

The table below summarizes the epoch-end losses extracted from the local training logs. The raw values are also included as `metrics/epoch_metrics.json` and `metrics/epoch_metrics.csv`.
| Epoch | Global step | fit/loss | fit/pred_loss | fit/sigreg_loss | validate/loss | validate/pred_loss | validate/sigreg_loss |
|---|---|---|---|---|---|---|---|
| 1 | 3046 | 2.676974 | 0.005099 | 29.750000 | 2.766938 | 0.009671 | 30.635967 |
| 2 | 6092 | 1.143889 | 0.436858 | 7.843750 | 1.173404 | 0.444629 | 8.096438 |
| 3 | 9138 | 0.673702 | 0.310420 | 4.031250 | 0.746696 | 0.313994 | 4.807232 |
| 4 | 12184 | 0.526915 | 0.163634 | 4.031250 | 0.601880 | 0.166302 | 4.839801 |
| 5 | 15230 | 0.496008 | 0.156165 | 3.765625 | 0.520261 | 0.127601 | 4.362802 |
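The trend in the table can be confirmed programmatically: both loss curves fall monotonically, and the fit/validate gap stays small throughout. The values below are copied from the table; in practice they would be read from `metrics/epoch_metrics.csv`:

```python
# Epoch-end losses copied from the table above (epochs 1-5).
fit_loss = [2.676974, 1.143889, 0.673702, 0.526915, 0.496008]
val_loss = [2.766938, 1.173404, 0.746696, 0.601880, 0.520261]

# Both curves decrease monotonically across epochs.
assert all(a > b for a, b in zip(fit_loss, fit_loss[1:]))
assert all(a > b for a, b in zip(val_loss, val_loss[1:]))

# The fit/validate gap stays below 0.1, so the model is not badly overfitting.
gaps = [v - f for f, v in zip(fit_loss, val_loss)]
print([f"{g:.3f}" for g in gaps])
```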
## Files in this repo
| File | Purpose | Size |
|---|---|---|
| `checkpoints/lewm_epoch_1_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_2_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_3_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_4_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_5_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_weights.ckpt` | Full Lightning trainer checkpoint | 206.75 MB |
| `config.yaml` | Hydra config used for the run | 1.20 KB |
| `metrics/epoch_metrics.json` | Raw epoch metrics extracted from local logs | small |
| `metrics/epoch_metrics.csv` | Tabular epoch metrics for spreadsheets / plotting | small |
## How to use

Load a serialized model-object checkpoint:

```python
import torch

# Object checkpoints contain a pickled module, so weights_only must be False
# (on PyTorch >= 2.6 the default is True, which would reject this file).
model = torch.load('checkpoints/lewm_epoch_5_object.ckpt',
                   map_location='cpu', weights_only=False)
model.eval()
```
Load the full Lightning trainer checkpoint:

```python
import torch

# Lightning checkpoints bundle non-tensor state (optimizer, loop counters),
# so weights_only=False is needed here as well.
checkpoint = torch.load('checkpoints/lewm_weights.ckpt',
                        map_location='cpu', weights_only=False)
print(checkpoint.keys())
```
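If you only need the weights out of the full trainer checkpoint, Lightning stores them under the `state_dict` key, with parameter names prefixed by the LightningModule's attribute name. The `model.` prefix below is an assumption about this repo's module layout, and the placeholder dict stands in for the real checkpoint so the sketch runs without the file:

```python
# Stand-in for torch.load('checkpoints/lewm_weights.ckpt', ...): a Lightning
# checkpoint is a dict whose 'state_dict' maps prefixed names to tensors.
checkpoint = {
    "epoch": 5,
    "global_step": 15230,
    "state_dict": {
        "model.encoder.proj.weight": "tensor(...)",  # placeholder for a tensor
        "model.predictor.head.bias": "tensor(...)",
    },
}

# Strip the LightningModule attribute prefix ('model.' is an assumption) so
# the weights can be fed to the bare world model's load_state_dict.
prefix = "model."
weights = {
    key[len(prefix):]: value
    for key, value in checkpoint["state_dict"].items()
    if key.startswith(prefix)
}
print(sorted(weights))  # ['encoder.proj.weight', 'predictor.head.bias']
```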
## Intended uses

- Reproducing the causal disentanglement experiment in `research/runme.md`
- Running the Step 3 causal AAP analysis with `research/glitched_hue_experiment.py`
- Comparing epoch-wise world-model checkpoints during training
## Limitations
- These checkpoints are research artifacts, not production control policies.
- Performance is specific to the Glitched Hue TwoRoom environment and the confounded blue/green data collection procedure.
- The object checkpoints are convenient for inspection, while the full trainer checkpoint is the correct file for resuming optimization.