Sync local repo state
#1
by PatrykT - opened
- .gitignore +16 -0
- LICENSE +21 -0
- README.md +87 -3
- configs/fourrooms_ppo.yaml +31 -0
- configs/fourrooms_world_model.yaml +68 -0
- docs/spec_clarifications.md +26 -0
- minidreamer_project_spec.md +911 -0
- notebooks/results_analysis.ipynb +25 -0
- notebooks/rollout_debug.ipynb +26 -0
- plots/.gitkeep +1 -0
- plots/learning_curves.png +0 -0
- plots/model_error_vs_rollout_horizon.png +0 -0
- plots/success_rate_vs_env_steps.png +0 -0
- pyproject.toml +46 -0
- results.md +127 -0
- scripts/collect_random.sh +9 -0
- scripts/eval_planner.sh +26 -0
- scripts/generate_results_plots.py +159 -0
- scripts/train_ppo.sh +9 -0
- scripts/train_world_model.sh +10 -0
- src/evaluate.py +69 -0
- src/minidreamer/__init__.py +6 -0
- src/minidreamer/baselines/__init__.py +2 -0
- src/minidreamer/baselines/train_ppo.py +127 -0
- src/minidreamer/config.py +59 -0
- src/minidreamer/envs/__init__.py +2 -0
- src/minidreamer/envs/make_env.py +125 -0
- src/minidreamer/evaluation.py +147 -0
- src/minidreamer/models/__init__.py +2 -0
- src/minidreamer/models/decoder.py +27 -0
- src/minidreamer/models/encoder.py +29 -0
- src/minidreamer/models/heads.py +29 -0
- src/minidreamer/models/rssm.py +147 -0
- src/minidreamer/models/world_model.py +267 -0
- src/minidreamer/planning/__init__.py +2 -0
- src/minidreamer/planning/cem.py +103 -0
- src/minidreamer/planning/evaluate_planner.py +99 -0
- src/minidreamer/serialization.py +40 -0
- src/minidreamer/utils/__init__.py +2 -0
- src/minidreamer/utils/common.py +58 -0
- src/train_world_model.py +334 -0
- tests/test_cem_planner.py +31 -0
- tests/test_env.py +19 -0
- tests/test_replay_buffer.py +50 -0
- tests/test_rssm_shapes.py +43 -0
.gitignore
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.pytest_cache/
|
| 2 |
+
.ruff_cache/
|
| 3 |
+
.venv/
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.pyc
|
| 6 |
+
*.pyo
|
| 7 |
+
*.pyd
|
| 8 |
+
*.so
|
| 9 |
+
*.egg-info/
|
| 10 |
+
.DS_Store
|
| 11 |
+
|
| 12 |
+
artifacts/
|
| 13 |
+
checkpoints/
|
| 14 |
+
metrics/
|
| 15 |
+
data/
|
| 16 |
+
logs/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 alpatrykos
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,3 +1,87 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MiniDreamer
|
| 2 |
+
|
| 3 |
+
MiniDreamer is a PlaNet-style world model project for `MiniGrid-FourRooms-v0`. It learns a recurrent latent dynamics model from partial RGB observations, predicts reward and episode termination, and uses discrete CEM planning in latent space.
|
| 4 |
+
|
| 5 |
+
The repository contains:
|
| 6 |
+
|
| 7 |
+
- MiniGrid RGB environment wrappers and bootstrap trajectory collection
|
| 8 |
+
- Episode-aware replay buffer with reproducible train/val/test splits
|
| 9 |
+
- CNN encoder, Gaussian RSSM, reward/done heads, optional decoder
|
| 10 |
+
- Discrete CEM planner with termination-aware return scoring
|
| 11 |
+
- PPO baseline entrypoint with a MiniGrid-compatible CNN feature extractor
|
| 12 |
+
- Evaluation code, configs, scripts, tests, and project documentation
|
| 13 |
+
|
| 14 |
+
A complete baseline training run has been executed. A summary is recorded in [results.md](/Users/patryktargosinski/minidreamer/results.md), while the frozen baseline artifacts remain gitignored under `artifacts/world_model/`.
|
| 15 |
+
|
| 16 |
+
## Layout
|
| 17 |
+
|
| 18 |
+
```text
|
| 19 |
+
configs/
|
| 20 |
+
docs/
|
| 21 |
+
notebooks/
|
| 22 |
+
scripts/
|
| 23 |
+
src/
|
| 24 |
+
tests/
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
Core code lives under `src/minidreamer/`, with CLI entrypoints at `src/train_world_model.py` and `src/evaluate.py`.
|
| 28 |
+
|
| 29 |
+
## Setup
|
| 30 |
+
|
| 31 |
+
Use Python 3.11 or 3.12. The project metadata is defined in [pyproject.toml](/Users/patryktargosinski/minidreamer/pyproject.toml).
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
python3.11 -m venv .venv
|
| 35 |
+
source .venv/bin/activate
|
| 36 |
+
pip install -e ".[dev]"
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
## Main Commands
|
| 40 |
+
|
| 41 |
+
Bootstrap replay collection:
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
./scripts/collect_random.sh
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
World-model pipeline:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
./scripts/train_world_model.sh
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
By default, the script writes new experiments to `artifacts/world_model_experiment/`. To choose a different experiment directory without touching the frozen baseline, set `MINIDREAMER_OUTPUT_DIR`:
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
MINIDREAMER_OUTPUT_DIR=artifacts/world_model_restricted_actions ./scripts/train_world_model.sh
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Resume an interrupted world-model run from a checkpoint:
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
python3.11 src/train_world_model.py \
|
| 63 |
+
--config configs/fourrooms_world_model.yaml \
|
| 64 |
+
--output-dir artifacts/world_model \
|
| 65 |
+
--replay-dir artifacts/world_model/replay \
|
| 66 |
+
--resume-checkpoint artifacts/world_model/checkpoints/world_model_env_steps_90021.pt
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Planner evaluation from a checkpoint:
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
./scripts/eval_planner.sh /path/to/checkpoint.pt /path/to/replay
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
PPO baseline:
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
./scripts/train_ppo.sh
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
## Notes
|
| 82 |
+
|
| 83 |
+
- The latest completed run summary is in [results.md](/Users/patryktargosinski/minidreamer/results.md).
|
| 84 |
+
- The baseline run in `artifacts/world_model/` is intentionally frozen as the reference artifact.
|
| 85 |
+
- New world-model experiments should write to separate directories under `artifacts/`.
|
| 86 |
+
- The trainer refuses to overwrite an existing run directory unless you resume with `--resume-checkpoint` or explicitly pass `--allow-overwrite-existing-output`.
|
| 87 |
+
- Metrics, replay snapshots, and checkpoints are intentionally gitignored.
|
configs/fourrooms_ppo.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project:
|
| 2 |
+
name: minidreamer-fourrooms-ppo
|
| 3 |
+
seed: 0
|
| 4 |
+
|
| 5 |
+
env:
|
| 6 |
+
id: MiniGrid-FourRooms-v0
|
| 7 |
+
rgb_partial_obs: true
|
| 8 |
+
image_only: true
|
| 9 |
+
resize: [64, 64]
|
| 10 |
+
normalize_obs: true
|
| 11 |
+
|
| 12 |
+
ppo:
|
| 13 |
+
total_timesteps: 100000
|
| 14 |
+
num_envs: 4
|
| 15 |
+
learning_rate: 0.0003
|
| 16 |
+
n_steps: 256
|
| 17 |
+
batch_size: 256
|
| 18 |
+
n_epochs: 4
|
| 19 |
+
gamma: 0.99
|
| 20 |
+
gae_lambda: 0.95
|
| 21 |
+
clip_range: 0.2
|
| 22 |
+
ent_coef: 0.01
|
| 23 |
+
vf_coef: 0.5
|
| 24 |
+
features_dim: 256
|
| 25 |
+
device: auto
|
| 26 |
+
|
| 27 |
+
evaluation:
|
| 28 |
+
episodes: 100
|
| 29 |
+
seeds: [0, 1, 2]
|
| 30 |
+
eval_every_env_steps: 10000
|
| 31 |
+
|
configs/fourrooms_world_model.yaml
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
project:
|
| 2 |
+
name: minidreamer-fourrooms
|
| 3 |
+
seed: 0
|
| 4 |
+
|
| 5 |
+
env:
|
| 6 |
+
id: MiniGrid-FourRooms-v0
|
| 7 |
+
rgb_partial_obs: true
|
| 8 |
+
image_only: true
|
| 9 |
+
resize: [64, 64]
|
| 10 |
+
normalize_obs: true
|
| 11 |
+
action_space: full
|
| 12 |
+
|
| 13 |
+
replay:
|
| 14 |
+
capacity_episodes: 5000
|
| 15 |
+
sequence_length: 32
|
| 16 |
+
batch_size: 32
|
| 17 |
+
train_fraction: 0.8
|
| 18 |
+
val_fraction: 0.1
|
| 19 |
+
test_fraction: 0.1
|
| 20 |
+
split_key: episode_id
|
| 21 |
+
|
| 22 |
+
collection:
|
| 23 |
+
bootstrap_env_steps: 5000
|
| 24 |
+
bootstrap_success_threshold: 20
|
| 25 |
+
bootstrap_fallback_policy: restricted_random_3_actions
|
| 26 |
+
bootstrap_env_step_cap: 20000
|
| 27 |
+
collect_steps_per_iteration: 1000
|
| 28 |
+
gradient_updates_per_iteration: 1000
|
| 29 |
+
train_collect_ratio: 1.0
|
| 30 |
+
random_action_fraction_after_planner: 0.3
|
| 31 |
+
|
| 32 |
+
model:
|
| 33 |
+
embedding_dim: 256
|
| 34 |
+
deter_dim: 256
|
| 35 |
+
stoch_dim: 32
|
| 36 |
+
hidden_dim: 256
|
| 37 |
+
use_decoder: true
|
| 38 |
+
min_std: 0.1
|
| 39 |
+
|
| 40 |
+
training:
|
| 41 |
+
optimizer: adam
|
| 42 |
+
lr: 0.0003
|
| 43 |
+
grad_clip_norm: 100.0
|
| 44 |
+
train_steps: 100000
|
| 45 |
+
beta_kl: 1.0
|
| 46 |
+
free_nats: 1.0
|
| 47 |
+
beta_recon: 1.0
|
| 48 |
+
beta_reward: 1.0
|
| 49 |
+
beta_done: 1.0
|
| 50 |
+
device: null
|
| 51 |
+
|
| 52 |
+
planner:
|
| 53 |
+
type: discrete_cem
|
| 54 |
+
horizon: 8
|
| 55 |
+
candidates: 256
|
| 56 |
+
elites: 32
|
| 57 |
+
iterations: 4
|
| 58 |
+
discount: 0.99
|
| 59 |
+
use_done_mask: true
|
| 60 |
+
|
| 61 |
+
evaluation:
|
| 62 |
+
episodes: 100
|
| 63 |
+
seeds: [0, 1, 2]
|
| 64 |
+
eval_every_env_steps: 10000
|
| 65 |
+
|
| 66 |
+
comparison:
|
| 67 |
+
env_steps: [10000, 25000, 50000, 100000]
|
| 68 |
+
|
docs/spec_clarifications.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Spec Clarifications
|
| 2 |
+
|
| 3 |
+
This document records the implementation choices the spec left open enough that code needed an explicit default.
|
| 4 |
+
|
| 5 |
+
## Implemented choices
|
| 6 |
+
|
| 7 |
+
1. The Python package lives under `src/minidreamer/`, with thin CLI entrypoints at `src/train_world_model.py` and `src/evaluate.py`. That keeps imports stable while still matching the spec's requested top-level scripts.
|
| 8 |
+
2. The encoder and decoder use `padding=1` in all `4x4, stride=2` convolutions so `64x64` inputs shrink cleanly to `4x4` and decode symmetrically back to `64x64`.
|
| 9 |
+
3. CEM planning uses the prior mean during latent imagination instead of sampling the stochastic latent. That reduces planner variance and makes candidate ranking deterministic given the current model parameters.
|
| 10 |
+
4. Replay sampling pads only at the tail of an in-episode chunk and applies a transition mask so padded steps do not contribute to loss terms.
|
| 11 |
+
5. Bootstrap and online training both default to `train_collect_ratio = 1.0`, so the initial bootstrap replay produces one gradient update per collected environment step unless an explicit `gradient_updates_per_iteration` override is set.
|
| 12 |
+
6. Evaluation computes one-step held-out metrics over full episodes and open-loop rollout metrics for horizons `1`, `5`, and `10` using actual held-out action sequences.
|
| 13 |
+
|
| 14 |
+
## Remaining optional extensions
|
| 15 |
+
|
| 16 |
+
These are intentionally not implemented in v1 because the spec marked them as later improvements or ablations:
|
| 17 |
+
|
| 18 |
+
- KL balancing beyond free nats
|
| 19 |
+
- uncertainty penalties or ensembles in the planner
|
| 20 |
+
- actor-critic imagination learning
|
| 21 |
+
- richer sparse-reward heads beyond scalar reward regression
|
| 22 |
+
|
| 23 |
+
## Budget semantics
|
| 24 |
+
|
| 25 |
+
Collection currently finishes complete episodes rather than cutting trajectories mid-episode. That keeps replay episodes semantically clean for recurrent training, but it means a run can land slightly above a requested step target when the last episode crosses the threshold. If exact step-matched checkpoints become mandatory for reporting, the next refinement should snapshot metrics at the first checkpoint at or above each budget and report the realized environment step count alongside the nominal target.
|
| 26 |
+
|
minidreamer_project_spec.md
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MiniDreamer: A PlaNet-style World Model for Pixel-based MiniGrid Planning
|
| 2 |
+
|
| 3 |
+
## 1. Project name
|
| 4 |
+
|
| 5 |
+
**MiniDreamer: A PlaNet-style world model for pixel-based MiniGrid planning**
|
| 6 |
+
|
| 7 |
+
## 2. Goal
|
| 8 |
+
|
| 9 |
+
Build a small research-grade world model agent that learns a latent dynamics model from pixel observations in `MiniGrid-FourRooms-v0`, then uses the learned model for short-horizon planning.
|
| 10 |
+
|
| 11 |
+
The project should answer:
|
| 12 |
+
|
| 13 |
+
> Can a compact latent dynamics model trained from partial RGB observations support useful planning in a sparse-reward gridworld?
|
| 14 |
+
|
| 15 |
+
MiniGrid is a good first target because it is designed as a simple, fast, customizable RL benchmark suite, with discrete actions and goal-oriented environments. FourRooms specifically asks an agent to navigate a four-room maze to reach a green goal; the registered environment is `MiniGrid-FourRooms-v0`.
|
| 16 |
+
|
| 17 |
+
Reference: [MiniGrid documentation](https://minigrid.farama.org/index.html)
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
## 3. Scope
|
| 22 |
+
|
| 23 |
+
### In scope
|
| 24 |
+
|
| 25 |
+
- Pixel-based observations using:
|
| 26 |
+
- `RGBImgPartialObsWrapper`
|
| 27 |
+
- `ImgObsWrapper`
|
| 28 |
+
- Environment:
|
| 29 |
+
- MVP: `MiniGrid-FourRooms-v0`
|
| 30 |
+
- Extension: `MiniGrid-DoorKey-*`, `MiniGrid-LockedRoom-v0`, or `MiniGrid-Dynamic-Obstacles-*`
|
| 31 |
+
- Learned latent dynamics model
|
| 32 |
+
- Reward and termination prediction
|
| 33 |
+
- MPC/CEM-style planning in latent space
|
| 34 |
+
- PPO baseline for comparison
|
| 35 |
+
- Ablations:
|
| 36 |
+
- with image reconstruction vs. without
|
| 37 |
+
- latent size
|
| 38 |
+
- planning horizon
|
| 39 |
+
- stochastic vs. deterministic latent state
|
| 40 |
+
|
| 41 |
+
### Out of scope for v1
|
| 42 |
+
|
| 43 |
+
- Full DreamerV3 reproduction
|
| 44 |
+
- Actor-critic learning inside imagination
|
| 45 |
+
- Minecraft/Crafter-scale environments
|
| 46 |
+
- Large video-generation-style world models
|
| 47 |
+
- Language-conditioned BabyAI tasks
|
| 48 |
+
|
| 49 |
+
---
|
| 50 |
+
|
| 51 |
+
## 4. Technical background
|
| 52 |
+
|
| 53 |
+
This should be a **PlaNet-lite** implementation. PlaNet learns a latent dynamics model from images and chooses actions through online planning in latent space, using both deterministic and stochastic transition components.
|
| 54 |
+
|
| 55 |
+
Reference: [PlaNet: Learning Latent Dynamics for Planning from Pixels](https://arxiv.org/abs/1811.04551)
|
| 56 |
+
|
| 57 |
+
DreamerV3 is the more mature descendant: it learns a world model and improves behavior by imagining future scenarios, but its full implementation is too much for a first project. Use DreamerV3 as conceptual inspiration, not as the implementation target.
|
| 58 |
+
|
| 59 |
+
Reference: [DreamerV3](https://arxiv.org/abs/2301.04104)
|
| 60 |
+
|
| 61 |
+
MiniGrid’s default observation is a compact symbolic encoding, not raw pixels. For this project, use `RGBImgPartialObsWrapper` to obtain RGB pixel observations, then `ImgObsWrapper` to keep only the image tensor.
|
| 62 |
+
|
| 63 |
+
Reference: [MiniGrid wrappers](https://minigrid.farama.org/api/wrapper/)
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## 5. Core hypothesis
|
| 68 |
+
|
| 69 |
+
### Main hypothesis
|
| 70 |
+
|
| 71 |
+
A learned recurrent latent dynamics model can support better sample efficiency than a model-free baseline in early training, even if final performance is lower or less stable.
|
| 72 |
+
|
| 73 |
+
### Secondary hypotheses
|
| 74 |
+
|
| 75 |
+
1. **Reconstruction loss may help representation learning early**, but may hurt planning if it forces the latent state to model irrelevant pixels.
|
| 76 |
+
2. **Short planning horizons will work better than long ones**, because model error compounds over imagined rollouts.
|
| 77 |
+
3. **Stochastic latent states should outperform deterministic-only states** under partial observability.
|
| 78 |
+
4. **Reward prediction quality matters more than pixel reconstruction quality** for planning performance.
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## 6. Environment specification
|
| 83 |
+
|
| 84 |
+
### MVP environment
|
| 85 |
+
|
| 86 |
+
```python
|
| 87 |
+
env_id = "MiniGrid-FourRooms-v0"
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
FourRooms has a discrete action space of size 7. The meaningful actions for this task are mostly `left`, `right`, and `forward`; `pickup`, `drop`, `toggle`, and `done` are listed but unused in FourRooms.
|
| 91 |
+
|
| 92 |
+
Reference: [MiniGrid FourRooms environment](https://minigrid.farama.org/environments/minigrid/FourRoomsEnv/)
|
| 93 |
+
|
| 94 |
+
### Wrappers
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
import gymnasium as gym
|
| 98 |
+
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper
|
| 99 |
+
|
| 100 |
+
env = gym.make("MiniGrid-FourRooms-v0")
|
| 101 |
+
env = RGBImgPartialObsWrapper(env)
|
| 102 |
+
env = ImgObsWrapper(env)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
### Observation
|
| 106 |
+
|
| 107 |
+
Use RGB partial observation.
|
| 108 |
+
|
| 109 |
+
Normalize to:
|
| 110 |
+
|
| 111 |
+
```python
|
| 112 |
+
obs = obs.astype("float32") / 255.0
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
Recommended internal size:
|
| 116 |
+
|
| 117 |
+
```text
|
| 118 |
+
64 x 64 x 3
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
Resize if needed.
|
| 122 |
+
|
| 123 |
+
### Action space
|
| 124 |
+
|
| 125 |
+
Use the full 7-action space for baseline compatibility.
|
| 126 |
+
|
| 127 |
+
Optional ablation:
|
| 128 |
+
|
| 129 |
+
```text
|
| 130 |
+
restricted_action_space = {left, right, forward}
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
If bootstrap data collection yields too few successful episodes, it is acceptable to use the restricted 3-action random policy during bootstrap only. If that fallback is used, record it explicitly in the report and still count all bootstrap environment steps toward the world-model sample-efficiency budget.
|
| 134 |
+
|
| 135 |
+
### Reward
|
| 136 |
+
|
| 137 |
+
FourRooms gives a success reward based on step count and zero reward for failure; the documented reward is:
|
| 138 |
+
|
| 139 |
+
```text
|
| 140 |
+
1 - 0.9 * step_count / max_steps
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
For model training, predict:
|
| 144 |
+
|
| 145 |
+
```text
|
| 146 |
+
reward_t ∈ R
|
| 147 |
+
done_t ∈ {0, 1}
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### Episode end semantics
|
| 151 |
+
|
| 152 |
+
Use the Gymnasium step API explicitly:
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
obs_next, reward, terminated, truncated, info = env.step(action)
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
Store all three episode-end signals:
|
| 159 |
+
|
| 160 |
+
```text
|
| 161 |
+
terminated_t = goal reached / environment terminal condition
|
| 162 |
+
truncated_t = episode stopped by time limit
|
| 163 |
+
done_t = terminated_t or truncated_t
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
Training and planning semantics:
|
| 167 |
+
|
| 168 |
+
- `reward_t` is the scalar emitted after taking `action_t`
|
| 169 |
+
- `done_t` is the episode-end target used by the world model and planner
|
| 170 |
+
- `success_t = 1` iff `terminated_t == 1` and `reward_t > 0`
|
| 171 |
+
|
| 172 |
+
Do not collapse `terminated` and `truncated` in logged metrics; report them separately when debugging failures.
|
| 173 |
+
|
| 174 |
+
---
|
| 175 |
+
|
| 176 |
+
## 7. Model architecture
|
| 177 |
+
|
| 178 |
+
### 7.1 Overview
|
| 179 |
+
|
| 180 |
+
```text
|
| 181 |
+
obs_t ──► CNN encoder ──► embedding e_t
|
| 182 |
+
│
|
| 183 |
+
action_{t-1} ────────────┤
|
| 184 |
+
▼
|
| 185 |
+
RSSM latent model
|
| 186 |
+
│
|
| 187 |
+
┌──────────┼──────────┐
|
| 188 |
+
▼ ▼ ▼
|
| 189 |
+
reward head done head decoder? optional
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
### 7.2 Latent state
|
| 193 |
+
|
| 194 |
+
Use a simplified **RSSM-style** state:
|
| 195 |
+
|
| 196 |
+
```text
|
| 197 |
+
h_t = deterministic recurrent state
|
| 198 |
+
z_t = stochastic latent state
|
| 199 |
+
s_t = concat(h_t, z_t)
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
Recommended dimensions:
|
| 203 |
+
|
| 204 |
+
```yaml
|
| 205 |
+
embedding_dim: 256
|
| 206 |
+
deterministic_state_dim: 256
|
| 207 |
+
stochastic_state_dim: 32
|
| 208 |
+
num_stochastic_classes: null # use Gaussian latent for v1
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
### 7.3 Encoder
|
| 212 |
+
|
| 213 |
+
CNN encoder:
|
| 214 |
+
|
| 215 |
+
```text
|
| 216 |
+
Input: 64x64x3 RGB
|
| 217 |
+
Conv 4x4 stride 2, channels 32
|
| 218 |
+
Conv 4x4 stride 2, channels 64
|
| 219 |
+
Conv 4x4 stride 2, channels 128
|
| 220 |
+
Conv 4x4 stride 2, channels 256
|
| 221 |
+
Flatten
|
| 222 |
+
Linear -> embedding_dim
|
| 223 |
+
```
|
| 224 |
+
|
| 225 |
+
Use `LayerNorm` or `BatchNorm` only if training is unstable. Start simple.
|
| 226 |
+
|
| 227 |
+
### 7.4 Recurrent dynamics
|
| 228 |
+
|
| 229 |
+
Prior:
|
| 230 |
+
|
| 231 |
+
```text
|
| 232 |
+
p(z_t | h_t)
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
Posterior:
|
| 236 |
+
|
| 237 |
+
```text
|
| 238 |
+
q(z_t | h_t, e_t)
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
Transition:
|
| 242 |
+
|
| 243 |
+
```text
|
| 244 |
+
h_t = GRU(h_{t-1}, concat(z_{t-1}, one_hot(action_{t-1})))
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
Gaussian latent:
|
| 248 |
+
|
| 249 |
+
```text
|
| 250 |
+
prior_mean, prior_std = prior_net(h_t)
|
| 251 |
+
post_mean, post_std = posterior_net(h_t, e_t)
|
| 252 |
+
z_t ~ Normal(post_mean, post_std)
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
Use reparameterization.
|
| 256 |
+
|
| 257 |
+
Clamp std:
|
| 258 |
+
|
| 259 |
+
```python
|
| 260 |
+
std = softplus(raw_std) + 0.1
|
| 261 |
+
```
|
| 262 |
+
|
| 263 |
+
### 7.4.1 Sequence indexing and initialization
|
| 264 |
+
|
| 265 |
+
Use transition tuples with explicit alignment:
|
| 266 |
+
|
| 267 |
+
```text
|
| 268 |
+
(obs_t, action_t, reward_t, terminated_t, truncated_t, done_t, obs_{t+1})
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
Training-time convention:
|
| 272 |
+
|
| 273 |
+
1. Infer posterior `s_t` from `obs_t` and the previous recurrent state.
|
| 274 |
+
2. Apply `action_t` in the transition model to form the prior for step `t+1`.
|
| 275 |
+
3. Condition on `obs_{t+1}` to infer posterior `s_{t+1}`.
|
| 276 |
+
4. Predict `reward_t`, `done_t`, and optional reconstruction of `obs_{t+1}` from `s_{t+1}`.
|
| 277 |
+
|
| 278 |
+
Sequence-start convention:
|
| 279 |
+
|
| 280 |
+
- initialize `h_0` and `z_0` to zeros
|
| 281 |
+
- use a zero action embedding in place of `action_{-1}`
|
| 282 |
+
- never sample a training chunk that crosses an episode boundary
|
| 283 |
+
- if padding is needed for batching, pad only within an episode and apply a loss mask so padded steps do not contribute to any loss
|
| 284 |
+
|
| 285 |
+
### 7.5 Prediction heads
|
| 286 |
+
|
| 287 |
+
Reward head:
|
| 288 |
+
|
| 289 |
+
```text
|
| 290 |
+
MLP(s_{t+1}) -> scalar reward_t
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
Done head:
|
| 294 |
+
|
| 295 |
+
```text
|
| 296 |
+
MLP(s_{t+1}) -> Bernoulli logit for done_t
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
Optional decoder:
|
| 300 |
+
|
| 301 |
+
```text
|
| 302 |
+
MLP/ConvTranspose(s_{t+1}) -> reconstructed RGB observation obs_{t+1}
|
| 303 |
+
```
|
| 304 |
+
|
| 305 |
+
For v1, include the decoder behind a config flag:
|
| 306 |
+
|
| 307 |
+
```yaml
|
| 308 |
+
use_decoder: true | false
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
---
|
| 312 |
+
|
| 313 |
+
## 8. Loss function
|
| 314 |
+
|
| 315 |
+
Total loss:
|
| 316 |
+
|
| 317 |
+
```text
|
| 318 |
+
L = β_reward * L_reward + β_done * L_done + β_kl * max(L_kl, free_nats) + β_recon * L_recon
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
### Reward loss
|
| 322 |
+
|
| 323 |
+
```text
|
| 324 |
+
L_reward = MSE(pred_reward_t, reward_t)
|
| 325 |
+
```
|
| 326 |
+
|
| 327 |
+
Alternative for sparse rewards:
|
| 328 |
+
|
| 329 |
+
```text
|
| 330 |
+
two_head_reward:
|
| 331 |
+
reward_occurrence: BCE
|
| 332 |
+
reward_value_given_success: MSE
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
Do not start there unless MSE fails.
|
| 336 |
+
|
| 337 |
+
### Done loss
|
| 338 |
+
|
| 339 |
+
```text
|
| 340 |
+
L_done = BCEWithLogits(done_logit_t, done_t)
|
| 341 |
+
```
|
| 342 |
+
|
| 343 |
+
### KL loss
|
| 344 |
+
|
| 345 |
+
```text
|
| 346 |
+
L_kl = KL(q(z_t | h_t, e_t) || p(z_t | h_t))
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
Use free-nats from the start in v1. KL balancing is a later stabilization improvement.
|
| 350 |
+
|
| 351 |
+
Initial config:
|
| 352 |
+
|
| 353 |
+
```yaml
|
| 354 |
+
beta_reward: 1.0
|
| 355 |
+
beta_done: 1.0
|
| 356 |
+
beta_kl: 1.0
|
| 357 |
+
free_nats: 1.0
|
| 358 |
+
```
|
| 359 |
+
|
| 360 |
+
### Reconstruction loss
|
| 361 |
+
|
| 362 |
+
If decoder is enabled:
|
| 363 |
+
|
| 364 |
+
```text
|
| 365 |
+
L_recon = MSE(reconstructed_obs_{t+1}, obs_{t+1})
|
| 366 |
+
```
|
| 367 |
+
|
| 368 |
+
If decoder is disabled, set:
|
| 369 |
+
|
| 370 |
+
```yaml
|
| 371 |
+
beta_recon: 0.0
|
| 372 |
+
```
|
| 373 |
+
|
| 374 |
+
Initial config:
|
| 375 |
+
|
| 376 |
+
```yaml
|
| 377 |
+
beta_recon: 1.0
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
Ablate with:
|
| 381 |
+
|
| 382 |
+
```yaml
|
| 383 |
+
beta_recon: 0.0
|
| 384 |
+
```
|
| 385 |
+
|
| 386 |
+
---
|
| 387 |
+
|
| 388 |
+
## 9. Data collection
|
| 389 |
+
|
| 390 |
+
### Phase 1: bootstrap dataset
|
| 391 |
+
|
| 392 |
+
Collect:
|
| 393 |
+
|
| 394 |
+
```yaml
|
| 395 |
+
bootstrap_env_steps: 5000
|
| 396 |
+
policy: random_full_action_space
|
| 397 |
+
fallback_if_successes_too_low: random_restricted_action_space
|
| 398 |
+
min_success_episodes_before_planning: 20
|
| 399 |
+
bootstrap_env_step_cap: 20000
|
| 400 |
+
```
|
| 401 |
+
|
| 402 |
+
Use environment-step budgets, not fixed episode counts, so the early comparison points against PPO remain meaningful.
|
| 403 |
+
|
| 404 |
+
Bootstrap rule:
|
| 405 |
+
|
| 406 |
+
- start with full-action random data collection
|
| 407 |
+
- if the replay buffer contains fewer than `20` successful episodes after `5000` steps, continue collecting with restricted-action random `{left, right, forward}`
|
| 408 |
+
- do not make sample-efficiency claims for planner performance before the replay buffer contains at least `20` successful episodes
|
| 409 |
+
|
| 410 |
+
Store transitions:
|
| 411 |
+
|
| 412 |
+
```text
|
| 413 |
+
obs_t
|
| 414 |
+
action_t
|
| 415 |
+
reward_t
|
| 416 |
+
terminated_t
|
| 417 |
+
truncated_t
|
| 418 |
+
done_t
|
| 419 |
+
obs_{t+1}
|
| 420 |
+
episode_id
|
| 421 |
+
step_id
|
| 422 |
+
```
|
| 423 |
+
|
| 424 |
+
### Phase 2: mixed dataset
|
| 425 |
+
|
| 426 |
+
After bootstrap, alternate fixed collection/training rounds:
|
| 427 |
+
|
| 428 |
+
```yaml
|
| 429 |
+
collect_steps_per_iteration: 1000
|
| 430 |
+
gradient_updates_per_iteration: 1000
|
| 431 |
+
```
|
| 432 |
+
|
| 433 |
+
Default collection mix:
|
| 434 |
+
|
| 435 |
+
```text
|
| 436 |
+
70% planner policy
|
| 437 |
+
30% random exploration
|
| 438 |
+
```
|
| 439 |
+
|
| 440 |
+
This prevents early model errors from collapsing exploration. `train_collect_ratio = 1.0` means one gradient update per newly collected environment step.
|
| 441 |
+
|
| 442 |
+
### Replay format and splits
|
| 443 |
+
|
| 444 |
+
Use chunked trajectory sampling:
|
| 445 |
+
|
| 446 |
+
```yaml
|
| 447 |
+
sequence_length: 32
|
| 448 |
+
batch_size: 32
|
| 449 |
+
```
|
| 450 |
+
|
| 451 |
+
Sample contiguous sequences, not independent transitions, because the recurrent model needs temporal structure.
|
| 452 |
+
|
| 453 |
+
Maintain episode-level dataset splits:
|
| 454 |
+
|
| 455 |
+
```yaml
|
| 456 |
+
train_fraction: 0.8
|
| 457 |
+
val_fraction: 0.1
|
| 458 |
+
test_fraction: 0.1
|
| 459 |
+
split_key: episode_id
|
| 460 |
+
```
|
| 461 |
+
|
| 462 |
+
Assign splits by episode id, not by transition, and keep the split fixed as more episodes are collected. Use train episodes for optimization, val episodes for model selection and debugging, and held-out test episodes only for final world-model reporting.
|
| 463 |
+
|
| 464 |
+
---
|
| 465 |
+
|
| 466 |
+
## 10. Planner
|
| 467 |
+
|
| 468 |
+
### MVP planner: CEM over discrete actions
|
| 469 |
+
|
| 470 |
+
At each real environment step:
|
| 471 |
+
|
| 472 |
+
1. Encode current observation.
|
| 473 |
+
2. Update posterior latent state.
|
| 474 |
+
3. Sample candidate action sequences.
|
| 475 |
+
4. Roll them forward using the learned prior model.
|
| 476 |
+
5. Score imagined returns.
|
| 477 |
+
6. Execute the first action from the best sequence.
|
| 478 |
+
7. Replan at the next step.
|
| 479 |
+
|
| 480 |
+
### Planning config
|
| 481 |
+
|
| 482 |
+
```yaml
|
| 483 |
+
planning_horizon: 8
|
| 484 |
+
num_candidates: 256
|
| 485 |
+
num_elites: 32
|
| 486 |
+
cem_iterations: 4
|
| 487 |
+
discount: 0.99
|
| 488 |
+
```
|
| 489 |
+
|
| 490 |
+
Because actions are discrete, maintain categorical probabilities over actions per timestep:
|
| 491 |
+
|
| 492 |
+
```text
|
| 493 |
+
π_h ∈ R[horizon, num_actions]
|
| 494 |
+
```
|
| 495 |
+
|
| 496 |
+
CEM loop:
|
| 497 |
+
|
| 498 |
+
```text
|
| 499 |
+
sample action sequences from π
|
| 500 |
+
roll out latent dynamics
|
| 501 |
+
score by termination-aware discounted reward
|
| 502 |
+
select top-k elites
|
| 503 |
+
update π toward elite action frequencies
|
| 504 |
+
```
|
| 505 |
+
|
| 506 |
+
### Score function
|
| 507 |
+
|
| 508 |
+
```text
|
| 509 |
+
done_prob_t = sigmoid(pred_done_logit_t)
|
| 510 |
+
alive_0 = 1
|
| 511 |
+
alive_{t+1} = alive_t * (1 - done_prob_t)
|
| 512 |
+
score = Σ_t γ^t * alive_t * predicted_reward_t
|
| 513 |
+
```
|
| 514 |
+
|
| 515 |
+
This makes the planner termination-aware and prevents CEM from exploiting trajectories that unrealistically continue after predicted episode end.
|
| 516 |
+
|
| 517 |
+
Optional later refinement:
|
| 518 |
+
|
| 519 |
+
```text
|
| 520 |
+
score += λ * predicted_success_probability
|
| 521 |
+
score -= μ * predicted_done_probability_before_reward
|
| 522 |
+
```
|
| 523 |
+
|
| 524 |
+
Keep v1 simple: use only discounted reward with the alive-mask above.
|
| 525 |
+
|
| 526 |
+
---
|
| 527 |
+
|
| 528 |
+
## 11. Baselines
|
| 529 |
+
|
| 530 |
+
### Baseline 1: random policy
|
| 531 |
+
|
| 532 |
+
Measure:
|
| 533 |
+
|
| 534 |
+
```text
|
| 535 |
+
success_rate
|
| 536 |
+
mean_return
|
| 537 |
+
mean_episode_length
|
| 538 |
+
```
|
| 539 |
+
|
| 540 |
+
### Baseline 2: PPO
|
| 541 |
+
|
| 542 |
+
Use Stable-Baselines3 PPO with a custom CNN feature extractor. MiniGrid’s own training tutorial shows PPO training with Stable-Baselines3 and notes that a custom feature extractor is needed because the default CNN architecture does not directly support MiniGrid’s observation space.
|
| 543 |
+
|
| 544 |
+
Reference: [MiniGrid training tutorial](https://minigrid.farama.org/content/training/)
|
| 545 |
+
|
| 546 |
+
Compare PPO and world model under equal environment steps:
|
| 547 |
+
|
| 548 |
+
```yaml
|
| 549 |
+
env_steps:
|
| 550 |
+
- 10_000
|
| 551 |
+
- 25_000
|
| 552 |
+
- 50_000
|
| 553 |
+
- 100_000
|
| 554 |
+
```
|
| 555 |
+
|
| 556 |
+
### Comparison protocol
|
| 557 |
+
|
| 558 |
+
Use the same wrapped observation setup for PPO and the world-model agent unless an experiment explicitly states otherwise.
|
| 559 |
+
|
| 560 |
+
Budget accounting:
|
| 561 |
+
|
| 562 |
+
- count every environment interaction used for world-model data collection toward the world-model budget, including bootstrap random data and later planner-driven data
|
| 563 |
+
- do not count offline gradient updates toward environment-step budgets
|
| 564 |
+
- do not count evaluation episodes toward training budgets
|
| 565 |
+
|
| 566 |
+
Checkpointing protocol at budget `B`:
|
| 567 |
+
|
| 568 |
+
1. PPO is trained for exactly `B` environment steps, then evaluated.
|
| 569 |
+
2. The world-model pipeline is allowed to collect exactly `B` environment steps in total, updating the model online according to Section 9, then evaluated.
|
| 570 |
+
3. Report the latest checkpoint at budget `B`, not the best checkpoint seen so far.
|
| 571 |
+
4. Use the same seeds and the same number of evaluation episodes for both methods.
|
| 572 |
+
|
| 573 |
+
---
|
| 574 |
+
|
| 575 |
+
## 12. Evaluation metrics
|
| 576 |
+
|
| 577 |
+
### Agent metrics
|
| 578 |
+
|
| 579 |
+
```yaml
|
| 580 |
+
success_rate: fraction of episodes reaching goal
|
| 581 |
+
mean_return: average episodic return
|
| 582 |
+
median_return: robust episodic return
|
| 583 |
+
mean_episode_length: average steps per episode
|
| 584 |
+
env_steps_to_80_percent_success: sample efficiency metric
|
| 585 |
+
```
|
| 586 |
+
|
| 587 |
+
### World model metrics
|
| 588 |
+
|
| 589 |
+
```yaml
|
| 590 |
+
reward_mse: reward prediction error
|
| 591 |
+
done_bce: termination prediction error
|
| 592 |
+
kl_loss: posterior-prior divergence
|
| 593 |
+
reconstruction_mse: if decoder enabled
|
| 594 |
+
open_loop_reward_error_h1_h5_h10: reward rollout error over horizons
|
| 595 |
+
open_loop_done_accuracy_h1_h5_h10: done prediction accuracy over horizons
|
| 596 |
+
```
|
| 597 |
+
|
| 598 |
+
World-model metrics must be reported on held-out validation or test episodes, never on the training replay used for optimization. For final tables and plots, prefer the held-out test split.
|
| 599 |
+
|
| 600 |
+
### Planning metrics
|
| 601 |
+
|
| 602 |
+
```yaml
|
| 603 |
+
imagined_return_vs_real_return_correlation
|
| 604 |
+
planner_action_entropy
|
| 605 |
+
model_rollout_horizon_sensitivity
|
| 606 |
+
```
|
| 607 |
+
|
| 608 |
+
---
|
| 609 |
+
|
| 610 |
+
## 13. Experiment matrix
|
| 611 |
+
|
| 612 |
+
### MVP experiments
|
| 613 |
+
|
| 614 |
+
| ID | Model | Decoder | Latent | Horizon | Actions | Purpose |
|
| 615 |
+
|---:|---|---:|---:|---:|---|---|
|
| 616 |
+
| E0 | Random | n/a | n/a | n/a | 7 | lower bound |
|
| 617 |
+
| E1 | PPO | n/a | n/a | n/a | 7 | model-free baseline |
|
| 618 |
+
| E2 | World model | yes | 32 | 8 | 7 | default model |
|
| 619 |
+
| E3 | World model | no | 32 | 8 | 7 | test reconstruction value |
|
| 620 |
+
| E4 | World model | yes | 16 | 8 | 7 | latent bottleneck |
|
| 621 |
+
| E5 | World model | yes | 64 | 8 | 7 | larger latent |
|
| 622 |
+
| E6 | World model | yes | 32 | 4 | 7 | short planning |
|
| 623 |
+
| E7 | World model | yes | 32 | 16 | 7 | long planning |
|
| 624 |
+
|
| 625 |
+
### Extension experiments
|
| 626 |
+
|
| 627 |
+
| ID | Environment | Purpose |
|
| 628 |
+
|---:|---|---|
|
| 629 |
+
| X1 | DoorKey | test object interaction |
|
| 630 |
+
| X2 | LockedRoom | test longer-horizon partial observability |
|
| 631 |
+
| X3 | DynamicObstacles | test non-stationary local hazards |
|
| 632 |
+
| X4 | FourRooms larger/custom | test layout generalization |
|
| 633 |
+
|
| 634 |
+
---
|
| 635 |
+
|
| 636 |
+
## 14. Expected deliverables
|
| 637 |
+
|
| 638 |
+
### Code deliverables
|
| 639 |
+
|
| 640 |
+
```text
|
| 641 |
+
minidreamer/
|
| 642 |
+
configs/
|
| 643 |
+
fourrooms_world_model.yaml
|
| 644 |
+
fourrooms_ppo.yaml
|
| 645 |
+
src/
|
| 646 |
+
envs/
|
| 647 |
+
make_env.py
|
| 648 |
+
data/
|
| 649 |
+
replay_buffer.py
|
| 650 |
+
collect_random.py
|
| 651 |
+
models/
|
| 652 |
+
encoder.py
|
| 653 |
+
rssm.py
|
| 654 |
+
heads.py
|
| 655 |
+
decoder.py
|
| 656 |
+
world_model.py
|
| 657 |
+
planning/
|
| 658 |
+
cem.py
|
| 659 |
+
evaluate_planner.py
|
| 660 |
+
baselines/
|
| 661 |
+
train_ppo.py
|
| 662 |
+
train_world_model.py
|
| 663 |
+
evaluate.py
|
| 664 |
+
scripts/
|
| 665 |
+
collect_random.sh
|
| 666 |
+
train_world_model.sh
|
| 667 |
+
eval_planner.sh
|
| 668 |
+
train_ppo.sh
|
| 669 |
+
notebooks/
|
| 670 |
+
rollout_debug.ipynb
|
| 671 |
+
results_analysis.ipynb
|
| 672 |
+
tests/
|
| 673 |
+
test_env.py
|
| 674 |
+
test_replay_buffer.py
|
| 675 |
+
test_rssm_shapes.py
|
| 676 |
+
test_cem_planner.py
|
| 677 |
+
```
|
| 678 |
+
|
| 679 |
+
### Research deliverables
|
| 680 |
+
|
| 681 |
+
```text
|
| 682 |
+
README.md
|
| 683 |
+
results.md
|
| 684 |
+
plots/
|
| 685 |
+
learning_curves.png
|
| 686 |
+
success_rate_vs_env_steps.png
|
| 687 |
+
model_error_vs_rollout_horizon.png
|
| 688 |
+
reconstruction_examples.png
|
| 689 |
+
imagined_vs_real_rollouts.png
|
| 690 |
+
```
|
| 691 |
+
|
| 692 |
+
### Minimum publishable artifact
|
| 693 |
+
|
| 694 |
+
A short report with:
|
| 695 |
+
|
| 696 |
+
1. Problem statement
|
| 697 |
+
2. Method
|
| 698 |
+
3. Environment setup
|
| 699 |
+
4. Baselines
|
| 700 |
+
5. Main learning curves
|
| 701 |
+
6. Ablations
|
| 702 |
+
7. Failure cases
|
| 703 |
+
8. Next steps
|
| 704 |
+
|
| 705 |
+
---
|
| 706 |
+
|
| 707 |
+
## 15. Milestones
|
| 708 |
+
|
| 709 |
+
### Milestone 1: environment and data pipeline
|
| 710 |
+
|
| 711 |
+
Acceptance criteria:
|
| 712 |
+
|
| 713 |
+
- Can create `MiniGrid-FourRooms-v0`
|
| 714 |
+
- Can wrap it into RGB-only observation mode
|
| 715 |
+
- Can collect and save bootstrap trajectories with explicit `terminated`, `truncated`, and `done` flags
|
| 716 |
+
- Can reload trajectory chunks for sequence training
|
| 717 |
+
- Train/val/test episode splits are reproducible
|
| 718 |
+
- Shape tests pass
|
| 719 |
+
|
| 720 |
+
### Milestone 2: world model training
|
| 721 |
+
|
| 722 |
+
Acceptance criteria:
|
| 723 |
+
|
| 724 |
+
- RSSM forward pass works on sequence batches
|
| 725 |
+
- Reward, done, KL, and optional reconstruction losses train without NaNs
|
| 726 |
+
- One-step reward/done prediction beats trivial constant predictor on held-out validation episodes
|
| 727 |
+
- Model checkpointing works
|
| 728 |
+
|
| 729 |
+
### Milestone 3: open-loop model evaluation
|
| 730 |
+
|
| 731 |
+
Acceptance criteria:
|
| 732 |
+
|
| 733 |
+
- Can visualize real vs. reconstructed observations
|
| 734 |
+
- Can roll model forward for 1, 5, and 10 imagined steps
|
| 735 |
+
- Can report reward and done prediction error by horizon on held-out episodes
|
| 736 |
+
|
| 737 |
+
### Milestone 4: CEM planner
|
| 738 |
+
|
| 739 |
+
Acceptance criteria:
|
| 740 |
+
|
| 741 |
+
- Planner can choose actions from latent state
|
| 742 |
+
- Planner runs online in the environment
|
| 743 |
+
- Planner scoring uses termination-aware reward masking
|
| 744 |
+
- Success rate beats random policy over 100 evaluation episodes
|
| 745 |
+
|
| 746 |
+
### Milestone 5: baseline comparison
|
| 747 |
+
|
| 748 |
+
Acceptance criteria:
|
| 749 |
+
|
| 750 |
+
- PPO baseline runs on same observation setup
|
| 751 |
+
- World model and PPO are compared at fixed environment-step budgets, with bootstrap data counted toward the world-model budget
|
| 752 |
+
- At least 3 seeds per method
|
| 753 |
+
|
| 754 |
+
### Milestone 6: ablations and report
|
| 755 |
+
|
| 756 |
+
Acceptance criteria:
|
| 757 |
+
|
| 758 |
+
- Run decoder/no-decoder ablation
|
| 759 |
+
- Run at least 2 planning horizons
|
| 760 |
+
- Run at least 2 latent sizes
|
| 761 |
+
- Produce final plots and written analysis
|
| 762 |
+
|
| 763 |
+
---
|
| 764 |
+
|
| 765 |
+
## 16. Suggested default config
|
| 766 |
+
|
| 767 |
+
```yaml
|
| 768 |
+
project:
|
| 769 |
+
name: minidreamer-fourrooms
|
| 770 |
+
seed: 0
|
| 771 |
+
|
| 772 |
+
env:
|
| 773 |
+
id: MiniGrid-FourRooms-v0
|
| 774 |
+
rgb_partial_obs: true
|
| 775 |
+
image_only: true
|
| 776 |
+
resize: [64, 64]
|
| 777 |
+
normalize_obs: true
|
| 778 |
+
action_space: full
|
| 779 |
+
|
| 780 |
+
replay:
|
| 781 |
+
capacity_episodes: 5000
|
| 782 |
+
sequence_length: 32
|
| 783 |
+
batch_size: 32
|
| 784 |
+
train_fraction: 0.8
|
| 785 |
+
val_fraction: 0.1
|
| 786 |
+
test_fraction: 0.1
|
| 787 |
+
split_key: episode_id
|
| 788 |
+
|
| 789 |
+
collection:
|
| 790 |
+
bootstrap_env_steps: 5000
|
| 791 |
+
bootstrap_success_threshold: 20
|
| 792 |
+
bootstrap_fallback_policy: restricted_random_3_actions
|
| 793 |
+
bootstrap_env_step_cap: 20000
|
| 794 |
+
collect_steps_per_iteration: 1000
|
| 795 |
+
train_collect_ratio: 1.0
|
| 796 |
+
random_action_fraction_after_planner: 0.3
|
| 797 |
+
|
| 798 |
+
model:
|
| 799 |
+
embedding_dim: 256
|
| 800 |
+
deter_dim: 256
|
| 801 |
+
stoch_dim: 32
|
| 802 |
+
hidden_dim: 256
|
| 803 |
+
use_decoder: true
|
| 804 |
+
min_std: 0.1
|
| 805 |
+
|
| 806 |
+
training:
|
| 807 |
+
optimizer: adam
|
| 808 |
+
lr: 0.0003
|
| 809 |
+
grad_clip_norm: 100.0
|
| 810 |
+
train_steps: 100000
|
| 811 |
+
beta_kl: 1.0
|
| 812 |
+
free_nats: 1.0
|
| 813 |
+
beta_recon: 1.0
|
| 814 |
+
beta_reward: 1.0
|
| 815 |
+
beta_done: 1.0
|
| 816 |
+
|
| 817 |
+
planner:
|
| 818 |
+
type: discrete_cem
|
| 819 |
+
horizon: 8
|
| 820 |
+
candidates: 256
|
| 821 |
+
elites: 32
|
| 822 |
+
iterations: 4
|
| 823 |
+
discount: 0.99
|
| 824 |
+
use_done_mask: true
|
| 825 |
+
|
| 826 |
+
evaluation:
|
| 827 |
+
episodes: 100
|
| 828 |
+
seeds: [0, 1, 2]
|
| 829 |
+
eval_every_env_steps: 10000
|
| 830 |
+
```
|
| 831 |
+
|
| 832 |
+
---
|
| 833 |
+
|
| 834 |
+
## 17. Acceptance criteria for v1
|
| 835 |
+
|
| 836 |
+
The project is successful if:
|
| 837 |
+
|
| 838 |
+
1. The world model trains stably for at least 3 seeds.
|
| 839 |
+
2. The CEM planner beats random policy on `MiniGrid-FourRooms-v0`.
|
| 840 |
+
3. The report shows learning curves for world-model planner, PPO, and random.
|
| 841 |
+
4. The report includes at least one ablation showing whether reconstruction helps.
|
| 842 |
+
5. The repo can reproduce the main result from a clean config file and fixed environment-step comparison protocol.
|
| 843 |
+
|
| 844 |
+
A strong v1 result would be:
|
| 845 |
+
|
| 846 |
+
```text
|
| 847 |
+
world_model_success_rate > random_success_rate
|
| 848 |
+
```
|
| 849 |
+
|
| 850 |
+
An excellent v1 result would be:
|
| 851 |
+
|
| 852 |
+
```text
|
| 853 |
+
world_model reaches useful success rate with fewer env steps than PPO
|
| 854 |
+
```
|
| 855 |
+
|
| 856 |
+
Do not require the world model to beat PPO asymptotically. For a first project, sample efficiency and interpretable failure analysis are more important.
|
| 857 |
+
|
| 858 |
+
---
|
| 859 |
+
|
| 860 |
+
## 18. Main risks
|
| 861 |
+
|
| 862 |
+
### Risk 1: sparse reward makes reward model hard to train
|
| 863 |
+
|
| 864 |
+
Mitigations:
|
| 865 |
+
|
| 866 |
+
- start with random + heuristic exploratory data
|
| 867 |
+
- add goal-reaching trajectories if random rarely succeeds
|
| 868 |
+
- optionally train first on `MiniGrid-Empty-*`
|
| 869 |
+
- use success classification in addition to reward regression
|
| 870 |
+
|
| 871 |
+
### Risk 2: model learns visual reconstruction but not controllable dynamics
|
| 872 |
+
|
| 873 |
+
Mitigations:
|
| 874 |
+
|
| 875 |
+
- ablate decoder
|
| 876 |
+
- track reward/done prediction separately from reconstruction
|
| 877 |
+
- evaluate imagined-vs-real return correlation
|
| 878 |
+
|
| 879 |
+
### Risk 3: CEM planner exploits model errors
|
| 880 |
+
|
| 881 |
+
Mitigations:
|
| 882 |
+
|
| 883 |
+
- short planning horizon
|
| 884 |
+
- replan every step
|
| 885 |
+
- penalize high predicted uncertainty if using ensemble later
|
| 886 |
+
- mix random exploration into data collection
|
| 887 |
+
|
| 888 |
+
### Risk 4: partial observability hurts Markov assumptions
|
| 889 |
+
|
| 890 |
+
Mitigations:
|
| 891 |
+
|
| 892 |
+
- keep recurrent state
|
| 893 |
+
- compare deterministic-only vs. stochastic recurrent latent
|
| 894 |
+
- optionally add frame stacking as a diagnostic, not as the main solution
|
| 895 |
+
|
| 896 |
+
---
|
| 897 |
+
|
| 898 |
+
## 19. Recommended implementation order
|
| 899 |
+
|
| 900 |
+
1. Implement environment wrapper and trajectory collection.
|
| 901 |
+
2. Implement replay buffer with sequence sampling.
|
| 902 |
+
3. Implement encoder + RSSM + reward/done heads.
|
| 903 |
+
4. Train model on random trajectories.
|
| 904 |
+
5. Add decoder and visualization.
|
| 905 |
+
6. Add open-loop rollout diagnostics.
|
| 906 |
+
7. Implement discrete CEM planner.
|
| 907 |
+
8. Evaluate against random.
|
| 908 |
+
9. Add PPO baseline.
|
| 909 |
+
10. Run ablations.
|
| 910 |
+
|
| 911 |
+
The key design constraint: **do not optimize for final score first**. Optimize for observability: diagnostics, rollout plots, prediction errors, and failure cases. That will make this a research project instead of just another RL training script.
|
notebooks/results_analysis.ipynb
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Results Analysis\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Use this notebook to load `metrics/*.jsonl`, aggregate seed-level runs, and generate the plots referenced in `results.md`."
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
],
|
| 13 |
+
"metadata": {
|
| 14 |
+
"kernelspec": {
|
| 15 |
+
"display_name": "Python 3",
|
| 16 |
+
"language": "python",
|
| 17 |
+
"name": "python3"
|
| 18 |
+
},
|
| 19 |
+
"language_info": {
|
| 20 |
+
"name": "python"
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"nbformat": 4,
|
| 24 |
+
"nbformat_minor": 5
|
| 25 |
+
}
|
notebooks/rollout_debug.ipynb
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Rollout Debug\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Use this notebook to inspect replay sequences, reconstructions, and imagined-vs-real reward rollouts after training has been run."
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
],
|
| 13 |
+
"metadata": {
|
| 14 |
+
"kernelspec": {
|
| 15 |
+
"display_name": "Python 3",
|
| 16 |
+
"language": "python",
|
| 17 |
+
"name": "python3"
|
| 18 |
+
},
|
| 19 |
+
"language_info": {
|
| 20 |
+
"name": "python"
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"nbformat": 4,
|
| 24 |
+
"nbformat_minor": 5
|
| 25 |
+
}
|
| 26 |
+
|
plots/.gitkeep
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
plots/learning_curves.png
ADDED
|
plots/model_error_vs_rollout_horizon.png
ADDED
|
plots/success_rate_vs_env_steps.png
ADDED
|
pyproject.toml
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=69", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "minidreamer"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "PlaNet-style world model planning for pixel-based MiniGrid"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.11,<3.13"
|
| 11 |
+
license = { file = "LICENSE" }
|
| 12 |
+
authors = [{ name = "OpenAI Codex" }]
|
| 13 |
+
dependencies = [
|
| 14 |
+
"gymnasium>=0.29,<1.1",
|
| 15 |
+
"minigrid>=2.3.1",
|
| 16 |
+
"numpy>=1.26,<3",
|
| 17 |
+
"pillow>=10.0",
|
| 18 |
+
"pyyaml>=6.0",
|
| 19 |
+
"torch>=2.3,<3",
|
| 20 |
+
"tqdm>=4.66",
|
| 21 |
+
"matplotlib>=3.8",
|
| 22 |
+
"stable-baselines3>=2.3.0",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.optional-dependencies]
|
| 26 |
+
dev = [
|
| 27 |
+
"pytest>=8.2",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
[project.scripts]
|
| 31 |
+
minidreamer-collect-random = "minidreamer.data.collect_random:main"
|
| 32 |
+
minidreamer-train-world-model = "train_world_model:main"
|
| 33 |
+
minidreamer-evaluate = "evaluate:main"
|
| 34 |
+
minidreamer-train-ppo = "minidreamer.baselines.train_ppo:main"
|
| 35 |
+
|
| 36 |
+
[tool.setuptools]
|
| 37 |
+
package-dir = { "" = "src" }
|
| 38 |
+
py-modules = ["train_world_model", "evaluate"]
|
| 39 |
+
|
| 40 |
+
[tool.setuptools.packages.find]
|
| 41 |
+
where = ["src"]
|
| 42 |
+
include = ["minidreamer*"]
|
| 43 |
+
|
| 44 |
+
[tool.pytest.ini_options]
|
| 45 |
+
testpaths = ["tests"]
|
| 46 |
+
pythonpath = ["src"]
|
results.md
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Results
|
| 2 |
+
|
| 3 |
+
## Status
|
| 4 |
+
|
| 5 |
+
A baseline world-model training run completed on `2026-04-21` on Apple Silicon using the `mps` backend. This run is now frozen as the reference artifact under `artifacts/world_model/`. Final world-model summaries now come from `artifacts/world_model/metrics/run_summary.json`, `artifacts/world_model/metrics/final_eval_latest.json`, and `artifacts/world_model/metrics/planner_eval_latest_clean.json`, the PPO baseline summary comes from `artifacts/ppo/metrics/run_summary.json`, and generated figures live under `plots/`.
|
| 6 |
+
|
| 7 |
+
## Problem Statement
|
| 8 |
+
|
| 9 |
+
Train a PlaNet-style world model for `MiniGrid-FourRooms-v0` from partial RGB observations, then evaluate a latent-space discrete CEM planner against a random baseline.
|
| 10 |
+
|
| 11 |
+
## Method
|
| 12 |
+
|
| 13 |
+
- CNN encoder + Gaussian RSSM world model with reward, done, and reconstruction heads.
|
| 14 |
+
- Discrete CEM planning in latent space.
|
| 15 |
+
- Replay-buffer training with episode-aware train/val/test splits.
|
| 16 |
+
- Config: `configs/fourrooms_world_model.yaml`.
|
| 17 |
+
|
| 18 |
+
## Environment Setup
|
| 19 |
+
|
| 20 |
+
- Device: `mps`
|
| 21 |
+
- Target environment steps: `100000`
|
| 22 |
+
- Final realized environment steps: `100004`
|
| 23 |
+
- Total gradient updates completed: `10000`
|
| 24 |
+
- Final replay summary:
|
| 25 |
+
- Episodes: `1022`
|
| 26 |
+
- Success episodes: `42`
|
| 27 |
+
- Train/val/test episodes: `819 / 116 / 87`
|
| 28 |
+
|
| 29 |
+
## Baselines
|
| 30 |
+
|
| 31 |
+
Persisted comparison metrics were recorded at `97124` env steps:
|
| 32 |
+
|
| 33 |
+
- Random baseline success rate: `0.0`
|
| 34 |
+
- Random baseline mean return: `0.0`
|
| 35 |
+
- Random baseline mean episode length: `100.0`
|
| 36 |
+
|
| 37 |
+
Completed PPO baseline after `100000` training env steps:
|
| 38 |
+
|
| 39 |
+
- PPO success rate: `0.14`
|
| 40 |
+
- PPO mean return: `0.1336`
|
| 41 |
+
- PPO median return: `0.0`
|
| 42 |
+
- PPO return std: `0.3313`
|
| 43 |
+
- PPO mean episode length: `86.71`
|
| 44 |
+
|
| 45 |
+
## Main Metrics
|
| 46 |
+
|
| 47 |
+
Persisted evaluation metrics at `97124` env steps:
|
| 48 |
+
|
| 49 |
+
- Planner success rate: `0.10`
|
| 50 |
+
- Planner mean return: `0.0568`
|
| 51 |
+
- Planner median return: `0.0`
|
| 52 |
+
- Planner mean episode length: `94.8`
|
| 53 |
+
- Planner action entropy: `0.6569`
|
| 54 |
+
|
| 55 |
+
Held-out world-model metrics at the same checkpoint:
|
| 56 |
+
|
| 57 |
+
- Reward MSE: `1.98e-6`
|
| 58 |
+
- Done BCE: `0.1808`
|
| 59 |
+
- KL loss: `0.8996`
|
| 60 |
+
- Reconstruction MSE: `0.0109`
|
| 61 |
+
|
| 62 |
+
Open-loop rollout quality:
|
| 63 |
+
|
| 64 |
+
- Done accuracy @1/@5/@10: `0.9880 / 0.9960 / 0.9973`
|
| 65 |
+
- Reward error @1/@5/@10: `1.98e-6 / 1.83e-6 / 1.69e-6`
|
| 66 |
+
|
| 67 |
+
Canonical clean final planner evaluation at `world_model_latest.pt`:
|
| 68 |
+
|
| 69 |
+
- Evaluation budget: `100` planner episodes and `100` random episodes
|
| 70 |
+
- Planner success rate: `0.04`
|
| 71 |
+
- Planner mean return: `0.03415`
|
| 72 |
+
- Planner median return: `0.0`
|
| 73 |
+
- Planner mean episode length: `96.65`
|
| 74 |
+
- Planner action entropy: `0.9142`
|
| 75 |
+
- Final random baseline success rate: `0.03`
|
| 76 |
+
- Final random baseline mean return: `0.02073`
|
| 77 |
+
- Final random baseline mean episode length: `98.03`
|
| 78 |
+
|
| 79 |
+
Small-sample final world-model diagnostics at `world_model_latest.pt`:
|
| 80 |
+
|
| 81 |
+
- Evaluation budget: `5` held-out world-model episodes
|
| 82 |
+
- Final reward MSE: `4.95e-6`
|
| 83 |
+
- Final done BCE: `0.1513`
|
| 84 |
+
- Final KL loss: `0.8833`
|
| 85 |
+
- Final reconstruction MSE: `0.0096`
|
| 86 |
+
- Final done accuracy @1/@5/@10: `0.9880 / 0.9971 / 0.9978`
|
| 87 |
+
- Final reward error @1/@5/@10: `4.98e-6 / 5.09e-6 / 5.12e-6`
|
| 88 |
+
|
| 89 |
+
## Comparison
|
| 90 |
+
|
| 91 |
+
At roughly matched data budgets, the PPO baseline produced stronger direct control performance than the world-model planner:
|
| 92 |
+
|
| 93 |
+
- PPO at `100000` env steps: `0.14` success rate, `0.1336` mean return, `86.71` mean episode length over `100` evaluation episodes.
|
| 94 |
+
- Final clean planner eval at `100004` env steps: `0.04` success rate, `0.03415` mean return, `96.65` mean episode length over `100` evaluation episodes.
|
| 95 |
+
- Final clean random eval at `100004` env steps: `0.03` success rate, `0.02073` mean return, `98.03` mean episode length over `100` evaluation episodes.
|
| 96 |
+
|
| 97 |
+
The predictive model stayed numerically strong through the end of training, but that did not translate into a robust planner win on FourRooms at this training budget. After removing evaluation-time action noise, the final planner is slightly better than random, but still clearly below PPO.
|
| 98 |
+
|
| 99 |
+
## Ablations
|
| 100 |
+
|
| 101 |
+
No ablation runs have been recorded yet.
|
| 102 |
+
|
| 103 |
+
## Failure Cases / Operational Notes
|
| 104 |
+
|
| 105 |
+
- The initial long run stopped near `90021` env steps when the local machine hit severe disk pressure.
|
| 106 |
+
- Training was resumed successfully from `artifacts/world_model/checkpoints/world_model_env_steps_90021.pt` after adding checkpoint-resume support to the trainer.
|
| 107 |
+
- Only one scheduled planner evaluation row is persisted in `eval_metrics.jsonl`. The resumed segment completed without crossing another configured evaluation boundary before the final save, so the end-of-run planner metrics are stored separately in `final_eval_latest.json` and `planner_eval_latest_clean.json`.
|
| 108 |
+
- The scheduled planner row in `eval_metrics.jsonl` and the planner section inside `final_eval_latest.json` were recorded before the evaluation-noise fix. They are kept for provenance, but `planner_eval_latest_clean.json` is the canonical final planner result.
|
| 109 |
+
|
| 110 |
+
## Visualizations
|
| 111 |
+
|
| 112 |
+
- Generated from `artifacts/world_model/metrics/train_metrics.jsonl`, `artifacts/world_model/metrics/eval_metrics.jsonl`, `artifacts/world_model/metrics/final_eval_latest.json`, and `artifacts/world_model/metrics/planner_eval_latest_clean.json` using `scripts/generate_results_plots.py`.
|
| 113 |
+
- The success-rate chart uses the clean final planner evaluation for the last point.
|
| 114 |
+
- `plots/learning_curves.png`
|
| 115 |
+
- `plots/success_rate_vs_env_steps.png`
|
| 116 |
+
- `plots/model_error_vs_rollout_horizon.png`
|
| 117 |
+
|
| 118 |
+
## Artifact Locations
|
| 119 |
+
|
| 120 |
+
- Metrics: `artifacts/world_model/metrics/`
|
| 121 |
+
- Replay: `artifacts/world_model/replay/`
|
| 122 |
+
- Checkpoints: `artifacts/world_model/checkpoints/`
|
| 123 |
+
- Final checkpoint: `artifacts/world_model/checkpoints/world_model_latest.pt`
|
| 124 |
+
- Canonical clean final planner eval: `artifacts/world_model/metrics/planner_eval_latest_clean.json`
|
| 125 |
+
- PPO metrics: `artifacts/ppo/metrics/`
|
| 126 |
+
- PPO checkpoint: `artifacts/ppo/checkpoints/ppo_latest.zip`
|
| 127 |
+
- Generated plots: `plots/`
|
scripts/collect_random.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
export PYTHONPATH="${PYTHONPATH:-}:$(pwd)/src"
|
| 5 |
+
python3.11 src/minidreamer/data/collect_random.py \
|
| 6 |
+
--config configs/fourrooms_world_model.yaml \
|
| 7 |
+
--output-dir artifacts/bootstrap_replay \
|
| 8 |
+
"$@"
|
| 9 |
+
|
scripts/eval_planner.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
if [ "$#" -lt 2 ]; then
|
| 5 |
+
echo "Usage: $0 CHECKPOINT REPLAY_DIR [extra evaluate args]" >&2
|
| 6 |
+
exit 1
|
| 7 |
+
fi
|
| 8 |
+
|
| 9 |
+
checkpoint="$1"
|
| 10 |
+
replay_dir="$2"
|
| 11 |
+
shift 2
|
| 12 |
+
|
| 13 |
+
export PYTHONPATH="${PYTHONPATH:-}:$(pwd)/src"
|
| 14 |
+
python3.11 src/evaluate.py \
|
| 15 |
+
planner \
|
| 16 |
+
--config configs/fourrooms_world_model.yaml \
|
| 17 |
+
--checkpoint "${checkpoint}" \
|
| 18 |
+
"$@"
|
| 19 |
+
|
| 20 |
+
python3.11 src/evaluate.py \
|
| 21 |
+
world-model \
|
| 22 |
+
--config configs/fourrooms_world_model.yaml \
|
| 23 |
+
--checkpoint "${checkpoint}" \
|
| 24 |
+
--replay-dir "${replay_dir}" \
|
| 25 |
+
"$@"
|
| 26 |
+
|
scripts/generate_results_plots.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 10 |
+
METRICS_DIR = ROOT / "artifacts" / "world_model" / "metrics"
|
| 11 |
+
PLOTS_DIR = ROOT / "plots"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_jsonl(path: Path) -> list[dict]:
|
| 15 |
+
rows = []
|
| 16 |
+
if not path.exists():
|
| 17 |
+
return rows
|
| 18 |
+
for line in path.read_text(encoding="utf-8").splitlines():
|
| 19 |
+
line = line.strip()
|
| 20 |
+
if line:
|
| 21 |
+
rows.append(json.loads(line))
|
| 22 |
+
return rows
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_json(path: Path) -> dict | None:
|
| 26 |
+
if not path.exists():
|
| 27 |
+
return None
|
| 28 |
+
return json.loads(path.read_text(encoding="utf-8"))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def rolling_mean(values: list[float], window: int) -> list[float]:
|
| 32 |
+
if not values:
|
| 33 |
+
return []
|
| 34 |
+
out: list[float] = []
|
| 35 |
+
running_sum = 0.0
|
| 36 |
+
for idx, value in enumerate(values):
|
| 37 |
+
running_sum += value
|
| 38 |
+
if idx >= window:
|
| 39 |
+
running_sum -= values[idx - window]
|
| 40 |
+
out.append(running_sum / min(idx + 1, window))
|
| 41 |
+
return out
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def generate_learning_curves(train_rows: list[dict]) -> None:
|
| 45 |
+
if not train_rows:
|
| 46 |
+
return
|
| 47 |
+
steps = list(range(1, len(train_rows) + 1))
|
| 48 |
+
window = min(250, len(train_rows))
|
| 49 |
+
loss = rolling_mean([row["loss"] for row in train_rows], window)
|
| 50 |
+
kl = rolling_mean([row["kl_loss"] for row in train_rows], window)
|
| 51 |
+
done = rolling_mean([row["done_loss"] for row in train_rows], window)
|
| 52 |
+
recon = rolling_mean([row["recon_loss"] for row in train_rows], window)
|
| 53 |
+
|
| 54 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 55 |
+
ax.plot(steps, loss, label="total loss", linewidth=2.0)
|
| 56 |
+
ax.plot(steps, kl, label="kl loss", linewidth=1.5)
|
| 57 |
+
ax.plot(steps, done, label="done loss", linewidth=1.5)
|
| 58 |
+
ax.plot(steps, recon, label="recon loss", linewidth=1.5)
|
| 59 |
+
ax.set_title("World Model Training Curves")
|
| 60 |
+
ax.set_xlabel("Gradient update")
|
| 61 |
+
ax.set_ylabel("Smoothed loss")
|
| 62 |
+
ax.legend()
|
| 63 |
+
ax.grid(alpha=0.25)
|
| 64 |
+
fig.tight_layout()
|
| 65 |
+
fig.savefig(PLOTS_DIR / "learning_curves.png", dpi=200)
|
| 66 |
+
plt.close(fig)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def generate_success_plot(
|
| 70 |
+
eval_rows: list[dict],
|
| 71 |
+
final_eval: dict | None,
|
| 72 |
+
clean_planner_eval: dict | None,
|
| 73 |
+
) -> None:
|
| 74 |
+
steps: list[int] = []
|
| 75 |
+
planner_success: list[float] = []
|
| 76 |
+
random_success: list[float] = []
|
| 77 |
+
|
| 78 |
+
for row in eval_rows:
|
| 79 |
+
steps.append(int(row["env_steps"]))
|
| 80 |
+
planner_success.append(float(row["planner/success_rate"]))
|
| 81 |
+
random_success.append(float(row["random/success_rate"]))
|
| 82 |
+
|
| 83 |
+
if clean_planner_eval is not None:
|
| 84 |
+
steps.append(int(clean_planner_eval["metadata"]["env_steps"]))
|
| 85 |
+
planner_success.append(float(clean_planner_eval["planner_clean"]["success_rate"]))
|
| 86 |
+
random_success.append(float(clean_planner_eval["random"]["success_rate"]))
|
| 87 |
+
elif final_eval is not None:
|
| 88 |
+
steps.append(int(final_eval["metadata"]["env_steps"]))
|
| 89 |
+
planner_success.append(float(final_eval["planner"]["success_rate"]))
|
| 90 |
+
random_success.append(float(final_eval["random"]["success_rate"]))
|
| 91 |
+
|
| 92 |
+
if not steps:
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
paired = sorted(zip(steps, planner_success, random_success), key=lambda item: item[0])
|
| 96 |
+
steps = [item[0] for item in paired]
|
| 97 |
+
planner_success = [item[1] for item in paired]
|
| 98 |
+
random_success = [item[2] for item in paired]
|
| 99 |
+
|
| 100 |
+
fig, ax = plt.subplots(figsize=(9, 5))
|
| 101 |
+
ax.plot(steps, planner_success, marker="o", linewidth=2.0, label="planner success rate")
|
| 102 |
+
ax.plot(steps, random_success, marker="o", linewidth=2.0, label="random success rate")
|
| 103 |
+
ax.set_title("Success Rate vs Environment Steps")
|
| 104 |
+
ax.set_xlabel("Environment steps")
|
| 105 |
+
ax.set_ylabel("Success rate")
|
| 106 |
+
ax.set_ylim(-0.02, 1.02)
|
| 107 |
+
ax.legend()
|
| 108 |
+
ax.grid(alpha=0.25)
|
| 109 |
+
fig.tight_layout()
|
| 110 |
+
fig.savefig(PLOTS_DIR / "success_rate_vs_env_steps.png", dpi=200)
|
| 111 |
+
plt.close(fig)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def generate_rollout_error_plot(final_eval: dict | None) -> None:
|
| 115 |
+
if final_eval is None:
|
| 116 |
+
return
|
| 117 |
+
horizons = [1, 5, 10]
|
| 118 |
+
reward_errors = [
|
| 119 |
+
float(final_eval["world_model"][f"open_loop_reward_error_h{h}"])
|
| 120 |
+
for h in horizons
|
| 121 |
+
]
|
| 122 |
+
done_accuracy = [
|
| 123 |
+
float(final_eval["world_model"][f"open_loop_done_accuracy_h{h}"])
|
| 124 |
+
for h in horizons
|
| 125 |
+
]
|
| 126 |
+
|
| 127 |
+
fig, ax1 = plt.subplots(figsize=(9, 5))
|
| 128 |
+
ax1.bar([str(h) for h in horizons], reward_errors, color="#2b6cb0", alpha=0.8)
|
| 129 |
+
ax1.set_xlabel("Open-loop horizon")
|
| 130 |
+
ax1.set_ylabel("Reward MSE", color="#2b6cb0")
|
| 131 |
+
ax1.tick_params(axis="y", labelcolor="#2b6cb0")
|
| 132 |
+
ax1.set_title("Model Error vs Rollout Horizon")
|
| 133 |
+
ax1.grid(alpha=0.2, axis="y")
|
| 134 |
+
|
| 135 |
+
ax2 = ax1.twinx()
|
| 136 |
+
ax2.plot([str(h) for h in horizons], done_accuracy, color="#c05621", marker="o", linewidth=2.0)
|
| 137 |
+
ax2.set_ylabel("Done accuracy", color="#c05621")
|
| 138 |
+
ax2.tick_params(axis="y", labelcolor="#c05621")
|
| 139 |
+
ax2.set_ylim(0.95, 1.001)
|
| 140 |
+
|
| 141 |
+
fig.tight_layout()
|
| 142 |
+
fig.savefig(PLOTS_DIR / "model_error_vs_rollout_horizon.png", dpi=200)
|
| 143 |
+
plt.close(fig)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def main() -> None:
|
| 147 |
+
PLOTS_DIR.mkdir(parents=True, exist_ok=True)
|
| 148 |
+
train_rows = load_jsonl(METRICS_DIR / "train_metrics.jsonl")
|
| 149 |
+
eval_rows = load_jsonl(METRICS_DIR / "eval_metrics.jsonl")
|
| 150 |
+
final_eval = load_json(METRICS_DIR / "final_eval_latest.json")
|
| 151 |
+
clean_planner_eval = load_json(METRICS_DIR / "planner_eval_latest_clean.json")
|
| 152 |
+
|
| 153 |
+
generate_learning_curves(train_rows)
|
| 154 |
+
generate_success_plot(eval_rows, final_eval, clean_planner_eval)
|
| 155 |
+
generate_rollout_error_plot(final_eval)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
+
main()
|
scripts/train_ppo.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
export PYTHONPATH="${PYTHONPATH:-}:$(pwd)/src"
|
| 5 |
+
python3.11 src/minidreamer/baselines/train_ppo.py \
|
| 6 |
+
--config configs/fourrooms_ppo.yaml \
|
| 7 |
+
--output-dir artifacts/ppo \
|
| 8 |
+
"$@"
|
| 9 |
+
|
scripts/train_world_model.sh
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
output_dir="${MINIDREAMER_OUTPUT_DIR:-artifacts/world_model_experiment}"
|
| 5 |
+
|
| 6 |
+
export PYTHONPATH="${PYTHONPATH:-}:$(pwd)/src"
|
| 7 |
+
python3.11 src/train_world_model.py \
|
| 8 |
+
--config configs/fourrooms_world_model.yaml \
|
| 9 |
+
--output-dir "${output_dir}" \
|
| 10 |
+
"$@"
|
src/evaluate.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
from minidreamer.config import load_config
|
| 7 |
+
from minidreamer.data.replay_buffer import ReplayBuffer
|
| 8 |
+
from minidreamer.evaluation import evaluate_random_policy, evaluate_world_model
|
| 9 |
+
from minidreamer.envs.make_env import make_env_from_config
|
| 10 |
+
from minidreamer.planning.evaluate_planner import evaluate_planner
|
| 11 |
+
from minidreamer.serialization import load_world_model_checkpoint
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_arg_parser() -> argparse.ArgumentParser:
|
| 15 |
+
parser = argparse.ArgumentParser(description="Evaluate MiniDreamer components.")
|
| 16 |
+
subparsers = parser.add_subparsers(dest="command", required=True)
|
| 17 |
+
|
| 18 |
+
random_parser = subparsers.add_parser("random", help="Evaluate a random policy.")
|
| 19 |
+
random_parser.add_argument("--config", type=Path, required=True)
|
| 20 |
+
|
| 21 |
+
planner_parser = subparsers.add_parser("planner", help="Evaluate a trained planner.")
|
| 22 |
+
planner_parser.add_argument("--config", type=Path, required=True)
|
| 23 |
+
planner_parser.add_argument("--checkpoint", type=Path, required=True)
|
| 24 |
+
planner_parser.add_argument(
|
| 25 |
+
"--random-action-fraction",
|
| 26 |
+
type=float,
|
| 27 |
+
default=0.0,
|
| 28 |
+
help="Optional evaluation-time action noise. Defaults to 0.0 for a clean planner evaluation.",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
world_model_parser = subparsers.add_parser("world-model", help="Evaluate held-out world model metrics.")
|
| 32 |
+
world_model_parser.add_argument("--config", type=Path, required=True)
|
| 33 |
+
world_model_parser.add_argument("--checkpoint", type=Path, required=True)
|
| 34 |
+
world_model_parser.add_argument("--replay-dir", type=Path, required=True)
|
| 35 |
+
world_model_parser.add_argument("--split", type=str, default="val", choices=["train", "val", "test"])
|
| 36 |
+
return parser
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main() -> None:
|
| 40 |
+
parser = build_arg_parser()
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
config = load_config(args.config)
|
| 43 |
+
|
| 44 |
+
if args.command == "random":
|
| 45 |
+
print(evaluate_random_policy(config))
|
| 46 |
+
return
|
| 47 |
+
|
| 48 |
+
env = make_env_from_config(config, seed=config.get("project", {}).get("seed", 0))
|
| 49 |
+
action_dim = env.action_space.n
|
| 50 |
+
env.close()
|
| 51 |
+
model, _, metadata = load_world_model_checkpoint(args.checkpoint, action_dim=action_dim, map_location="cpu")
|
| 52 |
+
|
| 53 |
+
if args.command == "planner":
|
| 54 |
+
print({
|
| 55 |
+
"metadata": metadata,
|
| 56 |
+
**evaluate_planner(
|
| 57 |
+
config,
|
| 58 |
+
model,
|
| 59 |
+
random_action_fraction=args.random_action_fraction,
|
| 60 |
+
),
|
| 61 |
+
})
|
| 62 |
+
return
|
| 63 |
+
|
| 64 |
+
replay = ReplayBuffer.load(args.replay_dir)
|
| 65 |
+
print({"metadata": metadata, **evaluate_world_model(config, model, replay, split=args.split)})
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
if __name__ == "__main__":
|
| 69 |
+
main()
|
src/minidreamer/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MiniDreamer package."""
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"config",
|
| 5 |
+
]
|
| 6 |
+
|
src/minidreamer/baselines/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Baseline agents."""
|
| 2 |
+
|
src/minidreamer/baselines/train_ppo.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
from minidreamer.config import ensure_run_dirs, load_config
|
| 11 |
+
from minidreamer.envs.make_env import make_env_from_config
|
| 12 |
+
from minidreamer.utils.common import seed_everything
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from stable_baselines3 import PPO
|
| 16 |
+
from stable_baselines3.common.evaluation import evaluate_policy
|
| 17 |
+
from stable_baselines3.common.monitor import Monitor
|
| 18 |
+
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
| 19 |
+
from stable_baselines3.common.vec_env import DummyVecEnv
|
| 20 |
+
except ImportError as exc: # pragma: no cover - exercised only when dependency is missing.
|
| 21 |
+
PPO = None
|
| 22 |
+
IMPORT_ERROR = exc
|
| 23 |
+
else:
|
| 24 |
+
IMPORT_ERROR = None
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MiniGridCNNExtractor(BaseFeaturesExtractor):
|
| 28 |
+
def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 256) -> None:
|
| 29 |
+
super().__init__(observation_space, features_dim)
|
| 30 |
+
if len(observation_space.shape) != 3:
|
| 31 |
+
raise ValueError(f"Expected 3D image observations, got {observation_space.shape}.")
|
| 32 |
+
self.channel_first = observation_space.shape[0] in (1, 3)
|
| 33 |
+
channels = observation_space.shape[0] if self.channel_first else observation_space.shape[2]
|
| 34 |
+
self.cnn = nn.Sequential(
|
| 35 |
+
nn.Conv2d(channels, 32, kernel_size=4, stride=2, padding=1),
|
| 36 |
+
nn.ReLU(),
|
| 37 |
+
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
|
| 38 |
+
nn.ReLU(),
|
| 39 |
+
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
|
| 40 |
+
nn.ReLU(),
|
| 41 |
+
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
|
| 42 |
+
nn.ReLU(),
|
| 43 |
+
nn.Flatten(),
|
| 44 |
+
)
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
sample = torch.as_tensor(observation_space.sample()[None]).float()
|
| 47 |
+
if not self.channel_first:
|
| 48 |
+
sample = sample.permute(0, 3, 1, 2)
|
| 49 |
+
flattened_dim = self.cnn(sample).shape[1]
|
| 50 |
+
self.linear = nn.Sequential(nn.Linear(flattened_dim, features_dim), nn.ReLU())
|
| 51 |
+
|
| 52 |
+
def forward(self, observations: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
if observations.dim() == 4 and observations.shape[1] not in (1, 3):
|
| 54 |
+
observations = observations.permute(0, 3, 1, 2)
|
| 55 |
+
return self.linear(self.cnn(observations.float()))
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_env(config: dict, seed: int, rank: int):
|
| 59 |
+
def _make():
|
| 60 |
+
env = make_env_from_config(config, seed=seed + rank)
|
| 61 |
+
return Monitor(env)
|
| 62 |
+
|
| 63 |
+
return _make
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def train_ppo(config: dict, output_dir: str | Path) -> dict[str, float]:
|
| 67 |
+
if PPO is None:
|
| 68 |
+
raise ImportError(
|
| 69 |
+
"stable-baselines3 is required for PPO training."
|
| 70 |
+
) from IMPORT_ERROR
|
| 71 |
+
ppo_cfg = config["ppo"]
|
| 72 |
+
seed = config.get("project", {}).get("seed", 0)
|
| 73 |
+
seed_everything(seed)
|
| 74 |
+
run_dirs = ensure_run_dirs(output_dir)
|
| 75 |
+
|
| 76 |
+
env_fns = [build_env(config, seed, rank) for rank in range(ppo_cfg.get("num_envs", 4))]
|
| 77 |
+
vec_env = DummyVecEnv(env_fns)
|
| 78 |
+
policy_kwargs = {
|
| 79 |
+
"features_extractor_class": MiniGridCNNExtractor,
|
| 80 |
+
"features_extractor_kwargs": {"features_dim": ppo_cfg.get("features_dim", 256)},
|
| 81 |
+
}
|
| 82 |
+
model = PPO(
|
| 83 |
+
"CnnPolicy",
|
| 84 |
+
vec_env,
|
| 85 |
+
policy_kwargs=policy_kwargs,
|
| 86 |
+
learning_rate=ppo_cfg.get("learning_rate", 3e-4),
|
| 87 |
+
n_steps=ppo_cfg.get("n_steps", 256),
|
| 88 |
+
batch_size=ppo_cfg.get("batch_size", 256),
|
| 89 |
+
n_epochs=ppo_cfg.get("n_epochs", 4),
|
| 90 |
+
gamma=ppo_cfg.get("gamma", 0.99),
|
| 91 |
+
gae_lambda=ppo_cfg.get("gae_lambda", 0.95),
|
| 92 |
+
clip_range=ppo_cfg.get("clip_range", 0.2),
|
| 93 |
+
ent_coef=ppo_cfg.get("ent_coef", 0.01),
|
| 94 |
+
vf_coef=ppo_cfg.get("vf_coef", 0.5),
|
| 95 |
+
seed=seed,
|
| 96 |
+
device=ppo_cfg.get("device", "auto"),
|
| 97 |
+
verbose=1,
|
| 98 |
+
)
|
| 99 |
+
model.learn(total_timesteps=ppo_cfg["total_timesteps"])
|
| 100 |
+
model.save(Path(run_dirs["checkpoints"]) / "ppo_latest")
|
| 101 |
+
mean_reward, std_reward = evaluate_policy(
|
| 102 |
+
model,
|
| 103 |
+
vec_env,
|
| 104 |
+
n_eval_episodes=config["evaluation"]["episodes"],
|
| 105 |
+
deterministic=True,
|
| 106 |
+
)
|
| 107 |
+
vec_env.close()
|
| 108 |
+
return {"mean_reward": float(mean_reward), "std_reward": float(std_reward)}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def build_arg_parser() -> argparse.ArgumentParser:
|
| 112 |
+
parser = argparse.ArgumentParser(description="Train a PPO baseline on MiniGrid pixels.")
|
| 113 |
+
parser.add_argument("--config", type=Path, required=True)
|
| 114 |
+
parser.add_argument("--output-dir", type=Path, required=True)
|
| 115 |
+
return parser
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def main() -> None:
|
| 119 |
+
parser = build_arg_parser()
|
| 120 |
+
args = parser.parse_args()
|
| 121 |
+
config = load_config(args.config)
|
| 122 |
+
summary = train_ppo(config, args.output_dir)
|
| 123 |
+
print(summary)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
if __name__ == "__main__":
|
| 127 |
+
main()
|
src/minidreamer/config.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import yaml
|
| 8 |
+
|
| 9 |
+
ConfigDict = dict[str, Any]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_config(path: str | Path) -> ConfigDict:
|
| 13 |
+
path = Path(path)
|
| 14 |
+
with path.open("r", encoding="utf-8") as handle:
|
| 15 |
+
config = yaml.safe_load(handle)
|
| 16 |
+
if not isinstance(config, dict):
|
| 17 |
+
raise ValueError(f"Config at {path} must be a mapping.")
|
| 18 |
+
return config
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def merge_dicts(base: ConfigDict, overrides: ConfigDict) -> ConfigDict:
|
| 22 |
+
merged = copy.deepcopy(base)
|
| 23 |
+
for key, value in overrides.items():
|
| 24 |
+
if isinstance(value, dict) and isinstance(merged.get(key), dict):
|
| 25 |
+
merged[key] = merge_dicts(merged[key], value)
|
| 26 |
+
else:
|
| 27 |
+
merged[key] = value
|
| 28 |
+
return merged
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def save_config(config: ConfigDict, path: str | Path) -> None:
|
| 32 |
+
path = Path(path)
|
| 33 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 34 |
+
with path.open("w", encoding="utf-8") as handle:
|
| 35 |
+
yaml.safe_dump(config, handle, sort_keys=False)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def ensure_run_dirs(base_dir: str | Path) -> dict[str, Path]:
|
| 39 |
+
base = Path(base_dir)
|
| 40 |
+
paths = {
|
| 41 |
+
"base": base,
|
| 42 |
+
"checkpoints": base / "checkpoints",
|
| 43 |
+
"metrics": base / "metrics",
|
| 44 |
+
"plots": base / "plots",
|
| 45 |
+
"replay": base / "replay",
|
| 46 |
+
}
|
| 47 |
+
for path in paths.values():
|
| 48 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 49 |
+
return paths
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def deep_get(config: ConfigDict, *keys: str, default: Any = None) -> Any:
|
| 53 |
+
current: Any = config
|
| 54 |
+
for key in keys:
|
| 55 |
+
if not isinstance(current, dict) or key not in current:
|
| 56 |
+
return default
|
| 57 |
+
current = current[key]
|
| 58 |
+
return current
|
| 59 |
+
|
src/minidreamer/envs/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Environment helpers."""
|
| 2 |
+
|
src/minidreamer/envs/make_env.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Iterable
|
| 5 |
+
|
| 6 |
+
import gymnasium as gym
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from gymnasium import spaces
|
| 10 |
+
from minigrid.wrappers import ImgObsWrapper, RGBImgPartialObsWrapper
|
| 11 |
+
from PIL import Image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass(frozen=True)
|
| 15 |
+
class EnvSpec:
|
| 16 |
+
env_id: str
|
| 17 |
+
resize: tuple[int, int] = (64, 64)
|
| 18 |
+
normalize_obs: bool = True
|
| 19 |
+
rgb_partial_obs: bool = True
|
| 20 |
+
image_only: bool = True
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ResizeNormalizeObservation(gym.ObservationWrapper):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
env: gym.Env,
|
| 27 |
+
resize: tuple[int, int] | None = (64, 64),
|
| 28 |
+
normalize: bool = True,
|
| 29 |
+
) -> None:
|
| 30 |
+
super().__init__(env)
|
| 31 |
+
self.resize = resize
|
| 32 |
+
self.normalize = normalize
|
| 33 |
+
base_space = env.observation_space
|
| 34 |
+
if not isinstance(base_space, spaces.Box):
|
| 35 |
+
raise TypeError("MiniDreamer expects a Box observation space after wrappers.")
|
| 36 |
+
channels = base_space.shape[-1]
|
| 37 |
+
if resize is None:
|
| 38 |
+
height, width = base_space.shape[:2]
|
| 39 |
+
else:
|
| 40 |
+
height, width = resize
|
| 41 |
+
low, high = (0.0, 1.0) if normalize else (0, 255)
|
| 42 |
+
dtype = np.float32 if normalize else np.uint8
|
| 43 |
+
self.observation_space = spaces.Box(
|
| 44 |
+
low=low,
|
| 45 |
+
high=high,
|
| 46 |
+
shape=(height, width, channels),
|
| 47 |
+
dtype=dtype,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def observation(self, observation: np.ndarray) -> np.ndarray:
|
| 51 |
+
obs = observation
|
| 52 |
+
if self.resize is not None and tuple(obs.shape[:2]) != self.resize:
|
| 53 |
+
pil_image = Image.fromarray(obs.astype(np.uint8))
|
| 54 |
+
pil_image = pil_image.resize((self.resize[1], self.resize[0]), Image.Resampling.BILINEAR)
|
| 55 |
+
obs = np.asarray(pil_image)
|
| 56 |
+
if self.normalize:
|
| 57 |
+
return obs.astype(np.float32) / 255.0
|
| 58 |
+
return obs.astype(np.uint8)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def make_env(
|
| 62 |
+
env_id: str = "MiniGrid-FourRooms-v0",
|
| 63 |
+
seed: int | None = None,
|
| 64 |
+
resize: tuple[int, int] = (64, 64),
|
| 65 |
+
normalize_obs: bool = True,
|
| 66 |
+
rgb_partial_obs: bool = True,
|
| 67 |
+
image_only: bool = True,
|
| 68 |
+
render_mode: str | None = None,
|
| 69 |
+
) -> gym.Env:
|
| 70 |
+
env = gym.make(env_id, render_mode=render_mode)
|
| 71 |
+
env = gym.wrappers.RecordEpisodeStatistics(env)
|
| 72 |
+
if rgb_partial_obs:
|
| 73 |
+
env = RGBImgPartialObsWrapper(env)
|
| 74 |
+
if image_only:
|
| 75 |
+
env = ImgObsWrapper(env)
|
| 76 |
+
env = ResizeNormalizeObservation(env, resize=resize, normalize=normalize_obs)
|
| 77 |
+
if seed is not None:
|
| 78 |
+
env.reset(seed=seed)
|
| 79 |
+
env.action_space.seed(seed)
|
| 80 |
+
return env
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def make_env_from_config(config: dict, seed: int | None = None) -> gym.Env:
|
| 84 |
+
env_cfg = config["env"]
|
| 85 |
+
return make_env(
|
| 86 |
+
env_id=env_cfg["id"],
|
| 87 |
+
seed=seed,
|
| 88 |
+
resize=tuple(env_cfg.get("resize", (64, 64))),
|
| 89 |
+
normalize_obs=env_cfg.get("normalize_obs", True),
|
| 90 |
+
rgb_partial_obs=env_cfg.get("rgb_partial_obs", True),
|
| 91 |
+
image_only=env_cfg.get("image_only", True),
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def observation_to_tensor(observation: np.ndarray, device: torch.device | None = None) -> torch.Tensor:
|
| 96 |
+
if observation.ndim != 3:
|
| 97 |
+
raise ValueError(f"Expected HWC observation, got shape {observation.shape}.")
|
| 98 |
+
tensor = torch.from_numpy(observation).permute(2, 0, 1).float()
|
| 99 |
+
return tensor.to(device) if device is not None else tensor
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def batch_observations_to_tensor(
|
| 103 |
+
observations: np.ndarray,
|
| 104 |
+
device: torch.device | None = None,
|
| 105 |
+
) -> torch.Tensor:
|
| 106 |
+
if observations.ndim != 5:
|
| 107 |
+
raise ValueError(f"Expected BT HWC observations, got shape {observations.shape}.")
|
| 108 |
+
tensor = torch.from_numpy(observations).permute(0, 1, 4, 2, 3).float()
|
| 109 |
+
return tensor.to(device) if device is not None else tensor
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def action_subset(action_space_n: int, names: Iterable[str] | None = None) -> list[int]:
|
| 113 |
+
if names is None:
|
| 114 |
+
return list(range(action_space_n))
|
| 115 |
+
lookup = {
|
| 116 |
+
"left": 0,
|
| 117 |
+
"right": 1,
|
| 118 |
+
"forward": 2,
|
| 119 |
+
"pickup": 3,
|
| 120 |
+
"drop": 4,
|
| 121 |
+
"toggle": 5,
|
| 122 |
+
"done": 6,
|
| 123 |
+
}
|
| 124 |
+
return [lookup[name] for name in names]
|
| 125 |
+
|
src/minidreamer/evaluation.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from minidreamer.data.replay_buffer import Episode, ReplayBuffer
|
| 11 |
+
from minidreamer.envs.make_env import make_env_from_config
|
| 12 |
+
from minidreamer.models.world_model import WorldModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def evaluate_random_policy(config: dict, episodes: int | None = None, seed: int | None = None) -> dict[str, float]:
|
| 16 |
+
eval_cfg = config["evaluation"]
|
| 17 |
+
episodes = episodes or eval_cfg["episodes"]
|
| 18 |
+
seed = config.get("project", {}).get("seed", 0) if seed is None else seed
|
| 19 |
+
env = make_env_from_config(config, seed=seed)
|
| 20 |
+
rng = np.random.default_rng(seed)
|
| 21 |
+
returns = []
|
| 22 |
+
lengths = []
|
| 23 |
+
successes = []
|
| 24 |
+
for episode_idx in range(episodes):
|
| 25 |
+
obs, _ = env.reset(seed=seed + episode_idx)
|
| 26 |
+
total_return = 0.0
|
| 27 |
+
terminated = False
|
| 28 |
+
truncated = False
|
| 29 |
+
length = 0
|
| 30 |
+
while not (terminated or truncated):
|
| 31 |
+
action = int(rng.integers(0, env.action_space.n))
|
| 32 |
+
obs, reward, terminated, truncated, _ = env.step(action)
|
| 33 |
+
total_return += float(reward)
|
| 34 |
+
length += 1
|
| 35 |
+
returns.append(total_return)
|
| 36 |
+
lengths.append(length)
|
| 37 |
+
successes.append(float(terminated and total_return > 0.0))
|
| 38 |
+
env.close()
|
| 39 |
+
returns_array = np.asarray(returns, dtype=np.float32)
|
| 40 |
+
lengths_array = np.asarray(lengths, dtype=np.float32)
|
| 41 |
+
successes_array = np.asarray(successes, dtype=np.float32)
|
| 42 |
+
return {
|
| 43 |
+
"success_rate": float(successes_array.mean()),
|
| 44 |
+
"mean_return": float(returns_array.mean()),
|
| 45 |
+
"median_return": float(np.median(returns_array)),
|
| 46 |
+
"mean_episode_length": float(lengths_array.mean()),
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _episode_to_batch(episode: Episode, device: torch.device) -> dict[str, torch.Tensor]:
|
| 51 |
+
batch = {
|
| 52 |
+
"obs": torch.from_numpy(episode.obs[None]).permute(0, 1, 4, 2, 3).float().to(device),
|
| 53 |
+
"actions": torch.from_numpy(episode.actions[None]).long().to(device),
|
| 54 |
+
"rewards": torch.from_numpy(episode.rewards[None]).float().to(device),
|
| 55 |
+
"terminated": torch.from_numpy(episode.terminated[None]).float().to(device),
|
| 56 |
+
"truncated": torch.from_numpy(episode.truncated[None]).float().to(device),
|
| 57 |
+
"done": torch.from_numpy(episode.done[None]).float().to(device),
|
| 58 |
+
"mask": torch.ones((1, episode.length), dtype=torch.float32, device=device),
|
| 59 |
+
}
|
| 60 |
+
return batch
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _sequence_state(model: WorldModel, episode: Episode, start_idx: int) -> Any:
|
| 64 |
+
state = model.posterior_step(model.initial_state(1), None, episode.obs[0], sample=False)
|
| 65 |
+
for idx in range(start_idx):
|
| 66 |
+
state = model.posterior_step(state, int(episode.actions[idx]), episode.obs[idx + 1], sample=False)
|
| 67 |
+
return state
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _discounted_return(rewards: np.ndarray, done: np.ndarray, discount: float) -> float:
|
| 71 |
+
total = 0.0
|
| 72 |
+
alive = 1.0
|
| 73 |
+
for step, reward in enumerate(rewards):
|
| 74 |
+
total += (discount**step) * alive * float(reward)
|
| 75 |
+
alive *= 1.0 - float(done[step])
|
| 76 |
+
return total
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def evaluate_world_model(
|
| 80 |
+
config: dict,
|
| 81 |
+
model: WorldModel,
|
| 82 |
+
replay: ReplayBuffer,
|
| 83 |
+
split: str = "val",
|
| 84 |
+
max_episodes: int | None = None,
|
| 85 |
+
) -> dict[str, float]:
|
| 86 |
+
device = model.device
|
| 87 |
+
model.eval()
|
| 88 |
+
metrics: dict[str, list[float]] = defaultdict(list)
|
| 89 |
+
horizons = [1, 5, 10]
|
| 90 |
+
discount = float(config["planner"]["discount"])
|
| 91 |
+
episodes = replay.episode_ids(split)
|
| 92 |
+
if max_episodes is not None:
|
| 93 |
+
episodes = episodes[:max_episodes]
|
| 94 |
+
|
| 95 |
+
with torch.no_grad():
|
| 96 |
+
for episode_id in episodes:
|
| 97 |
+
episode = replay.episodes[episode_id]
|
| 98 |
+
batch = _episode_to_batch(episode, device)
|
| 99 |
+
outputs = model.observe_sequence(batch["obs"], batch["actions"], sample=False)
|
| 100 |
+
reward_mse = F.mse_loss(outputs.reward_pred, batch["rewards"], reduction="none").mean()
|
| 101 |
+
done_bce = F.binary_cross_entropy_with_logits(outputs.done_logits, batch["done"], reduction="none").mean()
|
| 102 |
+
kl = model.rssm.kl_divergence(
|
| 103 |
+
outputs.post_mean,
|
| 104 |
+
outputs.post_std,
|
| 105 |
+
outputs.prior_mean,
|
| 106 |
+
outputs.prior_std,
|
| 107 |
+
).mean()
|
| 108 |
+
metrics["reward_mse"].append(float(reward_mse.cpu()))
|
| 109 |
+
metrics["done_bce"].append(float(done_bce.cpu()))
|
| 110 |
+
metrics["kl_loss"].append(float(kl.cpu()))
|
| 111 |
+
if outputs.reconstructions is not None:
|
| 112 |
+
recon_mse = F.mse_loss(outputs.reconstructions, batch["obs"][:, 1:], reduction="none").mean()
|
| 113 |
+
metrics["reconstruction_mse"].append(float(recon_mse.cpu()))
|
| 114 |
+
|
| 115 |
+
for horizon in horizons:
|
| 116 |
+
if episode.length < horizon:
|
| 117 |
+
continue
|
| 118 |
+
reward_errors = []
|
| 119 |
+
done_correct = []
|
| 120 |
+
imagined_returns = []
|
| 121 |
+
real_returns = []
|
| 122 |
+
for start_idx in range(episode.length - horizon + 1):
|
| 123 |
+
state = _sequence_state(model, episode, start_idx)
|
| 124 |
+
actions = torch.from_numpy(episode.actions[start_idx : start_idx + horizon]).long().to(device)
|
| 125 |
+
rollout = model.score_action_sequences(
|
| 126 |
+
state,
|
| 127 |
+
actions.unsqueeze(0),
|
| 128 |
+
discount=discount,
|
| 129 |
+
use_done_mask=True,
|
| 130 |
+
)
|
| 131 |
+
reward_pred = rollout["reward_pred"].squeeze(0).cpu().numpy()
|
| 132 |
+
done_prob = rollout["done_prob"].squeeze(0).cpu().numpy()
|
| 133 |
+
done_pred = (done_prob >= 0.5).astype(np.float32)
|
| 134 |
+
real_rewards = episode.rewards[start_idx : start_idx + horizon]
|
| 135 |
+
real_done = episode.done[start_idx : start_idx + horizon]
|
| 136 |
+
reward_errors.append(np.mean((reward_pred - real_rewards) ** 2))
|
| 137 |
+
done_correct.append(np.mean(done_pred == real_done))
|
| 138 |
+
imagined_returns.append(float(rollout["scores"].squeeze(0).cpu()))
|
| 139 |
+
real_returns.append(_discounted_return(real_rewards, real_done, discount))
|
| 140 |
+
metrics[f"open_loop_reward_error_h{horizon}"].append(float(np.mean(reward_errors)))
|
| 141 |
+
metrics[f"open_loop_done_accuracy_h{horizon}"].append(float(np.mean(done_correct)))
|
| 142 |
+
if len(imagined_returns) > 1 and np.std(real_returns) > 0.0 and np.std(imagined_returns) > 0.0:
|
| 143 |
+
correlation = np.corrcoef(imagined_returns, real_returns)[0, 1]
|
| 144 |
+
metrics["imagined_return_vs_real_return_correlation"].append(float(correlation))
|
| 145 |
+
|
| 146 |
+
return {name: float(np.mean(values)) for name, values in metrics.items() if values}
|
| 147 |
+
|
src/minidreamer/models/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model components for MiniDreamer."""
|
| 2 |
+
|
src/minidreamer/models/decoder.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ConvDecoder(nn.Module):
|
| 7 |
+
def __init__(self, feature_dim: int, out_channels: int = 3) -> None:
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.projection = nn.Sequential(
|
| 10 |
+
nn.Linear(feature_dim, 256 * 4 * 4),
|
| 11 |
+
nn.ReLU(),
|
| 12 |
+
)
|
| 13 |
+
self.decoder = nn.Sequential(
|
| 14 |
+
nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
|
| 15 |
+
nn.ReLU(),
|
| 16 |
+
nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
|
| 17 |
+
nn.ReLU(),
|
| 18 |
+
nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
|
| 19 |
+
nn.ReLU(),
|
| 20 |
+
nn.ConvTranspose2d(32, out_channels, kernel_size=4, stride=2, padding=1),
|
| 21 |
+
nn.Sigmoid(),
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
def forward(self, features):
|
| 25 |
+
hidden = self.projection(features).view(-1, 256, 4, 4)
|
| 26 |
+
return self.decoder(hidden)
|
| 27 |
+
|
src/minidreamer/models/encoder.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ConvEncoder(nn.Module):
|
| 8 |
+
def __init__(self, in_channels: int = 3, embedding_dim: int = 256) -> None:
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.conv = nn.Sequential(
|
| 11 |
+
nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1),
|
| 12 |
+
nn.ReLU(),
|
| 13 |
+
nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
|
| 14 |
+
nn.ReLU(),
|
| 15 |
+
nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
|
| 16 |
+
nn.ReLU(),
|
| 17 |
+
nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
|
| 18 |
+
nn.ReLU(),
|
| 19 |
+
)
|
| 20 |
+
self.projection = nn.Sequential(
|
| 21 |
+
nn.Flatten(),
|
| 22 |
+
nn.Linear(256 * 4 * 4, embedding_dim),
|
| 23 |
+
nn.ReLU(),
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
def forward(self, obs: torch.Tensor) -> torch.Tensor:
|
| 27 |
+
hidden = self.conv(obs)
|
| 28 |
+
return self.projection(hidden)
|
| 29 |
+
|
src/minidreamer/models/heads.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class MLPHead(nn.Module):
|
| 7 |
+
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 256) -> None:
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.net = nn.Sequential(
|
| 10 |
+
nn.Linear(in_dim, hidden_dim),
|
| 11 |
+
nn.ELU(),
|
| 12 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 13 |
+
nn.ELU(),
|
| 14 |
+
nn.Linear(hidden_dim, out_dim),
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
return self.net(x)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RewardHead(MLPHead):
|
| 22 |
+
def __init__(self, in_dim: int, hidden_dim: int = 256) -> None:
|
| 23 |
+
super().__init__(in_dim=in_dim, out_dim=1, hidden_dim=hidden_dim)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DoneHead(MLPHead):
|
| 27 |
+
def __init__(self, in_dim: int, hidden_dim: int = 256) -> None:
|
| 28 |
+
super().__init__(in_dim=in_dim, out_dim=1, hidden_dim=hidden_dim)
|
| 29 |
+
|
src/minidreamer/models/rssm.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class RSSMState:
|
| 12 |
+
deter: torch.Tensor
|
| 13 |
+
stoch: torch.Tensor
|
| 14 |
+
mean: torch.Tensor
|
| 15 |
+
std: torch.Tensor
|
| 16 |
+
|
| 17 |
+
def features(self) -> torch.Tensor:
|
| 18 |
+
return torch.cat([self.deter, self.stoch], dim=-1)
|
| 19 |
+
|
| 20 |
+
def detach(self) -> "RSSMState":
|
| 21 |
+
return RSSMState(
|
| 22 |
+
deter=self.deter.detach(),
|
| 23 |
+
stoch=self.stoch.detach(),
|
| 24 |
+
mean=self.mean.detach(),
|
| 25 |
+
std=self.std.detach(),
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
def repeat(self, count: int) -> "RSSMState":
|
| 29 |
+
return RSSMState(
|
| 30 |
+
deter=self.deter.repeat(count, 1),
|
| 31 |
+
stoch=self.stoch.repeat(count, 1),
|
| 32 |
+
mean=self.mean.repeat(count, 1),
|
| 33 |
+
std=self.std.repeat(count, 1),
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class RSSM(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
action_dim: int,
|
| 41 |
+
embedding_dim: int = 256,
|
| 42 |
+
deter_dim: int = 256,
|
| 43 |
+
stoch_dim: int = 32,
|
| 44 |
+
hidden_dim: int = 256,
|
| 45 |
+
min_std: float = 0.1,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.action_dim = action_dim
|
| 49 |
+
self.embedding_dim = embedding_dim
|
| 50 |
+
self.deter_dim = deter_dim
|
| 51 |
+
self.stoch_dim = stoch_dim
|
| 52 |
+
self.hidden_dim = hidden_dim
|
| 53 |
+
self.min_std = min_std
|
| 54 |
+
|
| 55 |
+
self.input_net = nn.Sequential(
|
| 56 |
+
nn.Linear(stoch_dim + action_dim, hidden_dim),
|
| 57 |
+
nn.ELU(),
|
| 58 |
+
)
|
| 59 |
+
self.gru = nn.GRUCell(hidden_dim, deter_dim)
|
| 60 |
+
self.prior_net = nn.Sequential(
|
| 61 |
+
nn.Linear(deter_dim, hidden_dim),
|
| 62 |
+
nn.ELU(),
|
| 63 |
+
nn.Linear(hidden_dim, 2 * stoch_dim),
|
| 64 |
+
)
|
| 65 |
+
self.posterior_net = nn.Sequential(
|
| 66 |
+
nn.Linear(deter_dim + embedding_dim, hidden_dim),
|
| 67 |
+
nn.ELU(),
|
| 68 |
+
nn.Linear(hidden_dim, 2 * stoch_dim),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def initial(self, batch_size: int, device: torch.device) -> RSSMState:
|
| 72 |
+
zeros_deter = torch.zeros(batch_size, self.deter_dim, device=device)
|
| 73 |
+
zeros_stoch = torch.zeros(batch_size, self.stoch_dim, device=device)
|
| 74 |
+
return RSSMState(
|
| 75 |
+
deter=zeros_deter,
|
| 76 |
+
stoch=zeros_stoch,
|
| 77 |
+
mean=zeros_stoch,
|
| 78 |
+
std=torch.ones_like(zeros_stoch),
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
def _stats(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 82 |
+
mean, raw_std = torch.chunk(tensor, 2, dim=-1)
|
| 83 |
+
std = F.softplus(raw_std) + self.min_std
|
| 84 |
+
return mean, std
|
| 85 |
+
|
| 86 |
+
def _action_one_hot(self, action: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
action = action.long().view(-1)
|
| 88 |
+
return F.one_hot(action, num_classes=self.action_dim).float()
|
| 89 |
+
|
| 90 |
+
def _next_deter(self, prev_state: RSSMState, action: torch.Tensor) -> torch.Tensor:
|
| 91 |
+
action_one_hot = self._action_one_hot(action)
|
| 92 |
+
gru_input = self.input_net(torch.cat([prev_state.stoch, action_one_hot], dim=-1))
|
| 93 |
+
return self.gru(gru_input, prev_state.deter)
|
| 94 |
+
|
| 95 |
+
def prior(self, deter: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 96 |
+
return self._stats(self.prior_net(deter))
|
| 97 |
+
|
| 98 |
+
def posterior(self, deter: torch.Tensor, embed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 99 |
+
return self._stats(self.posterior_net(torch.cat([deter, embed], dim=-1)))
|
| 100 |
+
|
| 101 |
+
@staticmethod
|
| 102 |
+
def sample(mean: torch.Tensor, std: torch.Tensor, sample: bool = True) -> torch.Tensor:
|
| 103 |
+
if sample:
|
| 104 |
+
return mean + torch.randn_like(std) * std
|
| 105 |
+
return mean
|
| 106 |
+
|
| 107 |
+
def observe(
|
| 108 |
+
self,
|
| 109 |
+
prev_state: RSSMState,
|
| 110 |
+
prev_action: torch.Tensor | None,
|
| 111 |
+
embed: torch.Tensor,
|
| 112 |
+
sample: bool = True,
|
| 113 |
+
) -> tuple[RSSMState, tuple[torch.Tensor, torch.Tensor]]:
|
| 114 |
+
if prev_action is None:
|
| 115 |
+
deter = prev_state.deter
|
| 116 |
+
else:
|
| 117 |
+
deter = self._next_deter(prev_state, prev_action)
|
| 118 |
+
prior_mean, prior_std = self.prior(deter)
|
| 119 |
+
post_mean, post_std = self.posterior(deter, embed)
|
| 120 |
+
stoch = self.sample(post_mean, post_std, sample=sample)
|
| 121 |
+
state = RSSMState(deter=deter, stoch=stoch, mean=post_mean, std=post_std)
|
| 122 |
+
return state, (prior_mean, prior_std)
|
| 123 |
+
|
| 124 |
+
def imagine(
|
| 125 |
+
self,
|
| 126 |
+
prev_state: RSSMState,
|
| 127 |
+
action: torch.Tensor,
|
| 128 |
+
sample: bool = False,
|
| 129 |
+
) -> tuple[RSSMState, tuple[torch.Tensor, torch.Tensor]]:
|
| 130 |
+
deter = self._next_deter(prev_state, action)
|
| 131 |
+
mean, std = self.prior(deter)
|
| 132 |
+
stoch = self.sample(mean, std, sample=sample)
|
| 133 |
+
state = RSSMState(deter=deter, stoch=stoch, mean=mean, std=std)
|
| 134 |
+
return state, (mean, std)
|
| 135 |
+
|
| 136 |
+
@staticmethod
|
| 137 |
+
def kl_divergence(
|
| 138 |
+
post_mean: torch.Tensor,
|
| 139 |
+
post_std: torch.Tensor,
|
| 140 |
+
prior_mean: torch.Tensor,
|
| 141 |
+
prior_std: torch.Tensor,
|
| 142 |
+
) -> torch.Tensor:
|
| 143 |
+
log_var_ratio = 2.0 * (torch.log(prior_std) - torch.log(post_std))
|
| 144 |
+
var_ratio = (post_std / prior_std) ** 2
|
| 145 |
+
mean_term = ((post_mean - prior_mean) / prior_std) ** 2
|
| 146 |
+
return 0.5 * torch.sum(var_ratio + mean_term + log_var_ratio - 1.0, dim=-1)
|
| 147 |
+
|
src/minidreamer/models/world_model.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
from minidreamer.models.decoder import ConvDecoder
|
| 12 |
+
from minidreamer.models.encoder import ConvEncoder
|
| 13 |
+
from minidreamer.models.heads import DoneHead, RewardHead
|
| 14 |
+
from minidreamer.models.rssm import RSSM, RSSMState
|
| 15 |
+
from minidreamer.utils.common import masked_mean
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class WorldModelOutputs:
|
| 20 |
+
states: list[RSSMState]
|
| 21 |
+
prior_mean: torch.Tensor
|
| 22 |
+
prior_std: torch.Tensor
|
| 23 |
+
post_mean: torch.Tensor
|
| 24 |
+
post_std: torch.Tensor
|
| 25 |
+
reward_pred: torch.Tensor
|
| 26 |
+
done_logits: torch.Tensor
|
| 27 |
+
reconstructions: torch.Tensor | None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class WorldModel(nn.Module):
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
action_dim: int,
|
| 34 |
+
embedding_dim: int = 256,
|
| 35 |
+
deter_dim: int = 256,
|
| 36 |
+
stoch_dim: int = 32,
|
| 37 |
+
hidden_dim: int = 256,
|
| 38 |
+
use_decoder: bool = True,
|
| 39 |
+
min_std: float = 0.1,
|
| 40 |
+
obs_channels: int = 3,
|
| 41 |
+
) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.action_dim = action_dim
|
| 44 |
+
self.embedding_dim = embedding_dim
|
| 45 |
+
self.deter_dim = deter_dim
|
| 46 |
+
self.stoch_dim = stoch_dim
|
| 47 |
+
self.hidden_dim = hidden_dim
|
| 48 |
+
self.use_decoder = use_decoder
|
| 49 |
+
self.obs_channels = obs_channels
|
| 50 |
+
|
| 51 |
+
self.encoder = ConvEncoder(in_channels=obs_channels, embedding_dim=embedding_dim)
|
| 52 |
+
self.rssm = RSSM(
|
| 53 |
+
action_dim=action_dim,
|
| 54 |
+
embedding_dim=embedding_dim,
|
| 55 |
+
deter_dim=deter_dim,
|
| 56 |
+
stoch_dim=stoch_dim,
|
| 57 |
+
hidden_dim=hidden_dim,
|
| 58 |
+
min_std=min_std,
|
| 59 |
+
)
|
| 60 |
+
feature_dim = deter_dim + stoch_dim
|
| 61 |
+
self.reward_head = RewardHead(feature_dim, hidden_dim=hidden_dim)
|
| 62 |
+
self.done_head = DoneHead(feature_dim, hidden_dim=hidden_dim)
|
| 63 |
+
self.decoder = ConvDecoder(feature_dim, out_channels=obs_channels) if use_decoder else None
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def from_config(cls, config: dict, action_dim: int, obs_channels: int = 3) -> "WorldModel":
|
| 67 |
+
model_cfg = config["model"]
|
| 68 |
+
return cls(
|
| 69 |
+
action_dim=action_dim,
|
| 70 |
+
embedding_dim=model_cfg["embedding_dim"],
|
| 71 |
+
deter_dim=model_cfg["deter_dim"],
|
| 72 |
+
stoch_dim=model_cfg["stoch_dim"],
|
| 73 |
+
hidden_dim=model_cfg["hidden_dim"],
|
| 74 |
+
use_decoder=model_cfg.get("use_decoder", True),
|
| 75 |
+
min_std=model_cfg.get("min_std", 0.1),
|
| 76 |
+
obs_channels=obs_channels,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def device(self) -> torch.device:
|
| 81 |
+
return next(self.parameters()).device
|
| 82 |
+
|
| 83 |
+
def initial_state(self, batch_size: int) -> RSSMState:
|
| 84 |
+
return self.rssm.initial(batch_size=batch_size, device=self.device)
|
| 85 |
+
|
| 86 |
+
def encode(self, obs: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
if obs.dim() < 4:
|
| 88 |
+
raise ValueError(f"Expected at least 4 dims for observations, got {obs.shape}.")
|
| 89 |
+
leading_shape = obs.shape[:-3]
|
| 90 |
+
flat_obs = obs.reshape(-1, *obs.shape[-3:])
|
| 91 |
+
flat_embeddings = self.encoder(flat_obs)
|
| 92 |
+
return flat_embeddings.reshape(*leading_shape, -1)
|
| 93 |
+
|
| 94 |
+
def observe_sequence(
|
| 95 |
+
self,
|
| 96 |
+
obs: torch.Tensor,
|
| 97 |
+
actions: torch.Tensor,
|
| 98 |
+
sample: bool = True,
|
| 99 |
+
) -> WorldModelOutputs:
|
| 100 |
+
if obs.dim() != 5:
|
| 101 |
+
raise ValueError(f"Expected obs shape [B, T+1, C, H, W], got {obs.shape}.")
|
| 102 |
+
if actions.dim() != 2:
|
| 103 |
+
raise ValueError(f"Expected actions shape [B, T], got {actions.shape}.")
|
| 104 |
+
batch_size, time_steps = actions.shape
|
| 105 |
+
embeddings = self.encode(obs)
|
| 106 |
+
state, _ = self.rssm.observe(self.initial_state(batch_size), None, embeddings[:, 0], sample=sample)
|
| 107 |
+
|
| 108 |
+
states = [state]
|
| 109 |
+
prior_means = []
|
| 110 |
+
prior_stds = []
|
| 111 |
+
post_means = []
|
| 112 |
+
post_stds = []
|
| 113 |
+
rewards = []
|
| 114 |
+
done_logits = []
|
| 115 |
+
reconstructions = []
|
| 116 |
+
|
| 117 |
+
for t in range(time_steps):
|
| 118 |
+
next_state, (prior_mean, prior_std) = self.rssm.observe(
|
| 119 |
+
state,
|
| 120 |
+
actions[:, t],
|
| 121 |
+
embeddings[:, t + 1],
|
| 122 |
+
sample=sample,
|
| 123 |
+
)
|
| 124 |
+
features = next_state.features()
|
| 125 |
+
prior_means.append(prior_mean)
|
| 126 |
+
prior_stds.append(prior_std)
|
| 127 |
+
post_means.append(next_state.mean)
|
| 128 |
+
post_stds.append(next_state.std)
|
| 129 |
+
rewards.append(self.reward_head(features).squeeze(-1))
|
| 130 |
+
done_logits.append(self.done_head(features).squeeze(-1))
|
| 131 |
+
if self.decoder is not None:
|
| 132 |
+
reconstructions.append(self.decoder(features))
|
| 133 |
+
states.append(next_state)
|
| 134 |
+
state = next_state
|
| 135 |
+
|
| 136 |
+
recon_tensor = torch.stack(reconstructions, dim=1) if reconstructions else None
|
| 137 |
+
return WorldModelOutputs(
|
| 138 |
+
states=states,
|
| 139 |
+
prior_mean=torch.stack(prior_means, dim=1),
|
| 140 |
+
prior_std=torch.stack(prior_stds, dim=1),
|
| 141 |
+
post_mean=torch.stack(post_means, dim=1),
|
| 142 |
+
post_std=torch.stack(post_stds, dim=1),
|
| 143 |
+
reward_pred=torch.stack(rewards, dim=1),
|
| 144 |
+
done_logits=torch.stack(done_logits, dim=1),
|
| 145 |
+
reconstructions=recon_tensor,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
def compute_losses(self, batch: dict[str, torch.Tensor], config: dict[str, Any]) -> dict[str, torch.Tensor]:
|
| 149 |
+
outputs = self.observe_sequence(batch["obs"], batch["actions"], sample=True)
|
| 150 |
+
training_cfg = config["training"]
|
| 151 |
+
rewards = batch["rewards"]
|
| 152 |
+
done = batch["done"]
|
| 153 |
+
mask = batch["mask"]
|
| 154 |
+
|
| 155 |
+
reward_loss = masked_mean(F.mse_loss(outputs.reward_pred, rewards, reduction="none"), mask)
|
| 156 |
+
done_loss = masked_mean(
|
| 157 |
+
F.binary_cross_entropy_with_logits(outputs.done_logits, done, reduction="none"),
|
| 158 |
+
mask,
|
| 159 |
+
)
|
| 160 |
+
kl_per_step = self.rssm.kl_divergence(
|
| 161 |
+
outputs.post_mean,
|
| 162 |
+
outputs.post_std,
|
| 163 |
+
outputs.prior_mean,
|
| 164 |
+
outputs.prior_std,
|
| 165 |
+
)
|
| 166 |
+
free_nats = torch.full_like(kl_per_step, float(training_cfg.get("free_nats", 1.0)))
|
| 167 |
+
kl_loss = masked_mean(torch.maximum(kl_per_step, free_nats), mask)
|
| 168 |
+
|
| 169 |
+
if outputs.reconstructions is not None and training_cfg.get("beta_recon", 0.0) > 0.0:
|
| 170 |
+
recon_target = batch["obs"][:, 1:]
|
| 171 |
+
recon_error = F.mse_loss(outputs.reconstructions, recon_target, reduction="none").mean(dim=(2, 3, 4))
|
| 172 |
+
recon_loss = masked_mean(recon_error, mask)
|
| 173 |
+
else:
|
| 174 |
+
recon_loss = torch.zeros((), device=self.device)
|
| 175 |
+
|
| 176 |
+
total_loss = (
|
| 177 |
+
float(training_cfg.get("beta_reward", 1.0)) * reward_loss
|
| 178 |
+
+ float(training_cfg.get("beta_done", 1.0)) * done_loss
|
| 179 |
+
+ float(training_cfg.get("beta_kl", 1.0)) * kl_loss
|
| 180 |
+
+ float(training_cfg.get("beta_recon", 0.0)) * recon_loss
|
| 181 |
+
)
|
| 182 |
+
return {
|
| 183 |
+
"loss": total_loss,
|
| 184 |
+
"reward_loss": reward_loss.detach(),
|
| 185 |
+
"done_loss": done_loss.detach(),
|
| 186 |
+
"kl_loss": kl_loss.detach(),
|
| 187 |
+
"recon_loss": recon_loss.detach(),
|
| 188 |
+
"reward_mse": F.mse_loss(outputs.reward_pred, rewards, reduction="none").mul(mask).sum() / mask.sum().clamp_min(1.0),
|
| 189 |
+
"done_bce": F.binary_cross_entropy_with_logits(outputs.done_logits, done, reduction="none").mul(mask).sum() / mask.sum().clamp_min(1.0),
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
def posterior_step(
|
| 193 |
+
self,
|
| 194 |
+
prev_state: RSSMState,
|
| 195 |
+
prev_action: int | torch.Tensor | None,
|
| 196 |
+
observation: np.ndarray | torch.Tensor,
|
| 197 |
+
sample: bool = False,
|
| 198 |
+
) -> RSSMState:
|
| 199 |
+
obs_tensor = self._prepare_single_observation(observation)
|
| 200 |
+
embed = self.encode(obs_tensor)
|
| 201 |
+
if prev_action is None:
|
| 202 |
+
action_tensor = None
|
| 203 |
+
else:
|
| 204 |
+
action_tensor = torch.as_tensor(prev_action, device=self.device).view(1)
|
| 205 |
+
state, _ = self.rssm.observe(prev_state, action_tensor, embed, sample=sample)
|
| 206 |
+
return state
|
| 207 |
+
|
| 208 |
+
def imagine_rollout(
|
| 209 |
+
self,
|
| 210 |
+
start_state: RSSMState,
|
| 211 |
+
action_sequences: torch.Tensor,
|
| 212 |
+
sample: bool = False,
|
| 213 |
+
) -> dict[str, torch.Tensor]:
|
| 214 |
+
if action_sequences.dim() == 1:
|
| 215 |
+
action_sequences = action_sequences.unsqueeze(0)
|
| 216 |
+
batch_size, horizon = action_sequences.shape
|
| 217 |
+
state = start_state if start_state.deter.shape[0] == batch_size else start_state.repeat(batch_size)
|
| 218 |
+
rewards = []
|
| 219 |
+
done_logits = []
|
| 220 |
+
states = []
|
| 221 |
+
for t in range(horizon):
|
| 222 |
+
state, _ = self.rssm.imagine(state, action_sequences[:, t], sample=sample)
|
| 223 |
+
features = state.features()
|
| 224 |
+
rewards.append(self.reward_head(features).squeeze(-1))
|
| 225 |
+
done_logits.append(self.done_head(features).squeeze(-1))
|
| 226 |
+
states.append(state)
|
| 227 |
+
return {
|
| 228 |
+
"states": states,
|
| 229 |
+
"reward_pred": torch.stack(rewards, dim=1),
|
| 230 |
+
"done_logits": torch.stack(done_logits, dim=1),
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
def score_action_sequences(
|
| 234 |
+
self,
|
| 235 |
+
start_state: RSSMState,
|
| 236 |
+
action_sequences: torch.Tensor,
|
| 237 |
+
discount: float = 0.99,
|
| 238 |
+
use_done_mask: bool = True,
|
| 239 |
+
) -> dict[str, torch.Tensor]:
|
| 240 |
+
rollout = self.imagine_rollout(start_state, action_sequences, sample=False)
|
| 241 |
+
reward_pred = rollout["reward_pred"]
|
| 242 |
+
done_prob = torch.sigmoid(rollout["done_logits"])
|
| 243 |
+
alive = torch.ones(reward_pred.shape[0], device=self.device)
|
| 244 |
+
scores = torch.zeros_like(alive)
|
| 245 |
+
for t in range(reward_pred.shape[1]):
|
| 246 |
+
scores = scores + (discount**t) * alive * reward_pred[:, t]
|
| 247 |
+
if use_done_mask:
|
| 248 |
+
alive = alive * (1.0 - done_prob[:, t])
|
| 249 |
+
rollout["scores"] = scores
|
| 250 |
+
rollout["done_prob"] = done_prob
|
| 251 |
+
return rollout
|
| 252 |
+
|
| 253 |
+
def _prepare_single_observation(self, observation: np.ndarray | torch.Tensor) -> torch.Tensor:
|
| 254 |
+
if torch.is_tensor(observation):
|
| 255 |
+
obs_tensor = observation.to(self.device).float()
|
| 256 |
+
else:
|
| 257 |
+
obs_tensor = torch.as_tensor(observation, device=self.device).float()
|
| 258 |
+
if obs_tensor.dim() == 3 and obs_tensor.shape[-1] in (1, 3):
|
| 259 |
+
obs_tensor = obs_tensor.permute(2, 0, 1).unsqueeze(0)
|
| 260 |
+
elif obs_tensor.dim() == 3 and obs_tensor.shape[0] in (1, 3):
|
| 261 |
+
obs_tensor = obs_tensor.unsqueeze(0)
|
| 262 |
+
elif obs_tensor.dim() == 4:
|
| 263 |
+
pass
|
| 264 |
+
else:
|
| 265 |
+
raise ValueError(f"Unsupported observation shape {tuple(obs_tensor.shape)}.")
|
| 266 |
+
return obs_tensor
|
| 267 |
+
|
src/minidreamer/planning/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Planning utilities."""
|
| 2 |
+
|
src/minidreamer/planning/cem.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from minidreamer.models.rssm import RSSMState
|
| 9 |
+
from minidreamer.models.world_model import WorldModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class PlannerOutput:
|
| 14 |
+
action: int
|
| 15 |
+
sequence: list[int]
|
| 16 |
+
score: float
|
| 17 |
+
policy: torch.Tensor
|
| 18 |
+
entropy: float
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DiscreteCEMPlanner:
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
world_model: WorldModel,
|
| 25 |
+
action_dim: int,
|
| 26 |
+
horizon: int = 8,
|
| 27 |
+
candidates: int = 256,
|
| 28 |
+
elites: int = 32,
|
| 29 |
+
iterations: int = 4,
|
| 30 |
+
discount: float = 0.99,
|
| 31 |
+
use_done_mask: bool = True,
|
| 32 |
+
smoothing: float = 1e-3,
|
| 33 |
+
) -> None:
|
| 34 |
+
self.world_model = world_model
|
| 35 |
+
self.action_dim = action_dim
|
| 36 |
+
self.horizon = horizon
|
| 37 |
+
self.candidates = candidates
|
| 38 |
+
self.elites = min(elites, candidates)
|
| 39 |
+
self.iterations = iterations
|
| 40 |
+
self.discount = discount
|
| 41 |
+
self.use_done_mask = use_done_mask
|
| 42 |
+
self.smoothing = smoothing
|
| 43 |
+
|
| 44 |
+
@classmethod
|
| 45 |
+
def from_config(cls, world_model: WorldModel, action_dim: int, config: dict) -> "DiscreteCEMPlanner":
|
| 46 |
+
planner_cfg = config["planner"]
|
| 47 |
+
return cls(
|
| 48 |
+
world_model=world_model,
|
| 49 |
+
action_dim=action_dim,
|
| 50 |
+
horizon=planner_cfg["horizon"],
|
| 51 |
+
candidates=planner_cfg["candidates"],
|
| 52 |
+
elites=planner_cfg["elites"],
|
| 53 |
+
iterations=planner_cfg["iterations"],
|
| 54 |
+
discount=planner_cfg["discount"],
|
| 55 |
+
use_done_mask=planner_cfg.get("use_done_mask", True),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
def _sample_sequences(self, probs: torch.Tensor) -> torch.Tensor:
|
| 59 |
+
flat = probs.unsqueeze(0).expand(self.candidates, -1, -1).reshape(-1, self.action_dim)
|
| 60 |
+
sampled = torch.multinomial(flat, num_samples=1, replacement=True)
|
| 61 |
+
return sampled.view(self.candidates, self.horizon)
|
| 62 |
+
|
| 63 |
+
def plan(self, state: RSSMState) -> PlannerOutput:
|
| 64 |
+
device = self.world_model.device
|
| 65 |
+
probs = torch.full(
|
| 66 |
+
(self.horizon, self.action_dim),
|
| 67 |
+
fill_value=1.0 / self.action_dim,
|
| 68 |
+
device=device,
|
| 69 |
+
)
|
| 70 |
+
best_sequence = None
|
| 71 |
+
best_score = torch.tensor(float("-inf"), device=device)
|
| 72 |
+
|
| 73 |
+
for _ in range(self.iterations):
|
| 74 |
+
action_sequences = self._sample_sequences(probs)
|
| 75 |
+
scores = self.world_model.score_action_sequences(
|
| 76 |
+
state,
|
| 77 |
+
action_sequences,
|
| 78 |
+
discount=self.discount,
|
| 79 |
+
use_done_mask=self.use_done_mask,
|
| 80 |
+
)["scores"]
|
| 81 |
+
elite_indices = torch.topk(scores, k=self.elites, largest=True).indices
|
| 82 |
+
elites = action_sequences[elite_indices]
|
| 83 |
+
elite_freq = F.one_hot(elites, num_classes=self.action_dim).float().mean(dim=0)
|
| 84 |
+
probs = elite_freq + self.smoothing
|
| 85 |
+
probs = probs / probs.sum(dim=-1, keepdim=True)
|
| 86 |
+
|
| 87 |
+
iteration_best_idx = scores.argmax()
|
| 88 |
+
iteration_best_score = scores[iteration_best_idx]
|
| 89 |
+
if iteration_best_score > best_score:
|
| 90 |
+
best_score = iteration_best_score
|
| 91 |
+
best_sequence = action_sequences[iteration_best_idx]
|
| 92 |
+
|
| 93 |
+
if best_sequence is None:
|
| 94 |
+
raise RuntimeError("CEM planner failed to sample any action sequence.")
|
| 95 |
+
entropy = float((-(probs * probs.clamp_min(1e-8).log()).sum(dim=-1)).mean().detach().cpu())
|
| 96 |
+
return PlannerOutput(
|
| 97 |
+
action=int(best_sequence[0].item()),
|
| 98 |
+
sequence=[int(action.item()) for action in best_sequence],
|
| 99 |
+
score=float(best_score.detach().cpu()),
|
| 100 |
+
policy=probs.detach().cpu(),
|
| 101 |
+
entropy=entropy,
|
| 102 |
+
)
|
| 103 |
+
|
src/minidreamer/planning/evaluate_planner.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from minidreamer.envs.make_env import make_env_from_config
|
| 9 |
+
from minidreamer.planning.cem import DiscreteCEMPlanner
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class PlannerEpisode:
|
| 14 |
+
success: bool
|
| 15 |
+
total_return: float
|
| 16 |
+
length: int
|
| 17 |
+
terminated: bool
|
| 18 |
+
truncated: bool
|
| 19 |
+
planner_entropy: float
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def run_planner_episode(
|
| 23 |
+
env,
|
| 24 |
+
world_model,
|
| 25 |
+
planner: DiscreteCEMPlanner,
|
| 26 |
+
rng: np.random.Generator,
|
| 27 |
+
seed: int | None = None,
|
| 28 |
+
random_action_fraction: float = 0.0,
|
| 29 |
+
) -> PlannerEpisode:
|
| 30 |
+
obs, _ = env.reset(seed=seed)
|
| 31 |
+
world_model.eval()
|
| 32 |
+
with torch.no_grad():
|
| 33 |
+
state = world_model.posterior_step(world_model.initial_state(1), None, obs, sample=False)
|
| 34 |
+
total_return = 0.0
|
| 35 |
+
length = 0
|
| 36 |
+
terminated = False
|
| 37 |
+
truncated = False
|
| 38 |
+
entropies: list[float] = []
|
| 39 |
+
|
| 40 |
+
while not (terminated or truncated):
|
| 41 |
+
if rng.random() < random_action_fraction:
|
| 42 |
+
action = int(env.action_space.sample())
|
| 43 |
+
else:
|
| 44 |
+
plan = planner.plan(state)
|
| 45 |
+
action = plan.action
|
| 46 |
+
entropies.append(plan.entropy)
|
| 47 |
+
obs, reward, terminated, truncated, _ = env.step(action)
|
| 48 |
+
total_return += float(reward)
|
| 49 |
+
length += 1
|
| 50 |
+
if not (terminated or truncated):
|
| 51 |
+
state = world_model.posterior_step(state, action, obs, sample=False)
|
| 52 |
+
|
| 53 |
+
return PlannerEpisode(
|
| 54 |
+
success=bool(terminated and total_return > 0.0),
|
| 55 |
+
total_return=total_return,
|
| 56 |
+
length=length,
|
| 57 |
+
terminated=bool(terminated),
|
| 58 |
+
truncated=bool(truncated),
|
| 59 |
+
planner_entropy=float(np.mean(entropies)) if entropies else float("nan"),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def evaluate_planner(
|
| 64 |
+
config: dict,
|
| 65 |
+
world_model,
|
| 66 |
+
episodes: int | None = None,
|
| 67 |
+
seed: int | None = None,
|
| 68 |
+
random_action_fraction: float = 0.0,
|
| 69 |
+
) -> dict[str, float]:
|
| 70 |
+
eval_cfg = config["evaluation"]
|
| 71 |
+
episodes = episodes or eval_cfg["episodes"]
|
| 72 |
+
seed = config.get("project", {}).get("seed", 0) if seed is None else seed
|
| 73 |
+
env = make_env_from_config(config, seed=seed)
|
| 74 |
+
planner = DiscreteCEMPlanner.from_config(world_model, env.action_space.n, config)
|
| 75 |
+
rng = np.random.default_rng(seed)
|
| 76 |
+
results = [
|
| 77 |
+
run_planner_episode(
|
| 78 |
+
env,
|
| 79 |
+
world_model,
|
| 80 |
+
planner,
|
| 81 |
+
rng,
|
| 82 |
+
seed=seed + episode_idx,
|
| 83 |
+
random_action_fraction=random_action_fraction,
|
| 84 |
+
)
|
| 85 |
+
for episode_idx in range(episodes)
|
| 86 |
+
]
|
| 87 |
+
env.close()
|
| 88 |
+
|
| 89 |
+
returns = np.asarray([result.total_return for result in results], dtype=np.float32)
|
| 90 |
+
lengths = np.asarray([result.length for result in results], dtype=np.float32)
|
| 91 |
+
successes = np.asarray([result.success for result in results], dtype=np.float32)
|
| 92 |
+
entropies = np.asarray([result.planner_entropy for result in results], dtype=np.float32)
|
| 93 |
+
return {
|
| 94 |
+
"success_rate": float(successes.mean()),
|
| 95 |
+
"mean_return": float(returns.mean()),
|
| 96 |
+
"median_return": float(np.median(returns)),
|
| 97 |
+
"mean_episode_length": float(lengths.mean()),
|
| 98 |
+
"planner_action_entropy": float(np.nanmean(entropies)),
|
| 99 |
+
}
|
src/minidreamer/serialization.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from minidreamer.models.world_model import WorldModel
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def save_world_model_checkpoint(
|
| 12 |
+
path: str | Path,
|
| 13 |
+
model: WorldModel,
|
| 14 |
+
config: dict[str, Any],
|
| 15 |
+
optimizer: torch.optim.Optimizer | None = None,
|
| 16 |
+
metadata: dict[str, Any] | None = None,
|
| 17 |
+
) -> None:
|
| 18 |
+
path = Path(path)
|
| 19 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 20 |
+
payload = {
|
| 21 |
+
"model_state": model.state_dict(),
|
| 22 |
+
"config": config,
|
| 23 |
+
"metadata": metadata or {},
|
| 24 |
+
}
|
| 25 |
+
if optimizer is not None:
|
| 26 |
+
payload["optimizer_state"] = optimizer.state_dict()
|
| 27 |
+
torch.save(payload, path)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_world_model_checkpoint(
|
| 31 |
+
path: str | Path,
|
| 32 |
+
action_dim: int,
|
| 33 |
+
map_location: str | torch.device | None = None,
|
| 34 |
+
) -> tuple[WorldModel, dict[str, Any], dict[str, Any]]:
|
| 35 |
+
payload = torch.load(path, map_location=map_location, weights_only=False)
|
| 36 |
+
config = payload["config"]
|
| 37 |
+
model = WorldModel.from_config(config, action_dim=action_dim)
|
| 38 |
+
model.load_state_dict(payload["model_state"])
|
| 39 |
+
return model, config, payload.get("metadata", {})
|
| 40 |
+
|
src/minidreamer/utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility helpers for MiniDreamer."""
|
| 2 |
+
|
src/minidreamer/utils/common.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def seed_everything(seed: int) -> None:
|
| 13 |
+
random.seed(seed)
|
| 14 |
+
np.random.seed(seed)
|
| 15 |
+
torch.manual_seed(seed)
|
| 16 |
+
if torch.cuda.is_available():
|
| 17 |
+
torch.cuda.manual_seed_all(seed)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_device(device: str | None = None) -> torch.device:
|
| 21 |
+
if device is not None:
|
| 22 |
+
return torch.device(device)
|
| 23 |
+
if torch.cuda.is_available():
|
| 24 |
+
return torch.device("cuda")
|
| 25 |
+
if torch.backends.mps.is_available():
|
| 26 |
+
return torch.device("mps")
|
| 27 |
+
return torch.device("cpu")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def to_numpy(value: Any) -> np.ndarray:
|
| 31 |
+
if isinstance(value, np.ndarray):
|
| 32 |
+
return value
|
| 33 |
+
if torch.is_tensor(value):
|
| 34 |
+
return value.detach().cpu().numpy()
|
| 35 |
+
return np.asarray(value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 39 |
+
mask = mask.to(values.dtype)
|
| 40 |
+
denom = torch.clamp(mask.sum(), min=1.0)
|
| 41 |
+
return (values * mask).sum() / denom
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def write_json(path: str | Path, payload: dict[str, Any]) -> None:
|
| 45 |
+
path = Path(path)
|
| 46 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 47 |
+
with path.open("w", encoding="utf-8") as handle:
|
| 48 |
+
json.dump(payload, handle, indent=2, sort_keys=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def write_jsonl(path: str | Path, rows: list[dict[str, Any]]) -> None:
|
| 52 |
+
path = Path(path)
|
| 53 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
with path.open("w", encoding="utf-8") as handle:
|
| 55 |
+
for row in rows:
|
| 56 |
+
handle.write(json.dumps(row, sort_keys=True))
|
| 57 |
+
handle.write("\n")
|
| 58 |
+
|
src/train_world_model.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn.utils import clip_grad_norm_
|
| 10 |
+
from tqdm import trange
|
| 11 |
+
|
| 12 |
+
from minidreamer.config import ensure_run_dirs, load_config, merge_dicts, save_config
|
| 13 |
+
from minidreamer.data.collect_random import collect_bootstrap_dataset
|
| 14 |
+
from minidreamer.data.replay_buffer import ReplayBuffer
|
| 15 |
+
from minidreamer.evaluation import evaluate_random_policy, evaluate_world_model
|
| 16 |
+
from minidreamer.envs.make_env import make_env_from_config
|
| 17 |
+
from minidreamer.models.world_model import WorldModel
|
| 18 |
+
from minidreamer.planning.cem import DiscreteCEMPlanner
|
| 19 |
+
from minidreamer.planning.evaluate_planner import evaluate_planner
|
| 20 |
+
from minidreamer.serialization import save_world_model_checkpoint
|
| 21 |
+
from minidreamer.utils.common import get_device, seed_everything, write_json, write_jsonl
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def train_world_model_updates(
|
| 25 |
+
model: WorldModel,
|
| 26 |
+
replay: ReplayBuffer,
|
| 27 |
+
optimizer: torch.optim.Optimizer,
|
| 28 |
+
config: dict[str, Any],
|
| 29 |
+
num_updates: int,
|
| 30 |
+
device: torch.device,
|
| 31 |
+
) -> list[dict[str, float]]:
|
| 32 |
+
if num_updates <= 0:
|
| 33 |
+
return []
|
| 34 |
+
model.train()
|
| 35 |
+
logs: list[dict[str, float]] = []
|
| 36 |
+
progress = trange(num_updates, desc="world-model-updates", leave=False)
|
| 37 |
+
for _ in progress:
|
| 38 |
+
batch = ReplayBuffer.batch_to_torch(replay.sample_sequences(split="train"), device=device)
|
| 39 |
+
losses = model.compute_losses(batch, config)
|
| 40 |
+
optimizer.zero_grad(set_to_none=True)
|
| 41 |
+
losses["loss"].backward()
|
| 42 |
+
clip_grad_norm_(model.parameters(), float(config["training"].get("grad_clip_norm", 100.0)))
|
| 43 |
+
optimizer.step()
|
| 44 |
+
log_row = {
|
| 45 |
+
"loss": float(losses["loss"].detach().cpu()),
|
| 46 |
+
"reward_loss": float(losses["reward_loss"].cpu()),
|
| 47 |
+
"done_loss": float(losses["done_loss"].cpu()),
|
| 48 |
+
"kl_loss": float(losses["kl_loss"].cpu()),
|
| 49 |
+
"recon_loss": float(losses["recon_loss"].cpu()),
|
| 50 |
+
}
|
| 51 |
+
logs.append(log_row)
|
| 52 |
+
progress.set_postfix({key: f"{value:.3f}" for key, value in log_row.items()})
|
| 53 |
+
return logs
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def optimizer_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
|
| 57 |
+
for state in optimizer.state.values():
|
| 58 |
+
for key, value in state.items():
|
| 59 |
+
if torch.is_tensor(value):
|
| 60 |
+
state[key] = value.to(device)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_training_state(
|
| 64 |
+
checkpoint_path: str | Path,
|
| 65 |
+
config: dict[str, Any],
|
| 66 |
+
action_dim: int,
|
| 67 |
+
device: torch.device,
|
| 68 |
+
) -> tuple[dict[str, Any], WorldModel, torch.optim.Optimizer, dict[str, Any]]:
|
| 69 |
+
payload = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 70 |
+
resolved_config = merge_dicts(payload["config"], config)
|
| 71 |
+
model = WorldModel.from_config(resolved_config, action_dim=action_dim).to(device)
|
| 72 |
+
model.load_state_dict(payload["model_state"])
|
| 73 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=float(resolved_config["training"]["lr"]))
|
| 74 |
+
optimizer_state = payload.get("optimizer_state")
|
| 75 |
+
if optimizer_state is not None:
|
| 76 |
+
optimizer.load_state_dict(optimizer_state)
|
| 77 |
+
optimizer_to_device(optimizer, device)
|
| 78 |
+
return resolved_config, model, optimizer, payload.get("metadata", {})
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def find_existing_run_artifacts(base_dir: str | Path) -> list[Path]:
|
| 82 |
+
base = Path(base_dir)
|
| 83 |
+
if not base.exists():
|
| 84 |
+
return []
|
| 85 |
+
|
| 86 |
+
artifact_files = [
|
| 87 |
+
base / "metrics" / "run_summary.json",
|
| 88 |
+
base / "metrics" / "train_metrics.jsonl",
|
| 89 |
+
base / "metrics" / "eval_metrics.jsonl",
|
| 90 |
+
base / "checkpoints" / "world_model_latest.pt",
|
| 91 |
+
base / "replay" / "metadata.json",
|
| 92 |
+
]
|
| 93 |
+
found = [path for path in artifact_files if path.exists()]
|
| 94 |
+
if found:
|
| 95 |
+
return found
|
| 96 |
+
|
| 97 |
+
for subdir_name in ("checkpoints", "metrics", "replay"):
|
| 98 |
+
subdir = base / subdir_name
|
| 99 |
+
if subdir.exists():
|
| 100 |
+
for child in subdir.iterdir():
|
| 101 |
+
found.append(child)
|
| 102 |
+
break
|
| 103 |
+
return found
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def collect_planner_steps(
|
| 107 |
+
env,
|
| 108 |
+
replay: ReplayBuffer,
|
| 109 |
+
model: WorldModel,
|
| 110 |
+
planner: DiscreteCEMPlanner,
|
| 111 |
+
num_steps: int,
|
| 112 |
+
random_action_fraction: float,
|
| 113 |
+
rng: np.random.Generator,
|
| 114 |
+
) -> dict[str, int]:
|
| 115 |
+
collected_steps = 0
|
| 116 |
+
episodes = 0
|
| 117 |
+
success_episodes = 0
|
| 118 |
+
model.eval()
|
| 119 |
+
while collected_steps < num_steps:
|
| 120 |
+
obs, _ = env.reset()
|
| 121 |
+
observations = [obs]
|
| 122 |
+
actions: list[int] = []
|
| 123 |
+
rewards: list[float] = []
|
| 124 |
+
terminated_flags: list[float] = []
|
| 125 |
+
truncated_flags: list[float] = []
|
| 126 |
+
done_flags: list[float] = []
|
| 127 |
+
terminated = False
|
| 128 |
+
truncated = False
|
| 129 |
+
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
state = model.posterior_step(model.initial_state(1), None, obs, sample=False)
|
| 132 |
+
while not (terminated or truncated):
|
| 133 |
+
if rng.random() < random_action_fraction:
|
| 134 |
+
action = int(env.action_space.sample())
|
| 135 |
+
else:
|
| 136 |
+
action = planner.plan(state).action
|
| 137 |
+
obs, reward, terminated, truncated, _ = env.step(action)
|
| 138 |
+
actions.append(action)
|
| 139 |
+
rewards.append(float(reward))
|
| 140 |
+
terminated_flags.append(float(terminated))
|
| 141 |
+
truncated_flags.append(float(truncated))
|
| 142 |
+
done_flags.append(float(terminated or truncated))
|
| 143 |
+
observations.append(obs)
|
| 144 |
+
collected_steps += 1
|
| 145 |
+
if terminated or truncated:
|
| 146 |
+
break
|
| 147 |
+
state = model.posterior_step(state, action, obs, sample=False)
|
| 148 |
+
|
| 149 |
+
replay.add_episode(
|
| 150 |
+
obs=np.asarray(observations, dtype=np.float32),
|
| 151 |
+
actions=np.asarray(actions, dtype=np.int64),
|
| 152 |
+
rewards=np.asarray(rewards, dtype=np.float32),
|
| 153 |
+
terminated=np.asarray(terminated_flags, dtype=np.float32),
|
| 154 |
+
truncated=np.asarray(truncated_flags, dtype=np.float32),
|
| 155 |
+
done=np.asarray(done_flags, dtype=np.float32),
|
| 156 |
+
)
|
| 157 |
+
episodes += 1
|
| 158 |
+
success_episodes += int(bool(terminated and np.sum(rewards) > 0.0))
|
| 159 |
+
return {
|
| 160 |
+
"env_steps": collected_steps,
|
| 161 |
+
"episodes": episodes,
|
| 162 |
+
"success_episodes": success_episodes,
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def run_training(
|
| 167 |
+
config: dict[str, Any],
|
| 168 |
+
output_dir: str | Path,
|
| 169 |
+
replay_dir: str | Path | None = None,
|
| 170 |
+
resume_checkpoint: str | Path | None = None,
|
| 171 |
+
allow_overwrite_existing_output: bool = False,
|
| 172 |
+
) -> dict[str, Any]:
|
| 173 |
+
seed = config.get("project", {}).get("seed", 0)
|
| 174 |
+
seed_everything(seed)
|
| 175 |
+
existing_artifacts = find_existing_run_artifacts(output_dir)
|
| 176 |
+
if existing_artifacts and resume_checkpoint is None and not allow_overwrite_existing_output:
|
| 177 |
+
preview = ", ".join(str(path) for path in existing_artifacts[:3])
|
| 178 |
+
raise FileExistsError(
|
| 179 |
+
f"Refusing to overwrite existing run directory '{output_dir}'. "
|
| 180 |
+
f"Found existing artifacts: {preview}. "
|
| 181 |
+
"Choose a new --output-dir, resume with --resume-checkpoint, "
|
| 182 |
+
"or pass --allow-overwrite-existing-output to overwrite intentionally."
|
| 183 |
+
)
|
| 184 |
+
run_dirs = ensure_run_dirs(output_dir)
|
| 185 |
+
device = get_device(config.get("training", {}).get("device"))
|
| 186 |
+
|
| 187 |
+
env = make_env_from_config(config, seed=seed)
|
| 188 |
+
action_dim = env.action_space.n
|
| 189 |
+
env.close()
|
| 190 |
+
|
| 191 |
+
if replay_dir is not None and Path(replay_dir).exists():
|
| 192 |
+
replay = ReplayBuffer.load(replay_dir)
|
| 193 |
+
collection_summary = {"replay_loaded": replay.summary()}
|
| 194 |
+
else:
|
| 195 |
+
replay, collection_summary = collect_bootstrap_dataset(config, output_dir=run_dirs["replay"], seed=seed)
|
| 196 |
+
|
| 197 |
+
resume_metadata: dict[str, Any] = {}
|
| 198 |
+
if resume_checkpoint is not None:
|
| 199 |
+
config, model, optimizer, resume_metadata = load_training_state(
|
| 200 |
+
checkpoint_path=resume_checkpoint,
|
| 201 |
+
config=config,
|
| 202 |
+
action_dim=action_dim,
|
| 203 |
+
device=device,
|
| 204 |
+
)
|
| 205 |
+
else:
|
| 206 |
+
model = WorldModel.from_config(config, action_dim=action_dim).to(device)
|
| 207 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=float(config["training"]["lr"]))
|
| 208 |
+
|
| 209 |
+
save_config(config, run_dirs["base"] / "resolved_config.yaml")
|
| 210 |
+
training_logs: list[dict[str, float]] = []
|
| 211 |
+
evaluation_logs: list[dict[str, float]] = []
|
| 212 |
+
|
| 213 |
+
train_collect_ratio = float(config["collection"].get("train_collect_ratio", 1.0))
|
| 214 |
+
total_updates_budget = int(config["training"]["train_steps"])
|
| 215 |
+
if resume_checkpoint is not None:
|
| 216 |
+
updates_done = int(resume_metadata.get("updates_done", 0))
|
| 217 |
+
checkpoint_env_steps = int(resume_metadata.get("env_steps", 0))
|
| 218 |
+
if replay.env_steps > checkpoint_env_steps and updates_done < total_updates_budget:
|
| 219 |
+
collect_steps_per_iteration = max(1, int(config["collection"].get("collect_steps_per_iteration", 1)))
|
| 220 |
+
per_iteration_updates = int(
|
| 221 |
+
config["collection"].get(
|
| 222 |
+
"gradient_updates_per_iteration",
|
| 223 |
+
round(collect_steps_per_iteration * train_collect_ratio),
|
| 224 |
+
)
|
| 225 |
+
)
|
| 226 |
+
missed_iterations = max(0, round((replay.env_steps - checkpoint_env_steps) / collect_steps_per_iteration))
|
| 227 |
+
catch_up_updates = min(total_updates_budget - updates_done, per_iteration_updates * missed_iterations)
|
| 228 |
+
catch_up_logs = train_world_model_updates(model, replay, optimizer, config, catch_up_updates, device)
|
| 229 |
+
training_logs.extend(catch_up_logs)
|
| 230 |
+
updates_done += len(catch_up_logs)
|
| 231 |
+
else:
|
| 232 |
+
initial_updates = min(total_updates_budget, max(1, int(round(replay.env_steps * train_collect_ratio))))
|
| 233 |
+
training_logs.extend(train_world_model_updates(model, replay, optimizer, config, initial_updates, device))
|
| 234 |
+
updates_done = len(training_logs)
|
| 235 |
+
|
| 236 |
+
comparison_budgets = config.get("comparison", {}).get("env_steps", [replay.env_steps])
|
| 237 |
+
target_env_steps = int(max(comparison_budgets))
|
| 238 |
+
rng = np.random.default_rng(seed)
|
| 239 |
+
env = make_env_from_config(config, seed=seed)
|
| 240 |
+
planner = DiscreteCEMPlanner.from_config(model, env.action_space.n, config)
|
| 241 |
+
eval_every_steps = int(config["evaluation"].get("eval_every_env_steps", target_env_steps))
|
| 242 |
+
next_eval_step = replay.env_steps
|
| 243 |
+
|
| 244 |
+
while replay.env_steps < target_env_steps and updates_done < total_updates_budget:
|
| 245 |
+
collect_steps = min(
|
| 246 |
+
int(config["collection"]["collect_steps_per_iteration"]),
|
| 247 |
+
target_env_steps - replay.env_steps,
|
| 248 |
+
)
|
| 249 |
+
collection_row = collect_planner_steps(
|
| 250 |
+
env,
|
| 251 |
+
replay,
|
| 252 |
+
model,
|
| 253 |
+
planner,
|
| 254 |
+
num_steps=collect_steps,
|
| 255 |
+
random_action_fraction=float(config["collection"].get("random_action_fraction_after_planner", 0.0)),
|
| 256 |
+
rng=rng,
|
| 257 |
+
)
|
| 258 |
+
updates = int(config["collection"].get("gradient_updates_per_iteration", round(collection_row["env_steps"] * train_collect_ratio)))
|
| 259 |
+
updates = min(updates, total_updates_budget - updates_done)
|
| 260 |
+
training_logs.extend(train_world_model_updates(model, replay, optimizer, config, updates, device))
|
| 261 |
+
updates_done = len(training_logs)
|
| 262 |
+
replay.save(run_dirs["replay"])
|
| 263 |
+
|
| 264 |
+
if replay.env_steps >= next_eval_step:
|
| 265 |
+
world_model_metrics = evaluate_world_model(config, model, replay, split="val", max_episodes=10)
|
| 266 |
+
planner_metrics = evaluate_planner(config, model, episodes=min(10, config["evaluation"]["episodes"]), seed=seed)
|
| 267 |
+
random_metrics = evaluate_random_policy(config, episodes=min(10, config["evaluation"]["episodes"]), seed=seed)
|
| 268 |
+
eval_row = {
|
| 269 |
+
"env_steps": replay.env_steps,
|
| 270 |
+
"updates_done": updates_done,
|
| 271 |
+
**{f"world_model/{key}": value for key, value in world_model_metrics.items()},
|
| 272 |
+
**{f"planner/{key}": value for key, value in planner_metrics.items()},
|
| 273 |
+
**{f"random/{key}": value for key, value in random_metrics.items()},
|
| 274 |
+
}
|
| 275 |
+
evaluation_logs.append(eval_row)
|
| 276 |
+
next_eval_step += eval_every_steps
|
| 277 |
+
save_world_model_checkpoint(
|
| 278 |
+
run_dirs["checkpoints"] / f"world_model_env_steps_{replay.env_steps}.pt",
|
| 279 |
+
model,
|
| 280 |
+
config,
|
| 281 |
+
optimizer=optimizer,
|
| 282 |
+
metadata={"env_steps": replay.env_steps, "updates_done": updates_done},
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
env.close()
|
| 286 |
+
save_world_model_checkpoint(
|
| 287 |
+
run_dirs["checkpoints"] / "world_model_latest.pt",
|
| 288 |
+
model,
|
| 289 |
+
config,
|
| 290 |
+
optimizer=optimizer,
|
| 291 |
+
metadata={"env_steps": replay.env_steps, "updates_done": updates_done},
|
| 292 |
+
)
|
| 293 |
+
write_json(run_dirs["metrics"] / "collection_summary.json", collection_summary)
|
| 294 |
+
write_jsonl(run_dirs["metrics"] / "train_metrics.jsonl", training_logs)
|
| 295 |
+
write_jsonl(run_dirs["metrics"] / "eval_metrics.jsonl", evaluation_logs)
|
| 296 |
+
summary = {
|
| 297 |
+
"replay": replay.summary(),
|
| 298 |
+
"updates_done": updates_done,
|
| 299 |
+
"device": str(device),
|
| 300 |
+
}
|
| 301 |
+
write_json(run_dirs["metrics"] / "run_summary.json", summary)
|
| 302 |
+
return summary
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def build_arg_parser() -> argparse.ArgumentParser:
|
| 306 |
+
parser = argparse.ArgumentParser(description="Train the MiniDreamer world model.")
|
| 307 |
+
parser.add_argument("--config", type=Path, required=True)
|
| 308 |
+
parser.add_argument("--output-dir", type=Path, required=True)
|
| 309 |
+
parser.add_argument("--replay-dir", type=Path, default=None, help="Optional existing replay directory.")
|
| 310 |
+
parser.add_argument("--resume-checkpoint", type=Path, default=None, help="Optional checkpoint to resume from.")
|
| 311 |
+
parser.add_argument(
|
| 312 |
+
"--allow-overwrite-existing-output",
|
| 313 |
+
action="store_true",
|
| 314 |
+
help="Allow overwriting an existing run directory when not resuming.",
|
| 315 |
+
)
|
| 316 |
+
return parser
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def main() -> None:
|
| 320 |
+
parser = build_arg_parser()
|
| 321 |
+
args = parser.parse_args()
|
| 322 |
+
config = load_config(args.config)
|
| 323 |
+
summary = run_training(
|
| 324 |
+
config,
|
| 325 |
+
args.output_dir,
|
| 326 |
+
replay_dir=args.replay_dir,
|
| 327 |
+
resume_checkpoint=args.resume_checkpoint,
|
| 328 |
+
allow_overwrite_existing_output=args.allow_overwrite_existing_output,
|
| 329 |
+
)
|
| 330 |
+
print(summary)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
if __name__ == "__main__":
|
| 334 |
+
main()
|
tests/test_cem_planner.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from minidreamer.planning.cem import DiscreteCEMPlanner
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DummyWorldModel:
|
| 7 |
+
def __init__(self):
|
| 8 |
+
self.device = torch.device("cpu")
|
| 9 |
+
|
| 10 |
+
def score_action_sequences(self, state, action_sequences, discount=0.99, use_done_mask=True):
|
| 11 |
+
target = torch.tensor([1, 2, 0, 1], device=action_sequences.device)
|
| 12 |
+
scores = -(action_sequences != target).float().sum(dim=-1)
|
| 13 |
+
return {"scores": scores}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_discrete_cem_planner_finds_high_scoring_sequence():
|
| 17 |
+
torch.manual_seed(0)
|
| 18 |
+
planner = DiscreteCEMPlanner(
|
| 19 |
+
world_model=DummyWorldModel(),
|
| 20 |
+
action_dim=3,
|
| 21 |
+
horizon=4,
|
| 22 |
+
candidates=512,
|
| 23 |
+
elites=64,
|
| 24 |
+
iterations=5,
|
| 25 |
+
discount=1.0,
|
| 26 |
+
use_done_mask=False,
|
| 27 |
+
)
|
| 28 |
+
output = planner.plan(state=object())
|
| 29 |
+
assert output.action == 1
|
| 30 |
+
assert output.sequence == [1, 2, 0, 1]
|
| 31 |
+
|
tests/test_env.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from minidreamer.envs.make_env import make_env
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_make_env_returns_normalized_rgb_observation():
|
| 7 |
+
env = make_env(seed=0)
|
| 8 |
+
obs, _ = env.reset()
|
| 9 |
+
assert obs.shape == (64, 64, 3)
|
| 10 |
+
assert obs.dtype == np.float32
|
| 11 |
+
assert 0.0 <= float(obs.min()) <= float(obs.max()) <= 1.0
|
| 12 |
+
|
| 13 |
+
next_obs, reward, terminated, truncated, _ = env.step(env.action_space.sample())
|
| 14 |
+
assert next_obs.shape == (64, 64, 3)
|
| 15 |
+
assert isinstance(float(reward), float)
|
| 16 |
+
assert isinstance(bool(terminated), bool)
|
| 17 |
+
assert isinstance(bool(truncated), bool)
|
| 18 |
+
env.close()
|
| 19 |
+
|
tests/test_replay_buffer.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from minidreamer.data.replay_buffer import ReplayBuffer
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def make_episode(length: int, reward: float = 0.0):
|
| 10 |
+
obs = np.random.rand(length + 1, 64, 64, 3).astype(np.float32)
|
| 11 |
+
actions = np.arange(length, dtype=np.int64) % 7
|
| 12 |
+
rewards = np.full(length, reward, dtype=np.float32)
|
| 13 |
+
terminated = np.zeros(length, dtype=np.float32)
|
| 14 |
+
truncated = np.zeros(length, dtype=np.float32)
|
| 15 |
+
done = np.zeros(length, dtype=np.float32)
|
| 16 |
+
terminated[-1] = 1.0
|
| 17 |
+
done[-1] = 1.0
|
| 18 |
+
return obs, actions, rewards, terminated, truncated, done
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_replay_buffer_sampling_and_padding(tmp_path: Path):
|
| 22 |
+
buffer = ReplayBuffer(capacity_episodes=10, sequence_length=8, batch_size=4)
|
| 23 |
+
for episode_id, length in enumerate((3, 5, 9)):
|
| 24 |
+
obs, actions, rewards, terminated, truncated, done = make_episode(length, reward=float(episode_id))
|
| 25 |
+
buffer.add_episode(obs, actions, rewards, terminated, truncated, done, episode_id=episode_id)
|
| 26 |
+
|
| 27 |
+
available_split = next(split for split in ("train", "val", "test") if buffer.episode_ids(split))
|
| 28 |
+
batch = buffer.sample_sequences(split=available_split, batch_size=2, rng=np.random.default_rng(0))
|
| 29 |
+
assert batch["obs"].shape == (2, 9, 64, 64, 3)
|
| 30 |
+
assert batch["actions"].shape == (2, 8)
|
| 31 |
+
assert batch["mask"].shape == (2, 8)
|
| 32 |
+
assert np.all(batch["mask"].sum(axis=1) >= 1)
|
| 33 |
+
|
| 34 |
+
save_dir = tmp_path / "replay"
|
| 35 |
+
buffer.save(save_dir)
|
| 36 |
+
loaded = ReplayBuffer.load(save_dir)
|
| 37 |
+
assert loaded.summary()["episodes"] == buffer.summary()["episodes"]
|
| 38 |
+
assert loaded.summary()["env_steps"] == buffer.summary()["env_steps"]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_replay_buffer_torch_batch_shapes():
|
| 42 |
+
buffer = ReplayBuffer(capacity_episodes=4, sequence_length=4, batch_size=2)
|
| 43 |
+
obs, actions, rewards, terminated, truncated, done = make_episode(5, reward=1.0)
|
| 44 |
+
buffer.add_episode(obs, actions, rewards, terminated, truncated, done)
|
| 45 |
+
available_split = next(split for split in ("train", "val", "test") if buffer.episode_ids(split))
|
| 46 |
+
batch = buffer.sample_sequences(split=available_split, batch_size=2, rng=np.random.default_rng(1))
|
| 47 |
+
tensor_batch = ReplayBuffer.batch_to_torch(batch)
|
| 48 |
+
assert tensor_batch["obs"].shape == (2, 5, 3, 64, 64)
|
| 49 |
+
assert tensor_batch["actions"].dtype == torch.int64
|
| 50 |
+
|
tests/test_rssm_shapes.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from minidreamer.models.world_model import WorldModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_world_model_sequence_shapes_and_loss():
|
| 7 |
+
torch.manual_seed(0)
|
| 8 |
+
model = WorldModel(
|
| 9 |
+
action_dim=7,
|
| 10 |
+
embedding_dim=128,
|
| 11 |
+
deter_dim=128,
|
| 12 |
+
stoch_dim=16,
|
| 13 |
+
hidden_dim=128,
|
| 14 |
+
use_decoder=True,
|
| 15 |
+
)
|
| 16 |
+
obs = torch.rand(4, 33, 3, 64, 64)
|
| 17 |
+
actions = torch.randint(0, 7, (4, 32))
|
| 18 |
+
outputs = model.observe_sequence(obs, actions, sample=False)
|
| 19 |
+
assert outputs.reward_pred.shape == (4, 32)
|
| 20 |
+
assert outputs.done_logits.shape == (4, 32)
|
| 21 |
+
assert outputs.prior_mean.shape == (4, 32, 16)
|
| 22 |
+
assert outputs.reconstructions is not None
|
| 23 |
+
assert outputs.reconstructions.shape == (4, 32, 3, 64, 64)
|
| 24 |
+
|
| 25 |
+
batch = {
|
| 26 |
+
"obs": obs,
|
| 27 |
+
"actions": actions,
|
| 28 |
+
"rewards": torch.zeros(4, 32),
|
| 29 |
+
"done": torch.zeros(4, 32),
|
| 30 |
+
"mask": torch.ones(4, 32),
|
| 31 |
+
}
|
| 32 |
+
config = {
|
| 33 |
+
"training": {
|
| 34 |
+
"beta_reward": 1.0,
|
| 35 |
+
"beta_done": 1.0,
|
| 36 |
+
"beta_kl": 1.0,
|
| 37 |
+
"beta_recon": 1.0,
|
| 38 |
+
"free_nats": 1.0,
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
losses = model.compute_losses(batch, config)
|
| 42 |
+
assert torch.isfinite(losses["loss"])
|
| 43 |
+
|