robomotic/causality-two-rooms

LeWM checkpoints trained on the confounded Glitched Hue TwoRoom dataset for causal world-model experiments. The goal is to test whether the model learns the true teleport mechanism or the spurious background-hue correlation.

Model description

  • Architecture: LeWM / JEPA-style world model with an autoregressive predictor
  • Domain: swm/GlitchedHueTwoRoom-v1
  • Framework: PyTorch + Lightning
  • Repository: https://github.com/epokhcs/stable-worldmodel
  • Upload generated (UTC): 2026-04-10 08:16:57

Training data

  • Dataset path: /home/robomotic/.stable_worldmodel/glitched_hue_tworoom.h5
  • Dataset size: 9.09 GB
  • Episodes: 10,000
  • Frames: 609,539
  • Pixel tensor shape: (609539, 224, 224, 3)
  • Teleport events: 1,593

The dataset was collected with:

```shell
python scripts/data/collect_glitched_hue.py \
    num_traj=10000 \
    seed=3072 \
    world.num_envs=10
```
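As a quick sanity check on the numbers above, the raw pixel tensor is roughly an order of magnitude larger than the 9.09 GB file, implying about 10x HDF5 compression. A minimal sketch using only the figures quoted in the dataset summary:

```python
# Figures quoted in the dataset summary above.
frames = 609_539
height, width, channels = 224, 224, 3
stored_gb = 9.09

# Raw uint8 footprint of the (609539, 224, 224, 3) pixel tensor.
raw_bytes = frames * height * width * channels
raw_gb = raw_bytes / 1e9

print(f"raw pixels: {raw_gb:.2f} GB")             # ~91.75 GB
print(f"compression: {raw_gb / stored_gb:.1f}x")  # ~10.1x
```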

Training procedure

The checkpoints in this repo come from the 5-epoch LeWM training run used in the causality experiment. The run completed successfully after resuming from the last full trainer checkpoint.

Command family:

```shell
python scripts/train/lewm.py \
    data=glitched_hue_tworoom \
    trainer.max_epochs=5 \
    num_workers=1 \
    loader.num_workers=1 \
    loader.persistent_workers=False
```

Key hyperparameters

| Parameter | Value |
| --- | --- |
| trainer.max_epochs | 5 |
| trainer.accelerator | gpu |
| trainer.precision | bf16 |
| loader.batch_size | 128 |
| loader.num_workers | 1 (for the resumed run) |
| optimizer.lr | 5e-5 |
| wm.history_size | 3 |
| wm.num_preds | 1 |
| wm.embed_dim | 192 |
| loss.sigreg.weight | 0.09 |
| data.dataset.frameskip | 5 |

The full Hydra config is included as config.yaml in the repo root.
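One plausible reading of `wm.history_size=3`, `wm.num_preds=1`, and `data.dataset.frameskip=5` is that each training sample is three context frames spaced five raw frames apart plus one prediction target. The exact indexing is an assumption here; the stable-worldmodel loader is the authoritative source.

```python
def sample_window(t, history_size=3, num_preds=1, frameskip=5):
    """Hypothetical index layout for one training sample ending at raw frame t.

    Returns (context_indices, target_indices). The real loader in
    stable-worldmodel may index differently; this is an illustration only.
    """
    context = [t - frameskip * i for i in range(history_size - 1, -1, -1)]
    targets = [t + frameskip * (i + 1) for i in range(num_preds)]
    return context, targets

print(sample_window(100))  # ([90, 95, 100], [105])
```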

Epoch metrics (logged to W&B / Lightning)

The table below summarizes the epoch-end losses extracted from the local training logs. The raw values are also included as metrics/epoch_metrics.json and metrics/epoch_metrics.csv.

| Epoch | Global step | fit/loss | fit/pred_loss | fit/sigreg_loss | validate/loss | validate/pred_loss | validate/sigreg_loss |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 3046 | 2.676974 | 0.005099 | 29.750000 | 2.766938 | 0.009671 | 30.635967 |
| 2 | 6092 | 1.143889 | 0.436858 | 7.843750 | 1.173404 | 0.444629 | 8.096438 |
| 3 | 9138 | 0.673702 | 0.310420 | 4.031250 | 0.746696 | 0.313994 | 4.807232 |
| 4 | 12184 | 0.526915 | 0.163634 | 4.031250 | 0.601880 | 0.166302 | 4.839801 |
| 5 | 15230 | 0.496008 | 0.156165 | 3.765625 | 0.520261 | 0.127601 | 4.362802 |
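The gap between validation and training loss narrows over the run. The sketch below computes it per epoch with the stdlib csv module, using rows copied from the table; metrics/epoch_metrics.csv presumably exposes the same columns, but its exact header layout is an assumption.

```python
import csv
import io

# A few rows copied from the epoch table above; the real file is
# metrics/epoch_metrics.csv (its exact header is an assumption here).
sample = """epoch,fit/loss,validate/loss
1,2.676974,2.766938
3,0.673702,0.746696
5,0.496008,0.520261
"""

for row in csv.DictReader(io.StringIO(sample)):
    gap = float(row["validate/loss"]) - float(row["fit/loss"])
    print(f"epoch {row['epoch']}: generalization gap = {gap:.4f}")
```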

Files in this repo

| File | Purpose | Size |
| --- | --- | --- |
| checkpoints/lewm_epoch_1_object.ckpt | Serialized model object checkpoint | 69.00 MB |
| checkpoints/lewm_epoch_2_object.ckpt | Serialized model object checkpoint | 69.00 MB |
| checkpoints/lewm_epoch_3_object.ckpt | Serialized model object checkpoint | 69.00 MB |
| checkpoints/lewm_epoch_4_object.ckpt | Serialized model object checkpoint | 69.00 MB |
| checkpoints/lewm_epoch_5_object.ckpt | Serialized model object checkpoint | 69.00 MB |
| checkpoints/lewm_weights.ckpt | Full Lightning trainer checkpoint | 206.75 MB |
| config.yaml | Hydra config used for the run | 1.20 KB |
| metrics/epoch_metrics.json | Raw epoch metrics extracted from local logs | small |
| metrics/epoch_metrics.csv | Tabular epoch metrics for spreadsheets / plotting | small |

How to use

Load a serialized model-object checkpoint:

```python
import torch

# The object checkpoints are pickled module instances, so full unpickling
# is required: pass weights_only=False (PyTorch >= 2.6 defaults to
# weights_only=True and would refuse to load them otherwise). Only do
# this for checkpoints from a source you trust.
model = torch.load('checkpoints/lewm_epoch_5_object.ckpt',
                   map_location='cpu', weights_only=False)
model.eval()
```

Load the full Lightning trainer checkpoint:

```python
import torch

# The trainer checkpoint also contains pickled training state, so it
# needs weights_only=False as well.
checkpoint = torch.load('checkpoints/lewm_weights.ckpt',
                        map_location='cpu', weights_only=False)
# Lightning checkpoints typically expose keys such as 'epoch',
# 'global_step', 'state_dict', and 'optimizer_states'.
print(checkpoint.keys())
```
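If you only want the raw weights out of the trainer checkpoint, Lightning stores them under the `state_dict` key, usually with a module-path prefix. The sketch below shows the prefix-stripping step on a stand-in dictionary; the actual prefix in this repo's checkpoint is an assumption, so inspect `checkpoint['state_dict'].keys()` first.

```python
# Stand-in for checkpoint['state_dict'] as saved by Lightning; the
# "model." prefix is hypothetical -- check the real keys first.
state_dict = {
    "model.encoder.weight": "tensor-0",
    "model.predictor.weight": "tensor-1",
}

prefix = "model."
stripped = {
    k[len(prefix):] if k.startswith(prefix) else k: v
    for k, v in state_dict.items()
}

print(sorted(stripped))  # ['encoder.weight', 'predictor.weight']
```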

Intended uses

  • Reproducing the causal disentanglement experiment in research/runme.md
  • Running the Step 3 causal AAP analysis with research/glitched_hue_experiment.py
  • Comparing epoch-wise world-model checkpoints during training

Limitations

  • These checkpoints are research artifacts, not production control policies.
  • Performance is specific to the Glitched Hue TwoRoom environment and the confounded blue/green data collection procedure.
  • The object checkpoints are convenient for inspection, while the full trainer checkpoint is the correct file for resuming optimization.