Files changed (45) hide show
  1. .gitignore +16 -0
  2. LICENSE +21 -0
  3. README.md +87 -3
  4. configs/fourrooms_ppo.yaml +31 -0
  5. configs/fourrooms_world_model.yaml +68 -0
  6. docs/spec_clarifications.md +26 -0
  7. minidreamer_project_spec.md +911 -0
  8. notebooks/results_analysis.ipynb +25 -0
  9. notebooks/rollout_debug.ipynb +26 -0
  10. plots/.gitkeep +1 -0
  11. plots/learning_curves.png +0 -0
  12. plots/model_error_vs_rollout_horizon.png +0 -0
  13. plots/success_rate_vs_env_steps.png +0 -0
  14. pyproject.toml +46 -0
  15. results.md +127 -0
  16. scripts/collect_random.sh +9 -0
  17. scripts/eval_planner.sh +26 -0
  18. scripts/generate_results_plots.py +159 -0
  19. scripts/train_ppo.sh +9 -0
  20. scripts/train_world_model.sh +10 -0
  21. src/evaluate.py +69 -0
  22. src/minidreamer/__init__.py +6 -0
  23. src/minidreamer/baselines/__init__.py +2 -0
  24. src/minidreamer/baselines/train_ppo.py +127 -0
  25. src/minidreamer/config.py +59 -0
  26. src/minidreamer/envs/__init__.py +2 -0
  27. src/minidreamer/envs/make_env.py +125 -0
  28. src/minidreamer/evaluation.py +147 -0
  29. src/minidreamer/models/__init__.py +2 -0
  30. src/minidreamer/models/decoder.py +27 -0
  31. src/minidreamer/models/encoder.py +29 -0
  32. src/minidreamer/models/heads.py +29 -0
  33. src/minidreamer/models/rssm.py +147 -0
  34. src/minidreamer/models/world_model.py +267 -0
  35. src/minidreamer/planning/__init__.py +2 -0
  36. src/minidreamer/planning/cem.py +103 -0
  37. src/minidreamer/planning/evaluate_planner.py +99 -0
  38. src/minidreamer/serialization.py +40 -0
  39. src/minidreamer/utils/__init__.py +2 -0
  40. src/minidreamer/utils/common.py +58 -0
  41. src/train_world_model.py +334 -0
  42. tests/test_cem_planner.py +31 -0
  43. tests/test_env.py +19 -0
  44. tests/test_replay_buffer.py +50 -0
  45. tests/test_rssm_shapes.py +43 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .pytest_cache/
2
+ .ruff_cache/
3
+ .venv/
4
+ __pycache__/
5
+ *.pyc
6
+ *.pyo
7
+ *.pyd
8
+ *.so
9
+ *.egg-info/
10
+ .DS_Store
11
+
12
+ artifacts/
13
+ checkpoints/
14
+ metrics/
15
+ data/
16
+ logs/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 alpatrykos
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,87 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MiniDreamer
2
+
3
+ MiniDreamer is a PlaNet-style world model project for `MiniGrid-FourRooms-v0`. It learns a recurrent latent dynamics model from partial RGB observations, predicts reward and episode termination, and uses discrete CEM planning in latent space.
4
+
5
+ The repository contains:
6
+
7
+ - MiniGrid RGB environment wrappers and bootstrap trajectory collection
8
+ - Episode-aware replay buffer with reproducible train/val/test splits
9
+ - CNN encoder, Gaussian RSSM, reward/done heads, optional decoder
10
+ - Discrete CEM planner with termination-aware return scoring
11
+ - PPO baseline entrypoint with a MiniGrid-compatible CNN feature extractor
12
+ - Evaluation code, configs, scripts, tests, and project documentation
13
+
14
+ A complete baseline training run has been executed. A summary is recorded in [results.md](results.md), while the frozen baseline artifacts remain gitignored under `artifacts/world_model/`.
15
+
16
+ ## Layout
17
+
18
+ ```text
19
+ configs/
20
+ docs/
21
+ notebooks/
22
+ scripts/
23
+ src/
24
+ tests/
25
+ ```
26
+
27
+ Core code lives under `src/minidreamer/`, with CLI entrypoints at `src/train_world_model.py` and `src/evaluate.py`.
28
+
29
+ ## Setup
30
+
31
+ Use Python 3.11 or 3.12. The project metadata is defined in [pyproject.toml](pyproject.toml).
32
+
33
+ ```bash
34
+ python3.11 -m venv .venv
35
+ source .venv/bin/activate
36
+ pip install -e ".[dev]"
37
+ ```
38
+
39
+ ## Main Commands
40
+
41
+ Bootstrap replay collection:
42
+
43
+ ```bash
44
+ ./scripts/collect_random.sh
45
+ ```
46
+
47
+ World-model pipeline:
48
+
49
+ ```bash
50
+ ./scripts/train_world_model.sh
51
+ ```
52
+
53
+ By default, the script writes new experiments to `artifacts/world_model_experiment/`. To choose a different experiment directory without touching the frozen baseline, set `MINIDREAMER_OUTPUT_DIR`:
54
+
55
+ ```bash
56
+ MINIDREAMER_OUTPUT_DIR=artifacts/world_model_restricted_actions ./scripts/train_world_model.sh
57
+ ```
58
+
59
+ Resume an interrupted world-model run from a checkpoint:
60
+
61
+ ```bash
62
+ python3.11 src/train_world_model.py \
63
+ --config configs/fourrooms_world_model.yaml \
64
+ --output-dir artifacts/world_model \
65
+ --replay-dir artifacts/world_model/replay \
66
+ --resume-checkpoint artifacts/world_model/checkpoints/world_model_env_steps_90021.pt
67
+ ```
68
+
69
+ Planner evaluation from a checkpoint:
70
+
71
+ ```bash
72
+ ./scripts/eval_planner.sh /path/to/checkpoint.pt /path/to/replay
73
+ ```
74
+
75
+ PPO baseline:
76
+
77
+ ```bash
78
+ ./scripts/train_ppo.sh
79
+ ```
80
+
81
+ ## Notes
82
+
83
+ - The latest completed run summary is in [results.md](results.md).
84
+ - The baseline run in `artifacts/world_model/` is intentionally frozen as the reference artifact.
85
+ - New world-model experiments should write to separate directories under `artifacts/`.
86
+ - The trainer refuses to overwrite an existing run directory unless you resume with `--resume-checkpoint` or explicitly pass `--allow-overwrite-existing-output`.
87
+ - Metrics, replay snapshots, and checkpoints are intentionally gitignored.
configs/fourrooms_ppo.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ project:
2
+ name: minidreamer-fourrooms-ppo
3
+ seed: 0
4
+
5
+ env:
6
+ id: MiniGrid-FourRooms-v0
7
+ rgb_partial_obs: true
8
+ image_only: true
9
+ resize: [64, 64]
10
+ normalize_obs: true
11
+
12
+ ppo:
13
+ total_timesteps: 100000
14
+ num_envs: 4
15
+ learning_rate: 0.0003
16
+ n_steps: 256
17
+ batch_size: 256
18
+ n_epochs: 4
19
+ gamma: 0.99
20
+ gae_lambda: 0.95
21
+ clip_range: 0.2
22
+ ent_coef: 0.01
23
+ vf_coef: 0.5
24
+ features_dim: 256
25
+ device: auto
26
+
27
+ evaluation:
28
+ episodes: 100
29
+ seeds: [0, 1, 2]
30
+ eval_every_env_steps: 10000
31
+
configs/fourrooms_world_model.yaml ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ project:
2
+ name: minidreamer-fourrooms
3
+ seed: 0
4
+
5
+ env:
6
+ id: MiniGrid-FourRooms-v0
7
+ rgb_partial_obs: true
8
+ image_only: true
9
+ resize: [64, 64]
10
+ normalize_obs: true
11
+ action_space: full
12
+
13
+ replay:
14
+ capacity_episodes: 5000
15
+ sequence_length: 32
16
+ batch_size: 32
17
+ train_fraction: 0.8
18
+ val_fraction: 0.1
19
+ test_fraction: 0.1
20
+ split_key: episode_id
21
+
22
+ collection:
23
+ bootstrap_env_steps: 5000
24
+ bootstrap_success_threshold: 20
25
+ bootstrap_fallback_policy: restricted_random_3_actions
26
+ bootstrap_env_step_cap: 20000
27
+ collect_steps_per_iteration: 1000
28
+ gradient_updates_per_iteration: 1000
29
+ train_collect_ratio: 1.0
30
+ random_action_fraction_after_planner: 0.3
31
+
32
+ model:
33
+ embedding_dim: 256
34
+ deter_dim: 256
35
+ stoch_dim: 32
36
+ hidden_dim: 256
37
+ use_decoder: true
38
+ min_std: 0.1
39
+
40
+ training:
41
+ optimizer: adam
42
+ lr: 0.0003
43
+ grad_clip_norm: 100.0
44
+ train_steps: 100000
45
+ beta_kl: 1.0
46
+ free_nats: 1.0
47
+ beta_recon: 1.0
48
+ beta_reward: 1.0
49
+ beta_done: 1.0
50
+ device: null
51
+
52
+ planner:
53
+ type: discrete_cem
54
+ horizon: 8
55
+ candidates: 256
56
+ elites: 32
57
+ iterations: 4
58
+ discount: 0.99
59
+ use_done_mask: true
60
+
61
+ evaluation:
62
+ episodes: 100
63
+ seeds: [0, 1, 2]
64
+ eval_every_env_steps: 10000
65
+
66
+ comparison:
67
+ env_steps: [10000, 25000, 50000, 100000]
68
+
docs/spec_clarifications.md ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Spec Clarifications
2
+
3
+ This document records implementation choices for points where the spec was open-ended enough that the code needed an explicit default.
4
+
5
+ ## Implemented choices
6
+
7
+ 1. The Python package lives under `src/minidreamer/`, with thin CLI entrypoints at `src/train_world_model.py` and `src/evaluate.py`. That keeps imports stable while still matching the spec's requested top-level scripts.
8
+ 2. The encoder and decoder use `padding=1` in all `4x4, stride=2` convolutions so `64x64` inputs shrink cleanly to `4x4` and decode symmetrically back to `64x64`.
9
+ 3. CEM planning uses the prior mean during latent imagination instead of sampling the stochastic latent. That reduces planner variance and makes candidate ranking deterministic given the current model parameters.
10
+ 4. Replay sampling pads only at the tail of an in-episode chunk and applies a transition mask so padded steps do not contribute to loss terms.
11
+ 5. Bootstrap and online training both default to `train_collect_ratio = 1.0`, so the initial bootstrap replay produces one gradient update per collected environment step unless an explicit `gradient_updates_per_iteration` override is set.
12
+ 6. Evaluation computes one-step held-out metrics over full episodes and open-loop rollout metrics for horizons `1`, `5`, and `10` using actual held-out action sequences.
13
+
14
+ ## Remaining optional extensions
15
+
16
+ These are intentionally not implemented in v1 because the spec marked them as later improvements or ablations:
17
+
18
+ - KL balancing beyond free nats
19
+ - uncertainty penalties or ensembles in the planner
20
+ - actor-critic imagination learning
21
+ - richer sparse-reward heads beyond scalar reward regression
22
+
23
+ ## Budget semantics
24
+
25
+ Collection currently finishes complete episodes rather than cutting trajectories mid-episode. That keeps replay episodes semantically clean for recurrent training, but it means a run can land slightly above a requested step target when the last episode crosses the threshold. If exact step-matched checkpoints become mandatory for reporting, the next refinement should snapshot metrics at the first checkpoint at or above each budget and report the realized environment step count alongside the nominal target.
26
+
minidreamer_project_spec.md ADDED
@@ -0,0 +1,911 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MiniDreamer: A PlaNet-style World Model for Pixel-based MiniGrid Planning
2
+
3
+ ## 1. Project name
4
+
5
+ **MiniDreamer: A PlaNet-style world model for pixel-based MiniGrid planning**
6
+
7
+ ## 2. Goal
8
+
9
+ Build a small research-grade world model agent that learns a latent dynamics model from pixel observations in `MiniGrid-FourRooms-v0`, then uses the learned model for short-horizon planning.
10
+
11
+ The project should answer:
12
+
13
+ > Can a compact latent dynamics model trained from partial RGB observations support useful planning in a sparse-reward gridworld?
14
+
15
+ MiniGrid is a good first target because it is designed as a simple, fast, customizable RL benchmark suite, with discrete actions and goal-oriented environments. FourRooms specifically asks an agent to navigate a four-room maze to reach a green goal; the registered environment is `MiniGrid-FourRooms-v0`.
16
+
17
+ Reference: [MiniGrid documentation](https://minigrid.farama.org/index.html)
18
+
19
+ ---
20
+
21
+ ## 3. Scope
22
+
23
+ ### In scope
24
+
25
+ - Pixel-based observations using:
26
+ - `RGBImgPartialObsWrapper`
27
+ - `ImgObsWrapper`
28
+ - Environment:
29
+ - MVP: `MiniGrid-FourRooms-v0`
30
+ - Extension: `MiniGrid-DoorKey-*`, `MiniGrid-LockedRoom-v0`, or `MiniGrid-Dynamic-Obstacles-*`
31
+ - Learned latent dynamics model
32
+ - Reward and termination prediction
33
+ - MPC/CEM-style planning in latent space
34
+ - PPO baseline for comparison
35
+ - Ablations:
36
+ - with image reconstruction vs. without
37
+ - latent size
38
+ - planning horizon
39
+ - stochastic vs. deterministic latent state
40
+
41
+ ### Out of scope for v1
42
+
43
+ - Full DreamerV3 reproduction
44
+ - Actor-critic learning inside imagination
45
+ - Minecraft/Crafter-scale environments
46
+ - Large video-generation-style world models
47
+ - Language-conditioned BabyAI tasks
48
+
49
+ ---
50
+
51
+ ## 4. Technical background
52
+
53
+ This should be a **PlaNet-lite** implementation. PlaNet learns a latent dynamics model from images and chooses actions through online planning in latent space, using both deterministic and stochastic transition components.
54
+
55
+ Reference: [PlaNet: Learning Latent Dynamics for Planning from Pixels](https://arxiv.org/abs/1811.04551)
56
+
57
+ DreamerV3 is the more mature descendant: it learns a world model and improves behavior by imagining future scenarios, but its full implementation is too much for a first project. Use DreamerV3 as conceptual inspiration, not as the implementation target.
58
+
59
+ Reference: [DreamerV3](https://arxiv.org/abs/2301.04104)
60
+
61
+ MiniGrid’s default observation is a compact symbolic encoding, not raw pixels. For this project, use `RGBImgPartialObsWrapper` to obtain RGB pixel observations, then `ImgObsWrapper` to keep only the image tensor.
62
+
63
+ Reference: [MiniGrid wrappers](https://minigrid.farama.org/api/wrapper/)
64
+
65
+ ---
66
+
67
+ ## 5. Core hypothesis
68
+
69
+ ### Main hypothesis
70
+
71
+ A learned recurrent latent dynamics model can support better sample efficiency than a model-free baseline in early training, even if final performance is lower or less stable.
72
+
73
+ ### Secondary hypotheses
74
+
75
+ 1. **Reconstruction loss may help representation learning early**, but may hurt planning if it forces the latent state to model irrelevant pixels.
76
+ 2. **Short planning horizons will work better than long ones**, because model error compounds over imagined rollouts.
77
+ 3. **Stochastic latent states should outperform deterministic-only states** under partial observability.
78
+ 4. **Reward prediction quality matters more than pixel reconstruction quality** for planning performance.
79
+
80
+ ---
81
+
82
+ ## 6. Environment specification
83
+
84
+ ### MVP environment
85
+
86
+ ```python
87
+ env_id = "MiniGrid-FourRooms-v0"
88
+ ```
89
+
90
+ FourRooms has a discrete action space of size 7. The meaningful actions for this task are mostly `left`, `right`, and `forward`; `pickup`, `drop`, `toggle`, and `done` are listed but unused in FourRooms.
91
+
92
+ Reference: [MiniGrid FourRooms environment](https://minigrid.farama.org/environments/minigrid/FourRoomsEnv/)
93
+
94
+ ### Wrappers
95
+
96
+ ```python
97
+ import gymnasium as gym
98
+ from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
99
+
100
+ env = gym.make("MiniGrid-FourRooms-v0")
101
+ env = RGBImgPartialObsWrapper(env)
102
+ env = ImgObsWrapper(env)
103
+ ```
104
+
105
+ ### Observation
106
+
107
+ Use RGB partial observation.
108
+
109
+ Normalize to:
110
+
111
+ ```python
112
+ obs = obs.astype("float32") / 255.0
113
+ ```
114
+
115
+ Recommended internal size:
116
+
117
+ ```text
118
+ 64 x 64 x 3
119
+ ```
120
+
121
+ Resize if needed.
122
+
123
+ ### Action space
124
+
125
+ Use the full 7-action space for baseline compatibility.
126
+
127
+ Optional ablation:
128
+
129
+ ```text
130
+ restricted_action_space = {left, right, forward}
131
+ ```
132
+
133
+ If bootstrap data collection yields too few successful episodes, it is acceptable to use the restricted 3-action random policy during bootstrap only. If that fallback is used, record it explicitly in the report and still count all bootstrap environment steps toward the world-model sample-efficiency budget.
134
+
135
+ ### Reward
136
+
137
+ FourRooms gives a success reward based on step count and zero reward for failure; the documented reward is:
138
+
139
+ ```text
140
+ 1 - 0.9 * step_count / max_steps
141
+ ```
142
+
143
+ For model training, predict:
144
+
145
+ ```text
146
+ reward_t ∈ R
147
+ done_t ∈ {0, 1}
148
+ ```
149
+
150
+ ### Episode end semantics
151
+
152
+ Use the Gymnasium step API explicitly:
153
+
154
+ ```python
155
+ obs_next, reward, terminated, truncated, info = env.step(action)
156
+ ```
157
+
158
+ Store all three episode-end signals:
159
+
160
+ ```text
161
+ terminated_t = goal reached / environment terminal condition
162
+ truncated_t = episode stopped by time limit
163
+ done_t = terminated_t or truncated_t
164
+ ```
165
+
166
+ Training and planning semantics:
167
+
168
+ - `reward_t` is the scalar emitted after taking `action_t`
169
+ - `done_t` is the episode-end target used by the world model and planner
170
+ - `success_t = 1` iff `terminated_t == 1` and `reward_t > 0`
171
+
172
+ Do not collapse `terminated` and `truncated` in logged metrics; report them separately when debugging failures.
173
+
174
+ ---
175
+
176
+ ## 7. Model architecture
177
+
178
+ ### 7.1 Overview
179
+
180
+ ```text
181
+ obs_t ──► CNN encoder ──► embedding e_t
182
+
183
+ action_{t-1} ────────────┤
184
+
185
+ RSSM latent model
186
+
187
+ ┌──────────┼──────────┐
188
+ ▼ ▼ ▼
189
+ reward head done head decoder? optional
190
+ ```
191
+
192
+ ### 7.2 Latent state
193
+
194
+ Use a simplified **RSSM-style** state:
195
+
196
+ ```text
197
+ h_t = deterministic recurrent state
198
+ z_t = stochastic latent state
199
+ s_t = concat(h_t, z_t)
200
+ ```
201
+
202
+ Recommended dimensions:
203
+
204
+ ```yaml
205
+ embedding_dim: 256
206
+ deterministic_state_dim: 256
207
+ stochastic_state_dim: 32
208
+ num_stochastic_classes: null # use Gaussian latent for v1
209
+ ```
210
+
211
+ ### 7.3 Encoder
212
+
213
+ CNN encoder:
214
+
215
+ ```text
216
+ Input: 64x64x3 RGB
217
+ Conv 4x4 stride 2, channels 32
218
+ Conv 4x4 stride 2, channels 64
219
+ Conv 4x4 stride 2, channels 128
220
+ Conv 4x4 stride 2, channels 256
221
+ Flatten
222
+ Linear -> embedding_dim
223
+ ```
224
+
225
+ Use `LayerNorm` or `BatchNorm` only if training is unstable. Start simple.
226
+
227
+ ### 7.4 Recurrent dynamics
228
+
229
+ Prior:
230
+
231
+ ```text
232
+ p(z_t | h_t)
233
+ ```
234
+
235
+ Posterior:
236
+
237
+ ```text
238
+ q(z_t | h_t, e_t)
239
+ ```
240
+
241
+ Transition:
242
+
243
+ ```text
244
+ h_t = GRU(h_{t-1}, concat(z_{t-1}, one_hot(action_{t-1})))
245
+ ```
246
+
247
+ Gaussian latent:
248
+
249
+ ```text
250
+ prior_mean, prior_std = prior_net(h_t)
251
+ post_mean, post_std = posterior_net(h_t, e_t)
252
+ z_t ~ Normal(post_mean, post_std)
253
+ ```
254
+
255
+ Use reparameterization.
256
+
257
+ Lower-bound the std (softplus plus a constant floor):
258
+
259
+ ```python
260
+ std = softplus(raw_std) + 0.1
261
+ ```
262
+
263
+ ### 7.4.1 Sequence indexing and initialization
264
+
265
+ Use transition tuples with explicit alignment:
266
+
267
+ ```text
268
+ (obs_t, action_t, reward_t, terminated_t, truncated_t, done_t, obs_{t+1})
269
+ ```
270
+
271
+ Training-time convention:
272
+
273
+ 1. Infer posterior `s_t` from `obs_t` and the previous recurrent state.
274
+ 2. Apply `action_t` in the transition model to form the prior for step `t+1`.
275
+ 3. Condition on `obs_{t+1}` to infer posterior `s_{t+1}`.
276
+ 4. Predict `reward_t`, `done_t`, and optional reconstruction of `obs_{t+1}` from `s_{t+1}`.
277
+
278
+ Sequence-start convention:
279
+
280
+ - initialize `h_0` and `z_0` to zeros
281
+ - use a zero action embedding in place of `action_{-1}`
282
+ - never sample a training chunk that crosses an episode boundary
283
+ - if padding is needed for batching, pad only within an episode and apply a loss mask so padded steps do not contribute to any loss
284
+
285
+ ### 7.5 Prediction heads
286
+
287
+ Reward head:
288
+
289
+ ```text
290
+ MLP(s_{t+1}) -> scalar reward_t
291
+ ```
292
+
293
+ Done head:
294
+
295
+ ```text
296
+ MLP(s_{t+1}) -> Bernoulli logit for done_t
297
+ ```
298
+
299
+ Optional decoder:
300
+
301
+ ```text
302
+ MLP/ConvTranspose(s_{t+1}) -> reconstructed RGB observation obs_{t+1}
303
+ ```
304
+
305
+ For v1, include the decoder behind a config flag:
306
+
307
+ ```yaml
308
+ use_decoder: true | false
309
+ ```
310
+
311
+ ---
312
+
313
+ ## 8. Loss function
314
+
315
+ Total loss:
316
+
317
+ ```text
318
+ L = β_reward * L_reward + β_done * L_done + β_kl * max(L_kl, free_nats) + β_recon * L_recon
319
+ ```
320
+
321
+ ### Reward loss
322
+
323
+ ```text
324
+ L_reward = MSE(pred_reward_t, reward_t)
325
+ ```
326
+
327
+ Alternative for sparse rewards:
328
+
329
+ ```text
330
+ two_head_reward:
331
+ reward_occurrence: BCE
332
+ reward_value_given_success: MSE
333
+ ```
334
+
335
+ Do not start there unless MSE fails.
336
+
337
+ ### Done loss
338
+
339
+ ```text
340
+ L_done = BCEWithLogits(done_logit_t, done_t)
341
+ ```
342
+
343
+ ### KL loss
344
+
345
+ ```text
346
+ L_kl = KL(q(z_t | h_t, e_t) || p(z_t | h_t))
347
+ ```
348
+
349
+ Use free-nats from the start in v1. KL balancing is a later stabilization improvement.
350
+
351
+ Initial config:
352
+
353
+ ```yaml
354
+ beta_reward: 1.0
355
+ beta_done: 1.0
356
+ beta_kl: 1.0
357
+ free_nats: 1.0
358
+ ```
359
+
360
+ ### Reconstruction loss
361
+
362
+ If decoder is enabled:
363
+
364
+ ```text
365
+ L_recon = MSE(reconstructed_obs_{t+1}, obs_{t+1})
366
+ ```
367
+
368
+ If decoder is disabled, set:
369
+
370
+ ```yaml
371
+ beta_recon: 0.0
372
+ ```
373
+
374
+ Initial config:
375
+
376
+ ```yaml
377
+ beta_recon: 1.0
378
+ ```
379
+
380
+ Ablate with:
381
+
382
+ ```yaml
383
+ beta_recon: 0.0
384
+ ```
385
+
386
+ ---
387
+
388
+ ## 9. Data collection
389
+
390
+ ### Phase 1: bootstrap dataset
391
+
392
+ Collect:
393
+
394
+ ```yaml
395
+ bootstrap_env_steps: 5000
396
+ policy: random_full_action_space
397
+ fallback_if_successes_too_low: random_restricted_action_space
398
+ min_success_episodes_before_planning: 20
399
+ bootstrap_env_step_cap: 20000
400
+ ```
401
+
402
+ Use environment-step budgets, not fixed episode counts, so the early comparison points against PPO remain meaningful.
403
+
404
+ Bootstrap rule:
405
+
406
+ - start with full-action random data collection
407
+ - if the replay buffer contains fewer than `20` successful episodes after `5000` steps, continue collecting with a random policy restricted to the actions `{left, right, forward}`
408
+ - do not make sample-efficiency claims for planner performance before the replay buffer contains at least `20` successful episodes
409
+
410
+ Store transitions:
411
+
412
+ ```text
413
+ obs_t
414
+ action_t
415
+ reward_t
416
+ terminated_t
417
+ truncated_t
418
+ done_t
419
+ obs_{t+1}
420
+ episode_id
421
+ step_id
422
+ ```
423
+
424
+ ### Phase 2: mixed dataset
425
+
426
+ After bootstrap, alternate fixed collection/training rounds:
427
+
428
+ ```yaml
429
+ collect_steps_per_iteration: 1000
430
+ gradient_updates_per_iteration: 1000
431
+ ```
432
+
433
+ Default collection mix:
434
+
435
+ ```text
436
+ 70% planner policy
437
+ 30% random exploration
438
+ ```
439
+
440
+ This prevents early model errors from collapsing exploration. `train_collect_ratio = 1.0` means one gradient update per newly collected environment step.
441
+
442
+ ### Replay format and splits
443
+
444
+ Use chunked trajectory sampling:
445
+
446
+ ```yaml
447
+ sequence_length: 32
448
+ batch_size: 32
449
+ ```
450
+
451
+ Sample contiguous sequences, not independent transitions, because the recurrent model needs temporal structure.
452
+
453
+ Maintain episode-level dataset splits:
454
+
455
+ ```yaml
456
+ train_fraction: 0.8
457
+ val_fraction: 0.1
458
+ test_fraction: 0.1
459
+ split_key: episode_id
460
+ ```
461
+
462
+ Assign splits by episode id, not by transition, and keep the split fixed as more episodes are collected. Use train episodes for optimization, val episodes for model selection and debugging, and held-out test episodes only for final world-model reporting.
463
+
464
+ ---
465
+
466
+ ## 10. Planner
467
+
468
+ ### MVP planner: CEM over discrete actions
469
+
470
+ At each real environment step:
471
+
472
+ 1. Encode current observation.
473
+ 2. Update posterior latent state.
474
+ 3. Sample candidate action sequences.
475
+ 4. Roll them forward using the learned prior model.
476
+ 5. Score imagined returns.
477
+ 6. Execute the first action from the best sequence.
478
+ 7. Replan at the next step.
479
+
480
+ ### Planning config
481
+
482
+ ```yaml
483
+ planning_horizon: 8
484
+ num_candidates: 256
485
+ num_elites: 32
486
+ cem_iterations: 4
487
+ discount: 0.99
488
+ ```
489
+
490
+ Because actions are discrete, maintain categorical probabilities over actions per timestep:
491
+
492
+ ```text
493
+ π ∈ R^{horizon × num_actions} (row t is the action distribution at planning step t)
494
+ ```
495
+
496
+ CEM loop:
497
+
498
+ ```text
499
+ sample action sequences from π
500
+ roll out latent dynamics
501
+ score by termination-aware discounted reward
502
+ select top-k elites
503
+ update π toward elite action frequencies
504
+ ```
505
+
506
+ ### Score function
507
+
508
+ ```text
509
+ done_prob_t = sigmoid(pred_done_logit_t)
510
+ alive_0 = 1
511
+ alive_{t+1} = alive_t * (1 - done_prob_t)
512
+ score = Σ_t γ^t * alive_t * predicted_reward_t
513
+ ```
514
+
515
+ This makes the planner termination-aware and prevents CEM from exploiting trajectories that unrealistically continue after predicted episode end.
516
+
517
+ Optional later refinement:
518
+
519
+ ```text
520
+ score += λ * predicted_success_probability
521
+ score -= μ * predicted_done_probability_before_reward
522
+ ```
523
+
524
+ Keep v1 simple: use only discounted reward with the alive-mask above.
525
+
526
+ ---
527
+
528
+ ## 11. Baselines
529
+
530
+ ### Baseline 1: random policy
531
+
532
+ Measure:
533
+
534
+ ```text
535
+ success_rate
536
+ mean_return
537
+ mean_episode_length
538
+ ```
539
+
540
+ ### Baseline 2: PPO
541
+
542
+ Use Stable-Baselines3 PPO with a custom CNN feature extractor. MiniGrid’s own training tutorial shows PPO training with Stable-Baselines3 and notes that a custom feature extractor is needed because the default CNN architecture does not directly support MiniGrid’s observation space.
543
+
544
+ Reference: [MiniGrid training tutorial](https://minigrid.farama.org/content/training/)
545
+
546
+ Compare PPO and world model under equal environment steps:
547
+
548
+ ```yaml
549
+ env_steps:
550
+ - 10_000
551
+ - 25_000
552
+ - 50_000
553
+ - 100_000
554
+ ```
555
+
556
+ ### Comparison protocol
557
+
558
+ Use the same wrapped observation setup for PPO and the world-model agent unless an experiment explicitly states otherwise.
559
+
560
+ Budget accounting:
561
+
562
+ - count every environment interaction used for world-model data collection toward the world-model budget, including bootstrap random data and later planner-driven data
563
+ - do not count offline gradient updates toward environment-step budgets
564
+ - do not count evaluation episodes toward training budgets
565
+
566
+ Checkpointing protocol at budget `B`:
567
+
568
+ 1. PPO is trained for exactly `B` environment steps, then evaluated.
569
+ 2. The world-model pipeline is allowed to collect exactly `B` environment steps in total, updating the model online according to Section 9, then evaluated.
570
+ 3. Report the latest checkpoint at budget `B`, not the best checkpoint seen so far.
571
+ 4. Use the same seeds and the same number of evaluation episodes for both methods.
572
+
573
+ ---
574
+
575
+ ## 12. Evaluation metrics
576
+
577
+ ### Agent metrics
578
+
579
+ ```yaml
580
+ success_rate: fraction of episodes reaching goal
581
+ mean_return: average episodic return
582
+ median_return: robust episodic return
583
+ mean_episode_length: average steps per episode
584
+ env_steps_to_80_percent_success: sample efficiency metric
585
+ ```
586
+
587
+ ### World model metrics
588
+
589
+ ```yaml
590
+ reward_mse: reward prediction error
591
+ done_bce: termination prediction error
592
+ kl_loss: posterior-prior divergence
593
+ reconstruction_mse: if decoder enabled
594
+ open_loop_reward_error_h1_h5_h10: reward rollout error over horizons
595
+ open_loop_done_accuracy_h1_h5_h10: done prediction accuracy over horizons
596
+ ```
597
+
598
+ World-model metrics must be reported on held-out validation or test episodes, never on the training replay used for optimization. For final tables and plots, prefer the held-out test split.
599
+
600
+ ### Planning metrics
601
+
602
+ ```yaml
603
+ imagined_return_vs_real_return_correlation
604
+ planner_action_entropy
605
+ model_rollout_horizon_sensitivity
606
+ ```
607
+
608
+ ---
609
+
610
+ ## 13. Experiment matrix
611
+
612
+ ### MVP experiments
613
+
614
+ | ID | Model | Decoder | Latent | Horizon | Actions | Purpose |
615
+ |---:|---|---:|---:|---:|---|---|
616
+ | E0 | Random | n/a | n/a | n/a | 7 | lower bound |
617
+ | E1 | PPO | n/a | n/a | n/a | 7 | model-free baseline |
618
+ | E2 | World model | yes | 32 | 8 | 7 | default model |
619
+ | E3 | World model | no | 32 | 8 | 7 | test reconstruction value |
620
+ | E4 | World model | yes | 16 | 8 | 7 | latent bottleneck |
621
+ | E5 | World model | yes | 64 | 8 | 7 | larger latent |
622
+ | E6 | World model | yes | 32 | 4 | 7 | short planning |
623
+ | E7 | World model | yes | 32 | 16 | 7 | long planning |
624
+
625
+ ### Extension experiments
626
+
627
+ | ID | Environment | Purpose |
628
+ |---:|---|---|
629
+ | X1 | DoorKey | test object interaction |
630
+ | X2 | LockedRoom | test longer-horizon partial observability |
631
+ | X3 | DynamicObstacles | test non-stationary local hazards |
632
+ | X4 | FourRooms larger/custom | test layout generalization |
633
+
634
+ ---
635
+
636
+ ## 14. Expected deliverables
637
+
638
+ ### Code deliverables
639
+
640
+ ```text
641
+ minidreamer/
642
+ configs/
643
+ fourrooms_world_model.yaml
644
+ fourrooms_ppo.yaml
645
+ src/
646
+ envs/
647
+ make_env.py
648
+ data/
649
+ replay_buffer.py
650
+ collect_random.py
651
+ models/
652
+ encoder.py
653
+ rssm.py
654
+ heads.py
655
+ decoder.py
656
+ world_model.py
657
+ planning/
658
+ cem.py
659
+ evaluate_planner.py
660
+ baselines/
661
+ train_ppo.py
662
+ train_world_model.py
663
+ evaluate.py
664
+ scripts/
665
+ collect_random.sh
666
+ train_world_model.sh
667
+ eval_planner.sh
668
+ train_ppo.sh
669
+ notebooks/
670
+ rollout_debug.ipynb
671
+ results_analysis.ipynb
672
+ tests/
673
+ test_env.py
674
+ test_replay_buffer.py
675
+ test_rssm_shapes.py
676
+ test_cem_planner.py
677
+ ```
678
+
679
+ ### Research deliverables
680
+
681
+ ```text
682
+ README.md
683
+ results.md
684
+ plots/
685
+ learning_curves.png
686
+ success_rate_vs_env_steps.png
687
+ model_error_vs_rollout_horizon.png
688
+ reconstruction_examples.png
689
+ imagined_vs_real_rollouts.png
690
+ ```
691
+
692
+ ### Minimum publishable artifact
693
+
694
+ A short report with:
695
+
696
+ 1. Problem statement
697
+ 2. Method
698
+ 3. Environment setup
699
+ 4. Baselines
700
+ 5. Main learning curves
701
+ 6. Ablations
702
+ 7. Failure cases
703
+ 8. Next steps
704
+
705
+ ---
706
+
707
+ ## 15. Milestones
708
+
709
+ ### Milestone 1: environment and data pipeline
710
+
711
+ Acceptance criteria:
712
+
713
+ - Can create `MiniGrid-FourRooms-v0`
714
+ - Can wrap it into RGB-only observation mode
715
+ - Can collect and save bootstrap trajectories with explicit `terminated`, `truncated`, and `done` flags
716
+ - Can reload trajectory chunks for sequence training
717
+ - Train/val/test episode splits are reproducible
718
+ - Shape tests pass
719
+
720
+ ### Milestone 2: world model training
721
+
722
+ Acceptance criteria:
723
+
724
+ - RSSM forward pass works on sequence batches
725
+ - Reward, done, KL, and optional reconstruction losses train without NaNs
726
+ - One-step reward/done prediction beats trivial constant predictor on held-out validation episodes
727
+ - Model checkpointing works
728
+
729
+ ### Milestone 3: open-loop model evaluation
730
+
731
+ Acceptance criteria:
732
+
733
+ - Can visualize real vs. reconstructed observations
734
+ - Can roll model forward for 1, 5, and 10 imagined steps
735
+ - Can report reward and done prediction error by horizon on held-out episodes
736
+
737
+ ### Milestone 4: CEM planner
738
+
739
+ Acceptance criteria:
740
+
741
+ - Planner can choose actions from latent state
742
+ - Planner runs online in the environment
743
+ - Planner scoring uses termination-aware reward masking
744
+ - Success rate beats random policy over 100 evaluation episodes
745
+
746
+ ### Milestone 5: baseline comparison
747
+
748
+ Acceptance criteria:
749
+
750
+ - PPO baseline runs on same observation setup
751
+ - World model and PPO are compared at fixed environment-step budgets, with bootstrap data counted toward the world-model budget
752
+ - At least 3 seeds per method
753
+
754
+ ### Milestone 6: ablations and report
755
+
756
+ Acceptance criteria:
757
+
758
+ - Run decoder/no-decoder ablation
759
+ - Run at least 2 planning horizons
760
+ - Run at least 2 latent sizes
761
+ - Produce final plots and written analysis
762
+
763
+ ---
764
+
765
+ ## 16. Suggested default config
766
+
767
+ ```yaml
768
+ project:
769
+ name: minidreamer-fourrooms
770
+ seed: 0
771
+
772
+ env:
773
+ id: MiniGrid-FourRooms-v0
774
+ rgb_partial_obs: true
775
+ image_only: true
776
+ resize: [64, 64]
777
+ normalize_obs: true
778
+ action_space: full
779
+
780
+ replay:
781
+ capacity_episodes: 5000
782
+ sequence_length: 32
783
+ batch_size: 32
784
+ train_fraction: 0.8
785
+ val_fraction: 0.1
786
+ test_fraction: 0.1
787
+ split_key: episode_id
788
+
789
+ collection:
790
+ bootstrap_env_steps: 5000
791
+ bootstrap_success_threshold: 20
792
+ bootstrap_fallback_policy: restricted_random_3_actions
793
+ bootstrap_env_step_cap: 20000
794
+ collect_steps_per_iteration: 1000
795
+ train_collect_ratio: 1.0
796
+ random_action_fraction_after_planner: 0.3
797
+
798
+ model:
799
+ embedding_dim: 256
800
+ deter_dim: 256
801
+ stoch_dim: 32
802
+ hidden_dim: 256
803
+ use_decoder: true
804
+ min_std: 0.1
805
+
806
+ training:
807
+ optimizer: adam
808
+ lr: 0.0003
809
+ grad_clip_norm: 100.0
810
+ train_steps: 100000
811
+ beta_kl: 1.0
812
+ free_nats: 1.0
813
+ beta_recon: 1.0
814
+ beta_reward: 1.0
815
+ beta_done: 1.0
816
+
817
+ planner:
818
+ type: discrete_cem
819
+ horizon: 8
820
+ candidates: 256
821
+ elites: 32
822
+ iterations: 4
823
+ discount: 0.99
824
+ use_done_mask: true
825
+
826
+ evaluation:
827
+ episodes: 100
828
+ seeds: [0, 1, 2]
829
+ eval_every_env_steps: 10000
830
+ ```
831
+
832
+ ---
833
+
834
+ ## 17. Acceptance criteria for v1
835
+
836
+ The project is successful if:
837
+
838
+ 1. The world model trains stably for at least 3 seeds.
839
+ 2. The CEM planner beats random policy on `MiniGrid-FourRooms-v0`.
840
+ 3. The report shows learning curves for world-model planner, PPO, and random.
841
+ 4. The report includes at least one ablation showing whether reconstruction helps.
842
+ 5. The repo can reproduce the main result from a clean config file and fixed environment-step comparison protocol.
843
+
844
+ A strong v1 result would be:
845
+
846
+ ```text
847
+ world_model_success_rate > random_success_rate
848
+ ```
849
+
850
+ An excellent v1 result would be:
851
+
852
+ ```text
853
+ world_model reaches useful success rate with fewer env steps than PPO
854
+ ```
855
+
856
+ Do not require the world model to beat PPO asymptotically. For a first project, sample efficiency and interpretable failure analysis are more important.
857
+
858
+ ---
859
+
860
+ ## 18. Main risks
861
+
862
+ ### Risk 1: sparse reward makes reward model hard to train
863
+
864
+ Mitigations:
865
+
866
+ - start with random + heuristic exploratory data
867
+ - add goal-reaching trajectories if random rarely succeeds
868
+ - optionally train first on `MiniGrid-Empty-*`
869
+ - use success classification in addition to reward regression
870
+
871
+ ### Risk 2: model learns visual reconstruction but not controllable dynamics
872
+
873
+ Mitigations:
874
+
875
+ - ablate decoder
876
+ - track reward/done prediction separately from reconstruction
877
+ - evaluate imagined-vs-real return correlation
878
+
879
+ ### Risk 3: CEM planner exploits model errors
880
+
881
+ Mitigations:
882
+
883
+ - short planning horizon
884
+ - replan every step
885
+ - penalize high predicted uncertainty if using ensemble later
886
+ - mix random exploration into data collection
887
+
888
+ ### Risk 4: partial observability hurts Markov assumptions
889
+
890
+ Mitigations:
891
+
892
+ - keep recurrent state
893
+ - compare deterministic-only vs. stochastic recurrent latent
894
+ - optionally add frame stacking as a diagnostic, not as the main solution
895
+
896
+ ---
897
+
898
+ ## 19. Recommended implementation order
899
+
900
+ 1. Implement environment wrapper and trajectory collection.
901
+ 2. Implement replay buffer with sequence sampling.
902
+ 3. Implement encoder + RSSM + reward/done heads.
903
+ 4. Train model on random trajectories.
904
+ 5. Add decoder and visualization.
905
+ 6. Add open-loop rollout diagnostics.
906
+ 7. Implement discrete CEM planner.
907
+ 8. Evaluate against random.
908
+ 9. Add PPO baseline.
909
+ 10. Run ablations.
910
+
911
+ The key design constraint: **do not optimize for final score first**. Optimize for observability: diagnostics, rollout plots, prediction errors, and failure cases. That will make this a research project instead of just another RL training script.
notebooks/results_analysis.ipynb ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Results Analysis\n",
8
+ "\n",
9
+ "Use this notebook to load `metrics/*.jsonl`, aggregate seed-level runs, and generate the plots referenced in `results.md`."
10
+ ]
11
+ }
12
+ ],
13
+ "metadata": {
14
+ "kernelspec": {
15
+ "display_name": "Python 3",
16
+ "language": "python",
17
+ "name": "python3"
18
+ },
19
+ "language_info": {
20
+ "name": "python"
21
+ }
22
+ },
23
+ "nbformat": 4,
24
+ "nbformat_minor": 5
25
+ }
notebooks/rollout_debug.ipynb ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Rollout Debug\n",
8
+ "\n",
9
+ "Use this notebook to inspect replay sequences, reconstructions, and imagined-vs-real reward rollouts after training has been run."
10
+ ]
11
+ }
12
+ ],
13
+ "metadata": {
14
+ "kernelspec": {
15
+ "display_name": "Python 3",
16
+ "language": "python",
17
+ "name": "python3"
18
+ },
19
+ "language_info": {
20
+ "name": "python"
21
+ }
22
+ },
23
+ "nbformat": 4,
24
+ "nbformat_minor": 5
25
+ }
26
+
plots/.gitkeep ADDED
@@ -0,0 +1 @@
 
 
1
+
plots/learning_curves.png ADDED
plots/model_error_vs_rollout_horizon.png ADDED
plots/success_rate_vs_env_steps.png ADDED
pyproject.toml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=69", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "minidreamer"
7
+ version = "0.1.0"
8
+ description = "PlaNet-style world model planning for pixel-based MiniGrid"
9
+ readme = "README.md"
10
+ requires-python = ">=3.11,<3.13"
11
+ license = { file = "LICENSE" }
12
+ authors = [{ name = "OpenAI Codex" }]
13
+ dependencies = [
14
+ "gymnasium>=0.29,<1.1",
15
+ "minigrid>=2.3.1",
16
+ "numpy>=1.26,<3",
17
+ "pillow>=10.0",
18
+ "pyyaml>=6.0",
19
+ "torch>=2.3,<3",
20
+ "tqdm>=4.66",
21
+ "matplotlib>=3.8",
22
+ "stable-baselines3>=2.3.0",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ dev = [
27
+ "pytest>=8.2",
28
+ ]
29
+
30
+ [project.scripts]
31
+ minidreamer-collect-random = "minidreamer.data.collect_random:main"
32
+ minidreamer-train-world-model = "train_world_model:main"
33
+ minidreamer-evaluate = "evaluate:main"
34
+ minidreamer-train-ppo = "minidreamer.baselines.train_ppo:main"
35
+
36
+ [tool.setuptools]
37
+ package-dir = { "" = "src" }
38
+ py-modules = ["train_world_model", "evaluate"]
39
+
40
+ [tool.setuptools.packages.find]
41
+ where = ["src"]
42
+ include = ["minidreamer*"]
43
+
44
+ [tool.pytest.ini_options]
45
+ testpaths = ["tests"]
46
+ pythonpath = ["src"]
results.md ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Results
2
+
3
+ ## Status
4
+
5
+ A baseline world-model training run completed on `2026-04-21` on Apple Silicon using the `mps` backend. This run is now frozen as the reference artifact under `artifacts/world_model/`. Final world-model summaries now come from `artifacts/world_model/metrics/run_summary.json`, `artifacts/world_model/metrics/final_eval_latest.json`, and `artifacts/world_model/metrics/planner_eval_latest_clean.json`, the PPO baseline summary comes from `artifacts/ppo/metrics/run_summary.json`, and generated figures live under `plots/`.
6
+
7
+ ## Problem Statement
8
+
9
+ Train a PlaNet-style world model for `MiniGrid-FourRooms-v0` from partial RGB observations, then evaluate a latent-space discrete CEM planner against a random-policy baseline and a PPO baseline.
10
+
11
+ ## Method
12
+
13
+ - CNN encoder + Gaussian RSSM world model with reward, done, and reconstruction heads.
14
+ - Discrete CEM planning in latent space.
15
+ - Replay-buffer training with episode-aware train/val/test splits.
16
+ - Config: `configs/fourrooms_world_model.yaml`.
17
+
18
+ ## Environment Setup
19
+
20
+ - Device: `mps`
21
+ - Target environment steps: `100000`
22
+ - Final realized environment steps: `100004`
23
+ - Total gradient updates completed: `10000`
24
+ - Final replay summary:
25
+ - Episodes: `1022`
26
+ - Success episodes: `42`
27
+ - Train/val/test episodes: `819 / 116 / 87`
28
+
29
+ ## Baselines
30
+
31
+ Persisted comparison metrics were recorded at `97124` env steps:
32
+
33
+ - Random baseline success rate: `0.0`
34
+ - Random baseline mean return: `0.0`
35
+ - Random baseline mean episode length: `100.0`
36
+
37
+ Completed PPO baseline after `100000` training env steps:
38
+
39
+ - PPO success rate: `0.14`
40
+ - PPO mean return: `0.1336`
41
+ - PPO median return: `0.0`
42
+ - PPO return std: `0.3313`
43
+ - PPO mean episode length: `86.71`
44
+
45
+ ## Main Metrics
46
+
47
+ Persisted evaluation metrics at `97124` env steps:
48
+
49
+ - Planner success rate: `0.10`
50
+ - Planner mean return: `0.0568`
51
+ - Planner median return: `0.0`
52
+ - Planner mean episode length: `94.8`
53
+ - Planner action entropy: `0.6569`
54
+
55
+ Held-out world-model metrics at the same checkpoint:
56
+
57
+ - Reward MSE: `1.98e-6`
58
+ - Done BCE: `0.1808`
59
+ - KL loss: `0.8996`
60
+ - Reconstruction MSE: `0.0109`
61
+
62
+ Open-loop rollout quality:
63
+
64
+ - Done accuracy @1/@5/@10: `0.9880 / 0.9960 / 0.9973`
65
+ - Reward error @1/@5/@10: `1.98e-6 / 1.83e-6 / 1.69e-6`
66
+
67
+ Canonical clean final planner evaluation at `world_model_latest.pt`:
68
+
69
+ - Evaluation budget: `100` planner episodes and `100` random episodes
70
+ - Planner success rate: `0.04`
71
+ - Planner mean return: `0.03415`
72
+ - Planner median return: `0.0`
73
+ - Planner mean episode length: `96.65`
74
+ - Planner action entropy: `0.9142`
75
+ - Final random baseline success rate: `0.03`
76
+ - Final random baseline mean return: `0.02073`
77
+ - Final random baseline mean episode length: `98.03`
78
+
79
+ Small-sample final world-model diagnostics at `world_model_latest.pt`:
80
+
81
+ - Evaluation budget: `5` held-out world-model episodes
82
+ - Final reward MSE: `4.95e-6`
83
+ - Final done BCE: `0.1513`
84
+ - Final KL loss: `0.8833`
85
+ - Final reconstruction MSE: `0.0096`
86
+ - Final done accuracy @1/@5/@10: `0.9880 / 0.9971 / 0.9978`
87
+ - Final reward error @1/@5/@10: `4.98e-6 / 5.09e-6 / 5.12e-6`
88
+
89
+ ## Comparison
90
+
91
+ At roughly matched data budgets, the PPO baseline produced stronger direct control performance than the world-model planner:
92
+
93
+ - PPO at `100000` env steps: `0.14` success rate, `0.1336` mean return, `86.71` mean episode length over `100` evaluation episodes.
94
+ - Final clean planner eval at `100004` env steps: `0.04` success rate, `0.03415` mean return, `96.65` mean episode length over `100` evaluation episodes.
95
+ - Final clean random eval at `100004` env steps: `0.03` success rate, `0.02073` mean return, `98.03` mean episode length over `100` evaluation episodes.
96
+
97
+ The predictive model stayed numerically strong through the end of training, but that did not translate into a robust planner win on FourRooms at this training budget. After removing evaluation-time action noise, the final planner is only nominally ahead of random (`0.04` vs. `0.03` success rate over `100` episodes each, a one-episode gap that is within sampling noise) and is still clearly below PPO (`0.14`).
98
+
99
+ ## Ablations
100
+
101
+ No ablation runs have been recorded yet.
102
+
103
+ ## Failure Cases / Operational Notes
104
+
105
+ - The initial long run stopped near `90021` env steps when the local machine hit severe disk pressure.
106
+ - Training was resumed successfully from `artifacts/world_model/checkpoints/world_model_env_steps_90021.pt` after adding checkpoint-resume support to the trainer.
107
+ - Only one scheduled planner evaluation row is persisted in `eval_metrics.jsonl`. The resumed segment completed without crossing another configured evaluation boundary before the final save, so the end-of-run planner metrics are stored separately in `final_eval_latest.json` and `planner_eval_latest_clean.json`.
108
+ - The scheduled planner row in `eval_metrics.jsonl` and the planner section inside `final_eval_latest.json` were recorded before the evaluation-noise fix. They are kept for provenance, but `planner_eval_latest_clean.json` is the canonical final planner result.
109
+
110
+ ## Visualizations
111
+
112
+ - Generated from `artifacts/world_model/metrics/train_metrics.jsonl`, `artifacts/world_model/metrics/eval_metrics.jsonl`, `artifacts/world_model/metrics/final_eval_latest.json`, and `artifacts/world_model/metrics/planner_eval_latest_clean.json` using `scripts/generate_results_plots.py`.
113
+ - The success-rate chart uses the clean final planner evaluation for the last point.
114
+ - `plots/learning_curves.png`
115
+ - `plots/success_rate_vs_env_steps.png`
116
+ - `plots/model_error_vs_rollout_horizon.png`
117
+
118
+ ## Artifact Locations
119
+
120
+ - Metrics: `artifacts/world_model/metrics/`
121
+ - Replay: `artifacts/world_model/replay/`
122
+ - Checkpoints: `artifacts/world_model/checkpoints/`
123
+ - Final checkpoint: `artifacts/world_model/checkpoints/world_model_latest.pt`
124
+ - Canonical clean final planner eval: `artifacts/world_model/metrics/planner_eval_latest_clean.json`
125
+ - PPO metrics: `artifacts/ppo/metrics/`
126
+ - PPO checkpoint: `artifacts/ppo/checkpoints/ppo_latest.zip`
127
+ - Generated plots: `plots/`
scripts/collect_random.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run the bootstrap collection entry point with the FourRooms world-model
# config, writing to artifacts/bootstrap_replay. Extra CLI arguments are
# forwarded unchanged to collect_random.py.
set -euo pipefail

# Append src/ to PYTHONPATH. The ${VAR:+...} form only emits the separator when
# PYTHONPATH is already set, so an unset variable no longer yields a leading
# empty entry (which Python would interpret as the current directory).
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)/src"
python3.11 src/minidreamer/data/collect_random.py \
  --config configs/fourrooms_world_model.yaml \
  --output-dir artifacts/bootstrap_replay \
  "$@"
9
+
scripts/eval_planner.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Evaluate a world-model checkpoint twice: first the planner, then the held-out
# world-model metrics against a saved replay directory.
#
# Usage: eval_planner.sh CHECKPOINT REPLAY_DIR [extra evaluate args]
set -euo pipefail

if [ "$#" -lt 2 ]; then
  echo "Usage: $0 CHECKPOINT REPLAY_DIR [extra evaluate args]" >&2
  exit 1
fi

checkpoint="$1"
replay_dir="$2"
shift 2

# The two evaluate.py subcommands accept different options, so forwarding the
# same "$@" to both makes argparse reject any subcommand-specific flag.
# Route --random-action-fraction to `planner` only and --split to `world-model`
# only; every other extra argument is still forwarded to both, as before.
planner_args=()
world_model_args=()
while [ "$#" -gt 0 ]; do
  case "$1" in
    --random-action-fraction|--split)
      if [ "$#" -lt 2 ]; then
        echo "Missing value for $1" >&2
        exit 1
      fi
      if [ "$1" = "--split" ]; then
        world_model_args+=("$1" "$2")
      else
        planner_args+=("$1" "$2")
      fi
      shift 2
      ;;
    --random-action-fraction=*)
      planner_args+=("$1")
      shift
      ;;
    --split=*)
      world_model_args+=("$1")
      shift
      ;;
    *)
      planner_args+=("$1")
      world_model_args+=("$1")
      shift
      ;;
  esac
done

# Only emit the ':' separator when PYTHONPATH is already set, so an unset
# variable does not leave a leading empty entry (the current directory).
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)/src"

# ${arr[@]+"${arr[@]}"} expands to nothing for an empty array without tripping
# `set -u` on older bash releases.
python3.11 src/evaluate.py \
  planner \
  --config configs/fourrooms_world_model.yaml \
  --checkpoint "${checkpoint}" \
  ${planner_args[@]+"${planner_args[@]}"}

python3.11 src/evaluate.py \
  world-model \
  --config configs/fourrooms_world_model.yaml \
  --checkpoint "${checkpoint}" \
  --replay-dir "${replay_dir}" \
  ${world_model_args[@]+"${world_model_args[@]}"}
26
+
scripts/generate_results_plots.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+ ROOT = Path(__file__).resolve().parents[1]
10
+ METRICS_DIR = ROOT / "artifacts" / "world_model" / "metrics"
11
+ PLOTS_DIR = ROOT / "plots"
12
+
13
+
14
def load_jsonl(path: Path) -> list[dict]:
    """Parse a JSON-lines file into a list of decoded rows.

    A missing file yields an empty list. Lines that are empty after stripping
    whitespace are skipped; every other line is decoded with ``json.loads``.
    """
    if not path.exists():
        return []
    stripped_lines = (raw.strip() for raw in path.read_text(encoding="utf-8").splitlines())
    return [json.loads(entry) for entry in stripped_lines if entry]
23
+
24
+
25
def load_json(path: Path) -> dict | None:
    """Return the decoded JSON document at *path*, or ``None`` if the file is absent."""
    return json.loads(path.read_text(encoding="utf-8")) if path.exists() else None
29
+
30
+
31
def rolling_mean(values: list[float], window: int) -> list[float]:
    """Return the trailing moving average of *values*.

    Element ``i`` of the result is the mean of the last ``window`` inputs ending
    at index ``i``; the first ``window - 1`` elements average over the shorter
    prefix that is available so far.

    Args:
        values: Series to smooth. An empty series yields an empty list.
        window: Number of trailing samples to average; must be at least 1.

    Raises:
        ValueError: If *values* is non-empty and *window* is smaller than 1.
            Such a window previously caused a ZeroDivisionError (``0``) or
            silently wrong averages (negative values).
    """
    if not values:
        return []
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}.")
    out: list[float] = []
    running_sum = 0.0
    for idx, value in enumerate(values):
        running_sum += value
        if idx >= window:
            # Drop the sample that just slid out of the trailing window.
            running_sum -= values[idx - window]
        out.append(running_sum / min(idx + 1, window))
    return out
42
+
43
+
44
def generate_learning_curves(train_rows: list[dict]) -> None:
    """Plot smoothed training losses and save them as ``learning_curves.png``.

    Each row must provide ``loss``, ``kl_loss``, ``done_loss`` and
    ``recon_loss``. Rows are treated as consecutive gradient updates, numbered
    from 1. Nothing is written when *train_rows* is empty.
    """
    if not train_rows:
        return

    update_index = list(range(1, len(train_rows) + 1))
    smoothing_window = min(250, len(train_rows))

    # (row key, legend label, line width), in drawing order.
    series_spec = (
        ("loss", "total loss", 2.0),
        ("kl_loss", "kl loss", 1.5),
        ("done_loss", "done loss", 1.5),
        ("recon_loss", "recon loss", 1.5),
    )
    # Smooth every series before opening a figure so a missing key fails early.
    curves = [
        (label, width, rolling_mean([row[key] for row in train_rows], smoothing_window))
        for key, label, width in series_spec
    ]

    fig, ax = plt.subplots(figsize=(10, 6))
    for label, width, curve in curves:
        ax.plot(update_index, curve, label=label, linewidth=width)
    ax.set_title("World Model Training Curves")
    ax.set_xlabel("Gradient update")
    ax.set_ylabel("Smoothed loss")
    ax.legend()
    ax.grid(alpha=0.25)
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "learning_curves.png", dpi=200)
    plt.close(fig)
67
+
68
+
69
def generate_success_plot(
    eval_rows: list[dict],
    final_eval: dict | None,
    clean_planner_eval: dict | None,
) -> None:
    """Plot planner and random success rates against environment steps.

    One point is drawn per scheduled evaluation row. A final point is appended
    from *clean_planner_eval* when it is available, otherwise from
    *final_eval*. Points are ordered by environment step before plotting, and
    nothing is written when there are no points at all.
    """
    # Each point is (env_steps, planner success rate, random success rate).
    points = [
        (
            int(row["env_steps"]),
            float(row["planner/success_rate"]),
            float(row["random/success_rate"]),
        )
        for row in eval_rows
    ]

    if clean_planner_eval is not None:
        points.append(
            (
                int(clean_planner_eval["metadata"]["env_steps"]),
                float(clean_planner_eval["planner_clean"]["success_rate"]),
                float(clean_planner_eval["random"]["success_rate"]),
            )
        )
    elif final_eval is not None:
        points.append(
            (
                int(final_eval["metadata"]["env_steps"]),
                float(final_eval["planner"]["success_rate"]),
                float(final_eval["random"]["success_rate"]),
            )
        )

    if not points:
        return

    # Stable sort on the step only, so ties keep their insertion order.
    points.sort(key=lambda point: point[0])
    steps = [point[0] for point in points]
    planner_success = [point[1] for point in points]
    random_success = [point[2] for point in points]

    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(steps, planner_success, marker="o", linewidth=2.0, label="planner success rate")
    ax.plot(steps, random_success, marker="o", linewidth=2.0, label="random success rate")
    ax.set_title("Success Rate vs Environment Steps")
    ax.set_xlabel("Environment steps")
    ax.set_ylabel("Success rate")
    ax.set_ylim(-0.02, 1.02)
    ax.legend()
    ax.grid(alpha=0.25)
    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "success_rate_vs_env_steps.png", dpi=200)
    plt.close(fig)
112
+
113
+
114
def generate_rollout_error_plot(final_eval: dict | None) -> None:
    """Plot open-loop reward error and done accuracy at horizons 1, 5 and 10.

    Reads ``open_loop_reward_error_h{h}`` and ``open_loop_done_accuracy_h{h}``
    from ``final_eval["world_model"]``. Reward error is drawn as bars on the
    left axis and done accuracy as a line on a twin right axis. The figure is
    saved as ``model_error_vs_rollout_horizon.png``; nothing is written when
    *final_eval* is ``None``.
    """
    if final_eval is None:
        return
    horizons = [1, 5, 10]
    horizon_labels = [str(h) for h in horizons]
    world_model_metrics = final_eval["world_model"]
    reward_errors = [
        float(world_model_metrics[f"open_loop_reward_error_h{h}"])
        for h in horizons
    ]
    done_accuracy = [
        float(world_model_metrics[f"open_loop_done_accuracy_h{h}"])
        for h in horizons
    ]

    fig, ax1 = plt.subplots(figsize=(9, 5))
    ax1.bar(horizon_labels, reward_errors, color="#2b6cb0", alpha=0.8)
    ax1.set_xlabel("Open-loop horizon")
    ax1.set_ylabel("Reward MSE", color="#2b6cb0")
    ax1.tick_params(axis="y", labelcolor="#2b6cb0")
    ax1.set_title("Model Error vs Rollout Horizon")
    ax1.grid(alpha=0.2, axis="y")

    ax2 = ax1.twinx()
    ax2.plot(horizon_labels, done_accuracy, color="#c05621", marker="o", linewidth=2.0)
    ax2.set_ylabel("Done accuracy", color="#c05621")
    ax2.tick_params(axis="y", labelcolor="#c05621")
    # Keep the zoomed-in 0.95 floor for accurate models, but widen the range
    # when any accuracy falls below it; a fixed floor would clip those points
    # out of the figure and hide exactly the runs that need attention.
    accuracy_floor = min(0.95, min(done_accuracy) - 0.01)
    ax2.set_ylim(accuracy_floor, 1.001)

    fig.tight_layout()
    fig.savefig(PLOTS_DIR / "model_error_vs_rollout_horizon.png", dpi=200)
    plt.close(fig)
144
+
145
+
146
def main() -> None:
    """Load the persisted world-model metrics and regenerate every plot."""
    PLOTS_DIR.mkdir(parents=True, exist_ok=True)

    # Per-update and per-evaluation logs are JSON lines; summaries are single documents.
    training_log = load_jsonl(METRICS_DIR / "train_metrics.jsonl")
    scheduled_evals = load_jsonl(METRICS_DIR / "eval_metrics.jsonl")
    final_summary = load_json(METRICS_DIR / "final_eval_latest.json")
    clean_planner_summary = load_json(METRICS_DIR / "planner_eval_latest_clean.json")

    generate_learning_curves(training_log)
    generate_success_plot(scheduled_evals, final_summary, clean_planner_summary)
    generate_rollout_error_plot(final_summary)
156
+
157
+
158
+ if __name__ == "__main__":
159
+ main()
scripts/train_ppo.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run the PPO baseline trainer with the FourRooms PPO config, writing to
# artifacts/ppo. Extra CLI arguments are forwarded unchanged to train_ppo.py.
set -euo pipefail

# Only emit the ':' separator when PYTHONPATH is already set, so an unset
# variable does not leave a leading empty entry (the current directory).
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)/src"
python3.11 src/minidreamer/baselines/train_ppo.py \
  --config configs/fourrooms_ppo.yaml \
  --output-dir artifacts/ppo \
  "$@"
9
+
scripts/train_world_model.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Run the world-model trainer with the FourRooms config. The output directory
# defaults to artifacts/world_model_experiment and can be overridden through
# MINIDREAMER_OUTPUT_DIR. Extra CLI arguments are forwarded unchanged.
set -euo pipefail

output_dir="${MINIDREAMER_OUTPUT_DIR:-artifacts/world_model_experiment}"

# Only emit the ':' separator when PYTHONPATH is already set, so an unset
# variable does not leave a leading empty entry (the current directory).
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}$(pwd)/src"
python3.11 src/train_world_model.py \
  --config configs/fourrooms_world_model.yaml \
  --output-dir "${output_dir}" \
  "$@"
src/evaluate.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ from minidreamer.config import load_config
7
+ from minidreamer.data.replay_buffer import ReplayBuffer
8
+ from minidreamer.evaluation import evaluate_random_policy, evaluate_world_model
9
+ from minidreamer.envs.make_env import make_env_from_config
10
+ from minidreamer.planning.evaluate_planner import evaluate_planner
11
+ from minidreamer.serialization import load_world_model_checkpoint
12
+
13
+
14
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the evaluation CLI with ``random``, ``planner`` and ``world-model`` subcommands.

    Every subcommand requires ``--config``. ``planner`` and ``world-model`` also
    require ``--checkpoint``. ``planner`` accepts ``--random-action-fraction``
    (default ``0.0``); ``world-model`` requires ``--replay-dir`` and accepts
    ``--split`` (default ``val``).
    """
    root = argparse.ArgumentParser(description="Evaluate MiniDreamer components.")
    commands = root.add_subparsers(dest="command", required=True)

    random_cmd = commands.add_parser("random", help="Evaluate a random policy.")
    planner_cmd = commands.add_parser("planner", help="Evaluate a trained planner.")
    world_model_cmd = commands.add_parser("world-model", help="Evaluate held-out world model metrics.")

    # Shared options are registered first so each subcommand keeps its option order.
    for command in (random_cmd, planner_cmd, world_model_cmd):
        command.add_argument("--config", type=Path, required=True)
    for command in (planner_cmd, world_model_cmd):
        command.add_argument("--checkpoint", type=Path, required=True)

    planner_cmd.add_argument(
        "--random-action-fraction",
        type=float,
        default=0.0,
        help="Optional evaluation-time action noise. Defaults to 0.0 for a clean planner evaluation.",
    )

    world_model_cmd.add_argument("--replay-dir", type=Path, required=True)
    world_model_cmd.add_argument("--split", type=str, default="val", choices=["train", "val", "test"])
    return root
37
+
38
+
39
def main() -> None:
    """Run the selected evaluation subcommand and print its result dictionary.

    ``random`` needs only the config. The other subcommands first load the
    checkpointed world model on CPU, then evaluate either the planner or the
    held-out world-model metrics on the requested replay split.
    """
    args = build_arg_parser().parse_args()
    config = load_config(args.config)
    command = args.command

    if command == "random":
        print(evaluate_random_policy(config))
        return

    # A throwaway environment is created only to read the action-space size.
    project_seed = config.get("project", {}).get("seed", 0)
    probe_env = make_env_from_config(config, seed=project_seed)
    action_dim = probe_env.action_space.n
    probe_env.close()
    model, _, metadata = load_world_model_checkpoint(args.checkpoint, action_dim=action_dim, map_location="cpu")

    if command == "planner":
        results = evaluate_planner(
            config,
            model,
            random_action_fraction=args.random_action_fraction,
        )
    else:
        replay = ReplayBuffer.load(args.replay_dir)
        results = evaluate_world_model(config, model, replay, split=args.split)

    print({"metadata": metadata, **results})
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
src/minidreamer/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """MiniDreamer package."""
2
+
3
+ __all__ = [
4
+ "config",
5
+ ]
6
+
src/minidreamer/baselines/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Baseline agents."""
2
+
src/minidreamer/baselines/train_ppo.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+
6
+ import gymnasium as gym
7
+ import torch
8
+ from torch import nn
9
+
10
+ from minidreamer.config import ensure_run_dirs, load_config
11
+ from minidreamer.envs.make_env import make_env_from_config
12
+ from minidreamer.utils.common import seed_everything
13
+
14
# stable-baselines3 is imported defensively so that the missing dependency is
# reported by train_ppo() rather than at module import time.
try:
    from stable_baselines3 import PPO
    from stable_baselines3.common.evaluation import evaluate_policy
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
    from stable_baselines3.common.vec_env import DummyVecEnv
except ImportError as exc:  # pragma: no cover - exercised only when dependency is missing.
    PPO = None
    IMPORT_ERROR = exc
    # The feature-extractor class in this module subclasses BaseFeaturesExtractor
    # in a module-level class statement. Without a stand-in base the module
    # would fail with NameError at import time, and the deferred ImportError
    # raised by train_ppo() could never be reached.
    BaseFeaturesExtractor = nn.Module
else:
    IMPORT_ERROR = None
25
+
26
+
27
class MiniGridCNNExtractor(BaseFeaturesExtractor):
    """CNN feature extractor for 3D image observations.

    Four stride-2 ``4x4`` convolutions with ReLU activations are followed by a
    flatten and a linear + ReLU projection to ``features_dim``.

    The memory layout is decided once from the observation space: it is treated
    as channel-first when its leading dimension is 1 or 3, and as channel-last
    otherwise. ``forward`` applies that same decision, so batches are permuted
    exactly when the probe sample in ``__init__`` was.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256) -> None:
        """Build the network and size the linear layer from a sampled observation.

        Raises:
            ValueError: If the observation space is not three-dimensional.
        """
        super().__init__(observation_space, features_dim)
        if len(observation_space.shape) != 3:
            raise ValueError(f"Expected 3D image observations, got {observation_space.shape}.")
        self.channel_first = observation_space.shape[0] in (1, 3)
        channels = observation_space.shape[0] if self.channel_first else observation_space.shape[2]
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Push one sampled observation through the conv stack to learn the
        # flattened feature size instead of hard-coding it for one resolution.
        with torch.no_grad():
            sample = torch.as_tensor(observation_space.sample()[None]).float()
            if not self.channel_first:
                sample = sample.permute(0, 3, 1, 2)
            flattened_dim = self.cnn(sample).shape[1]
        self.linear = nn.Sequential(nn.Linear(flattened_dim, features_dim), nn.ReLU())

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observations into ``features_dim``-sized vectors."""
        # Reuse the layout decision from __init__ rather than re-guessing it
        # from the batch shape: a channel-last batch whose height is 1 or 3
        # would otherwise be mistaken for channel-first and skip the permute.
        if observations.dim() == 4 and not self.channel_first:
            observations = observations.permute(0, 3, 1, 2)
        return self.linear(self.cnn(observations.float()))
56
+
57
+
58
def build_env(config: dict, seed: int, rank: int):
    """Return a zero-argument factory for one monitored training environment.

    The environment is created lazily, when the factory is called, from
    *config* with seed ``seed + rank`` and is wrapped in ``Monitor``.
    """

    def _factory():
        return Monitor(make_env_from_config(config, seed=seed + rank))

    return _factory
64
+
65
+
66
def train_ppo(config: dict, output_dir: str | Path) -> dict[str, float]:
    """Train a PPO baseline, save it, and evaluate it deterministically.

    Hyperparameters come from ``config["ppo"]`` with the defaults shown below;
    ``total_timesteps`` is required. The trained model is saved as
    ``ppo_latest`` in the run's checkpoints directory.

    Args:
        config: Parsed configuration with a ``ppo`` section and
            ``evaluation.episodes``. ``project.seed`` defaults to ``0``.
        output_dir: Run directory; its standard sub-directories are created.

    Returns:
        ``{"mean_reward": ..., "std_reward": ...}`` over the evaluation episodes.

    Raises:
        ImportError: If stable-baselines3 is not installed.
    """
    if PPO is None:
        raise ImportError(
            "stable-baselines3 is required for PPO training."
        ) from IMPORT_ERROR
    ppo_cfg = config["ppo"]
    seed = config.get("project", {}).get("seed", 0)
    seed_everything(seed)
    run_dirs = ensure_run_dirs(output_dir)

    env_fns = [build_env(config, seed, rank) for rank in range(ppo_cfg.get("num_envs", 4))]
    vec_env = DummyVecEnv(env_fns)
    # Close the environments even when training, saving or evaluation raises;
    # previously any failure after construction left them open.
    try:
        policy_kwargs = {
            "features_extractor_class": MiniGridCNNExtractor,
            "features_extractor_kwargs": {"features_dim": ppo_cfg.get("features_dim", 256)},
        }
        model = PPO(
            "CnnPolicy",
            vec_env,
            policy_kwargs=policy_kwargs,
            learning_rate=ppo_cfg.get("learning_rate", 3e-4),
            n_steps=ppo_cfg.get("n_steps", 256),
            batch_size=ppo_cfg.get("batch_size", 256),
            n_epochs=ppo_cfg.get("n_epochs", 4),
            gamma=ppo_cfg.get("gamma", 0.99),
            gae_lambda=ppo_cfg.get("gae_lambda", 0.95),
            clip_range=ppo_cfg.get("clip_range", 0.2),
            ent_coef=ppo_cfg.get("ent_coef", 0.01),
            vf_coef=ppo_cfg.get("vf_coef", 0.5),
            seed=seed,
            device=ppo_cfg.get("device", "auto"),
            verbose=1,
        )
        model.learn(total_timesteps=ppo_cfg["total_timesteps"])
        model.save(Path(run_dirs["checkpoints"]) / "ppo_latest")
        # NOTE(review): evaluation reuses the training vec_env (same seeds and
        # Monitor wrappers) rather than a separately seeded evaluation env.
        mean_reward, std_reward = evaluate_policy(
            model,
            vec_env,
            n_eval_episodes=config["evaluation"]["episodes"],
            deterministic=True,
        )
    finally:
        vec_env.close()
    return {"mean_reward": float(mean_reward), "std_reward": float(std_reward)}
109
+
110
+
111
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the PPO training CLI; ``--config`` and ``--output-dir`` are both required paths."""
    cli = argparse.ArgumentParser(description="Train a PPO baseline on MiniGrid pixels.")
    for flag in ("--config", "--output-dir"):
        cli.add_argument(flag, type=Path, required=True)
    return cli
116
+
117
+
118
def main() -> None:
    """Parse the CLI, train the PPO baseline, and print its evaluation summary."""
    cli_args = build_arg_parser().parse_args()
    loaded_config = load_config(cli_args.config)
    print(train_ppo(loaded_config, cli_args.output_dir))
124
+
125
+
126
+ if __name__ == "__main__":
127
+ main()
src/minidreamer/config.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import yaml
8
+
9
+ ConfigDict = dict[str, Any]
10
+
11
+
12
def load_config(path: str | Path) -> ConfigDict:
    """Read a YAML config file and return its top-level mapping.

    Args:
        path: Location of the YAML file.

    Returns:
        The parsed document.

    Raises:
        ValueError: If the parsed document is not a ``dict``.
    """
    path = Path(path)
    with path.open("r", encoding="utf-8") as stream:
        loaded = yaml.safe_load(stream)
    if isinstance(loaded, dict):
        return loaded
    raise ValueError(f"Config at {path} must be a mapping.")
19
+
20
+
21
def merge_dicts(base: ConfigDict, overrides: ConfigDict) -> ConfigDict:
    """Return a new dict with ``overrides`` recursively merged onto ``base``.

    Nested dicts present on both sides are merged key by key; any other
    override value replaces the base value wholesale. Neither input is
    modified, and the result shares no mutable objects with either of them.

    Args:
        base: Default configuration.
        overrides: Values that take precedence over ``base``.

    Returns:
        The merged configuration.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        existing = merged.get(key)
        if isinstance(value, dict) and isinstance(existing, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            # Copy the override so later mutation of the merged config cannot
            # leak back into the caller's ``overrides`` structure.
            merged[key] = copy.deepcopy(value)
    return merged
29
+
30
+
31
def save_config(config: ConfigDict, path: str | Path) -> None:
    """Write ``config`` to ``path`` as YAML, creating parent directories first.

    Keys are emitted in insertion order (``sort_keys=False``).
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as stream:
        yaml.safe_dump(config, stream, sort_keys=False)
36
+
37
+
38
def ensure_run_dirs(base_dir: str | Path) -> dict[str, Path]:
    """Create the standard run directory layout under ``base_dir``.

    Args:
        base_dir: Root directory of the run.

    Returns:
        Mapping with keys ``base``, ``checkpoints``, ``metrics``, ``plots`` and
        ``replay`` pointing at directories that exist on return.
    """
    root = Path(base_dir)
    layout: dict[str, Path] = {"base": root}
    for name in ("checkpoints", "metrics", "plots", "replay"):
        layout[name] = root / name
    for directory in layout.values():
        directory.mkdir(parents=True, exist_ok=True)
    return layout
50
+
51
+
52
def deep_get(config: ConfigDict, *keys: str, default: Any = None) -> Any:
    """Look up a nested value by following ``keys`` through nested dicts.

    Args:
        config: Root mapping to search.
        *keys: Successive keys, outermost first.
        default: Returned when a key is missing or a non-dict is reached
            before all keys are consumed.

    Returns:
        The value found at the end of the key path, or ``default``.
    """
    node: Any = config
    for name in keys:
        if isinstance(node, dict) and name in node:
            node = node[name]
        else:
            return default
    return node
59
+
src/minidreamer/envs/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Environment helpers."""
2
+
src/minidreamer/envs/make_env.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Iterable
5
+
6
+ import gymnasium as gym
7
+ import numpy as np
8
+ import torch
9
+ from gymnasium import spaces
10
+ from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
11
+ from PIL import Image
12
+
13
+
14
@dataclass(frozen=True)
class EnvSpec:
    """Immutable bundle of environment construction options.

    The field names and the defaults of the optional fields mirror the keyword
    arguments accepted by ``make_env``.
    """

    # Environment id; the only field without a default.
    env_id: str
    # Target (height, width) of observations.
    resize: tuple[int, int] = (64, 64)
    # Whether observations are scaled to float32 in [0, 1].
    normalize_obs: bool = True
    # Whether the RGB partial-view wrapper is applied.
    rgb_partial_obs: bool = True
    # Whether only the image part of the observation is kept.
    image_only: bool = True
21
+
22
+
23
class ResizeNormalizeObservation(gym.ObservationWrapper):
    """Observation wrapper that resizes image observations and optionally rescales them.

    With ``normalize=True`` observations are returned as ``float32`` divided by
    255; otherwise they are returned as ``uint8``. The advertised observation
    space is updated to match.
    """

    def __init__(
        self,
        env: gym.Env,
        resize: tuple[int, int] | None = (64, 64),
        normalize: bool = True,
    ) -> None:
        """Wrap ``env``.

        Args:
            env: Environment whose observation space must be a ``spaces.Box``.
            resize: Target ``(height, width)``; ``None`` keeps the source size.
            normalize: Whether to emit float32 values in ``[0, 1]``.

        Raises:
            TypeError: If the wrapped observation space is not a ``Box``.
        """
        super().__init__(env)
        self.resize = resize
        self.normalize = normalize
        source_space = env.observation_space
        if not isinstance(source_space, spaces.Box):
            raise TypeError("MiniDreamer expects a Box observation space after wrappers.")
        height, width = source_space.shape[:2] if resize is None else resize
        if normalize:
            low, high, dtype = 0.0, 1.0, np.float32
        else:
            low, high, dtype = 0, 255, np.uint8
        self.observation_space = spaces.Box(
            low=low,
            high=high,
            shape=(height, width, source_space.shape[-1]),
            dtype=dtype,
        )

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Resize (if needed) and convert a single HWC observation."""
        frame = observation
        needs_resize = self.resize is not None and tuple(frame.shape[:2]) != self.resize
        if needs_resize:
            target_h, target_w = self.resize[0], self.resize[1]
            # PIL expects the target size as (width, height).
            image = Image.fromarray(frame.astype(np.uint8))
            image = image.resize((target_w, target_h), Image.Resampling.BILINEAR)
            frame = np.asarray(image)
        if not self.normalize:
            return frame.astype(np.uint8)
        return frame.astype(np.float32) / 255.0
59
+
60
+
61
def make_env(
    env_id: str = "MiniGrid-FourRooms-v0",
    seed: int | None = None,
    resize: tuple[int, int] = (64, 64),
    normalize_obs: bool = True,
    rgb_partial_obs: bool = True,
    image_only: bool = True,
    render_mode: str | None = None,
) -> gym.Env:
    """Build a wrapped environment that yields resized image observations.

    Wrappers are applied in this order: episode statistics, optional RGB
    partial observation, optional image-only observation, then
    ``ResizeNormalizeObservation``.

    Args:
        env_id: Gymnasium id passed to ``gym.make``.
        seed: When given, the env is reset once with this seed and the action
            space is seeded with it too.
        resize: Target ``(height, width)`` for observations.
        normalize_obs: Whether observations are scaled to float32 in ``[0, 1]``.
        rgb_partial_obs: Whether to apply ``RGBImgPartialObsWrapper``.
        image_only: Whether to apply ``ImgObsWrapper``.
        render_mode: Forwarded to ``gym.make``.

    Returns:
        The fully wrapped environment.
    """
    env = gym.make(env_id, render_mode=render_mode)
    wrapper_plan = (
        (True, gym.wrappers.RecordEpisodeStatistics),
        (rgb_partial_obs, RGBImgPartialObsWrapper),
        (image_only, ImgObsWrapper),
    )
    for enabled, wrapper in wrapper_plan:
        if enabled:
            env = wrapper(env)
    env = ResizeNormalizeObservation(env, resize=resize, normalize=normalize_obs)
    if seed is None:
        return env
    env.reset(seed=seed)
    env.action_space.seed(seed)
    return env
81
+
82
+
83
def make_env_from_config(config: dict, seed: int | None = None) -> gym.Env:
    """Build an environment from the ``env`` section of ``config``.

    ``env.id`` is required; ``resize``, ``normalize_obs``, ``rgb_partial_obs``
    and ``image_only`` fall back to the same defaults as ``make_env``.
    """
    settings = config["env"]
    kwargs = {
        "env_id": settings["id"],
        "seed": seed,
        # YAML yields a list; downstream code compares against a tuple.
        "resize": tuple(settings.get("resize", (64, 64))),
        "normalize_obs": settings.get("normalize_obs", True),
        "rgb_partial_obs": settings.get("rgb_partial_obs", True),
        "image_only": settings.get("image_only", True),
    }
    return make_env(**kwargs)
93
+
94
+
95
def observation_to_tensor(observation: np.ndarray, device: torch.device | None = None) -> torch.Tensor:
    """Convert one HWC array into a float CHW tensor.

    Args:
        observation: Array with exactly three dimensions, channels last.
        device: Optional device to move the result to.

    Raises:
        ValueError: If ``observation`` is not three-dimensional.
    """
    if observation.ndim == 3:
        chw = torch.from_numpy(observation).permute(2, 0, 1).float()
        if device is None:
            return chw
        return chw.to(device)
    raise ValueError(f"Expected HWC observation, got shape {observation.shape}.")
100
+
101
+
102
def batch_observations_to_tensor(
    observations: np.ndarray,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Convert a ``[B, T, H, W, C]`` array into a float ``[B, T, C, H, W]`` tensor.

    Args:
        observations: Array with exactly five dimensions, channels last.
        device: Optional device to move the result to.

    Raises:
        ValueError: If ``observations`` is not five-dimensional.
    """
    if observations.ndim == 5:
        channels_first = torch.from_numpy(observations).permute(0, 1, 4, 2, 3).float()
        if device is None:
            return channels_first
        return channels_first.to(device)
    raise ValueError(f"Expected BT HWC observations, got shape {observations.shape}.")
110
+
111
+
112
def action_subset(action_space_n: int, names: Iterable[str] | None = None) -> list[int]:
    """Translate action names into action indices.

    Args:
        action_space_n: Size of the action space; only used when ``names`` is
            ``None``.
        names: Action names to translate, or ``None`` for every action.

    Returns:
        ``[0, ..., action_space_n - 1]`` when ``names`` is ``None``, otherwise
        the indices of ``names`` in the order given.

    Raises:
        KeyError: If a name is not one of the seven known action names.
    """
    if names is None:
        return list(range(action_space_n))
    ordered_names = ("left", "right", "forward", "pickup", "drop", "toggle", "done")
    index_of = {label: position for position, label in enumerate(ordered_names)}
    return [index_of[name] for name in names]
125
+
src/minidreamer/evaluation.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections import defaultdict
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from minidreamer.data.replay_buffer import Episode, ReplayBuffer
11
+ from minidreamer.envs.make_env import make_env_from_config
12
+ from minidreamer.models.world_model import WorldModel
13
+
14
+
15
def evaluate_random_policy(config: dict, episodes: int | None = None, seed: int | None = None) -> dict[str, float]:
    """Roll out a uniformly random policy and summarise the episodes.

    Args:
        config: Project config; ``config["evaluation"]`` must exist, and its
            ``episodes`` entry is used when ``episodes`` is ``None`` or ``0``.
        episodes: Number of episodes to run.
        seed: Base seed; defaults to ``config["project"]["seed"]`` (or ``0``).
            Episode ``i`` is reset with ``seed + i``.

    Returns:
        Dict with ``success_rate``, ``mean_return``, ``median_return`` and
        ``mean_episode_length``. An episode counts as a success when it
        terminated with a positive total return.
    """
    eval_cfg = config["evaluation"]
    episodes = episodes or eval_cfg["episodes"]
    seed = config.get("project", {}).get("seed", 0) if seed is None else seed
    env = make_env_from_config(config, seed=seed)
    rng = np.random.default_rng(seed)
    returns: list[float] = []
    lengths: list[int] = []
    successes: list[float] = []
    try:
        for episode_idx in range(episodes):
            env.reset(seed=seed + episode_idx)
            total_return = 0.0
            length = 0
            terminated = False
            truncated = False
            while not (terminated or truncated):
                action = int(rng.integers(0, env.action_space.n))
                _, reward, terminated, truncated, _ = env.step(action)
                total_return += float(reward)
                length += 1
            returns.append(total_return)
            lengths.append(length)
            successes.append(float(terminated and total_return > 0.0))
    finally:
        # Release the environment even when a reset/step call raises.
        env.close()
    returns_array = np.asarray(returns, dtype=np.float32)
    lengths_array = np.asarray(lengths, dtype=np.float32)
    successes_array = np.asarray(successes, dtype=np.float32)
    return {
        "success_rate": float(successes_array.mean()),
        "mean_return": float(returns_array.mean()),
        "median_return": float(np.median(returns_array)),
        "mean_episode_length": float(lengths_array.mean()),
    }
48
+
49
+
50
+ def _episode_to_batch(episode: Episode, device: torch.device) -> dict[str, torch.Tensor]:
51
+ batch = {
52
+ "obs": torch.from_numpy(episode.obs[None]).permute(0, 1, 4, 2, 3).float().to(device),
53
+ "actions": torch.from_numpy(episode.actions[None]).long().to(device),
54
+ "rewards": torch.from_numpy(episode.rewards[None]).float().to(device),
55
+ "terminated": torch.from_numpy(episode.terminated[None]).float().to(device),
56
+ "truncated": torch.from_numpy(episode.truncated[None]).float().to(device),
57
+ "done": torch.from_numpy(episode.done[None]).float().to(device),
58
+ "mask": torch.ones((1, episode.length), dtype=torch.float32, device=device),
59
+ }
60
+ return batch
61
+
62
+
63
+ def _sequence_state(model: WorldModel, episode: Episode, start_idx: int) -> Any:
64
+ state = model.posterior_step(model.initial_state(1), None, episode.obs[0], sample=False)
65
+ for idx in range(start_idx):
66
+ state = model.posterior_step(state, int(episode.actions[idx]), episode.obs[idx + 1], sample=False)
67
+ return state
68
+
69
+
70
+ def _discounted_return(rewards: np.ndarray, done: np.ndarray, discount: float) -> float:
71
+ total = 0.0
72
+ alive = 1.0
73
+ for step, reward in enumerate(rewards):
74
+ total += (discount**step) * alive * float(reward)
75
+ alive *= 1.0 - float(done[step])
76
+ return total
77
+
78
+
79
def evaluate_world_model(
    config: dict,
    model: WorldModel,
    replay: ReplayBuffer,
    split: str = "val",
    max_episodes: int | None = None,
) -> dict[str, float]:
    """Compute one-step and open-loop prediction metrics on stored episodes.

    For every episode the model first filters the whole sequence with the
    posterior mean. From each filtered state it then imagines the recorded
    actions open loop for horizons 1, 5 and 10 and compares predicted rewards,
    done flags and discounted returns against the recorded ones.

    The model is switched to eval mode and left there.

    Args:
        config: Project config; ``config["planner"]["discount"]`` is required.
        model: World model to evaluate.
        replay: Source of episodes (``episode_ids(split)`` and ``episodes``).
        split: Name of the replay split to evaluate.
        max_episodes: Optional cap on the number of episodes.

    Returns:
        Mean of each collected metric over episodes; metrics with no samples
        are omitted.
    """
    device = model.device
    model.eval()
    metrics: dict[str, list[float]] = defaultdict(list)
    horizons = [1, 5, 10]
    discount = float(config["planner"]["discount"])
    episodes = replay.episode_ids(split)
    if max_episodes is not None:
        episodes = episodes[:max_episodes]

    with torch.no_grad():
        for episode_id in episodes:
            episode = replay.episodes[episode_id]
            batch = _episode_to_batch(episode, device)
            outputs = model.observe_sequence(batch["obs"], batch["actions"], sample=False)
            reward_mse = F.mse_loss(outputs.reward_pred, batch["rewards"], reduction="none").mean()
            done_bce = F.binary_cross_entropy_with_logits(outputs.done_logits, batch["done"], reduction="none").mean()
            kl = model.rssm.kl_divergence(
                outputs.post_mean,
                outputs.post_std,
                outputs.prior_mean,
                outputs.prior_std,
            ).mean()
            metrics["reward_mse"].append(float(reward_mse.cpu()))
            metrics["done_bce"].append(float(done_bce.cpu()))
            metrics["kl_loss"].append(float(kl.cpu()))
            if outputs.reconstructions is not None:
                recon_mse = F.mse_loss(outputs.reconstructions, batch["obs"][:, 1:], reduction="none").mean()
                metrics["reconstruction_mse"].append(float(recon_mse.cpu()))

            # ``outputs.states[k]`` is the deterministic (sample=False) posterior
            # after the first k transitions, i.e. the state that re-filtering the
            # episode prefix would produce. Reusing it avoids re-running the
            # encoder and RSSM from step 0 for every start index and horizon.
            filtered_states = outputs.states
            episode_actions = batch["actions"][0]

            for horizon in horizons:
                if episode.length < horizon:
                    continue
                reward_errors = []
                done_correct = []
                imagined_returns = []
                real_returns = []
                for start_idx in range(episode.length - horizon + 1):
                    stop_idx = start_idx + horizon
                    rollout = model.score_action_sequences(
                        filtered_states[start_idx],
                        episode_actions[start_idx:stop_idx].unsqueeze(0),
                        discount=discount,
                        use_done_mask=True,
                    )
                    reward_pred = rollout["reward_pred"].squeeze(0).cpu().numpy()
                    done_prob = rollout["done_prob"].squeeze(0).cpu().numpy()
                    done_pred = (done_prob >= 0.5).astype(np.float32)
                    real_rewards = episode.rewards[start_idx:stop_idx]
                    real_done = episode.done[start_idx:stop_idx]
                    reward_errors.append(np.mean((reward_pred - real_rewards) ** 2))
                    done_correct.append(np.mean(done_pred == real_done))
                    imagined_returns.append(float(rollout["scores"].squeeze(0).cpu()))
                    real_returns.append(_discounted_return(real_rewards, real_done, discount))
                metrics[f"open_loop_reward_error_h{horizon}"].append(float(np.mean(reward_errors)))
                metrics[f"open_loop_done_accuracy_h{horizon}"].append(float(np.mean(done_correct)))
                # Correlation is undefined for a single point or constant series.
                if len(imagined_returns) > 1 and np.std(real_returns) > 0.0 and np.std(imagined_returns) > 0.0:
                    correlation = np.corrcoef(imagined_returns, real_returns)[0, 1]
                    metrics["imagined_return_vs_real_return_correlation"].append(float(correlation))

    return {name: float(np.mean(values)) for name, values in metrics.items() if values}
147
+
src/minidreamer/models/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Model components for MiniDreamer."""
2
+
src/minidreamer/models/decoder.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from torch import nn
4
+
5
+
6
class ConvDecoder(nn.Module):
    """Transposed-convolution decoder mapping feature vectors to images.

    Features are projected to a 256x4x4 map and upsampled by four stride-2
    transposed convolutions (4 -> 8 -> 16 -> 32 -> 64 pixels per side). A final
    sigmoid keeps outputs in ``[0, 1]``.
    """

    def __init__(self, feature_dim: int, out_channels: int = 3) -> None:
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(feature_dim, 256 * 4 * 4),
            nn.ReLU(),
        )
        channel_plan = (256, 128, 64, 32, out_channels)
        stages: list[nn.Module] = []
        for in_ch, out_ch in zip(channel_plan, channel_plan[1:]):
            stages.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        # The output layer is squashed with a sigmoid instead of a ReLU.
        stages[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*stages)

    def forward(self, features):
        """Decode ``features`` into images; all leading dims are flattened into the batch."""
        seed_map = self.projection(features).view(-1, 256, 4, 4)
        return self.decoder(seed_map)
27
+
src/minidreamer/models/encoder.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class ConvEncoder(nn.Module):
    """Convolutional encoder mapping images to embedding vectors.

    Four stride-2 convolutions halve the spatial size each time; the linear
    projection expects the result to be 256x4x4, i.e. 64x64 inputs.
    """

    def __init__(self, in_channels: int = 3, embedding_dim: int = 256) -> None:
        super().__init__()
        channel_plan = (in_channels, 32, 64, 128, 256)
        stages: list[nn.Module] = []
        for in_ch, out_ch in zip(channel_plan, channel_plan[1:]):
            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            stages.append(nn.ReLU())
        self.conv = nn.Sequential(*stages)
        self.projection = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, embedding_dim),
            nn.ReLU(),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode a ``[N, C, H, W]`` batch into ``[N, embedding_dim]``."""
        return self.projection(self.conv(obs))
29
+
src/minidreamer/models/heads.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from torch import nn
4
+
5
+
6
class MLPHead(nn.Module):
    """Three-layer perceptron with ELU activations between the linear layers."""

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 256) -> None:
        super().__init__()
        widths = (in_dim, hidden_dim, hidden_dim, out_dim)
        layers: list[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ELU())
        # No activation after the output layer.
        self.net = nn.Sequential(*layers[:-1])

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)
19
+
20
+
21
class RewardHead(MLPHead):
    """MLP head that emits one scalar per feature vector."""

    def __init__(self, in_dim: int, hidden_dim: int = 256) -> None:
        # Output width is fixed at 1; only the input and hidden sizes vary.
        super().__init__(in_dim, 1, hidden_dim)
24
+
25
+
26
class DoneHead(MLPHead):
    """MLP head that emits one unnormalised logit per feature vector."""

    def __init__(self, in_dim: int, hidden_dim: int = 256) -> None:
        # Output width is fixed at 1; only the input and hidden sizes vary.
        super().__init__(in_dim, 1, hidden_dim)
29
+
src/minidreamer/models/rssm.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
@dataclass
class RSSMState:
    """Latent state of the RSSM: deterministic part plus a Gaussian sample and its statistics."""

    deter: torch.Tensor
    stoch: torch.Tensor
    mean: torch.Tensor
    std: torch.Tensor

    def _apply(self, transform) -> "RSSMState":
        """Build a new state by applying ``transform`` to each tensor in field order."""
        return RSSMState(
            deter=transform(self.deter),
            stoch=transform(self.stoch),
            mean=transform(self.mean),
            std=transform(self.std),
        )

    def features(self) -> torch.Tensor:
        """Concatenate ``deter`` and ``stoch`` along the last dimension."""
        return torch.cat((self.deter, self.stoch), dim=-1)

    def detach(self) -> "RSSMState":
        """Return a copy whose tensors are detached from the autograd graph."""
        return self._apply(lambda tensor: tensor.detach())

    def repeat(self, count: int) -> "RSSMState":
        """Return a copy with every tensor tiled ``count`` times along dim 0."""
        return self._apply(lambda tensor: tensor.repeat(count, 1))
35
+
36
+
37
class RSSM(nn.Module):
    """Recurrent state-space model with a GRU path and diagonal-Gaussian latents.

    The deterministic state is advanced by a ``GRUCell`` fed with the previous
    stochastic sample and a one-hot action. The prior over the next stochastic
    state depends on the deterministic state only; the posterior additionally
    sees the observation embedding.
    """

    def __init__(
        self,
        action_dim: int,
        embedding_dim: int = 256,
        deter_dim: int = 256,
        stoch_dim: int = 32,
        hidden_dim: int = 256,
        min_std: float = 0.1,
    ) -> None:
        """Build the networks.

        Args:
            action_dim: Number of discrete actions (one-hot width).
            embedding_dim: Width of the observation embedding.
            deter_dim: Width of the deterministic (GRU) state.
            stoch_dim: Width of the stochastic state.
            hidden_dim: Width of the hidden layers.
            min_std: Lower bound added to every predicted standard deviation.
        """
        super().__init__()
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim
        self.deter_dim = deter_dim
        self.stoch_dim = stoch_dim
        self.hidden_dim = hidden_dim
        self.min_std = min_std

        self.input_net = nn.Sequential(
            nn.Linear(stoch_dim + action_dim, hidden_dim),
            nn.ELU(),
        )
        self.gru = nn.GRUCell(hidden_dim, deter_dim)
        self.prior_net = nn.Sequential(
            nn.Linear(deter_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 2 * stoch_dim),
        )
        self.posterior_net = nn.Sequential(
            nn.Linear(deter_dim + embedding_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 2 * stoch_dim),
        )

    def initial(self, batch_size: int, device: torch.device) -> RSSMState:
        """Return the start state: zero ``deter``/``stoch``/``mean`` and unit ``std``."""
        zeros_deter = torch.zeros(batch_size, self.deter_dim, device=device)
        zeros_stoch = torch.zeros(batch_size, self.stoch_dim, device=device)
        return RSSMState(
            deter=zeros_deter,
            stoch=zeros_stoch,
            # Separate buffer: sharing one tensor between ``stoch`` and ``mean``
            # would let an in-place edit of either silently change the other.
            mean=torch.zeros_like(zeros_stoch),
            std=torch.ones_like(zeros_stoch),
        )

    def _stats(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Split network output into a mean and a std bounded below by ``min_std``."""
        mean, raw_std = torch.chunk(tensor, 2, dim=-1)
        std = F.softplus(raw_std) + self.min_std
        return mean, std

    def _action_one_hot(self, action: torch.Tensor) -> torch.Tensor:
        """Flatten integer actions and one-hot encode them as floats."""
        # ``reshape`` rather than ``view``: it also accepts non-contiguous
        # inputs such as transposed or strided action tensors.
        action = action.long().reshape(-1)
        return F.one_hot(action, num_classes=self.action_dim).float()

    def _next_deter(self, prev_state: RSSMState, action: torch.Tensor) -> torch.Tensor:
        """Advance the deterministic state by one step given ``action``."""
        action_one_hot = self._action_one_hot(action)
        gru_input = self.input_net(torch.cat([prev_state.stoch, action_one_hot], dim=-1))
        return self.gru(gru_input, prev_state.deter)

    def prior(self, deter: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` of the prior given the deterministic state."""
        return self._stats(self.prior_net(deter))

    def posterior(self, deter: torch.Tensor, embed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, std)`` of the posterior given state and observation embedding."""
        return self._stats(self.posterior_net(torch.cat([deter, embed], dim=-1)))

    @staticmethod
    def sample(mean: torch.Tensor, std: torch.Tensor, sample: bool = True) -> torch.Tensor:
        """Draw a reparameterised Gaussian sample, or return ``mean`` when ``sample`` is false."""
        if sample:
            return mean + torch.randn_like(std) * std
        return mean

    def observe(
        self,
        prev_state: RSSMState,
        prev_action: torch.Tensor | None,
        embed: torch.Tensor,
        sample: bool = True,
    ) -> tuple[RSSMState, tuple[torch.Tensor, torch.Tensor]]:
        """Perform one filtering step using an observation embedding.

        Args:
            prev_state: State before the step.
            prev_action: Action taken from ``prev_state``; ``None`` keeps the
                deterministic state unchanged (used for the first observation).
            embed: Embedding of the new observation.
            sample: Whether to sample the stochastic state or use the mean.

        Returns:
            The posterior state and the ``(mean, std)`` of the matching prior.
        """
        if prev_action is None:
            deter = prev_state.deter
        else:
            deter = self._next_deter(prev_state, prev_action)
        prior_mean, prior_std = self.prior(deter)
        post_mean, post_std = self.posterior(deter, embed)
        stoch = self.sample(post_mean, post_std, sample=sample)
        state = RSSMState(deter=deter, stoch=stoch, mean=post_mean, std=post_std)
        return state, (prior_mean, prior_std)

    def imagine(
        self,
        prev_state: RSSMState,
        action: torch.Tensor,
        sample: bool = False,
    ) -> tuple[RSSMState, tuple[torch.Tensor, torch.Tensor]]:
        """Perform one prediction step from the prior, without an observation.

        Returns:
            The prior state and its ``(mean, std)``.
        """
        deter = self._next_deter(prev_state, action)
        mean, std = self.prior(deter)
        stoch = self.sample(mean, std, sample=sample)
        state = RSSMState(deter=deter, stoch=stoch, mean=mean, std=std)
        return state, (mean, std)

    @staticmethod
    def kl_divergence(
        post_mean: torch.Tensor,
        post_std: torch.Tensor,
        prior_mean: torch.Tensor,
        prior_std: torch.Tensor,
    ) -> torch.Tensor:
        """Return KL(posterior || prior) for diagonal Gaussians, summed over the last dim."""
        log_var_ratio = 2.0 * (torch.log(prior_std) - torch.log(post_std))
        var_ratio = (post_std / prior_std) ** 2
        mean_term = ((post_mean - prior_mean) / prior_std) ** 2
        return 0.5 * torch.sum(var_ratio + mean_term + log_var_ratio - 1.0, dim=-1)
147
+
src/minidreamer/models/world_model.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from minidreamer.models.decoder import ConvDecoder
12
+ from minidreamer.models.encoder import ConvEncoder
13
+ from minidreamer.models.heads import DoneHead, RewardHead
14
+ from minidreamer.models.rssm import RSSM, RSSMState
15
+ from minidreamer.utils.common import masked_mean
16
+
17
+
18
@dataclass
class WorldModelOutputs:
    """Per-step results of filtering an observation/action sequence.

    Tensor fields are stacked along dimension 1, one entry per action step;
    ``states`` additionally contains the state for the first observation.
    """

    # Posterior states, starting with the one for the initial observation.
    states: list[RSSMState]
    prior_mean: torch.Tensor
    prior_std: torch.Tensor
    post_mean: torch.Tensor
    post_std: torch.Tensor
    reward_pred: torch.Tensor
    # Raw (pre-sigmoid) termination scores.
    done_logits: torch.Tensor
    # ``None`` when no reconstructions were produced.
    reconstructions: torch.Tensor | None
28
+
29
+
30
class WorldModel(nn.Module):
    """Latent world model: encoder, RSSM, reward/done heads and optional decoder."""

    def __init__(
        self,
        action_dim: int,
        embedding_dim: int = 256,
        deter_dim: int = 256,
        stoch_dim: int = 32,
        hidden_dim: int = 256,
        use_decoder: bool = True,
        min_std: float = 0.1,
        obs_channels: int = 3,
    ) -> None:
        """Build all sub-networks.

        Args:
            action_dim: Number of discrete actions.
            embedding_dim: Width of the observation embedding.
            deter_dim: Width of the deterministic RSSM state.
            stoch_dim: Width of the stochastic RSSM state.
            hidden_dim: Hidden width shared by the RSSM and the heads.
            use_decoder: Whether to build the reconstruction decoder.
            min_std: Lower bound on latent standard deviations.
            obs_channels: Number of image channels.
        """
        super().__init__()
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim
        self.deter_dim = deter_dim
        self.stoch_dim = stoch_dim
        self.hidden_dim = hidden_dim
        self.use_decoder = use_decoder
        self.obs_channels = obs_channels

        self.encoder = ConvEncoder(in_channels=obs_channels, embedding_dim=embedding_dim)
        self.rssm = RSSM(
            action_dim=action_dim,
            embedding_dim=embedding_dim,
            deter_dim=deter_dim,
            stoch_dim=stoch_dim,
            hidden_dim=hidden_dim,
            min_std=min_std,
        )
        # Heads consume the concatenated deterministic + stochastic state.
        feature_dim = deter_dim + stoch_dim
        self.reward_head = RewardHead(feature_dim, hidden_dim=hidden_dim)
        self.done_head = DoneHead(feature_dim, hidden_dim=hidden_dim)
        self.decoder = ConvDecoder(feature_dim, out_channels=obs_channels) if use_decoder else None

    @classmethod
    def from_config(cls, config: dict, action_dim: int, obs_channels: int = 3) -> "WorldModel":
        """Build a model from ``config["model"]``.

        ``embedding_dim``, ``deter_dim``, ``stoch_dim`` and ``hidden_dim`` are
        required; ``use_decoder`` and ``min_std`` are optional.
        """
        model_cfg = config["model"]
        return cls(
            action_dim=action_dim,
            embedding_dim=model_cfg["embedding_dim"],
            deter_dim=model_cfg["deter_dim"],
            stoch_dim=model_cfg["stoch_dim"],
            hidden_dim=model_cfg["hidden_dim"],
            use_decoder=model_cfg.get("use_decoder", True),
            min_std=model_cfg.get("min_std", 0.1),
            obs_channels=obs_channels,
        )

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def initial_state(self, batch_size: int) -> RSSMState:
        """Return the RSSM start state for ``batch_size`` sequences on the model device."""
        return self.rssm.initial(batch_size=batch_size, device=self.device)

    def encode(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode images of shape ``[..., C, H, W]`` into ``[..., embedding]``.

        Raises:
            ValueError: If ``obs`` has fewer than four dimensions.
        """
        if obs.dim() < 4:
            raise ValueError(f"Expected at least 4 dims for observations, got {obs.shape}.")
        leading_shape = obs.shape[:-3]
        # Collapse all leading dims so the encoder sees a plain image batch.
        flat_obs = obs.reshape(-1, *obs.shape[-3:])
        flat_embeddings = self.encoder(flat_obs)
        return flat_embeddings.reshape(*leading_shape, -1)

    def observe_sequence(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        sample: bool = True,
    ) -> WorldModelOutputs:
        """Filter a batch of sequences and predict reward/done (and images) per step.

        Args:
            obs: Observations of shape ``[B, T+1, C, H, W]``.
            actions: Integer actions of shape ``[B, T]``.
            sample: Whether to sample latent states or use posterior means.

        Returns:
            Stacked per-step predictions plus the list of ``T+1`` states.

        Raises:
            ValueError: If ``obs`` is not 5-D or ``actions`` is not 2-D.
        """
        if obs.dim() != 5:
            raise ValueError(f"Expected obs shape [B, T+1, C, H, W], got {obs.shape}.")
        if actions.dim() != 2:
            raise ValueError(f"Expected actions shape [B, T], got {actions.shape}.")
        batch_size, time_steps = actions.shape
        embeddings = self.encode(obs)
        # The first observation has no preceding action.
        state, _ = self.rssm.observe(self.initial_state(batch_size), None, embeddings[:, 0], sample=sample)

        states = [state]
        prior_means = []
        prior_stds = []
        post_means = []
        post_stds = []
        rewards = []
        done_logits = []
        reconstructions = []

        for t in range(time_steps):
            next_state, (prior_mean, prior_std) = self.rssm.observe(
                state,
                actions[:, t],
                embeddings[:, t + 1],
                sample=sample,
            )
            features = next_state.features()
            prior_means.append(prior_mean)
            prior_stds.append(prior_std)
            post_means.append(next_state.mean)
            post_stds.append(next_state.std)
            rewards.append(self.reward_head(features).squeeze(-1))
            done_logits.append(self.done_head(features).squeeze(-1))
            if self.decoder is not None:
                reconstructions.append(self.decoder(features))
            states.append(next_state)
            state = next_state

        recon_tensor = torch.stack(reconstructions, dim=1) if reconstructions else None
        return WorldModelOutputs(
            states=states,
            prior_mean=torch.stack(prior_means, dim=1),
            prior_std=torch.stack(prior_stds, dim=1),
            post_mean=torch.stack(post_means, dim=1),
            post_std=torch.stack(post_stds, dim=1),
            reward_pred=torch.stack(rewards, dim=1),
            done_logits=torch.stack(done_logits, dim=1),
            reconstructions=recon_tensor,
        )

    def compute_losses(self, batch: dict[str, torch.Tensor], config: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Compute the training loss and logging metrics for one batch.

        Args:
            batch: Dict with ``obs``, ``actions``, ``rewards``, ``done`` and
                ``mask`` tensors.
            config: Project config; ``config["training"]`` supplies the
                ``beta_*`` loss weights and ``free_nats``.

        Returns:
            ``loss`` (differentiable) plus detached ``reward_loss``,
            ``done_loss``, ``kl_loss``, ``recon_loss``, ``reward_mse`` and
            ``done_bce`` metrics.
        """
        outputs = self.observe_sequence(batch["obs"], batch["actions"], sample=True)
        training_cfg = config["training"]
        rewards = batch["rewards"]
        done = batch["done"]
        mask = batch["mask"]

        # Element-wise errors are computed once and reused for loss and metrics.
        reward_error = F.mse_loss(outputs.reward_pred, rewards, reduction="none")
        done_error = F.binary_cross_entropy_with_logits(outputs.done_logits, done, reduction="none")
        reward_loss = masked_mean(reward_error, mask)
        done_loss = masked_mean(done_error, mask)
        kl_per_step = self.rssm.kl_divergence(
            outputs.post_mean,
            outputs.post_std,
            outputs.prior_mean,
            outputs.prior_std,
        )
        # Free nats: KL below the threshold is clipped up to it.
        free_nats = torch.full_like(kl_per_step, float(training_cfg.get("free_nats", 1.0)))
        kl_loss = masked_mean(torch.maximum(kl_per_step, free_nats), mask)

        if outputs.reconstructions is not None and training_cfg.get("beta_recon", 0.0) > 0.0:
            recon_target = batch["obs"][:, 1:]
            recon_error = F.mse_loss(outputs.reconstructions, recon_target, reduction="none").mean(dim=(2, 3, 4))
            recon_loss = masked_mean(recon_error, mask)
        else:
            recon_loss = torch.zeros((), device=self.device)

        total_loss = (
            float(training_cfg.get("beta_reward", 1.0)) * reward_loss
            + float(training_cfg.get("beta_done", 1.0)) * done_loss
            + float(training_cfg.get("beta_kl", 1.0)) * kl_loss
            + float(training_cfg.get("beta_recon", 0.0)) * recon_loss
        )
        mask_total = mask.sum().clamp_min(1.0)
        return {
            "loss": total_loss,
            "reward_loss": reward_loss.detach(),
            "done_loss": done_loss.detach(),
            "kl_loss": kl_loss.detach(),
            "recon_loss": recon_loss.detach(),
            # Detached like the other metrics so that logging them does not
            # keep the autograd graph alive.
            "reward_mse": (reward_error.mul(mask).sum() / mask_total).detach(),
            "done_bce": (done_error.mul(mask).sum() / mask_total).detach(),
        }

    def posterior_step(
        self,
        prev_state: RSSMState,
        prev_action: int | torch.Tensor | None,
        observation: np.ndarray | torch.Tensor,
        sample: bool = False,
    ) -> RSSMState:
        """Update the state with a single new observation.

        Args:
            prev_state: State before the step.
            prev_action: Action that led to ``observation``; ``None`` for the
                first observation.
            observation: One image, HWC or CHW, or an already batched tensor.
            sample: Whether to sample the latent or use the posterior mean.
        """
        obs_tensor = self._prepare_single_observation(observation)
        embed = self.encode(obs_tensor)
        if prev_action is None:
            action_tensor = None
        else:
            action_tensor = torch.as_tensor(prev_action, device=self.device).view(1)
        state, _ = self.rssm.observe(prev_state, action_tensor, embed, sample=sample)
        return state

    def imagine_rollout(
        self,
        start_state: RSSMState,
        action_sequences: torch.Tensor,
        sample: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Roll the prior forward along given action sequences.

        Args:
            start_state: State to start from; tiled to the number of sequences
                when its batch size differs.
            action_sequences: Actions of shape ``[N, H]`` (a 1-D tensor is
                treated as one sequence).
            sample: Whether to sample latent states or use prior means.

        Returns:
            Dict with the list of ``states`` and stacked ``reward_pred`` and
            ``done_logits`` of shape ``[N, H]``.
        """
        if action_sequences.dim() == 1:
            action_sequences = action_sequences.unsqueeze(0)
        batch_size, horizon = action_sequences.shape
        state = start_state if start_state.deter.shape[0] == batch_size else start_state.repeat(batch_size)
        rewards = []
        done_logits = []
        states = []
        for t in range(horizon):
            state, _ = self.rssm.imagine(state, action_sequences[:, t], sample=sample)
            features = state.features()
            rewards.append(self.reward_head(features).squeeze(-1))
            done_logits.append(self.done_head(features).squeeze(-1))
            states.append(state)
        return {
            "states": states,
            "reward_pred": torch.stack(rewards, dim=1),
            "done_logits": torch.stack(done_logits, dim=1),
        }

    def score_action_sequences(
        self,
        start_state: RSSMState,
        action_sequences: torch.Tensor,
        discount: float = 0.99,
        use_done_mask: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Imagine action sequences deterministically and score them by discounted reward.

        With ``use_done_mask`` each step's reward is weighted by the predicted
        probability that no earlier step ended the episode.

        Returns:
            The ``imagine_rollout`` dict extended with ``scores`` (``[N]``) and
            ``done_prob`` (``[N, H]``).
        """
        rollout = self.imagine_rollout(start_state, action_sequences, sample=False)
        reward_pred = rollout["reward_pred"]
        done_prob = torch.sigmoid(rollout["done_logits"])
        alive = torch.ones(reward_pred.shape[0], device=self.device)
        scores = torch.zeros_like(alive)
        for t in range(reward_pred.shape[1]):
            scores = scores + (discount**t) * alive * reward_pred[:, t]
            if use_done_mask:
                alive = alive * (1.0 - done_prob[:, t])
        rollout["scores"] = scores
        rollout["done_prob"] = done_prob
        return rollout

    def _prepare_single_observation(self, observation: np.ndarray | torch.Tensor) -> torch.Tensor:
        """Convert an observation into a float tensor on the model device.

        A 3-D input whose last dim is 1 or 3 is treated as HWC and permuted; a
        3-D input whose first dim is 1 or 3 is treated as CHW. Both get a
        leading batch dim. 4-D inputs are passed through unchanged.

        Raises:
            ValueError: For any other shape.
        """
        if torch.is_tensor(observation):
            obs_tensor = observation.to(self.device).float()
        else:
            obs_tensor = torch.as_tensor(observation, device=self.device).float()
        if obs_tensor.dim() == 3 and obs_tensor.shape[-1] in (1, 3):
            obs_tensor = obs_tensor.permute(2, 0, 1).unsqueeze(0)
        elif obs_tensor.dim() == 3 and obs_tensor.shape[0] in (1, 3):
            obs_tensor = obs_tensor.unsqueeze(0)
        elif obs_tensor.dim() == 4:
            pass
        else:
            raise ValueError(f"Unsupported observation shape {tuple(obs_tensor.shape)}.")
        return obs_tensor
267
+
src/minidreamer/planning/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Planning utilities."""
2
+
src/minidreamer/planning/cem.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from minidreamer.models.rssm import RSSMState
9
+ from minidreamer.models.world_model import WorldModel
10
+
11
+
12
@dataclass
class PlannerOutput:
    """Result of a single ``DiscreteCEMPlanner.plan`` call."""

    # First action of the best-scoring sequence found.
    action: int
    # The full best-scoring action sequence, one action per horizon step.
    sequence: list[int]
    # Score of that best sequence.
    score: float
    # Final per-step action probabilities after the last refinement.
    policy: torch.Tensor
    # Mean entropy of ``policy`` across horizon steps.
    entropy: float
19
+
20
+
21
class DiscreteCEMPlanner:
    """Cross-entropy-method planner over sequences of discrete actions.

    The planner keeps one categorical distribution per horizon step. Each
    iteration samples ``candidates`` sequences, scores them with
    ``world_model.score_action_sequences``, and refits the distributions to
    the action frequencies of the ``elites`` best sequences (plus additive
    ``smoothing`` so no action probability collapses to zero).
    """

    def __init__(
        self,
        world_model: WorldModel,
        action_dim: int,
        horizon: int = 8,
        candidates: int = 256,
        elites: int = 32,
        iterations: int = 4,
        discount: float = 0.99,
        use_done_mask: bool = True,
        smoothing: float = 1e-3,
    ) -> None:
        """Store planner hyper-parameters.

        ``elites`` is capped at ``candidates``.

        Raises:
            ValueError: if ``action_dim``, ``horizon``, ``candidates`` or
                ``elites`` is smaller than 1. Such values would otherwise
                surface later as NaN probabilities or opaque torch errors.
        """
        for label, value in (
            ("action_dim", action_dim),
            ("horizon", horizon),
            ("candidates", candidates),
            ("elites", elites),
        ):
            if value < 1:
                raise ValueError(f"{label} must be at least 1, got {value}.")
        self.world_model = world_model
        self.action_dim = action_dim
        self.horizon = horizon
        self.candidates = candidates
        self.elites = min(elites, candidates)
        self.iterations = iterations
        self.discount = discount
        self.use_done_mask = use_done_mask
        self.smoothing = smoothing

    @classmethod
    def from_config(cls, world_model: WorldModel, action_dim: int, config: dict) -> "DiscreteCEMPlanner":
        """Build a planner from the ``"planner"`` section of ``config``.

        ``horizon``, ``candidates``, ``elites``, ``iterations`` and
        ``discount`` are required keys; ``use_done_mask`` and ``smoothing``
        are optional and default to ``True`` and ``1e-3``.
        """
        planner_cfg = config["planner"]
        return cls(
            world_model=world_model,
            action_dim=action_dim,
            horizon=planner_cfg["horizon"],
            candidates=planner_cfg["candidates"],
            elites=planner_cfg["elites"],
            iterations=planner_cfg["iterations"],
            discount=planner_cfg["discount"],
            use_done_mask=planner_cfg.get("use_done_mask", True),
            smoothing=planner_cfg.get("smoothing", 1e-3),
        )

    def _sample_sequences(self, probs: torch.Tensor) -> torch.Tensor:
        """Draw ``candidates`` action sequences from per-step probabilities.

        ``probs`` is expanded across candidates and flattened to rows of
        ``action_dim`` entries; one action is drawn per row and the result is
        viewed as ``(candidates, horizon)``.
        """
        flat = probs.unsqueeze(0).expand(self.candidates, -1, -1).reshape(-1, self.action_dim)
        sampled = torch.multinomial(flat, num_samples=1, replacement=True)
        return sampled.view(self.candidates, self.horizon)

    def plan(self, state: RSSMState) -> PlannerOutput:
        """Search for a high-scoring action sequence starting from ``state``.

        Returns the best sequence seen across all iterations (not only the
        last one), together with the final sampling distribution.

        Raises:
            RuntimeError: if no sequence was ever selected, e.g. when
                ``iterations`` is 0 or no score exceeded ``-inf``.
        """
        device = self.world_model.device
        probs = torch.full(
            (self.horizon, self.action_dim),
            fill_value=1.0 / self.action_dim,
            device=device,
        )
        best_sequence = None
        best_score = torch.tensor(float("-inf"), device=device)

        # Planning only needs forward passes; disabling autograd here keeps
        # callers that forgot to do so from building a graph over
        # candidates x horizon imagined steps.
        with torch.no_grad():
            for _ in range(self.iterations):
                action_sequences = self._sample_sequences(probs)
                scores = self.world_model.score_action_sequences(
                    state,
                    action_sequences,
                    discount=self.discount,
                    use_done_mask=self.use_done_mask,
                )["scores"]

                # Refit the per-step distributions to the elite set.
                elite_indices = torch.topk(scores, k=self.elites, largest=True).indices
                elites = action_sequences[elite_indices]
                elite_freq = F.one_hot(elites, num_classes=self.action_dim).float().mean(dim=0)
                probs = elite_freq + self.smoothing
                probs = probs / probs.sum(dim=-1, keepdim=True)

                iteration_best_idx = scores.argmax()
                iteration_best_score = scores[iteration_best_idx]
                if iteration_best_score > best_score:
                    best_score = iteration_best_score
                    best_sequence = action_sequences[iteration_best_idx]

        if best_sequence is None:
            raise RuntimeError("CEM planner failed to sample any action sequence.")
        entropy = float((-(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)).mean().detach().cpu())
        sequence = [int(action) for action in best_sequence.tolist()]
        return PlannerOutput(
            action=sequence[0],
            sequence=sequence,
            score=float(best_score.detach().cpu()),
            policy=probs.detach().cpu(),
            entropy=entropy,
        )
103
+
src/minidreamer/planning/evaluate_planner.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+
5
+ import numpy as np
6
+ import torch
7
+
8
+ from minidreamer.envs.make_env import make_env_from_config
9
+ from minidreamer.planning.cem import DiscreteCEMPlanner
10
+
11
+
12
@dataclass
class PlannerEpisode:
    """Summary statistics for one planner-controlled episode."""

    # True when the episode terminated with a positive total return.
    success: bool
    # Sum of the rewards received during the episode.
    total_return: float
    # Number of environment steps taken.
    length: int
    # Final ``terminated`` flag reported by the environment.
    terminated: bool
    # Final ``truncated`` flag reported by the environment.
    truncated: bool
    # Mean planner entropy over planned steps; NaN if no step was planned.
    planner_entropy: float
20
+
21
+
22
def run_planner_episode(
    env,
    world_model,
    planner: DiscreteCEMPlanner,
    rng: np.random.Generator,
    seed: int | None = None,
    random_action_fraction: float = 0.0,
) -> PlannerEpisode:
    """Play one episode, choosing actions with the planner.

    The world model is switched to eval mode and its posterior state is
    updated after every non-final environment step. With probability
    ``random_action_fraction`` (drawn from ``rng``) a step uses a random
    action from ``env.action_space`` instead of the planner's choice.

    Returns:
        A ``PlannerEpisode``; ``success`` means the episode terminated with
        a positive return.
    """
    first_obs, _ = env.reset(seed=seed)
    world_model.eval()

    episode_return = 0.0
    steps_taken = 0
    terminated = False
    truncated = False
    plan_entropies: list[float] = []

    with torch.no_grad():
        state = world_model.posterior_step(world_model.initial_state(1), None, first_obs, sample=False)
        while not (terminated or truncated):
            take_random = rng.random() < random_action_fraction
            if take_random:
                action = int(env.action_space.sample())
            else:
                plan = planner.plan(state)
                action = plan.action
                plan_entropies.append(plan.entropy)

            obs, reward, terminated, truncated, _ = env.step(action)
            episode_return += float(reward)
            steps_taken += 1

            # No need to filter the final observation: the loop ends here.
            if not (terminated or truncated):
                state = world_model.posterior_step(state, action, obs, sample=False)

    mean_entropy = float(np.mean(plan_entropies)) if plan_entropies else float("nan")
    return PlannerEpisode(
        success=bool(terminated and episode_return > 0.0),
        total_return=episode_return,
        length=steps_taken,
        terminated=bool(terminated),
        truncated=bool(truncated),
        planner_entropy=mean_entropy,
    )
61
+
62
+
63
def evaluate_planner(
    config: dict,
    world_model,
    episodes: int | None = None,
    seed: int | None = None,
    random_action_fraction: float = 0.0,
) -> dict[str, float]:
    """Run several planner episodes and aggregate their statistics.

    Args:
        config: Experiment config; reads ``config["evaluation"]["episodes"]``
            and, when ``seed`` is ``None``, ``config["project"]["seed"]``
            (default 0).
        world_model: Model handed to the planner and to each episode.
        episodes: Number of episodes; a falsy value falls back to the config.
        seed: Base seed; episode ``i`` resets the env with ``seed + i``.
        random_action_fraction: Forwarded to ``run_planner_episode``.

    Returns:
        ``success_rate``, ``mean_return``, ``median_return``,
        ``mean_episode_length`` and ``planner_action_entropy``. The entropy
        is NaN when no episode contained a planned step.
    """
    eval_cfg = config["evaluation"]
    episodes = episodes or eval_cfg["episodes"]
    seed = config.get("project", {}).get("seed", 0) if seed is None else seed

    env = make_env_from_config(config, seed=seed)
    try:
        planner = DiscreteCEMPlanner.from_config(world_model, env.action_space.n, config)
        rng = np.random.default_rng(seed)
        results = [
            run_planner_episode(
                env,
                world_model,
                planner,
                rng,
                seed=seed + episode_idx,
                random_action_fraction=random_action_fraction,
            )
            for episode_idx in range(episodes)
        ]
    finally:
        # Release the environment even if an episode raised.
        env.close()

    returns = np.asarray([result.total_return for result in results], dtype=np.float32)
    lengths = np.asarray([result.length for result in results], dtype=np.float32)
    successes = np.asarray([result.success for result in results], dtype=np.float32)
    entropies = np.asarray([result.planner_entropy for result in results], dtype=np.float32)

    # np.nanmean warns on an all-NaN input (every action was random); the
    # result is NaN either way, so skip the call in that case.
    if entropies.size and not np.isnan(entropies).all():
        mean_entropy = float(np.nanmean(entropies))
    else:
        mean_entropy = float("nan")

    return {
        "success_rate": float(successes.mean()),
        "mean_return": float(returns.mean()),
        "median_return": float(np.median(returns)),
        "mean_episode_length": float(lengths.mean()),
        "planner_action_entropy": mean_entropy,
    }
src/minidreamer/serialization.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+
6
+ import torch
7
+
8
+ from minidreamer.models.world_model import WorldModel
9
+
10
+
11
def save_world_model_checkpoint(
    path: str | Path,
    model: WorldModel,
    config: dict[str, Any],
    optimizer: torch.optim.Optimizer | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Write a checkpoint with the model weights, config and metadata.

    Parent directories are created as needed. The saved payload holds
    ``"model_state"``, ``"config"`` and ``"metadata"`` (an empty dict when
    none is given), plus ``"optimizer_state"`` when an optimizer is passed.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload: dict[str, Any] = {}
    payload["model_state"] = model.state_dict()
    payload["config"] = config
    payload["metadata"] = metadata if metadata else {}
    if optimizer is not None:
        payload["optimizer_state"] = optimizer.state_dict()

    torch.save(payload, target)
28
+
29
+
30
def load_world_model_checkpoint(
    path: str | Path,
    action_dim: int,
    map_location: str | torch.device | None = None,
) -> tuple[WorldModel, dict[str, Any], dict[str, Any]]:
    """Rebuild a world model from a checkpoint file.

    The model is constructed from the config stored in the checkpoint and
    then loaded with the saved weights.

    Returns:
        ``(model, config, metadata)``; ``metadata`` is an empty dict when the
        checkpoint has none.
    """
    # NOTE: weights_only=False unpickles arbitrary objects, so only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    stored_config = checkpoint["config"]
    world_model = WorldModel.from_config(stored_config, action_dim=action_dim)
    world_model.load_state_dict(checkpoint["model_state"])
    stored_metadata = checkpoint.get("metadata", {})
    return world_model, stored_config, stored_metadata
40
+
src/minidreamer/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ """Utility helpers for MiniDreamer."""
2
+
src/minidreamer/utils/common.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import random
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
def seed_everything(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch.

    All CUDA devices are seeded as well when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
18
+
19
+
20
def get_device(device: str | None = None) -> torch.device:
    """Return the requested device, or pick one automatically.

    With no explicit ``device`` the preference order is CUDA, then MPS,
    then CPU.
    """
    if device is None:
        if torch.cuda.is_available():
            chosen = "cuda"
        elif torch.backends.mps.is_available():
            chosen = "mps"
        else:
            chosen = "cpu"
        return torch.device(chosen)
    return torch.device(device)
28
+
29
+
30
def to_numpy(value: Any) -> np.ndarray:
    """Convert ``value`` to a NumPy array.

    Arrays are returned as is, tensors are detached and moved to the CPU
    first, and anything else goes through ``np.asarray``.
    """
    if isinstance(value, np.ndarray):
        return value
    if not torch.is_tensor(value):
        return np.asarray(value)
    return value.detach().cpu().numpy()
36
+
37
+
38
def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average ``values`` over the entries selected by ``mask``.

    The mask is cast to the dtype of ``values``. The divisor is clamped to
    at least 1, so an all-zero mask yields 0 instead of dividing by zero.
    """
    weights = mask.to(values.dtype)
    weighted_total = (values * weights).sum()
    count = weights.sum().clamp(min=1.0)
    return weighted_total / count
42
+
43
+
44
def write_json(path: str | Path, payload: dict[str, Any]) -> None:
    """Write ``payload`` as indented, key-sorted UTF-8 JSON.

    Missing parent directories are created first.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as stream:
        stream.write(json.dumps(payload, indent=2, sort_keys=True))
49
+
50
+
51
def write_jsonl(path: str | Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` as UTF-8 JSON Lines, one key-sorted object per line.

    Missing parent directories are created first; an empty ``rows`` list
    produces an empty file.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as stream:
        stream.writelines(json.dumps(record, sort_keys=True) + "\n" for record in rows)
58
+
src/train_world_model.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.nn.utils import clip_grad_norm_
10
+ from tqdm import trange
11
+
12
+ from minidreamer.config import ensure_run_dirs, load_config, merge_dicts, save_config
13
+ from minidreamer.data.collect_random import collect_bootstrap_dataset
14
+ from minidreamer.data.replay_buffer import ReplayBuffer
15
+ from minidreamer.evaluation import evaluate_random_policy, evaluate_world_model
16
+ from minidreamer.envs.make_env import make_env_from_config
17
+ from minidreamer.models.world_model import WorldModel
18
+ from minidreamer.planning.cem import DiscreteCEMPlanner
19
+ from minidreamer.planning.evaluate_planner import evaluate_planner
20
+ from minidreamer.serialization import save_world_model_checkpoint
21
+ from minidreamer.utils.common import get_device, seed_everything, write_json, write_jsonl
22
+
23
+
24
def train_world_model_updates(
    model: WorldModel,
    replay: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    config: dict[str, Any],
    num_updates: int,
    device: torch.device,
) -> list[dict[str, float]]:
    """Run ``num_updates`` gradient steps on batches from the train split.

    Each step samples a sequence batch, computes the model losses, clips the
    gradient norm to ``config["training"]["grad_clip_norm"]`` (default 100)
    and applies the optimizer.

    Returns:
        One dict of scalar losses per update, or an empty list when
        ``num_updates`` is not positive (the model is then left untouched).
    """
    if num_updates <= 0:
        return []

    logged_terms = ("loss", "reward_loss", "done_loss", "kl_loss", "recon_loss")
    history: list[dict[str, float]] = []
    model.train()
    progress = trange(num_updates, desc="world-model-updates", leave=False)
    for _ in progress:
        numpy_batch = replay.sample_sequences(split="train")
        batch = ReplayBuffer.batch_to_torch(numpy_batch, device=device)
        losses = model.compute_losses(batch, config)

        optimizer.zero_grad(set_to_none=True)
        losses["loss"].backward()
        max_norm = float(config["training"].get("grad_clip_norm", 100.0))
        clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        row = {term: float(losses[term].detach().cpu()) for term in logged_terms}
        history.append(row)
        progress.set_postfix({key: f"{value:.3f}" for key, value in row.items()})
    return history
54
+
55
+
56
def optimizer_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every tensor held in the optimizer's state to ``device``.

    Non-tensor state entries are left unchanged.
    """
    for param_state in optimizer.state.values():
        relocated = {
            name: entry.to(device)
            for name, entry in param_state.items()
            if torch.is_tensor(entry)
        }
        param_state.update(relocated)
61
+
62
+
63
def load_training_state(
    checkpoint_path: str | Path,
    config: dict[str, Any],
    action_dim: int,
    device: torch.device,
) -> tuple[dict[str, Any], WorldModel, torch.optim.Optimizer, dict[str, Any]]:
    """Restore the model, optimizer and config needed to resume training.

    The checkpoint's config is merged with ``config`` via ``merge_dicts``;
    the model and a fresh Adam optimizer are built from the merged result
    and then loaded with the saved states (the optimizer only if the
    checkpoint contains one).

    Returns:
        ``(resolved_config, model, optimizer, metadata)``; ``metadata`` is an
        empty dict when the checkpoint has none.
    """
    # NOTE: weights_only=False unpickles arbitrary objects, so only resume
    # from trusted checkpoints.
    payload = torch.load(checkpoint_path, map_location=device, weights_only=False)
    resolved_config = merge_dicts(payload["config"], config)

    model = WorldModel.from_config(resolved_config, action_dim=action_dim)
    model = model.to(device)
    model.load_state_dict(payload["model_state"])

    learning_rate = float(resolved_config["training"]["lr"])
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    saved_optimizer = payload.get("optimizer_state")
    if saved_optimizer is not None:
        # NOTE(review): loading the optimizer state presumably also restores
        # the checkpoint's learning rate, overriding the value above —
        # confirm this is intended.
        optimizer.load_state_dict(saved_optimizer)
        optimizer_to_device(optimizer, device)

    metadata = payload.get("metadata", {})
    return resolved_config, model, optimizer, metadata
79
+
80
+
81
def find_existing_run_artifacts(base_dir: str | Path) -> list[Path]:
    """List files showing that ``base_dir`` already holds a training run.

    Well-known artifact files are checked first and returned if any exist.
    Otherwise the first entry (if any) of each of the ``checkpoints``,
    ``metrics`` and ``replay`` sub-directories is returned. A missing
    ``base_dir`` yields an empty list.
    """
    root = Path(base_dir)
    if not root.exists():
        return []

    well_known = (
        ("metrics", "run_summary.json"),
        ("metrics", "train_metrics.jsonl"),
        ("metrics", "eval_metrics.jsonl"),
        ("checkpoints", "world_model_latest.pt"),
        ("replay", "metadata.json"),
    )
    candidates = [root.joinpath(*parts) for parts in well_known]
    present = [candidate for candidate in candidates if candidate.exists()]
    if present:
        return present

    # Fall back to "anything at all" in the run sub-directories; one entry
    # per directory is enough to prove it is not empty.
    leftovers: list[Path] = []
    for folder_name in ("checkpoints", "metrics", "replay"):
        folder = root / folder_name
        if not folder.exists():
            continue
        first_entry = next(folder.iterdir(), None)
        if first_entry is not None:
            leftovers.append(first_entry)
    return leftovers
104
+
105
+
106
def collect_planner_steps(
    env,
    replay: ReplayBuffer,
    model: WorldModel,
    planner: DiscreteCEMPlanner,
    num_steps: int,
    random_action_fraction: float,
    rng: np.random.Generator,
) -> dict[str, int]:
    """Collect at least ``num_steps`` environment steps with the planner.

    Whole episodes are played and added to ``replay``, so the number of
    collected steps can exceed ``num_steps``. With probability
    ``random_action_fraction`` a step uses a random action instead of the
    planner's choice. The model is put in eval mode.

    Returns:
        Counts under ``"env_steps"``, ``"episodes"`` and
        ``"success_episodes"`` (episodes that terminated with positive
        total reward).
    """
    totals = {"env_steps": 0, "episodes": 0, "success_episodes": 0}
    model.eval()

    while totals["env_steps"] < num_steps:
        obs, _ = env.reset()
        frames = [obs]
        actions: list[int] = []
        rewards: list[float] = []
        terminated_flags: list[float] = []
        truncated_flags: list[float] = []
        done_flags: list[float] = []
        terminated = False
        truncated = False

        with torch.no_grad():
            state = model.posterior_step(model.initial_state(1), None, obs, sample=False)
            while not (terminated or truncated):
                take_random = rng.random() < random_action_fraction
                if take_random:
                    action = int(env.action_space.sample())
                else:
                    action = planner.plan(state).action

                obs, reward, terminated, truncated, _ = env.step(action)
                frames.append(obs)
                actions.append(action)
                rewards.append(float(reward))
                terminated_flags.append(float(terminated))
                truncated_flags.append(float(truncated))
                done_flags.append(float(terminated or truncated))
                totals["env_steps"] += 1

                # The final observation never needs to be filtered.
                if not (terminated or truncated):
                    state = model.posterior_step(state, action, obs, sample=False)

        replay.add_episode(
            obs=np.asarray(frames, dtype=np.float32),
            actions=np.asarray(actions, dtype=np.int64),
            rewards=np.asarray(rewards, dtype=np.float32),
            terminated=np.asarray(terminated_flags, dtype=np.float32),
            truncated=np.asarray(truncated_flags, dtype=np.float32),
            done=np.asarray(done_flags, dtype=np.float32),
        )
        totals["episodes"] += 1
        totals["success_episodes"] += int(bool(terminated and np.sum(rewards) > 0.0))

    return totals
164
+
165
+
166
def run_training(
    config: dict[str, Any],
    output_dir: str | Path,
    replay_dir: str | Path | None = None,
    resume_checkpoint: str | Path | None = None,
    allow_overwrite_existing_output: bool = False,
) -> dict[str, Any]:
    """Train the world model, alternating planner collection and updates.

    Steps:
        1. Seed RNGs and refuse to reuse a populated ``output_dir`` unless
           resuming or explicitly allowed.
        2. Load the replay buffer from ``replay_dir`` or collect a bootstrap
           dataset.
        3. Build (or restore) the model and optimizer and run the initial or
           catch-up gradient updates.
        4. Loop: collect planner-driven steps, train, save the replay, and
           periodically evaluate and checkpoint — until the env-step target
           or the update budget is reached.
        5. Write the final checkpoint and metric files.

    Args:
        config: Experiment configuration.
        output_dir: Run directory for checkpoints, metrics and replay data.
        replay_dir: Optional existing replay directory to load.
        resume_checkpoint: Optional checkpoint to resume from.
        allow_overwrite_existing_output: Permit writing into a populated
            ``output_dir`` when not resuming.

    Returns:
        The run summary that is also written to ``run_summary.json``.

    Raises:
        FileExistsError: if ``output_dir`` already holds run artifacts and
            neither resuming nor overwriting was requested.
    """
    seed = config.get("project", {}).get("seed", 0)
    seed_everything(seed)
    existing_artifacts = find_existing_run_artifacts(output_dir)
    if existing_artifacts and resume_checkpoint is None and not allow_overwrite_existing_output:
        preview = ", ".join(str(path) for path in existing_artifacts[:3])
        raise FileExistsError(
            f"Refusing to overwrite existing run directory '{output_dir}'. "
            f"Found existing artifacts: {preview}. "
            "Choose a new --output-dir, resume with --resume-checkpoint, "
            "or pass --allow-overwrite-existing-output to overwrite intentionally."
        )
    run_dirs = ensure_run_dirs(output_dir)
    device = get_device(config.get("training", {}).get("device"))

    # A throwaway env is only needed to read the action-space size.
    env = make_env_from_config(config, seed=seed)
    try:
        action_dim = env.action_space.n
    finally:
        env.close()

    if replay_dir is not None and Path(replay_dir).exists():
        replay = ReplayBuffer.load(replay_dir)
        collection_summary = {"replay_loaded": replay.summary()}
    else:
        replay, collection_summary = collect_bootstrap_dataset(config, output_dir=run_dirs["replay"], seed=seed)

    resume_metadata: dict[str, Any] = {}
    if resume_checkpoint is not None:
        config, model, optimizer, resume_metadata = load_training_state(
            checkpoint_path=resume_checkpoint,
            config=config,
            action_dim=action_dim,
            device=device,
        )
    else:
        model = WorldModel.from_config(config, action_dim=action_dim).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=float(config["training"]["lr"]))

    save_config(config, run_dirs["base"] / "resolved_config.yaml")
    training_logs: list[dict[str, float]] = []
    evaluation_logs: list[dict[str, float]] = []

    train_collect_ratio = float(config["collection"].get("train_collect_ratio", 1.0))
    total_updates_budget = int(config["training"]["train_steps"])
    if resume_checkpoint is not None:
        updates_done = int(resume_metadata.get("updates_done", 0))
        checkpoint_env_steps = int(resume_metadata.get("env_steps", 0))
        if replay.env_steps > checkpoint_env_steps and updates_done < total_updates_budget:
            # The replay holds steps collected after the checkpoint was
            # written; run the updates those iterations would have received.
            collect_steps_per_iteration = max(1, int(config["collection"].get("collect_steps_per_iteration", 1)))
            per_iteration_updates = int(
                config["collection"].get(
                    "gradient_updates_per_iteration",
                    round(collect_steps_per_iteration * train_collect_ratio),
                )
            )
            missed_iterations = max(0, round((replay.env_steps - checkpoint_env_steps) / collect_steps_per_iteration))
            catch_up_updates = min(total_updates_budget - updates_done, per_iteration_updates * missed_iterations)
            catch_up_logs = train_world_model_updates(model, replay, optimizer, config, catch_up_updates, device)
            training_logs.extend(catch_up_logs)
            updates_done += len(catch_up_logs)
    else:
        initial_updates = min(total_updates_budget, max(1, int(round(replay.env_steps * train_collect_ratio))))
        training_logs.extend(train_world_model_updates(model, replay, optimizer, config, initial_updates, device))
        updates_done = len(training_logs)

    comparison_budgets = config.get("comparison", {}).get("env_steps", [replay.env_steps])
    target_env_steps = int(max(comparison_budgets))
    rng = np.random.default_rng(seed)
    env = make_env_from_config(config, seed=seed)
    try:
        planner = DiscreteCEMPlanner.from_config(model, env.action_space.n, config)
        eval_every_steps = int(config["evaluation"].get("eval_every_env_steps", target_env_steps))
        next_eval_step = replay.env_steps

        while replay.env_steps < target_env_steps and updates_done < total_updates_budget:
            # Clamp to >= 1 (as the resume branch does) so a zero or negative
            # setting cannot stall the loop without collecting anything.
            collect_steps = min(
                max(1, int(config["collection"]["collect_steps_per_iteration"])),
                target_env_steps - replay.env_steps,
            )
            collection_row = collect_planner_steps(
                env,
                replay,
                model,
                planner,
                num_steps=collect_steps,
                random_action_fraction=float(config["collection"].get("random_action_fraction_after_planner", 0.0)),
                rng=rng,
            )
            updates = int(
                config["collection"].get(
                    "gradient_updates_per_iteration",
                    round(collection_row["env_steps"] * train_collect_ratio),
                )
            )
            updates = min(updates, total_updates_budget - updates_done)
            new_logs = train_world_model_updates(model, replay, optimizer, config, updates, device)
            training_logs.extend(new_logs)
            # Accumulate rather than recount from training_logs: after a
            # resume the logs only cover this process, while updates_done
            # also includes the updates recorded in the checkpoint.
            updates_done += len(new_logs)
            replay.save(run_dirs["replay"])

            if replay.env_steps >= next_eval_step:
                eval_episodes = min(10, config["evaluation"]["episodes"])
                world_model_metrics = evaluate_world_model(config, model, replay, split="val", max_episodes=10)
                planner_metrics = evaluate_planner(config, model, episodes=eval_episodes, seed=seed)
                random_metrics = evaluate_random_policy(config, episodes=eval_episodes, seed=seed)
                eval_row = {
                    "env_steps": replay.env_steps,
                    "updates_done": updates_done,
                    **{f"world_model/{key}": value for key, value in world_model_metrics.items()},
                    **{f"planner/{key}": value for key, value in planner_metrics.items()},
                    **{f"random/{key}": value for key, value in random_metrics.items()},
                }
                evaluation_logs.append(eval_row)
                next_eval_step += eval_every_steps
                save_world_model_checkpoint(
                    run_dirs["checkpoints"] / f"world_model_env_steps_{replay.env_steps}.pt",
                    model,
                    config,
                    optimizer=optimizer,
                    metadata={"env_steps": replay.env_steps, "updates_done": updates_done},
                )
    finally:
        # Release the environment even if collection, training or
        # evaluation raised.
        env.close()

    save_world_model_checkpoint(
        run_dirs["checkpoints"] / "world_model_latest.pt",
        model,
        config,
        optimizer=optimizer,
        metadata={"env_steps": replay.env_steps, "updates_done": updates_done},
    )
    write_json(run_dirs["metrics"] / "collection_summary.json", collection_summary)
    write_jsonl(run_dirs["metrics"] / "train_metrics.jsonl", training_logs)
    write_jsonl(run_dirs["metrics"] / "eval_metrics.jsonl", evaluation_logs)
    summary = {
        "replay": replay.summary(),
        "updates_done": updates_done,
        "device": str(device),
    }
    write_json(run_dirs["metrics"] / "run_summary.json", summary)
    return summary
303
+
304
+
305
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for world-model training.

    ``--config`` and ``--output-dir`` are required paths; ``--replay-dir``
    and ``--resume-checkpoint`` are optional paths; the overwrite switch is
    a boolean flag.
    """
    parser = argparse.ArgumentParser(description="Train the MiniDreamer world model.")

    for required_flag in ("--config", "--output-dir"):
        parser.add_argument(required_flag, type=Path, required=True)

    optional_paths = (
        ("--replay-dir", "Optional existing replay directory."),
        ("--resume-checkpoint", "Optional checkpoint to resume from."),
    )
    for flag, help_text in optional_paths:
        parser.add_argument(flag, type=Path, default=None, help=help_text)

    parser.add_argument(
        "--allow-overwrite-existing-output",
        action="store_true",
        help="Allow overwriting an existing run directory when not resuming.",
    )
    return parser
317
+
318
+
319
def main() -> None:
    """Parse command-line arguments, run training and print the summary."""
    args = build_arg_parser().parse_args()
    summary = run_training(
        load_config(args.config),
        args.output_dir,
        replay_dir=args.replay_dir,
        resume_checkpoint=args.resume_checkpoint,
        allow_overwrite_existing_output=args.allow_overwrite_existing_output,
    )
    print(summary)


if __name__ == "__main__":
    main()
tests/test_cem_planner.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from minidreamer.planning.cem import DiscreteCEMPlanner
4
+
5
+
6
class DummyWorldModel:
    """Stand-in world model whose score rewards matching a fixed sequence."""

    def __init__(self):
        self.device = torch.device("cpu")

    def score_action_sequences(self, state, action_sequences, discount=0.99, use_done_mask=True):
        """Score each sequence as minus its number of mismatches with the target."""
        target = torch.tensor([1, 2, 0, 1], device=action_sequences.device)
        mismatches = action_sequences.ne(target).float().sum(dim=-1)
        return {"scores": -mismatches}
14
+
15
+
16
def test_discrete_cem_planner_finds_high_scoring_sequence():
    """The planner should recover the dummy model's target sequence."""
    torch.manual_seed(0)
    planner_kwargs = dict(
        action_dim=3,
        horizon=4,
        candidates=512,
        elites=64,
        iterations=5,
        discount=1.0,
        use_done_mask=False,
    )
    planner = DiscreteCEMPlanner(world_model=DummyWorldModel(), **planner_kwargs)

    result = planner.plan(state=object())

    assert result.action == 1
    assert result.sequence == [1, 2, 0, 1]
31
+
tests/test_env.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from minidreamer.envs.make_env import make_env
4
+
5
+
6
def test_make_env_returns_normalized_rgb_observation():
    """Observations are 64x64x3 float32 images in [0, 1]; steps return scalars."""
    expected_shape = (64, 64, 3)
    env = make_env(seed=0)

    first_obs, _ = env.reset()
    assert first_obs.shape == expected_shape
    assert first_obs.dtype == np.float32
    lowest, highest = float(first_obs.min()), float(first_obs.max())
    assert 0.0 <= lowest <= highest <= 1.0

    step_result = env.step(env.action_space.sample())
    next_obs, reward, terminated, truncated, _ = step_result
    assert next_obs.shape == expected_shape
    assert isinstance(float(reward), float)
    assert isinstance(bool(terminated), bool)
    assert isinstance(bool(truncated), bool)
    env.close()
19
+
tests/test_replay_buffer.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from minidreamer.data.replay_buffer import ReplayBuffer
7
+
8
+
9
def make_episode(length: int, reward: float = 0.0):
    """Build a synthetic episode of ``length`` steps that ends by termination.

    Returns ``(obs, actions, rewards, terminated, truncated, done)``; ``obs``
    has ``length + 1`` random frames, actions cycle through 0..6, every
    reward equals ``reward`` and only the last step is flagged as done.
    """
    obs = np.random.rand(length + 1, 64, 64, 3).astype(np.float32)
    actions = np.mod(np.arange(length, dtype=np.int64), 7)
    rewards = np.full(length, reward, dtype=np.float32)

    terminated = np.zeros(length, dtype=np.float32)
    terminated[-1] = 1.0
    truncated = np.zeros(length, dtype=np.float32)
    done = terminated.copy()
    return obs, actions, rewards, terminated, truncated, done
19
+
20
+
21
def test_replay_buffer_sampling_and_padding(tmp_path: Path):
    """Sampled batches are padded to the sequence length and survive save/load."""
    buffer = ReplayBuffer(capacity_episodes=10, sequence_length=8, batch_size=4)
    for episode_id, length in enumerate((3, 5, 9)):
        episode = make_episode(length, reward=float(episode_id))
        buffer.add_episode(*episode, episode_id=episode_id)

    # Use whichever split actually received episodes.
    populated = [name for name in ("train", "val", "test") if buffer.episode_ids(name)]
    available_split = populated[0]
    batch = buffer.sample_sequences(split=available_split, batch_size=2, rng=np.random.default_rng(0))

    expected_shapes = {
        "obs": (2, 9, 64, 64, 3),
        "actions": (2, 8),
        "mask": (2, 8),
    }
    for key, shape in expected_shapes.items():
        assert batch[key].shape == shape
    assert np.all(batch["mask"].sum(axis=1) >= 1)

    save_dir = tmp_path / "replay"
    buffer.save(save_dir)
    restored = ReplayBuffer.load(save_dir)
    before, after = buffer.summary(), restored.summary()
    assert after["episodes"] == before["episodes"]
    assert after["env_steps"] == before["env_steps"]
39
+
40
+
41
def test_replay_buffer_torch_batch_shapes():
    """``batch_to_torch`` yields channels-first observations and int64 actions."""
    buffer = ReplayBuffer(capacity_episodes=4, sequence_length=4, batch_size=2)
    buffer.add_episode(*make_episode(5, reward=1.0))

    populated = [name for name in ("train", "val", "test") if buffer.episode_ids(name)]
    available_split = populated[0]
    numpy_batch = buffer.sample_sequences(split=available_split, batch_size=2, rng=np.random.default_rng(1))
    tensor_batch = ReplayBuffer.batch_to_torch(numpy_batch)

    assert tensor_batch["obs"].shape == (2, 5, 3, 64, 64)
    assert tensor_batch["actions"].dtype == torch.int64
50
+
tests/test_rssm_shapes.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from minidreamer.models.world_model import WorldModel
4
+
5
+
6
def test_world_model_sequence_shapes_and_loss():
    """``observe_sequence`` outputs have the expected shapes and the loss is finite."""
    batch_size, steps, action_dim, stoch_dim = 4, 32, 7, 16

    torch.manual_seed(0)
    model = WorldModel(
        action_dim=action_dim,
        embedding_dim=128,
        deter_dim=128,
        stoch_dim=stoch_dim,
        hidden_dim=128,
        use_decoder=True,
    )
    # One more observation than actions: obs[t] precedes action[t].
    obs = torch.rand(batch_size, steps + 1, 3, 64, 64)
    actions = torch.randint(0, action_dim, (batch_size, steps))

    outputs = model.observe_sequence(obs, actions, sample=False)
    assert outputs.reward_pred.shape == (batch_size, steps)
    assert outputs.done_logits.shape == (batch_size, steps)
    assert outputs.prior_mean.shape == (batch_size, steps, stoch_dim)
    assert outputs.reconstructions is not None
    assert outputs.reconstructions.shape == (batch_size, steps, 3, 64, 64)

    batch = {
        "obs": obs,
        "actions": actions,
        "rewards": torch.zeros(batch_size, steps),
        "done": torch.zeros(batch_size, steps),
        "mask": torch.ones(batch_size, steps),
    }
    loss_weights = {
        "beta_reward": 1.0,
        "beta_done": 1.0,
        "beta_kl": 1.0,
        "beta_recon": 1.0,
        "free_nats": 1.0,
    }
    losses = model.compute_losses(batch, {"training": loss_weights})
    assert torch.isfinite(losses["loss"])
43
+