---
license: mit
library_name: pytorch
pipeline_tag: reinforcement-learning
tags:
- robotics
- reinforcement-learning
- world-model
- causal-representation-learning
- stable-worldmodel
---
# robomotic/causality-two-rooms
LeWM checkpoints trained on the confounded **Glitched Hue TwoRoom** dataset for
causal world-model experiments. The goal is to test whether the model learns the
true teleport mechanism or the spurious background-hue correlation.
## Model description
- **Architecture:** LeWM / JEPA-style world model with an autoregressive predictor
- **Domain:** `swm/GlitchedHueTwoRoom-v1`
- **Framework:** PyTorch + Lightning
- **Repository:** https://github.com/epokhcs/stable-worldmodel
- **Upload generated (UTC):** 2026-04-10 08:16:57
## Training data
- Dataset path: `/home/robomotic/.stable_worldmodel/glitched_hue_tworoom.h5`
- Dataset size: 9.09 GB
- Episodes: 10,000
- Frames: 609,539
- Pixel tensor shape: `(609539, 224, 224, 3)`
- Teleport events: 1,593
The dataset was collected with:
```bash
python scripts/data/collect_glitched_hue.py num_traj=10000 seed=3072 world.num_envs=10
```
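The collected data is a single HDF5 file, so it can be inspected with `h5py` before training. The dataset key name (`pixels`) below is an assumption for illustration; list the real file's keys first. A tiny in-memory file stands in for the 9 GB dataset so the sketch is self-contained:

```python
import h5py
import numpy as np

# In-memory HDF5 file standing in for glitched_hue_tworoom.h5; the 'pixels'
# key is an assumed name, not confirmed against the real dataset layout.
with h5py.File("demo.h5", "w", driver="core", backing_store=False) as f:
    f.create_dataset("pixels", data=np.zeros((4, 224, 224, 3), dtype=np.uint8))
    keys = list(f.keys())          # discover what the file actually contains
    pixel_shape = f["pixels"].shape  # (609539, 224, 224, 3) in the real file
```

Against the real file, replace the in-memory handle with `h5py.File(path, "r")` and check `f.keys()` to find the actual dataset names.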
## Training procedure
The checkpoints in this repo come from the 5-epoch LeWM training run used in the
causality experiment. The run completed successfully after resuming from the last
full trainer checkpoint.
Command family:
```bash
python scripts/train/lewm.py data=glitched_hue_tworoom trainer.max_epochs=5 num_workers=1 loader.num_workers=1 loader.persistent_workers=False
```
### Key hyperparameters
| Parameter | Value |
|---|---|
| `trainer.max_epochs` | `5` |
| `trainer.accelerator` | `gpu` |
| `trainer.precision` | `bf16` |
| `loader.batch_size` | `128` |
| `loader.num_workers` | `1` for the resumed run |
| `optimizer.lr` | `5e-5` |
| `wm.history_size` | `3` |
| `wm.num_preds` | `1` |
| `wm.embed_dim` | `192` |
| `loss.sigreg.weight` | `0.09` |
| `data.dataset.frameskip` | `5` |

The full Hydra configuration for the run is included as `config.yaml` in the repo root.
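Since the run is Hydra-configured, the values above can be read back programmatically. The inline sample below mirrors the table's keys; the real `config.yaml` in the repo root may nest these differently, so treat the structure as an assumption:

```python
import yaml  # PyYAML

# Inline sample mirroring the hyperparameter table; the actual nesting in the
# repo's config.yaml is a Hydra config and may differ from this sketch.
SAMPLE_CONFIG = """
trainer:
  max_epochs: 5
  accelerator: gpu
  precision: bf16
loader:
  batch_size: 128
optimizer:
  lr: 5.0e-5
wm:
  history_size: 3
  embed_dim: 192
"""

cfg = yaml.safe_load(SAMPLE_CONFIG)
```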
## Epoch metrics (logged to W&B / Lightning)
The table below summarizes the epoch-end losses extracted from the local training
logs. The raw values are also included as `metrics/epoch_metrics.json` and
`metrics/epoch_metrics.csv`.
| Epoch | Global step | fit/loss | fit/pred_loss | fit/sigreg_loss | validate/loss | validate/pred_loss | validate/sigreg_loss |
|---:|---:|---:|---:|---:|---:|---:|---:|
| 1 | 3046 | 2.676974 | 0.005099 | 29.750000 | 2.766938 | 0.009671 | 30.635967 |
| 2 | 6092 | 1.143889 | 0.436858 | 7.843750 | 1.173404 | 0.444629 | 8.096438 |
| 3 | 9138 | 0.673702 | 0.310420 | 4.031250 | 0.746696 | 0.313994 | 4.807232 |
| 4 | 12184 | 0.526915 | 0.163634 | 4.031250 | 0.601880 | 0.166302 | 4.839801 |
| 5 | 15230 | 0.496008 | 0.156165 | 3.765625 | 0.520261 | 0.127601 | 4.362802 |
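The CSV export can be consumed with nothing but the standard library. The sketch below assumes `metrics/epoch_metrics.csv` has one row per epoch with the column names shown above; it is demonstrated on an inline sample taken from the table rather than the file itself:

```python
import csv
import io

# Sample mirroring rows of the table above; in practice, open
# metrics/epoch_metrics.csv from the repo instead of this inline string.
SAMPLE_CSV = """epoch,global_step,fit/loss,validate/loss
1,3046,2.676974,2.766938
2,6092,1.143889,1.173404
5,15230,0.496008,0.520261
"""

def load_epoch_metrics(text):
    """Parse epoch-metrics CSV text into a list of dicts with float values."""
    return [{k: float(v) for k, v in row.items()}
            for row in csv.DictReader(io.StringIO(text))]

rows = load_epoch_metrics(SAMPLE_CSV)
# Validation loss decreases monotonically across the logged epochs.
val_losses = [r["validate/loss"] for r in rows]
assert val_losses == sorted(val_losses, reverse=True)
```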
## Files in this repo
| File | Purpose | Size |
|---|---|---:|
| `checkpoints/lewm_epoch_1_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_2_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_3_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_4_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_epoch_5_object.ckpt` | Serialized model object checkpoint | 69.00 MB |
| `checkpoints/lewm_weights.ckpt` | Full Lightning trainer checkpoint | 206.75 MB |
| `config.yaml` | Hydra config used for the run | 1.20 KB |
| `metrics/epoch_metrics.json` | Raw epoch metrics extracted from local logs | small |
| `metrics/epoch_metrics.csv` | Tabular epoch metrics for spreadsheets / plotting | small |
## How to use
Load a serialized model-object checkpoint:
```python
import torch

# On PyTorch >= 2.6 the weights_only default changed to True, which refuses
# to unpickle full model objects; set it to False only for trusted files.
model = torch.load('checkpoints/lewm_epoch_5_object.ckpt',
                   map_location='cpu', weights_only=False)
model.eval()
```
Load the full Lightning trainer checkpoint:
```python
import torch

# Lightning trainer checkpoints include non-tensor metadata, so they also
# need weights_only=False on recent PyTorch.
checkpoint = torch.load('checkpoints/lewm_weights.ckpt',
                        map_location='cpu', weights_only=False)
print(checkpoint.keys())
```
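Lightning trainer checkpoints store the model weights under the `state_dict` key, with each parameter name prefixed by the attribute name the `LightningModule` uses for its network. The helper below strips such a prefix; `model.` is an assumed prefix, not confirmed for this repo, and the mock dict stands in for the loaded checkpoint:

```python
def extract_weights(checkpoint, prefix="model."):
    """Pull bare weight entries out of a Lightning trainer checkpoint.

    Lightning keeps weights under 'state_dict'; the 'model.' prefix is an
    assumption about this LightningModule's attribute naming and may differ.
    """
    return {
        k[len(prefix):] if k.startswith(prefix) else k: v
        for k, v in checkpoint["state_dict"].items()
    }

# Mock checkpoint illustrating the structure (real values are tensors).
mock = {"state_dict": {"model.encoder.weight": 0, "model.head.bias": 1},
        "epoch": 5, "global_step": 15230}
weights = extract_weights(mock)
```

With the prefix stripped, the resulting dict can be passed to a bare `nn.Module` via `load_state_dict` without instantiating the full Lightning wrapper.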
## Intended uses
- Reproducing the causal disentanglement experiment in `research/runme.md`
- Running the Step 3 causal AAP analysis with `research/glitched_hue_experiment.py`
- Comparing epoch-wise world-model checkpoints during training
## Limitations
- These checkpoints are research artifacts, not production control policies.
- Performance is specific to the Glitched Hue TwoRoom environment and the
confounded blue/green data collection procedure.
- The object checkpoints are convenient for inspection, while the full trainer
checkpoint is the correct file for resuming optimization.