MathisW78 committed on
Commit
f748552
·
verified ·
1 Parent(s): c540401

Demo notebook payload (source + checkpoint + assets)

Files changed (48)
  1. .gitattributes +3 -0
  2. README.md +739 -0
  3. ablation_assets/diagnosis_decision_tree.png +0 -0
  4. ablation_assets/grad_alignment.png +3 -0
  5. ablation_assets/gradient_conflict_map.png +0 -0
  6. ablation_assets/group_comparison.png +0 -0
  7. ablation_assets/group_summary.csv +5 -0
  8. ablation_assets/hypothesis_verdicts.csv +22 -0
  9. ablation_assets/main_results.csv +22 -0
  10. ablation_assets/per_env_delta.png +3 -0
  11. ablation_assets/per_env_win_rates.csv +22 -0
  12. ablation_assets/repr_drift.png +0 -0
  13. ablation_assets/results.json +0 -0
  14. ablation_assets/score_comparison.png +3 -0
  15. ablation_assets/score_delta.png +0 -0
  16. checkpoint_inference.pth +3 -0
  17. configs/defaults.yaml +242 -0
  18. configs/final_qmul_gpu.yaml +176 -0
  19. configs/final_ucl_gpu.yaml +158 -0
  20. configs/smoke.yaml +16 -0
  21. configs/ucl_gpu_bigger_model.yaml +103 -0
  22. configs/ucl_gpu_learning_behaviour.yaml +103 -0
  23. environments/.gitkeep +0 -0
  24. main.py +255 -0
  25. pyproject.toml +22 -0
  26. src/__init__.py +0 -0
  27. src/buffer.py +268 -0
  28. src/config.py +164 -0
  29. src/curriculum.py +143 -0
  30. src/diffusion/__init__.py +0 -0
  31. src/diffusion/forward.py +50 -0
  32. src/diffusion/loss.py +162 -0
  33. src/diffusion/sampling.py +398 -0
  34. src/diffusion/schedules.py +88 -0
  35. src/envs/__init__.py +0 -0
  36. src/envs/discovery.py +166 -0
  37. src/envs/minihack_env.py +454 -0
  38. src/models/__init__.py +0 -0
  39. src/models/denoiser.py +415 -0
  40. src/planners/__init__.py +0 -0
  41. src/planners/baselines.py +1247 -0
  42. src/planners/collect.py +588 -0
  43. src/planners/collect_oracle.py +185 -0
  44. src/planners/inference.py +360 -0
  45. src/planners/logging.py +291 -0
  46. src/planners/offline.py +727 -0
  47. src/planners/online.py +721 -0
  48. src/planners/smoke.py +63 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ablation_assets/grad_alignment.png filter=lfs diff=lfs merge=lfs -text
37
+ ablation_assets/per_env_delta.png filter=lfs diff=lfs merge=lfs -text
38
+ ablation_assets/score_comparison.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,739 @@
1
+ # ReMDM Planner for MiniHack
2
+
3
+ PyTorch implementation of **ReMDM** (Remasking Discrete Diffusion Model) for action-sequence planning in [MiniHack](https://github.com/facebookresearch/minihack) navigation environments. A dual-stream transformer generates 64-step action plans by iteratively denoising masked token sequences, conditioned on a 9x9 local crop and the full 21x79 dungeon map.
4
+
5
+ > The primary training method is **DAgger** with BFS oracle supervision: the model is trained from scratch, with the buffer seeded by pure expert trajectories on the first iteration. A standalone **offline BC** mode is also available as an independent baseline trained on pre-collected datasets. The paper compares both methods head-to-head; neither depends on the other. An offline BC checkpoint can optionally warm-start DAgger, but this is not used in the paper. The model generalises **zero-shot** from 4 in-distribution environments to 3 out-of-distribution environments.
6
+
7
+ ---
8
+
9
+ ## Pipeline
10
+
11
+ ```
12
+ [Primary] DAgger online training main.py --mode dagger
13
+ | (seed buffer with oracle demos on iter 0,
14
+ | collect with model, label with oracle,
15
+ | efficiency filter, curriculum sampling)
16
+ v checkpoint
17
+ [Evaluate] ID + OOD evaluation main.py --mode inference --checkpoint iter8000.pth
18
+ ```
19
+
20
+ **Other modes:**
21
+
22
+ ```
23
+ [Collect] Collect oracle demonstrations main.py --mode collect
24
+ [Offline BC] Train on pre-collected data main.py --mode offline --data dataset.pt
25
+ [Smoke test] Quick end-to-end check main.py --mode smoke
26
+ ```
27
+
28
+ DAgger trains from scratch and is the recommended pipeline. Offline BC (`--mode collect` + `--mode offline`) is an independent training method compared against DAgger in the paper. An offline BC checkpoint can optionally warm-start DAgger via `--checkpoint`, but this was not used in the paper results.
29
+
30
+ ---
31
+
32
+ ## Environments
33
+
34
+ **In-distribution (training):**
35
+
36
+ | Environment | Description |
37
+ |---|---|
38
+ | `MiniHack-Room-Random-5x5-v0` | Small random room |
39
+ | `MiniHack-Room-Random-15x15-v0` | Large random room |
40
+ | `MiniHack-Corridor-R2-v0` | Two-room corridor |
41
+ | `MiniHack-MazeWalk-9x9-v0` | Small maze |
42
+
43
+ **Out-of-distribution (zero-shot evaluation):**
44
+
45
+ | Environment | Description |
46
+ |---|---|
47
+ | `MiniHack-Room-Dark-15x15-v0` | Dark room (limited visibility) |
48
+ | `MiniHack-Corridor-R5-v0` | Five-room corridor |
49
+ | `MiniHack-MazeWalk-45x19-v0` | Large maze |
50
+
51
+ ---
52
+
53
+ ## Installation
54
+
55
+ ### Prerequisites
56
+
57
+ **Python 3.12+** is required.
58
+
59
+ **macOS (arm64):** Install cmake via Homebrew (needed to compile `nle` from source):
60
+
61
+ ```bash
62
+ brew install cmake
63
+ ```
64
+
65
+ **Linux (x86_64):** Pre-built wheels are available. If you build from source instead, install the build dependencies first:
66
+
67
+ ```bash
68
+ sudo apt-get install build-essential cmake bison flex libbz2-dev
69
+ ```
70
+
71
+ ### Setup
72
+
73
+ ```bash
74
+ uv sync
75
+ ```
76
+
77
+ This installs all dependencies from the lockfile, including `nle>=1.2.0` (from the maintained [NetHack-LE](https://github.com/NetHack-LE/nle) fork), `minihack`, `torch>=2.11.0`, `wandb`, `polars`, `orjson`, and `scipy`.
78
+
79
+ ### GPU support (optional)
80
+
81
+ By default PyTorch runs on CPU. For NVIDIA CUDA 12:
82
+
83
+ ```bash
84
+ uv pip install torch --index-url https://download.pytorch.org/whl/cu121
85
+ ```
86
+
87
+ Verify that the GPU is detected:
88
+
89
+ ```bash
90
+ uv run python -c "import torch; print(torch.cuda.is_available())"
91
+ ```
92
+
93
+ ---
94
+
95
+ ## Usage
96
+
97
+ All modes share a single entry point. Defaults load from `configs/defaults.yaml`; any value can be overridden via `key=value` pairs.
98
+
99
+ ```bash
100
+ python main.py --mode <MODE> [--config PATH] [key=value ...]
101
+ ```
102
+
103
+ ### Smoke test
104
+
105
+ Collects a few oracle trajectories, trains under a tiny 5k env-step budget, and prints ID evaluation results.
106
+
107
+ ```bash
108
+ python main.py --mode smoke
109
+ ```
110
+
111
+ ### Collect oracle demonstrations
112
+
113
+ Run the BFS oracle across all 4 ID environments and save the trajectories as a `.pt` dataset for offline BC training. Uses multiprocessing for parallelism.
114
+
115
+ ```bash
116
+ # Default: 5000 episodes per env, output to data/dataset.pt
117
+ python main.py --mode collect
118
+
119
+ # Custom episode count and output
120
+ python main.py --mode collect collect_episodes_per_env=2000 \
121
+ collect_output=data/small_dataset.pt
122
+
123
+ # Fewer workers (default: 8)
124
+ python main.py --mode collect collect_num_workers=4
125
+
126
+ # Reproducible with fixed seed
127
+ python main.py --mode collect seed=42
128
+ ```
129
+
130
+ The output `.pt` file is directly consumable by `--mode offline`:
131
+
132
+ ```bash
133
+ python main.py --mode collect
134
+ python main.py --mode offline --data data/dataset.pt
135
+ ```
136
+
137
+ ### Offline BC (optional)
138
+
139
+ Train the diffusion model on pre-collected oracle demonstrations. The run length
140
+ is controlled by `total_timesteps` — each env-step of the unified budget
141
+ corresponds to one dataset sample, so total gradient steps =
142
+ `total_timesteps // offline_batch_size`.
143
+
144
+ Periodic ID + OOD evaluation runs during training on the cadence defined by
145
+ `id_eval_every_timesteps` / `ood_eval_every_timesteps` (env-step units,
146
+ converted internally to grad-step deltas via `// offline_batch_size`),
147
+ mirroring the DAgger eval pattern. Results are logged to `eval_id/` and
148
+ `eval_ood/` W&B namespaces.
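As a rough illustration of the env-step to grad-step conversion described above, the sketch below applies floor division to the default values listed in this README (`total_timesteps=2,000,000`, `offline_batch_size=3584`, and the eval and checkpoint cadences). The helper name `to_grad_steps` is hypothetical and is not taken from the repository.

```python
def to_grad_steps(env_steps: int, batch_size: int) -> int:
    """Convert an env-step quantity to gradient steps via floor division."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return env_steps // batch_size


# Defaults quoted in this README (assumed unchanged by any --config override).
total_timesteps = 2_000_000
offline_batch_size = 3584

total_grad_steps = to_grad_steps(total_timesteps, offline_batch_size)
eval_every = to_grad_steps(25_000, offline_batch_size)
checkpoint_every = to_grad_steps(125_000, offline_batch_size)

print(total_grad_steps, eval_every, checkpoint_every)  # 558 6 34
```

With the defaults, an offline run therefore performs 558 gradient steps, evaluating every 6 and checkpointing every 34.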
149
+
150
+ ```bash
151
+ python main.py --mode offline --data path/to/dataset.pt
152
+
153
+ # Shorter / longer run (the same knob the DAgger and SB3 baselines use):
154
+ python main.py --mode offline --data dataset.pt total_timesteps=500000
155
+
156
+ # Resume from a step-level checkpoint (restores optimizer, scheduler,
157
+ # step counter, and W&B run)
158
+ python main.py --mode offline --data path/to/dataset.pt \
159
+ --checkpoint checkpoints/offline_step2000.pth
160
+ ```
161
+
162
+ Step-level checkpoints are written every `checkpoint_every_timesteps` env-step
163
+ equivalents (converted internally to grad steps via `// offline_batch_size`).
164
+ Set it to `0` to disable checkpointing:
165
+
166
+ ```bash
167
+ python main.py --mode offline --data dataset.pt checkpoint_every_timesteps=0
168
+ ```
169
+
170
+ #### Compute-match overrides (paper-fair BC vs DAgger)
171
+
172
+ For research comparisons against a specific DAgger checkpoint, four optional
173
+ offline-only overrides bypass the env-step budget derivation. The
174
+ sample-to-grad-step ratio between the two modes (~50×) makes a single shared
175
+ `total_timesteps` budget unfair to one side; these knobs pin offline metrics
176
+ in grad-step units instead. All default to `null` (backwards compatible).
177
+
178
+ | Key | Purpose |
179
+ |---|---|
180
+ | `offline_total_grad_steps` | Pin gradient budget. Overrides `total_timesteps // offline_batch_size`. Use to match a DAgger iteration count (e.g. `60000` = 600 iters × 100 `grad_steps_per_iteration`). |
181
+ | `offline_eval_every_grad_steps` | ID/OOD eval cadence in grad-step units. Without this, env-step cadence applied to BC's dense per-sample budget yields hundreds of evals. |
182
+ | `offline_checkpoint_every_grad_steps` | Checkpoint cadence in grad-step units. Same motivation. |
183
+ | `offline_buffer_capacity` | Distinct from `buffer_capacity` (sized for DAgger's small FIFO). The full BC dataset has ~500k–1M sliding windows; using DAgger's cap silently truncates. |
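The arithmetic behind the `60000` example in the table can be checked with a short sketch. The variable names are illustrative only, and the batch size of 2048 is the figure this README quotes for the paper configs, not the `defaults.yaml` value.

```python
# DAgger side: gradient updates accumulate per iteration.
dagger_iterations = 600
grad_steps_per_iteration = 100
dagger_grad_steps = dagger_iterations * grad_steps_per_iteration

# Offline side: pin the same number of updates via offline_total_grad_steps.
offline_total_grad_steps = dagger_grad_steps
batch_size = 2048  # batch size quoted for the paper configs
samples_seen = offline_total_grad_steps * batch_size

print(dagger_grad_steps, samples_seen)  # 60000 122880000
```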
184
+
185
+ Example: train a fair offline BC baseline matched to DAgger@iter600
186
+ (60k AdamW updates × 2048 batch):
187
+
188
+ ```bash
189
+ python main.py --mode offline --data data/oracle_bc_qmul.pt \
190
+ --config configs/final_qmul_gpu.yaml
191
+ ```
192
+
193
+ The `final_qmul_gpu.yaml` and `final_ucl_gpu.yaml` configs both ship with
194
+ these overrides pre-set and with cross-cluster-identical training
195
+ hyperparameters (only collection-worker counts and output paths differ).
196
+
197
+ ### DAgger online training
198
+
199
+ Full DAgger loop: seed buffer with oracle data, collect with model, label with BFS oracle, filter by efficiency, train on buffer.
200
+
201
+ ```bash
202
+ # From scratch (seeds buffer with oracle data automatically)
203
+ python main.py --mode dagger
204
+
205
+ # Resume from local checkpoint
206
+ python main.py --mode dagger --checkpoint checkpoints/iter3000.pth
207
+
208
+ # Resume from a W&B artifact
209
+ python main.py --mode dagger \
210
+ --wandb-artifact entity/project/checkpoint-iter3000:latest
211
+
212
+ # Skip warm-start from checkpoint (reinitialise model, keep config)
213
+ python main.py --mode dagger --checkpoint checkpoints/iter3000.pth --no-warm-start
214
+
215
+ # Override hyperparameters (total_timesteps is the unified run-length knob)
216
+ python main.py --mode dagger total_timesteps=1000000 dagger_lr=0.0001
217
+
218
+ # Use a GPU-optimised config (paper run, QMUL H200)
219
+ python main.py --mode dagger --config configs/final_qmul_gpu.yaml
220
+ ```
221
+
222
+ ### Inference
223
+
224
+ Evaluate a checkpoint on specified environments. Accepts either `--checkpoint` (local path) or `--wandb-artifact` (W&B artifact reference).
225
+
226
+ ```bash
227
+ # All ID + OOD environments
228
+ python main.py --mode inference --checkpoint checkpoints/iter8000.pth
229
+
230
+ # From a W&B artifact
231
+ python main.py --mode inference \
232
+ --wandb-artifact entity/project/checkpoint-iter8000:latest
233
+
234
+ # Specific environments, save JSON
235
+ python main.py --mode inference \
236
+ --checkpoint checkpoints/iter8000.pth \
237
+ --envs MiniHack-Room-Random-5x5-v0 MiniHack-MazeWalk-45x19-v0 \
238
+ --episodes 100 \
239
+ --output results.json
240
+
241
+ # Custom .des scenario files
242
+ python main.py --mode inference \
243
+ --checkpoint checkpoints/iter8000.pth \
244
+ --des environments/custom_level.des
245
+
246
+ # Local-only ablation (zero out global map)
247
+ python main.py --mode inference \
248
+ --checkpoint checkpoints/iter8000.pth --blind-global
249
+
250
+ # Use training weights instead of EMA
251
+ python main.py --mode inference --checkpoint iter8000.pth --no-ema
252
+ ```
253
+
254
+ ### Baselines (SB3 + Decision Transformer)
255
+
256
+ Train and evaluate the head-to-head baselines used in the paper comparison.
257
+ Six algorithms are wired in: standard discrete-action RL via Stable-Baselines3
258
+ (`ppo`, `a2c`, `dqn`, `ppo-rnn`), Behavioural Cloning (`bc`) on oracle
259
+ demonstrations, and a causal Decision Transformer (`dt`) with target-return
260
+ conditioning. All six share the unified `cfg.total_timesteps` budget so the
261
+ numbers are directly comparable to DAgger and offline BC.
262
+
263
+ Hyperparameters live under the `baselines_*` namespace in `configs/defaults.yaml`
264
+ (BC epochs / batch / LR, DT context length / depth / width, oracle episodes per
265
+ env, eval cadence, DQN replay buffer, parallel SubprocVecEnv count, etc.). The
266
+ runner writes per-seed checkpoints, SB3 logs, and an aggregated results JSON
267
+ under `cfg.baselines_output_dir` (default `outputs/baselines/`); W&B runs land
268
+ in a separate project (`cfg.baselines_wandb_project`, default `remdm-baselines`)
269
+ so they don't pollute the main training leaderboards.
270
+
271
+ ```bash
272
+ # PPO on the 4 ID maps for the unified env-step budget, 1 seed
273
+ python main.py --mode baselines --algo ppo
274
+
275
+ # DQN with a custom budget and 3 seeds
276
+ python main.py --mode baselines --algo dqn \
277
+ --seeds 0 1 2 \
278
+ total_timesteps=1000000
279
+
280
+ # Behavioural Cloning baseline (oracle demos -> SB3 ActorCriticPolicy)
281
+ python main.py --mode baselines --algo bc --n-seeds 3
282
+
283
+ # Decision Transformer (causal R/s/a transformer with target-return)
284
+ python main.py --mode baselines --algo dt --seeds 0 1 2
285
+
286
+ # Override the aggregated-results JSON destination
287
+ python main.py --mode baselines --algo ppo --output results/ppo_smoke.json
288
+
289
+ # Paper-fair comparison against the ReMDM online budget (~5.65M env-steps)
290
+ python main.py --mode baselines --algo ppo total_timesteps=5650000
291
+ ```
292
+
293
+ The BC and DT defaults (50 epochs, 5000 oracle trajectories per ID env, 64-token
294
+ DT context, 256-D DT embedding) are tuned to match the data and compute scale of
295
+ the offline BC and ReMDM runs reported in the paper.
296
+
297
+ ### CLI flags
298
+
299
+ | Flag | Description |
300
+ |---|---|
301
+ | `--mode` | Required. One of `smoke`, `collect`, `offline`, `dagger`, `inference`, `baselines` |
302
+ | `--config PATH` | Config file (default: `configs/defaults.yaml`) |
303
+ | `--algo NAME` | Baseline algorithm (`ppo`, `a2c`, `dqn`, `ppo-rnn`, `bc`, `dt`); required with `--mode baselines` |
304
+ | `--seeds N [N ...]` | Explicit seed list for `--mode baselines` |
305
+ | `--n-seeds N` | Number of seeds starting from 0 (alternative to `--seeds`) |
306
+ | `--data PATH` | Dataset `.pt` file (offline mode) |
307
+ | `--checkpoint PATH` | Checkpoint `.pth` file |
308
+ | `--wandb-artifact REF` | W&B artifact reference (e.g. `entity/project/name:latest`) |
309
+ | `--no-warm-start` | Skip model warm-start from checkpoint (DAgger) |
310
+ | `--no-ema` | Use training weights instead of EMA for inference |
311
+ | `--envs ENV [ENV ...]` | Override evaluation environments |
312
+ | `--des PATH [PATH ...]` | Custom `.des` scenario files for evaluation |
313
+ | `--episodes N` | Episodes per environment (default: 50) |
314
+ | `--output PATH` | Save evaluation results / aggregated baselines JSON |
315
+ | `--blind-global` | Zero out global map observations (local-only ablation) |
316
+
317
+ ---
318
+
319
+ ## Architecture
320
+
321
+ **`LocalDiffusionPlannerWithGlobal`** (~5.2M parameters):
322
+
323
+ ```
324
+ Local stream: 9x9 glyphs -> Embedding(6000,64) -> CNN(64->32->64) -> Linear -> 1 token
325
+ Global stream: 21x79 glyphs -> Embedding(6000,32) -> CNN(32->32->64) -> Pool(2,4) -> 8 tokens
326
+ Goal head: mean(global) -> MLP -> [B,2] staircase coords (aux loss)
327
+ Gate: sigmoid(learnable scalar, init=-3.0) * global_tokens
328
+ Action stream: Embedding(14, 256) + timestep_emb(100, 256) + position_emb(64, 256)
329
+ Transformer: concat [1 + 8 + 64 = 73 tokens] -> 4-layer encoder (256D, 4 heads, pre-norm)
330
+ Output head: last 64 tokens -> Linear(256, 12) -> action logits
331
+ ```
332
+
333
+ The model takes `(local_obs, global_obs, noisy_action_seq, t_discrete)` and returns `{"actions": [B,64,12], "goal_pred": [B,2]}`.
334
+
335
+ A `LocalDiffusionPlanner` variant (no global stream, no goal head) is also available for ablation studies.
336
+
337
+ ---
338
+
339
+ ## Diffusion
340
+
341
+ **Forward process (MDLM):** Each action token is independently replaced with `MASK` (token 12) with probability `1 - alpha(t)`, where `alpha(t)` follows a linear or cosine schedule. PAD tokens (13) are never masked.
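A minimal sketch of the forward masking step described above, in plain Python rather than PyTorch. It assumes a linear schedule `alpha(t) = 1 - t` with `t` in `[0, 1]`; how the repository maps the discrete timestep to `t` is not stated here, so that mapping is left out. The function names are hypothetical.

```python
import random

MASK, PAD = 12, 13  # token ids stated above


def alpha_linear(t: float) -> float:
    """Assumed linear schedule: keep-probability falls from 1 at t=0 to 0 at t=1."""
    return 1.0 - t


def forward_mask(actions, t, rng):
    """Independently replace each non-PAD token with MASK with probability 1 - alpha(t)."""
    p_mask = 1.0 - alpha_linear(t)
    noisy = []
    for tok in actions:
        if tok != PAD and rng.random() < p_mask:
            noisy.append(MASK)
        else:
            noisy.append(tok)
    return noisy


rng = random.Random(0)
plan = [0, 1, 2, 3, 3, 1, PAD, PAD]
print(forward_mask(plan, t=0.0, rng=rng))  # t=0: nothing is masked
print(forward_mask(plan, t=1.0, rng=rng))  # t=1: every non-PAD token is masked
```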
342
+
343
+ **Loss:** Cross-entropy on masked positions only, averaged globally across the batch. By default the loss is a flat average (matching the reference implementation). Optional SUBS importance weighting, `w(t) = -alpha'(t) / (1 - alpha(t))` clipped to `[0, 1000]`, can be enabled via `use_importance_weighting: true`. Label smoothing is available via `label_smoothing` (default 0.0).
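The weighting formula above can be sketched as follows. This is a simplified, single-sequence illustration under an assumed linear schedule (`alpha(t) = 1 - t`, so `w(t) = 1 / t`); the function names are hypothetical, and the actual batched implementation in `src/diffusion/loss.py` may differ.

```python
def subs_weight(alpha: float, dalpha_dt: float, clip: float = 1000.0) -> float:
    """w(t) = -alpha'(t) / (1 - alpha(t)), clipped to [0, clip]."""
    denom = 1.0 - alpha
    if denom <= 0.0:
        return clip  # t -> 0 limit: the weight diverges, so return the clip bound
    return min(max(-dalpha_dt / denom, 0.0), clip)


def masked_loss(token_nll, is_masked, weight=1.0):
    """Average per-token negative log-likelihood over masked positions only."""
    picked = [nll for nll, m in zip(token_nll, is_masked) if m]
    if not picked:
        return 0.0
    return weight * sum(picked) / len(picked)


# Assumed linear schedule alpha(t) = 1 - t, so alpha'(t) = -1 and w(t) = 1 / t.
print(subs_weight(alpha=0.5, dalpha_dt=-1.0))     # 2.0
print(subs_weight(alpha=0.9995, dalpha_dt=-1.0))  # clipped to 1000.0
print(masked_loss([0.2, 1.0, 3.0], [False, True, True]))  # 2.0
```

The clip matters mostly at small `t`, where few tokens are masked and the unclipped weight grows without bound.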
344
+
345
+ **Reverse sampling (ReMDM):** Over `K` denoising steps (default 10):
346
+ 1. Model predicts logits; apply temperature scaling and top-K filtering.
347
+ 2. Sample predictions; compute per-token confidence.
348
+ 3. **MaskGIT unmask:** commit the `n_unmask` highest-confidence masked positions.
349
+ 4. **ReMDM remask:** stochastically re-mask committed positions to allow refinement.
350
+ 5. Final step: commit all remaining positions.
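The five steps above can be sketched as a small pure-Python loop. This is an illustration under stated assumptions, not the code in `src/diffusion/sampling.py`: the unmask count spreads the remaining masks evenly over the remaining steps, remasking uses a flat probability `eta` instead of one of the configurable strategies, and `toy_model` is a hypothetical stand-in for the denoiser.

```python
import math
import random

MASK, NUM_ACTIONS = 12, 12  # MASK id and action count as stated in this README


def top_k_probs(logits, temperature, k):
    """Temperature-scaled softmax restricted to the k largest logits."""
    top = sorted(range(len(logits)), key=lambda a: logits[a], reverse=True)[:k]
    peak = max(logits[a] for a in top)
    exps = {a: math.exp((logits[a] - peak) / temperature) for a in top}
    total = sum(exps.values())
    return {a: e / total for a, e in exps.items()}


def remdm_sample(predict_logits, seq_len=64, steps=10, temperature=0.5,
                 top_k=4, eta=0.15, rng=None):
    rng = rng or random.Random(0)
    seq = [MASK] * seq_len
    for step in range(steps):
        last = step == steps - 1
        logits = predict_logits(seq)  # one list of NUM_ACTIONS logits per position
        masked = [i for i, tok in enumerate(seq) if tok == MASK]
        # Steps 1-2: filter, sample a proposal, record its probability as confidence.
        proposal = {}
        for i in masked:
            probs = top_k_probs(logits[i], temperature, top_k)
            tok = rng.choices(list(probs), weights=list(probs.values()))[0]
            proposal[i] = (tok, probs[tok])
        # Steps 3 and 5: commit the most confident proposals; everything on the last step.
        n_unmask = len(masked) if last else math.ceil(len(masked) / (steps - step))
        ranked = sorted(masked, key=lambda i: proposal[i][1], reverse=True)
        for i in ranked[:n_unmask]:
            seq[i] = proposal[i][0]
        # Step 4: re-mask committed positions with a flat probability eta (assumed rule).
        if not last:
            seq = [MASK if tok != MASK and rng.random() < eta else tok for tok in seq]
    return seq


def toy_model(seq):
    """Hypothetical stand-in for the denoiser: prefers action (position mod 12)."""
    return [[3.0 if a == i % NUM_ACTIONS else 0.0 for a in range(NUM_ACTIONS)]
            for i in range(len(seq))]


plan = remdm_sample(toy_model)
print(len(plan), MASK in plan)  # 64 False
```

Because the final step commits every remaining position and skips remasking, the returned plan never contains `MASK`.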
351
+
352
+ **Greedy sampling:** Used during DAgger data collection for deterministic rollouts. It runs the same MaskGIT progressive unmasking loop but with argmax decoding (no temperature, no top-K, no remasking), and uses fewer denoising steps (`diffusion_steps_collect: 5`) for faster collection.
353
+
354
+ ### Remasking strategies
355
+
356
+ | Strategy | Formula | Description |
357
+ |---|---|---|
358
+ | `rescale` | `p = eta * sigma_max` | Proportional to noise level |
359
+ | `cap` | `p = min(eta, sigma_max)` | Fixed upper bound |
360
+ | `conf` | `p = eta * sigma_max * (1 - confidence)` | Low-confidence tokens remasked more |
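The three formulas in the table translate directly into a small helper. The function name `remask_probability` is hypothetical; `eta` uses the documented default, while the `sigma_max` and `confidence` values are purely illustrative.

```python
def remask_probability(strategy, eta, sigma_max, confidence=0.0):
    """Per-token remasking probability for the three strategies in the table."""
    if strategy == "rescale":
        return eta * sigma_max
    if strategy == "cap":
        return min(eta, sigma_max)
    if strategy == "conf":
        return eta * sigma_max * (1.0 - confidence)
    raise ValueError(f"unknown remask strategy: {strategy!r}")


# eta matches the default (0.15); sigma_max and confidence are illustrative values.
for name in ("rescale", "cap", "conf"):
    print(name, remask_probability(name, eta=0.15, sigma_max=0.5, confidence=0.8))
```

Only `conf` depends on the per-token confidence, so a fully confident token (`confidence=1.0`) is never remasked under that strategy.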
361
+
362
+ ---
363
+
364
+ ## Configuration
365
+
366
+ ### Key hyperparameters
367
+
368
+ **Model**
369
+
370
+ | Parameter | Default | Description |
371
+ |---|---|---|
372
+ | `n_embd` | 256 | Transformer hidden dimension |
373
+ | `n_head` | 4 | Attention heads |
374
+ | `n_layer` | 4 | Transformer blocks |
375
+ | `n_global_tokens` | 8 | Global stream context tokens |
376
+ | `seq_len` | 64 | Action plan length |
377
+ | `dropout` | 0.0 | Transformer dropout (0.0 -- forward masking regularises) |
378
+ | `ema_decay` | 0.999 | EMA smoothing for inference weights |
379
+ | `global_gate_init` | -3.0 | Initial value for global gate logit |
380
+
381
+ **Diffusion**
382
+
383
+ | Parameter | Default | Description |
384
+ |---|---|---|
385
+ | `noise_schedule` | `linear` | `linear` or `cosine` |
386
+ | `num_diffusion_steps` | 100 | Discrete timestep resolution |
387
+ | `diffusion_steps_eval` | 10 | Denoising iterations at inference |
388
+ | `diffusion_steps_collect` | 5 | Denoising iterations during DAgger collection |
389
+ | `remask_strategy` | `conf` | `rescale`, `cap`, or `conf` |
390
+ | `eta` | 0.15 | Remasking strength |
391
+ | `temperature` | 0.5 | Sampling temperature |
392
+ | `top_k` | 4 | Top-K filtering |
393
+ | `replan_every` | 16 | Env steps before replanning |
394
+ | `loss_weight_clip` | 1000.0 | SUBS importance weight clip bound |
395
+ | `label_smoothing` | 0.0 | Label smoothing for cross-entropy |
396
+ | `use_importance_weighting` | false | SUBS w(t) in loss (off = flat average) |
397
+ | `physics_aware_sampling` | false | Penalise hazardous actions at inference |
398
+
399
+ **Training budget (unified)**
400
+
401
+ Offline BC, DAgger, and the SB3 baselines all share a single env-step budget
402
+ expressed in `total_timesteps` (matching the SB3 convention). This is the only
403
+ knob that should change to scale a run up or down.
404
+
405
+ | Parameter | Default | Description |
406
+ |---|---|---|
407
+ | `total_timesteps` | 2,000,000 | Env-step budget shared across offline / DAgger / SB3 |
408
+ | `id_eval_every_timesteps` | 25,000 | ID eval cadence (env-steps) |
409
+ | `ood_eval_every_timesteps` | 25,000 | OOD eval cadence (env-steps) |
410
+ | `checkpoint_every_timesteps` | 125,000 | Checkpoint cadence (env-steps) |
411
+
412
+ - **Offline BC:** each dataset sample is one env.step() equivalent, so total
413
+ gradient steps = `total_timesteps // offline_batch_size`. The cosine LR
414
+ schedule's `T_max` derives from the same quantity, so runs of different
415
+ lengths still decay to the 10% floor at their end.
416
+ - **DAgger:** the training loop tracks cumulative `env.step()` calls (model +
417
+ oracle rollouts combined) and halts when the running total reaches
418
+ `total_timesteps`. `episodes_per_iteration` and `grad_steps_per_iteration`
419
+ control the collect/train ratio but **must not** scale with the budget.
420
+ - **Fairness caveat (`ema_decay`):** this is an absolute-update-count constant
421
+ (time constant ~ `1 / (1 − decay)` updates, i.e. a half-life of roughly
422
+ `0.69 / (1 − decay)`). If `total_timesteps` shifts by more than ~2× from the
423
+ default, the fraction of training covered by the EMA window changes. For
424
+ very short or very long runs, consider setting a matching decay manually.
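To make the `ema_decay` caveat concrete, the sketch below computes how many updates the EMA window spans. Note that `1 / (1 - decay)` is the e-folding time constant, while the exact half-life solves `decay**n = 0.5`. The helper names are hypothetical, and `decay_for_half_life` is one possible way to pick a "matching" decay, not a rule from the repository.

```python
import math


def ema_time_constant(decay: float) -> float:
    """Updates after which an old weight's contribution has decayed by a factor of e."""
    return 1.0 / (1.0 - decay)


def ema_half_life(decay: float) -> float:
    """Updates after which an old weight's contribution has halved: decay**n == 0.5."""
    return math.log(0.5) / math.log(decay)


def decay_for_half_life(half_life_steps: float) -> float:
    """Inverse of ema_half_life: pick a decay for a target half-life."""
    return 0.5 ** (1.0 / half_life_steps)


decay = 0.999  # default ema_decay
print(round(ema_time_constant(decay)))  # 1000
print(round(ema_half_life(decay)))      # 693
```

With the default offline budget of 558 gradient steps, the EMA window at `decay=0.999` is longer than the whole run, which is the situation the caveat warns about.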
425
+
426
+ **Training**
427
+
428
+ | Parameter | Default | Description |
429
+ |---|---|---|
430
+ | `offline_lr` | 0.0003 | BC learning rate (cosine-decayed to 10% over `total_grad_steps`) |
431
+ | `dagger_lr` | 0.00003 | DAgger learning rate (constant) |
432
+ | `offline_batch_size` | 3584 | Offline BC batch size |
433
+ | `dagger_batch_size` | 3584 | DAgger batch size |
434
+ | `offline_grad_clip` | 1.0 | Gradient norm clip (offline) |
435
+ | `dagger_grad_clip` | 1.0 | Gradient norm clip (DAgger) |
436
+ | `weight_decay` | 0.0001 | AdamW weight decay (both optimizers) |
437
+ | `grad_steps_per_iteration` | 100 | Gradient steps per DAgger iteration |
438
+ | `episodes_per_iteration` | 30 | Episodes collected per DAgger iteration |
439
+ | `aux_loss_weight` | 0.5 | Weight for auxiliary goal loss |
440
+ | `buffer_capacity` | 10000 | Replay buffer size (windows) |
441
+ | `efficiency_multiplier` | 1.5 | DAgger efficiency filter threshold |
442
+ | `curriculum_preseed` | true | Pre-seed curriculum with 50/50 prior |
443
+ | `curriculum_queue_size` | 100 | Curriculum window size per environment |
444
+
445
+ **Data Collection**
446
+
447
+ | Parameter | Default | Description |
448
+ |---|---|---|
449
+ | `collect_episodes_per_env` | 5000 | Oracle episodes per ID environment |
450
+ | `collect_num_workers` | 8 | Parallel process workers for collection |
451
+ | `collect_output` | `data/dataset.pt` | Output path for collected dataset |
452
+
453
+ **Evaluation**
454
+
455
+ | Parameter | Default | Description |
456
+ |---|---|---|
457
+ | `eval_episodes_per_env` | 50 | Episodes per environment at eval time |
458
+ | `checkpoint_eval_episodes` | 50 | Episodes per env at checkpoint eval |
459
+
460
+ (Eval and checkpoint *cadences* are expressed in env-steps under
461
+ **Training budget (unified)** above.)
462
+
463
+ **Performance**
464
+
465
+ | Parameter | Default | Description |
466
+ |---|---|---|
467
+ | `use_amp` | false | Mixed-precision (FP16) training via `torch.amp` |
468
+ | `torch_compile` | false | `torch.compile` the model for fused kernels |
469
+ | `num_collection_workers` | 8 | Parallel workers for DAgger episode collection |
470
+
471
+ **Logging**
472
+
473
+ | Parameter | Default | Description |
474
+ |---|---|---|
475
+ | `use_wandb` | true | Enable W&B logging |
476
+ | `wandb_project` | `remdm-minihack` | W&B project name |
477
+ | `wandb_resume_id` | null | W&B run ID for resumption |
478
+ | `offline_log_every` | 10 | Stdout/W&B log frequency (offline steps) |
479
+ | `seed` | null | RNG seed (null = random) |
480
+
481
+ ### Config presets
482
+
483
+ | File | Purpose |
484
+ |---|---|
485
+ | `configs/defaults.yaml` | Base defaults for all modes |
486
+ | `configs/smoke.yaml` | Fast smoke test (`total_timesteps=5000`, small buffer, W&B off) |
487
+ | `configs/ucl_gpu_bigger_model.yaml` | UCL GPU exploration with a larger model (384D, 6 heads) |
488
+ | `configs/ucl_gpu_learning_behaviour.yaml` | UCL GPU learning-behaviour study (eta=0.18, B=6144) |
489
+ | `configs/final_qmul_gpu.yaml` | **Paper run, QMUL H200.** Drives both `--mode dagger` (reproduces the iter600 checkpoint) and `--mode offline` (compute-matched fair BC baseline: 60k grad steps × B=2048). AMP + torch.compile + 32 collection workers. |
490
+ | `configs/final_ucl_gpu.yaml` | **Paper run, UCL 3090 Ti 24 GB.** Identical training hyperparams to the QMUL config for cross-cluster fairness; only `num_collection_workers` (8 instead of 32) and output paths differ. |
491
+
492
+ ---
493
+
494
+ ## DAgger Training Loop
495
+
496
+ Each DAgger iteration:
497
+
498
+ 1. **Curriculum sampling:** Select an environment weighted by difficulty (low win-rate environments sampled more).
499
+ 2. **Model rollout:** Generate plans with the EMA model using greedy sampling; execute with replanning every 16 steps. Collects `episodes_per_iteration` (default 30) episodes per iteration.
500
+ 3. **Oracle rollout:** Run the BFS oracle on the **same seed** for comparison.
501
+ 4. **Efficiency filter:** Add the oracle trajectory to the buffer if the model failed or took >1.5x the oracle's steps.
502
+ 5. **Budget accounting:** Advance `env_steps_total += model_steps + oracle_steps`. The training loop halts when the running total reaches `total_timesteps`.
503
+ 6. **Training:** Sample from the replay buffer; run `grad_steps_per_iteration` gradient steps, updating EMA weights after each gradient step.
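The bookkeeping in steps 1, 4, and 5 can be sketched without any environment. The efficiency filter and budget accounting follow the description above; the curriculum weighting (probability proportional to `1 - win_rate`, with a small floor) is an assumption, since the exact rule is not given here. All names and episode outcomes are hypothetical.

```python
import random


def efficiency_filter(model_won, model_steps, oracle_steps, multiplier=1.5):
    """Keep the oracle trajectory when the model failed or was too slow."""
    return (not model_won) or model_steps > multiplier * oracle_steps


def sample_env(win_rates, rng, floor=0.05):
    """Assumed weighting: probability proportional to (1 - win_rate), with a small floor."""
    envs = list(win_rates)
    weights = [max(1.0 - win_rates[e], floor) for e in envs]
    return rng.choices(envs, weights=weights)[0]


rng = random.Random(0)
win_rates = {"Room-5x5": 0.95, "MazeWalk-9x9": 0.40}  # illustrative values
env_steps_total, total_timesteps = 0, 500
buffer_additions = 0
visited = []

# Hypothetical episode outcomes: (model_won, model_steps, oracle_steps).
episodes = [(True, 12, 10), (True, 40, 20), (False, 64, 18), (True, 9, 9)]
for model_won, model_steps, oracle_steps in episodes:
    if env_steps_total >= total_timesteps:
        break  # budget exhausted
    visited.append(sample_env(win_rates, rng))  # step 1: curriculum sampling
    if efficiency_filter(model_won, model_steps, oracle_steps):
        buffer_additions += 1  # step 4: oracle trajectory goes to the buffer
    env_steps_total += model_steps + oracle_steps  # step 5: budget accounting

print(buffer_additions, env_steps_total)  # 2 182
```

Both the model and the oracle rollouts count against `total_timesteps`, so an iteration costs more env-steps than the model episodes alone.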
504
+
505
+ Collection uses GPU-batched rollouts when on CUDA with `episodes_per_iteration > 1`, falling back to threaded CPU collection or sequential collection as appropriate.
506
+
507
+ The BFS oracle uses a 5-tier priority: (1) kick adjacent doors, (2) BFS to staircase, (3) BFS to frontier, (4) BFS to farthest tile, (5) random cardinal.
508
+
509
+ ---
510
+
511
+ ## Reward Shaping
512
+
513
+ The environment wrapper applies shaped rewards to guide learning:
514
+
515
+ | Component | Value | Condition |
516
+ |---|---|---|
517
+ | Win bonus | +20.0 | Episode won |
518
+ | BFS progress | +0.5 * (prev_dist - curr_dist) | Closer to staircase |
519
+ | Exploration | +0.05 | New tile visited |
520
+ | Step penalty | -0.01 | Every step |
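A per-step reward following the table might look like the sketch below. It assumes the BFS progress term applies only when the agent moves closer to the staircase (as the Condition column suggests) and that all components are summed; the function name is hypothetical and the wrapper in `src/envs/minihack_env.py` may combine them differently.

```python
def shaped_reward(won, prev_dist, curr_dist, new_tile):
    """Sum the reward components from the table for a single env step."""
    reward = -0.01                      # step penalty, every step
    if curr_dist < prev_dist:           # BFS progress, only when closer (assumed)
        reward += 0.5 * (prev_dist - curr_dist)
    if new_tile:
        reward += 0.05                  # exploration bonus
    if won:
        reward += 20.0                  # win bonus
    return reward


# One step closer to the staircase onto an unvisited tile, episode not yet won.
print(shaped_reward(won=False, prev_dist=7, curr_dist=6, new_tile=True))
```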
521
+
522
+ ---
523
+
524
+ ## Project Structure
525
+
526
+ ```
527
+ minihack-ReMDM-planner/
528
+ ├── configs/
529
+ │ ├── defaults.yaml Base hyperparameters
530
+ │ ├── smoke.yaml Smoke test overrides
531
+ │ ├── ucl_gpu_bigger_model.yaml UCL GPU (larger model: 384D, 6 heads)
532
+ │ ├── ucl_gpu_learning_behaviour.yaml UCL GPU learning-behaviour study
533
+ │ ├── final_qmul_gpu.yaml Paper run: DAgger + fair offline BC (QMUL H200)
534
+ │ └── final_ucl_gpu.yaml Paper run: DAgger + fair offline BC (UCL 3090 Ti)
535
+ ├── environments/ Custom .des scenario files
536
+ ├── src/
537
+ │ ├── config.py YAML config loader with CLI overrides
538
+ │ ├── buffer.py ReplayBuffer with offline-protected FIFO
539
+ │ ├── curriculum.py DynamicCurriculum + efficiency_filter
540
+ │ ├── diffusion/
541
+ │ │ ├── schedules.py Linear and cosine noise schedules
542
+ │ │ ├── forward.py Forward masking process q(z_t | x_0)
543
+ │ │ ├── loss.py MDLM ELBO + auxiliary goal loss
544
+ │ │ └── sampling.py ReMDM reverse sampling with remasking
545
+ │ ├── models/
546
+ │ │ └── denoiser.py LocalDiffusionPlannerWithGlobal + ModelEMA
547
+ │ ├── envs/
548
+ │ │ ├── minihack_env.py AdvancedObservationEnv + BFS oracle
549
+ │ │ └── discovery.py Env registry scanner + inference benchmark
550
+ │ └── planners/
551
+ │ ├── collect.py run_model_episode + DataCollector
552
+ │ ├── collect_oracle.py Standalone oracle data collection
553
+ │ ├── offline.py Offline BC trainer
554
+ │ ├── online.py DAgger Trainer + checkpointing
555
+ │ ├── inference.py Evaluator + result formatting
556
+ │ ├── baselines.py SB3 + Decision Transformer baselines
557
+ │ ├── smoke.py Smoke-test runner
558
+ │ └── logging.py Centralised W&B + stdout logging
559
+ ├── experiments/
560
+ │ └── rl_finetuning/ RL fine-tuning ablation suite
561
+ │ ├── run_ablations.py CLI entry point
562
+ │ ├── configs/ Ablation config files
563
+ │ ├── ablations/ Loss, optimizer, registry, training
564
+ │ ├── diagnostics/ Gradient, representation, timestep metrics
565
+ │ └── analysis/ Plots, tables, reports
566
+ ├── scripts/
567
+ │ ├── hf_upload.py HuggingFace Hub upload utility
568
+ │ └── profile_dagger.py DAgger iteration profiler
569
+ ├── main.py CLI entry point (smoke/collect/offline/dagger/inference/baselines)
570
+ ├── pyproject.toml PEP 621 project metadata + dependencies
571
+ ├── uv.lock Deterministic lockfile
572
+ └── README.md
573
+ ```
574
+
575
+ ---
576
+
+ ## W&B Metric Namespaces
+
+ | Namespace | Contents |
+ |---|---|
+ | `diffusion/` | `loss`, `loss_diff`, `loss_aux` |
+ | `train/` | `buffer_size`, `buffer_online_frac`, `model_won`, `added_to_buffer`, `episodes_collected`, `model_steps`, `oracle_steps`, `efficiency_ratio`, `lr`, `grad_norm`, `global_gate`, `env_steps`, `progress` |
+ | `speed/` | `iter_time_sec`, `collect_time_sec`, `train_step_time_sec`, `samples_per_sec`, `env_steps_per_sec`, `gpu_memory_mb` |
+ | `perf/` | `iter_time_s`, `collect_time_s`, `train_time_s`, `grad_steps_per_sec` (legacy compat) |
+ | `model/` | `param_norm`, `param_drift_from_init`, `ema_gate_value` (every 10 iters) |
+ | `eval_id/{env}/` | Per-environment win rate, avg steps, avg reward (in-distribution) |
+ | `eval_ood/{env}/` | Per-environment win rate, avg steps, avg reward (out-of-distribution) |
+ | `eval_id/` | `mean_win_rate` |
+ | `eval_ood/` | `mean_win_rate` |
+ | `curriculum/{env}/` | `win_rate` per training environment |
+ | `ckpt_eval_id/`, `ckpt_eval_ood/` | Per-env metrics at checkpoint time |
+ | `ckpt_eval/` | `id_winrate`, `ood_winrate` |
+ | `offline/` | `final_loss`, `total_steps`, `total_timesteps` (summary only) |
+
+ Both DAgger and offline BC emit to `eval_id/` and `eval_ood/` namespaces.
+ Offline mode reuses the same `Evaluator` and EMA-weight evaluation path as
+ DAgger, so curves are directly comparable across modes.
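
As a rough illustration of how the per-environment and aggregate evaluation keys fit together, the sketch below flattens one evaluation pass into namespaced keys. The helper name `build_eval_metrics` and the stat names `win_rate`, `avg_steps`, and `avg_reward` are assumptions made for this example; the repository's `Evaluator` may name and assemble them differently. The numbers are placeholders, not results.

```python
from statistics import mean


def build_eval_metrics(split, per_env_stats):
    """Flatten per-env stats into `eval_{split}/...` keys (hypothetical helper).

    split: "id" or "ood".
    per_env_stats: {env_name: {"win_rate": float, "avg_steps": float, ...}}
    """
    if split not in ("id", "ood"):
        raise ValueError(f"unknown split: {split!r}")
    prefix = f"eval_{split}"
    metrics = {}
    for env, stats in per_env_stats.items():
        for name, value in stats.items():
            # e.g. eval_id/MiniHack-Corridor-R2-v0/win_rate
            metrics[f"{prefix}/{env}/{name}"] = value
    # Aggregate key from the table: eval_id/mean_win_rate or eval_ood/mean_win_rate.
    metrics[f"{prefix}/mean_win_rate"] = mean(
        stats["win_rate"] for stats in per_env_stats.values()
    )
    return metrics


# Placeholder values for illustration only.
example = build_eval_metrics(
    "id",
    {
        "MiniHack-Room-Random-5x5-v0": {"win_rate": 1.0, "avg_steps": 8.0, "avg_reward": 1.0},
        "MiniHack-Corridor-R2-v0": {"win_rate": 0.5, "avg_steps": 90.0, "avg_reward": 0.4},
    },
)
```

A flat dictionary of this shape is what `wandb.log` accepts, so keys with a shared prefix end up grouped in the same dashboard section.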
+
+ ---
+
+ ## Checkpoint Format
+
+ **DAgger checkpoint:**
+
+ ```python
+ {
+ "model_state_dict": ...,
+ "ema_state_dict": ...,
+ "optimizer_state_dict": ...,
+ "scheduler_state_dict": ...,
+ "curriculum_state": {...},
+ "iteration": int,
+ "env_steps": int, # cumulative env.step() calls so far
+ "wandb_run_id": str | None,
+ "rng_states": {"torch": ..., "numpy": ..., "python": ...},
+ }
+ ```
+
+ **Offline BC checkpoint** (step-level, file `offline_step{N}.pth`, saved when
+ `checkpoint_every_timesteps > 0`):
+
+ ```python
+ {
+ "model_state_dict": ...,
+ "ema_state_dict": ...,
+ "optimizer_state_dict": ...,
+ "scheduler_state_dict": ...,
+ "step": int,
+ "env_steps": int, # step * offline_batch_size
+ "wandb_run_id": str | None,
+ }
+ ```
+
+ **Offline final checkpoint** (saved at the end of offline training):
+
+ ```python
+ {
+ "model_state_dict": ...,
+ "ema_state_dict": ...,
+ "wandb_run_id": str | None,
+ }
+ ```
+
+ Inference uses EMA weights by default. Pass `--no-ema` to use training weights.
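
The three layouts can be told apart by their keys. The sketch below is one way to do that and to pick the weights to load; `checkpoint_kind` and `select_state_dict` are hypothetical helper names, not functions from this repository, and the sketch assumes `ema_state_dict` can be loaded into the model the same way as `model_state_dict`.

```python
def checkpoint_kind(checkpoint):
    """Classify a loaded checkpoint dict by the keys documented above."""
    if "iteration" in checkpoint:
        return "dagger"
    if "step" in checkpoint:
        return "offline_step"
    return "offline_final"


def select_state_dict(checkpoint, use_ema=True):
    """Return EMA weights by default, or training weights (the `--no-ema` case)."""
    key = "ema_state_dict" if use_ema else "model_state_dict"
    if key not in checkpoint:
        raise KeyError(f"checkpoint has no {key!r} entry")
    return checkpoint[key]


# Stand-in dicts; a real checkpoint would come from torch.load(path).
dagger_ckpt = {"model_state_dict": {"w": 1}, "ema_state_dict": {"w": 2}, "iteration": 600}
final_ckpt = {"model_state_dict": {"w": 1}, "ema_state_dict": {"w": 2}, "wandb_run_id": None}
```

Only the DAgger and step-level offline checkpoints carry optimizer and scheduler state, so the final offline checkpoint is suitable for inference but not for resuming training.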
+
+ ### W&B Artifacts
+
+ Checkpoints are automatically uploaded as versioned W&B artifacts (type `"model"`) at each checkpoint save. Each artifact contains the `.pth` weights and a `config.yaml` snapshot of all hyperparameters used.
+
+ To resume from an artifact:
+
+ ```bash
+ # DAgger resume
+ python main.py --mode dagger \
+ --wandb-artifact entity/project/checkpoint-iter3000:latest
+
+ # Inference
+ python main.py --mode inference \
+ --wandb-artifact entity/project/checkpoint-iter8000:v2
+ ```
+
+ The artifact reference format is `entity/project/artifact-name:version` where version is `latest`, `v0`, `v1`, etc.
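
For scripts that need to validate a reference before starting a job, a small parser along these lines may help. It is a sketch of the format described here, not code from this repository: `ArtifactRef` and `parse_artifact_ref` are hypothetical names, and the pattern is deliberately stricter than W&B itself, which also accepts custom aliases.

```python
import re
from typing import NamedTuple


class ArtifactRef(NamedTuple):
    entity: str
    project: str
    name: str
    version: str


# entity/project/artifact-name:version, with version limited to `latest` or `vN`.
_REF_PATTERN = re.compile(r"^([^/:]+)/([^/:]+)/([^/:]+):(latest|v\d+)$")


def parse_artifact_ref(ref):
    """Split an artifact reference into its four parts, or raise ValueError."""
    match = _REF_PATTERN.match(ref)
    if match is None:
        raise ValueError(
            f"expected entity/project/artifact-name:version, got {ref!r}"
        )
    return ArtifactRef(*match.groups())


parsed = parse_artifact_ref("entity/project/checkpoint-iter8000:v2")
```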
+
+ ### W&B Run Resumption
+
+ All training loops save the W&B run ID in their checkpoints. When resuming from a checkpoint, the run ID is automatically extracted and passed to `wandb.init(resume="must")`, so metrics continue on the same W&B curves with no gaps.
+
+ ```bash
+ # DAgger: automatic -- run ID is read from the checkpoint
+ python main.py --mode dagger --checkpoint checkpoints/iter2000.pth
+
+ # Offline BC: automatic
+ python main.py --mode offline --data dataset.pt \
+ --checkpoint checkpoints/offline_step2000.pth
+
+ # Manual override (e.g. checkpoint saved before this feature was added):
+ python main.py --mode dagger --checkpoint old_checkpoint.pth \
+ wandb_resume_id=abc123xyz
+
+ # Ablation suite:
+ python experiments/rl_finetuning/run_ablations.py \
+ --checkpoint path/to/ckpt.pth --all --use_wandb \
+ --wandb_resume_id abc123xyz
+ ```
+
+ The run ID is visible in the W&B dashboard URL: `wandb.ai/.../runs/<run-id>`.
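
The resolution logic can be pictured as follows. This is a minimal sketch, not the repository's implementation: `resolve_wandb_resume` is a hypothetical helper, and giving the manual `wandb_resume_id` precedence over the ID stored in the checkpoint is an assumption.

```python
def resolve_wandb_resume(checkpoint=None, wandb_resume_id=None):
    """Build the extra keyword arguments for wandb.init when resuming.

    Assumed precedence: explicit override, then the checkpoint's
    `wandb_run_id`, then a fresh run (no extra arguments).
    """
    run_id = wandb_resume_id
    if run_id is None and checkpoint is not None:
        run_id = checkpoint.get("wandb_run_id")
    if run_id is None:
        return {}
    # resume="must" makes W&B fail loudly if the run ID does not exist.
    return {"id": run_id, "resume": "must"}


from_checkpoint = resolve_wandb_resume({"wandb_run_id": "abc123xyz"})
from_override = resolve_wandb_resume({"wandb_run_id": None}, wandb_resume_id="abc123xyz")
fresh_run = resolve_wandb_resume({"wandb_run_id": None})
```

The returned dictionary would be unpacked into the call, for example `wandb.init(project=..., **kwargs)`.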
+
+ ---
+
+ ## Performance Tuning
+
+ Three config keys control performance optimisations. Defaults are set for GPU training; override for CPU or different hardware.
+
+ ### Mixed precision (`use_amp: true`)
+
+ Wraps training forward/backward in `torch.amp.autocast("cuda")` with `GradScaler`. Active in both offline BC and DAgger training.
+
+ - **Measured speedup:** 2.2x on gradient steps, 1.7x on full smoke test wall-clock
+ - **Memory:** peak GPU stays ~16 GB at B=3584 (same as FP32 due to embedding-heavy model)
+ - **Correctness:** loss trajectory and win rates statistically equivalent to FP32
+ - **When to use:** always on GPU. No effect on CPU (autocast is a no-op)
+ - **Default:** `false` in `defaults.yaml`; enabled in GPU-specific configs
+
+ ### torch.compile (`torch_compile: true`)
+
+ Applies `torch.compile(model, mode="default")` before training. Falls back gracefully if no C compiler is found (common on managed GPU nodes).
+
+ - **Measured speedup:** none beyond AMP alone. Not recommended for primary training.
+ - **Default:** `false` in `defaults.yaml`; opt in via the `final_*_gpu.yaml` configs.
+ - **When to use:** experimental only. May help on future PyTorch versions with better dynamic shape support.
+
+ ### Parallel collection (`num_collection_workers: N`)
+
+ DAgger episode collection supports three strategies (auto-selected):
+ 1. **GPU-batched** (default on CUDA with `episodes_per_iteration > 1`): all envs in lockstep
+ 2. **Threaded CPU** (fallback when `num_collection_workers > 0`): `ThreadPoolExecutor` with CPU model copies
+ 3. **Sequential** (reference behaviour): one episode at a time
+
+ - **Default:** `8` workers in `defaults.yaml`
+ - **When to use:** GPU-batched is preferred; workers primarily affect the CPU fallback path
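
Read literally, the auto-selection order above amounts to the decision function below. It is a sketch of the documented rules only: `select_collection_strategy` and the returned labels are hypothetical, and the real collector may take additional conditions into account.

```python
def select_collection_strategy(device_type, episodes_per_iteration, num_collection_workers):
    """Pick a DAgger collection strategy following the documented priority order."""
    # 1. GPU-batched: CUDA device and more than one episode per iteration.
    if device_type == "cuda" and episodes_per_iteration > 1:
        return "gpu_batched"
    # 2. Threaded CPU: fallback when worker threads are configured.
    if num_collection_workers > 0:
        return "threaded_cpu"
    # 3. Sequential: one episode at a time (reference behaviour).
    return "sequential"


# With the defaults (30 episodes per iteration, 8 workers):
on_gpu = select_collection_strategy("cuda", 30, 8)
on_cpu = select_collection_strategy("cpu", 30, 8)
reference = select_collection_strategy("cpu", 30, 0)
```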
+
+ ### Profiling
+
+ Run `python scripts/profile_dagger.py [key=value ...]` to profile DAgger iteration components. Supports all config overrides (e.g., `use_amp=true`).
+
+ ---
+
+ ## Implementation Notes
+
+ - **MDLM loss** returns `0.0` (not NaN) when no masked positions exist in the batch. Uses global averaging by default; SUBS importance weighting is opt-in via `use_importance_weighting: true`.
+ - **PAD tokens** are never masked during the forward process and are excluded from the loss.
+ - **Sampling paths:** Evaluation uses stochastic ReMDM sampling (temperature, top-K, remasking) with `diffusion_steps_eval` (default 10) steps. DAgger collection uses greedy argmax sampling (deterministic, no remasking) with `diffusion_steps_collect` (default 5) steps for faster rollouts.
+ - **`remdm_sample`** guarantees a fully committed output (no MASK tokens) via a final-step commit and an assertion check. A min-keep 10% safety net prevents degenerate all-masked states.
+ - **EMA** shadow weights are updated after every gradient step (not per iteration). The `DataCollector` syncs the latest EMA weights before each rollout.
+ - **Curriculum** initialises with a 50/50 prior per environment (configurable via `curriculum_preseed`) and uses bucket-based weights over the rolling win-rate: low `[0, 0.15)` → 0.2, medium `[0.15, 0.85)` → 1.0, high `[0.85, 1.0]` → 0.1.
+ - **Replay buffer** pins offline data at the front; only online samples are FIFO-evicted. Returns `None` on empty buffer (callers handle gracefully).
+ - **Global gate** initialises at `sigmoid(-3.0) ≈ 0.047`, starting nearly closed to prevent the global stream from destabilising early training.
+ - **Dropout** is set to 0.0 by default. The discrete diffusion forward masking already regularises; dropout on top is redundant.
+ - **DAgger warm-start:** On iteration 0, the buffer is seeded with 3 oracle trajectories per ID environment (12 total), giving the curriculum and training loop data to work with immediately.
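
The curriculum buckets in the notes above can be written out as a small function. This is a sketch under stated assumptions: `bucket_weight` and `sampling_probabilities` are hypothetical names, and normalising the bucket weights into sampling probabilities is an assumption about how the weights are consumed.

```python
def bucket_weight(win_rate):
    """Map a rolling win rate to the documented bucket weight."""
    if not 0.0 <= win_rate <= 1.0:
        raise ValueError(f"win rate out of range: {win_rate}")
    if win_rate < 0.15:   # low: [0, 0.15)
        return 0.2
    if win_rate < 0.85:   # medium: [0.15, 0.85)
        return 1.0
    return 0.1            # high: [0.85, 1.0]


def sampling_probabilities(win_rates):
    """Normalise bucket weights into per-environment sampling probabilities."""
    weights = {env: bucket_weight(rate) for env, rate in win_rates.items()}
    total = sum(weights.values())
    return {env: weight / total for env, weight in weights.items()}


# With the 50/50 prior every environment sits in the medium bucket,
# so early sampling is uniform across the four ID environments.
preseeded = sampling_probabilities({
    "MiniHack-Room-Random-5x5-v0": 0.5,
    "MiniHack-Room-Random-15x15-v0": 0.5,
    "MiniHack-Corridor-R2-v0": 0.5,
    "MiniHack-MazeWalk-9x9-v0": 0.5,
})
```

Under this reading, environments the model has nearly mastered or nearly always fails are still sampled, just far less often than those in the medium bucket.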
ablation_assets/diagnosis_decision_tree.png ADDED
ablation_assets/grad_alignment.png ADDED

Git LFS Details

  • SHA256: 4f1186a56694e83030a800c6e098f302228018159c170f4bb4fc3203a9e23e6e
  • Pointer size: 131 Bytes
  • Size of remote file: 320 kB
ablation_assets/gradient_conflict_map.png ADDED
ablation_assets/group_comparison.png ADDED
ablation_assets/group_summary.csv ADDED
@@ -0,0 +1,5 @@
+ Group,N,Mean,Best,Worst,StdDev
+ Baseline,1,0.5625,0.5625,0.5625,0.0
+ A,6,0.6021,0.6667,0.5583,0.0358
+ B,7,0.4988,0.6542,0.0625,0.1834
+ C,7,0.6125,0.6458,0.5833,0.0184
ablation_assets/hypothesis_verdicts.csv ADDED
@@ -0,0 +1,22 @@
+ Method,Group,Score,Delta_Baseline,Verdict,Hypothesis
+ advantage_clip,B,0.4958,-0.0667,NEUTRAL,If clipping helps: large advantage magnitudes destabilise training
+ attention_only,C,0.6167,0.0542,IMPROVEMENT,"If attention-only works: model needs routing updates, not feature updates"
+ baseline_rl,Baseline,0.5625,0.0,NEUTRAL,Diagnoses whether the RL signal alone causes collapse
+ bc_wins,B,0.5708,0.0083,NEUTRAL,If BC on wins helps: the return weighting is the specific cause
+ entropy_bonus,B,0.5708,0.0083,NEUTRAL,If entropy bonus helps: collapse is mode-collapse; not a gradient problem
+ ewc,A,0.6667,0.1042,IMPROVEMENT,If EWC helps: forgetting pretrained representations is the proximate cause
+ ffn_only,C,0.6083,0.0458,NEUTRAL,If FFN-only works: stored knowledge (FFN as memory) needs updating; not attention
+ frozen_backbone,C,0.6167,0.0542,IMPROVEMENT,If frozen backbone helps: deep gradient flow into backbone causes collapse
+ gradient_surgery,B,0.6542,0.0917,IMPROVEMENT,If PCGrad helps: gradients are conflicting and resolvable by projection
+ head_only,C,0.5958,0.0333,NEUTRAL,If head-only works: backbone representations are fine; only decision boundary needs updating
+ kl_penalty,A,0.5583,-0.0042,NEUTRAL,If this helps: catastrophic forgetting is the primary cause; soft regularisation suffices
+ layer_ablation_top1,C,0.6208,0.0583,IMPROVEMENT,Minimal unfrozen depth needed; collapse depth correlates with gradient flow depth
+ layer_ablation_top2,C,0.6458,0.0833,IMPROVEMENT,Minimal unfrozen depth needed; collapse depth correlates with gradient flow depth
+ layer_ablation_top3,C,0.5833,0.0208,NEUTRAL,Minimal unfrozen depth needed; collapse depth correlates with gradient flow depth
+ llrd,A,0.625,0.0625,IMPROVEMENT,If LLRD helps: deep gradient flow into early layers corrupts representations
+ lora,A,0.6042,0.0417,NEUTRAL,If LoRA works: too many unconstrained degrees of freedom cause collapse
+ low_t,B,0.55,-0.0125,NEUTRAL,If low-t helps: high-t (coarse-structure) gradients are biased
+ mixed_replay,A,0.5833,0.0208,NEUTRAL,If mixed replay helps: online data distribution alone is too corrupted
+ normalized_adv,B,0.0625,-0.5,COLLAPSE,If std normalisation helps: simple mean normalisation is too loose
+ t_curriculum,B,0.5875,0.025,NEUTRAL,If curriculum helps: ordering of learning signals matters
+ trust_region_kl,A,0.575,0.0125,NEUTRAL,If hard constraint helps: soft KL is insufficient -- a hard boundary is needed
ablation_assets/main_results.csv ADDED
@@ -0,0 +1,22 @@
+ Method,Group,Score,Delta_Pretrained,Delta_Baseline,Verdict
+ ewc,A,0.6667,0.0792,0.1042,IMPROVEMENT
+ gradient_surgery,B,0.6542,0.0667,0.0917,IMPROVEMENT
+ layer_ablation_top2,C,0.6458,0.0583,0.0833,IMPROVEMENT
+ llrd,A,0.625,0.0375,0.0625,IMPROVEMENT
+ layer_ablation_top1,C,0.6208,0.0333,0.0583,IMPROVEMENT
+ frozen_backbone,C,0.6167,0.0292,0.0542,IMPROVEMENT
+ attention_only,C,0.6167,0.0292,0.0542,IMPROVEMENT
+ ffn_only,C,0.6083,0.0208,0.0458,NEUTRAL
+ lora,A,0.6042,0.0167,0.0417,NEUTRAL
+ head_only,C,0.5958,0.0083,0.0333,NEUTRAL
+ t_curriculum,B,0.5875,-0.0,0.025,NEUTRAL
+ mixed_replay,A,0.5833,-0.0042,0.0208,NEUTRAL
+ layer_ablation_top3,C,0.5833,-0.0042,0.0208,NEUTRAL
+ trust_region_kl,A,0.575,-0.0125,0.0125,NEUTRAL
+ entropy_bonus,B,0.5708,-0.0167,0.0083,NEUTRAL
+ bc_wins,B,0.5708,-0.0167,0.0083,NEUTRAL
+ baseline_rl,Baseline,0.5625,-0.025,0.0,NEUTRAL
+ kl_penalty,A,0.5583,-0.0292,-0.0042,NEUTRAL
+ low_t,B,0.55,-0.0375,-0.0125,NEUTRAL
+ advantage_clip,B,0.4958,-0.0917,-0.0667,NEUTRAL
+ normalized_adv,B,0.0625,-0.525,-0.5,COLLAPSE
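
The per-group figures in `group_summary.csv` can be recomputed from the `Method`, `Group`, and `Score` columns of this table. The sketch below does so with the standard library; the helper name `summarise_groups` is hypothetical, and the use of a population standard deviation (`pstdev`) is an assumption that appears to reproduce the reported `StdDev` column when checked by hand.

```python
import csv
import io
from statistics import mean, pstdev

# Method, Group, Score copied from main_results.csv.
SCORES_CSV = """Method,Group,Score
ewc,A,0.6667
gradient_surgery,B,0.6542
layer_ablation_top2,C,0.6458
llrd,A,0.625
layer_ablation_top1,C,0.6208
frozen_backbone,C,0.6167
attention_only,C,0.6167
ffn_only,C,0.6083
lora,A,0.6042
head_only,C,0.5958
t_curriculum,B,0.5875
mixed_replay,A,0.5833
layer_ablation_top3,C,0.5833
trust_region_kl,A,0.575
entropy_bonus,B,0.5708
bc_wins,B,0.5708
baseline_rl,Baseline,0.5625
kl_penalty,A,0.5583
low_t,B,0.55
advantage_clip,B,0.4958
normalized_adv,B,0.0625
"""


def summarise_groups(scores_csv):
    """Group scores and compute N, mean, best, worst, and population std."""
    groups = {}
    for row in csv.DictReader(io.StringIO(scores_csv)):
        groups.setdefault(row["Group"], []).append(float(row["Score"]))
    return {
        group: {
            "N": len(scores),
            "Mean": round(mean(scores), 4),
            "Best": max(scores),
            "Worst": min(scores),
            "StdDev": round(pstdev(scores), 4),
        }
        for group, scores in groups.items()
    }


summary = summarise_groups(SCORES_CSV)
```

Group B's low mean and large spread are driven almost entirely by the single `normalized_adv` collapse at 0.0625.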
ablation_assets/per_env_delta.png ADDED

Git LFS Details

  • SHA256: 8a205ec378a19466551c384d990d43f644008b711ecf43c0c014c97d4517b78d
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
ablation_assets/per_env_win_rates.csv ADDED
@@ -0,0 +1,22 @@
+ Method,MiniHack-Room-Random-5x5-v0,MiniHack-Room-Random-15x15-v0,MiniHack-Corridor-R2-v0,MiniHack-MazeWalk-9x9-v0
+ advantage_clip,0.9,0.95,0.25,0.15
+ attention_only,1.0,0.9,0.5,0.35
+ baseline_rl,1.0,0.7,0.3,0.1
+ bc_wins,0.9,0.7,0.4,0.1
+ entropy_bonus,0.9,0.4,0.45,0.15
+ ewc,1.0,0.85,0.6,0.3
+ ffn_only,1.0,1.0,0.35,0.3
+ frozen_backbone,0.95,0.9,0.5,0.4
+ gradient_surgery,1.0,0.9,0.45,0.25
+ head_only,1.0,0.8,0.3,0.3
+ kl_penalty,0.9,1.0,0.2,0.45
+ layer_ablation_top1,0.9,0.75,0.2,0.2
+ layer_ablation_top2,0.95,0.9,0.35,0.25
+ layer_ablation_top3,0.95,0.75,0.4,0.45
+ llrd,0.8,0.9,0.4,0.25
+ lora,1.0,0.75,0.2,0.2
+ low_t,1.0,0.6,0.45,0.15
+ mixed_replay,0.95,0.75,0.45,0.2
+ normalized_adv,0.1,0.0,0.1,0.1
+ t_curriculum,1.0,0.8,0.3,0.15
+ trust_region_kl,0.95,0.75,0.45,0.25
ablation_assets/repr_drift.png ADDED
ablation_assets/results.json ADDED
The diff for this file is too large to render. See raw diff
 
ablation_assets/score_comparison.png ADDED

Git LFS Details

  • SHA256: 93ee968f4869c03806b3243b8a95db5dab7f48c963151b5186693bd37330cd02
  • Pointer size: 131 Bytes
  • Size of remote file: 116 kB
ablation_assets/score_delta.png ADDED
checkpoint_inference.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4619870dd5fcdb2f1575c4a458e128f3da31f9a75a73562d9d316f60f288df20
+ size 20991233
configs/defaults.yaml ADDED
@@ -0,0 +1,242 @@
+ # ── Environments ──────────────────────────────────────────────────────
+ id_envs:
+ - MiniHack-Room-Random-5x5-v0
+ - MiniHack-Room-Random-15x15-v0
+ - MiniHack-Corridor-R2-v0
+ - MiniHack-MazeWalk-9x9-v0
+ ood_envs:
+ - MiniHack-Room-Dark-15x15-v0
+ - MiniHack-Corridor-R5-v0
+ - MiniHack-MazeWalk-45x19-v0
+ crop_size: 9
+ map_h: 21
+ map_w: 79
+ action_dim: 12
+ mask_token: 12
+ pad_token: 13
+
+ # ── Model ─────────────────────────────────────────────────────────────
+ n_embd: 256
+ n_head: 4
+ n_layer: 4
+ n_global_tokens: 8
+ seq_len: 64
+ global_gate_init: -3.0
+ # Transformer dropout. 0.0 is deliberate — discrete diffusion forward masking
+ # already regularises; dropout on top is redundant.
+ dropout: 0.0
+ ema_decay: 0.999
+
+ # ── Diffusion (MDLM) ─────────────────────────────────────────────────
+ noise_schedule: linear
+ num_diffusion_steps: 100
+ loss_weight_clip: 1000.0
+ label_smoothing: 0.0
+ # Use SUBS importance weighting w(t) in loss. Off by default (flat average
+ # matching reference). Enable for MDLM ELBO experiments.
+ use_importance_weighting: false
+ # ReMDM stochastic remask base fraction
+ eta: 0.15
+ # Remasking strategy: rescale | cap | conf
+ remask_strategy: conf
+
+ # ── Inference ─────────────────────────────────────────────────────────
+ # Number of reverse denoising steps at inference.
+ # Reference uses 5 (aggressive). Higher = better quality, slower.
+ diffusion_steps_eval: 10
+ # Denoising steps during DAgger collection. Fewer than eval since
+ # collection only needs "good enough" plans for efficiency comparison.
+ diffusion_steps_collect: 5
+ temperature: 0.5
+ top_k: 4
+ replan_every: 16
+ # Soft-penalise hazardous cardinal actions during stochastic sampling.
+ # Not active in the reference evaluation pipeline; off by default.
+ physics_aware_sampling: false
+
+ # ── Training budget (unified) ────────────────────────────────────────
+ # Total environment-step budget for training. Matches the SB3
+ # `total_timesteps` convention so runs can be compared apples-to-apples
+ # across offline BC, DAgger, and SB3 baselines.
+ #
+ # • DAgger: cumulative env.step() calls across model + oracle rollouts.
+ # Training stops once this budget is exhausted. `episodes_per_iteration`
+ # and `grad_steps_per_iteration` control the collect/train ratio;
+ # they do NOT change the total compute.
+ # • Offline BC: each dataset sample corresponds to one env.step() that
+ # collected it. Total gradient steps = total_timesteps // batch_size,
+ # i.e. the training consumes exactly `total_timesteps` samples.
+ #
+ # Fairness invariant — parameters that scale AUTOMATICALLY with this
+ # budget:
+ # * offline LR cosine T_max (= total_timesteps / offline_batch_size)
+ # * id_eval_every_timesteps / ood_eval_every_timesteps (env-step cadence)
+ # * checkpoint_every_timesteps (env-step cadence; offline converts via
+ # / offline_batch_size)
+ #
+ # Parameters held FIXED across different budgets (tuning knobs, not
+ # fairness knobs):
+ # * offline_batch_size, dagger_batch_size — per-step SNR
+ # * offline_lr, dagger_lr — peak learning rate
+ # * weight_decay, *_grad_clip, efficiency_multiplier, aux_loss_weight,
+ # loss_weight_clip, label_smoothing — optimisation regularisers
+ # * episodes_per_iteration, grad_steps_per_iteration — the collect/train
+ # ratio is itself a design choice; scaling these would confound
+ # collection coverage with update density
+ # * curriculum_queue_size, buffer_capacity — in absolute units by design
+ #
+ # Fairness caveat — `ema_decay` is an absolute-update-count constant
+ # (time constant ≈ 1 / (1 − decay) steps). If total_timesteps shifts by more
+ # than ~2x from the default, the fraction of training covered by the EMA
+ # window changes. For very short or very long runs, consider manually
+ # setting a matching decay (shorter run → lower decay, longer → higher).
+ total_timesteps: 2000000
+
+ # Evaluation + checkpoint cadence, in env-step units. These scale with
+ # total_timesteps so every run gets ~N eval points and ~M checkpoints
+ # regardless of budget. For offline BC, the cadence is converted to
+ # gradient-step intervals via `/ offline_batch_size`.
+ id_eval_every_timesteps: 25000
+ ood_eval_every_timesteps: 25000
+ checkpoint_every_timesteps: 125000
+
+ # ── Offline BC ────────────────────────────────────────────────────────
+ offline_lr: 0.0003
+ offline_batch_size: 3584
+ offline_grad_clip: 1.0
+ aux_loss_weight: 0.5
+ # ── Offline BC compute-match overrides (all opt-in, default null) ───
+ # These exist solely to support paper-fair comparisons against a
+ # specific DAgger iteration count, where the env-step / grad-step
+ # ratio between the two modes (~50x) makes a single shared
+ # `total_timesteps` budget unfair to one side. When null, offline
+ # falls back to the env-step-derived defaults.
+ #
+ # offline_total_grad_steps: pin gradient budget (e.g. 60000 to match
+ # 600 DAgger iters × 100 grad_steps_per_iter).
+ # offline_eval_every_grad_steps: ID/OOD eval cadence in grad-step
+ # units. Without this, dense env-step cadence yields ~500 evals.
+ # offline_checkpoint_every_grad_steps: checkpoint cadence in grad-step
+ # units. Same motivation as eval cadence.
+ # offline_buffer_capacity: distinct from `buffer_capacity` (which is
+ # sized for DAgger's small FIFO buffer). The full BC dataset has
+ # ~500k–1M sliding windows; using DAgger's cap silently truncates.
+ offline_total_grad_steps: null
+ offline_eval_every_grad_steps: null
+ offline_checkpoint_every_grad_steps: null
+ offline_buffer_capacity: null
+
+ # ── DAgger ────────────────────────────────────────────────────────────
+ dagger_lr: 0.00003
+ dagger_batch_size: 3584
+ dagger_grad_clip: 1.0
+ weight_decay: 0.0001
+ buffer_capacity: 10000
+ episodes_per_iteration: 30
+ grad_steps_per_iteration: 100
+ efficiency_multiplier: 1.5
+ curriculum_queue_size: 100
+ # Pre-seed curriculum queues with 50/50 prior for uniform early sampling.
+ curriculum_preseed: true
+ eval_episodes_per_env: 50
+ checkpoint_eval_episodes: 50
+
+ # ── Performance ──────────────────────────────────────────────────────
+ # Mixed-precision (FP16) training via torch.cuda.amp.
+ # Speeds up forward/backward ~1.5-2x on GPU. No effect on CPU.
+ use_amp: false
+ # torch.compile the model for fused kernels (experimental).
+ # May cause slow first iteration due to compilation. No effect on CPU.
+ torch_compile: false
+ # Number of parallel workers for DAgger episode collection.
+ # 0 = sequential (reference behaviour). Recommended: 4-8 on multi-core.
+ num_collection_workers: 8
+
+ # ── Data Collection ─────────────────────────────────────────────────
+ # Oracle episodes per ID environment for --mode collect.
+ collect_episodes_per_env: 5000
+ # Parallel environment workers for collection.
+ collect_num_workers: 8
+ # Output path for collected dataset.
+ collect_output: "data/dataset.pt"
+
+ # ── Checkpointing & Logging ──────────────────────────────────────────
+ checkpoint_dir: checkpoints
+ save_policy: true
+ hub_run_id: null
+ hub_repo_id: null
+ use_wandb: true
+ wandb_project: remdm-minihack
+ wandb_entity: "mathis-weil-university-college-london-ucl-"
+ wandb_run_name: null
+ wandb_resume_id: null
+ offline_log_every: 10
+ seed: null
+
+ # ── SB3 / DT baselines ───────────────────────────────────────────────
+ # Baselines compared head-to-head against the diffusion planner.
+ # Entry point:
+ # python main.py --mode baselines --algo {ppo,dqn,a2c,ppo-rnn,bc,dt}
+ #
+ # Algorithm families:
+ # * SB3 RL (ppo, a2c, dqn, ppo-rnn): consume `cfg.total_timesteps` as
+ # the env-step training budget — same convention as DAgger / offline
+ # BC. Use a custom MiniHack CNN feature extractor over the dict
+ # observation {"local": (1,9,9), "global": (1,21,79)}.
+ # * Behavioural Cloning (bc): collects oracle trajectories, trains an
+ # SB3 ActorCriticPolicy with a native PyTorch CE loop, evaluates on
+ # ID + OOD environments.
+ # * Decision Transformer (dt): collects oracle trajectories with
+ # return-to-go labels, trains a small causal transformer over
+ # interleaved (R, s, a) tokens, evaluates with target-return
+ # conditioning on ID + OOD environments.
+ #
+ # Number of parallel SB3 SubprocVecEnv workers per ID environment.
+ # Effective n_envs = baselines_n_envs_per_id * len(id_envs). Default = 2
+ # → 8 parallel envs over the 4 ID maps.
+ baselines_n_envs_per_id: 2
+ # DQN replay buffer capacity (transitions). Used only for --algo dqn.
+ baselines_dqn_buffer_size: 100000
+ # SB3 EvalCallback cadence in env-steps. Independent from
+ # id/ood_eval_every_timesteps because SB3's eval pipeline is per
+ # vector-env tick, not shared with the diffusion planner's evaluator.
+ baselines_eval_freq_env_steps: 10000
+ # Episodes per env at every eval trigger AND at the final BC / DT
+ # manual evaluation pass. Falls back to eval_episodes_per_env (50) when
+ # null so the comparison stays apples-to-apples with DAgger evals.
+ baselines_eval_episodes_per_env: null
+
+ # ── BC baseline ──────────────────────────────────────────────────────
+ # Oracle trajectories collected per ID environment (seeds 0..N-1).
+ # 5000 matches the offline BC dataset scale used by ReMDM.
+ baselines_bc_oracle_episodes_per_env: 5000
+ baselines_bc_epochs: 50
+ baselines_bc_batch_size: 256
+ baselines_bc_lr: 0.0003
+
+ # ── Decision Transformer baseline ────────────────────────────────────
+ # 5000 trajectories per ID env to match the BC / ReMDM data scale.
+ baselines_dt_oracle_episodes_per_env: 5000
+ baselines_dt_epochs: 50
+ baselines_dt_context_len: 64
+ baselines_dt_embed_dim: 256
+ baselines_dt_n_layers: 4
+ baselines_dt_n_heads: 4
+ baselines_dt_lr: 0.0003
+ baselines_dt_batch_size: 256
+ # Maximum episode length covered by DT positional embeddings. MUST be
+ # >= the longest oracle trajectory observed during data collection.
+ # Aligned with baselines_dt_eval_max_steps so positional embeddings
+ # cover the full eval-cap horizon.
+ baselines_dt_max_ep_len: 200
+ # DT eval rollout cap (steps before truncating an episode as a loss).
+ baselines_dt_eval_max_steps: 200
+
+ # ── Output / W&B ─────────────────────────────────────────────────────
+ # Separate W&B project for baselines (kept distinct from the main
+ # remdm-minihack project so baseline runs don't pollute training
+ # leaderboards). Set to null to fall back to wandb_project.
+ baselines_wandb_project: remdm-baselines
+ # Where per-seed checkpoints, SB3 logs, and aggregated results JSON
+ # are written. Resolved relative to the project root unless absolute.
+ baselines_output_dir: outputs/baselines
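
The budget comments in `configs/defaults.yaml` can be checked with a little arithmetic. The sketch below derives the offline gradient-step budget and cadences from the default values and computes the EMA half-life. The function names are hypothetical, and flooring the cadence conversion (with a minimum of one step) is an assumption, since the config only says the conversion is done "via `/ offline_batch_size`".

```python
import math


def offline_budget(total_timesteps, batch_size, eval_every_timesteps,
                   checkpoint_every_timesteps):
    """Convert env-step budgets into gradient-step units for offline BC."""
    return {
        # Total gradient steps = total_timesteps // batch_size (from the config comment).
        "grad_steps": total_timesteps // batch_size,
        # Assumed rounding: floor, but never less than one gradient step.
        "eval_every_grad_steps": max(1, eval_every_timesteps // batch_size),
        "checkpoint_every_grad_steps": max(1, checkpoint_every_timesteps // batch_size),
    }


def ema_half_life(decay):
    """Number of updates after which an old weight's contribution has halved."""
    return math.log(0.5) / math.log(decay)


def ema_time_constant(decay):
    """The 1 / (1 - decay) figure: updates for the contribution to fall by a factor of e."""
    return 1.0 / (1.0 - decay)


# Values from defaults.yaml.
defaults = offline_budget(
    total_timesteps=2_000_000,
    batch_size=3584,
    eval_every_timesteps=25_000,
    checkpoint_every_timesteps=125_000,
)
half_life = ema_half_life(0.999)
```

With `ema_decay: 0.999` the time constant is about 1000 updates and the half-life is `ln 2` times that, roughly 693 updates, which is the scale to compare against a run's total number of gradient steps when deciding whether to adjust the decay.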
configs/final_qmul_gpu.yaml ADDED
@@ -0,0 +1,176 @@
+ # =============================================================================
+ # QMUL H200 GPU — final paper run config
+ # =============================================================================
+ #
+ # This single config drives BOTH the final DAgger run that produced
+ # `checkpoint_final/online/final.pth` AND the compute-matched offline
+ # BC baseline used for the paper comparison.
+ #
+ # --mode dagger → reproduces the iter600 DAgger checkpoint recipe
+ # --mode offline → trains a fair offline BC baseline against it
+ #
+ # ── Fairness analysis ───────────────────────────────────────────────
+ #
+ # DAgger compute at iter600 (the checkpointed model):
+ # 600 iters × 100 grad_steps_per_iter × 2048 batch_size
+ # = 60,000 AdamW updates
+ # = 122,880,000 sample-equivalents
+ #
+ # The fair offline BC baseline matches this exactly:
+ # offline_total_grad_steps = 60,000 (override; pinned)
+ # offline_batch_size = 2048 (matches DAgger; same SNR)
+ # weight_decay, grad_clip, aux_loss_weight, model arch, diffusion
+ # params: all matched. Model is identical between modes.
+ #
+ # LR strategy follows "best-of-each-method" rather than identical
+ # optimisers — DAgger's 3e-5 constant is tuned for online refinement,
+ # offline's 3e-4 cosine→3e-5 is BC standard from-scratch. Both
+ # converge to the same effective late-training LR.
+ #
+ # Eval/checkpoint cadence is matched in *count* across modes (12 evals,
+ # 6 checkpoints per run) via the offline_*_every_grad_steps overrides,
+ # because the env-step→grad-step ratio differs by ~50× between modes.
+ #
+ # ── Hardware ─────────────────────────────────────────────────────────
+ #
+ # QMUL H200 (constrained VRAM allocation). The DAgger checkpoint was
+ # produced on this hardware, so batch_size and AMP settings must
+ # stay identical to the original run. AMP + torch.compile + 32-worker
+ # collection are the original perf settings.
+
+ # ── Environments ─────────────────────────────────────────────────────
+ id_envs:
+ - MiniHack-Room-Random-5x5-v0
+ - MiniHack-Room-Random-15x15-v0
+ - MiniHack-Corridor-R2-v0
+ - MiniHack-MazeWalk-9x9-v0
+ ood_envs:
+ - MiniHack-Room-Dark-15x15-v0
+ - MiniHack-Corridor-R5-v0
+ - MiniHack-MazeWalk-45x19-v0
+ crop_size: 9
+ map_h: 21
+ map_w: 79
+ action_dim: 12
+ mask_token: 12
+ pad_token: 13
+
+ # ── Model (matches checkpoint) ───────────────────────────────────────
+ n_embd: 256
+ n_head: 4
+ n_layer: 4
+ n_global_tokens: 8
+ seq_len: 64
+ global_gate_init: -3.0
+ dropout: 0.0
+ ema_decay: 0.999
+
+ # ── Diffusion (MDLM) — matches checkpoint ────────────────────────────
+ noise_schedule: linear
+ num_diffusion_steps: 100
+ loss_weight_clip: 1000.0
+ label_smoothing: 0.0
+ use_importance_weighting: false
+ eta: 0.15
+ remask_strategy: conf
+
+ # ── Inference / sampling — matches checkpoint ────────────────────────
+ diffusion_steps_eval: 10
+ diffusion_steps_collect: 5
+ temperature: 0.5
+ top_k: 4
+ replan_every: 16
+ physics_aware_sampling: false
+
+ # ── Shared training budget (DAgger only) ─────────────────────────────
+ # 5.65M env-steps reproduces the env-step budget consumed at iter600
+ # of the original DAgger run. This figure is calibrated against a real
+ # DAgger run with the same recipe (`p7wfp67q`, episodes_per_iteration=30,
+ # grad_steps_per_iteration=100): summing the per-iter env steps over
+ # the first 600 iterations gives 30 × 600 × mean(model_steps + oracle_steps)
+ # ≈ 30 × 600 × (198 + 116) ≈ 5.65 M real env.step() calls.
+ # (The earlier 3M figure was based on the buggy single-episode env-step
+ # accounting in `online.py:155-169` — fixed in the same commit as this
+ # config bump.) Used by `--mode dagger` only. Offline mode bypasses
+ # this via `offline_total_grad_steps` below — the unified env-step
+ # budget is fundamentally unfair when the sample-to-grad-step ratio
+ # differs by ~50× between modes.
+ total_timesteps: 5650000
+
+ # Eval/checkpoint cadence in env-step units (DAgger mode).
+ # Scaled with the corrected total_timesteps so the run still produces
+ # ~12 ID/OOD evals and ~6 checkpoints over its full duration.
+ # 470k → ~12 evals; 940k → ~6 checkpoints.
+ id_eval_every_timesteps: 470000
+ ood_eval_every_timesteps: 470000
+ checkpoint_every_timesteps: 940000
+
+ # Final-eval episode count (used by both ID/OOD eval triggers and
+ # checkpoint-time evals; matches the original DAgger run).
+ eval_episodes_per_env: 50
+ checkpoint_eval_episodes: 50
+ weight_decay: 0.0001
+ aux_loss_weight: 0.5
+
+ # ── DAgger (matches checkpoint_final/online/config_iter600.yaml) ─────
+ dagger_lr: 0.00003
+ dagger_batch_size: 2048
+ dagger_grad_clip: 1.0
+ buffer_capacity: 10000
+ episodes_per_iteration: 30
+ grad_steps_per_iteration: 100
+ efficiency_multiplier: 1.5
+ curriculum_queue_size: 100
+ curriculum_preseed: true
+
+ # ── Offline BC (compute-matched fair baseline) ───────────────────────
+ # Per the fairness analysis above:
+ # * Same gradient compute as DAgger (60k AdamW updates × 2048 batch)
+ # * Same model, diffusion, weight_decay, grad_clip, aux_loss
+ # * BC-tuned LR + cosine schedule (best practice from-scratch)
+ # * Eval/checkpoint counts matched to DAgger via grad-step overrides
+ offline_lr: 0.0003
+ offline_batch_size: 2048
+ offline_grad_clip: 1.0
+
+ # Compute pin: 60,000 AdamW updates = exactly DAgger@iter600.
+ offline_total_grad_steps: 60000
+
+ # Eval cadence: 5,000 grad steps → 12 evals (matches DAgger eval count).
+ offline_eval_every_grad_steps: 5000
+
+ # Checkpoint cadence: 10,000 grad steps → 6 checkpoints (matches DAgger).
+ offline_checkpoint_every_grad_steps: 10000
+
+ # Buffer cap for offline mode only — must hold the full pre-collected
+ # dataset (~1M sliding windows from 20k oracle trajectories). DAgger's
+ # `buffer_capacity: 10000` would silently FIFO-evict 99% of the data.
+ offline_buffer_capacity: 1500000
+
+ # ── Performance (cluster-tuned, matches original DAgger run) ─────────
+ use_amp: true
+ torch_compile: true
+ num_collection_workers: 32
+
+ # ── Data collection (for offline BC dataset) ─────────────────────────
+ # 5000 eps × 4 ID envs = 20k oracle trajectories. Strictly more than
+ # the ~7k unique trajectories DAgger had in its filtered buffer at
+ # iter600 — offline always gets a richer pre-collected pool, which is
+ # the standard fairness asymmetry in BC vs DAgger comparisons.
+ collect_episodes_per_env: 5000
+ collect_num_workers: 32
+ collect_output: data/oracle_bc_qmul.pt
+
+ # ── Checkpointing & Logging ──────────────────────────────────────────
+ checkpoint_dir: checkpoints_qmul
+ save_policy: true
+ hub_run_id: null
+ hub_repo_id: null
+ use_wandb: true
+ wandb_project: remdm-minihack
+ wandb_entity: "mathis-weil-university-college-london-ucl-"
+ wandb_run_name: null
+ # wandb_resume_id intentionally omitted — fresh runs by default.
174
+ # Override on the CLI (`wandb_resume_id=...`) to continue an existing run.
175
+ offline_log_every: 50
176
+ seed: null
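The budget figures quoted in the config comments above can be re-derived with a few lines of arithmetic. The following is only a sanity-check sketch: the per-episode averages (198 model steps, 116 oracle steps) are the ones quoted in the comment, and every variable name is illustrative rather than taken from the repository.

```python
# Re-derive the budget numbers quoted in the config comments.
# Variable names are illustrative; only the figures come from the config.
episodes_per_iteration = 30
grad_steps_per_iteration = 100
iterations = 600
mean_model_steps = 198   # per-episode average quoted in the comment
mean_oracle_steps = 116  # per-episode average quoted in the comment

# Real env.step() calls consumed by 600 DAgger iterations.
env_steps = episodes_per_iteration * iterations * (
    mean_model_steps + mean_oracle_steps
)

# AdamW updates performed over the same 600 iterations.
grad_steps = iterations * grad_steps_per_iteration

# Number of evals/checkpoints triggered by the env-step cadence (DAgger)
# and by the grad-step cadence (offline BC).
total_timesteps = 5_650_000
dagger_evals = total_timesteps // 470_000
dagger_checkpoints = total_timesteps // 940_000
offline_evals = grad_steps // 5_000
offline_checkpoints = grad_steps // 10_000

print(env_steps)                           # 5652000
print(grad_steps)                          # 60000
print(dagger_evals, dagger_checkpoints)    # 12 6
print(offline_evals, offline_checkpoints)  # 12 6
```

Under these figures both modes end up with the same number of evals and checkpoints, which is what the cadence settings are meant to achieve.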
configs/final_ucl_gpu.yaml ADDED
@@ -0,0 +1,158 @@
+ # =============================================================================
+ # UCL 3090 Ti GPU — final paper run config
+ # =============================================================================
+ #
+ # This single config drives BOTH the final DAgger run and the
+ # compute-matched offline BC baseline used for the paper comparison.
+ #
+ # --mode dagger → reproduces the iter600 DAgger checkpoint recipe
+ # --mode offline → trains a fair offline BC baseline against it
+ #
+ # All training hyperparameters are IDENTICAL to `final_qmul_gpu.yaml`
+ # so cross-cluster runs produce directly comparable results. The only
+ # differences are hardware-specific perf knobs (collection workers).
+ # See the QMUL config header for the full fairness analysis.
+ #
+ # ── Hardware ─────────────────────────────────────────────────────────
+ #
+ # UCL 3090 Ti — 24 GB VRAM. The 4-layer × 256-dim model with
+ # batch=2048 and AMP fits with comfortable headroom (~6-8 GB peak).
+ # Lower core count than the QMUL cluster, so collection workers
+ # capped at 8.
+
+ # ── Environments ─────────────────────────────────────────────────────
+ id_envs:
+   - MiniHack-Room-Random-5x5-v0
+   - MiniHack-Room-Random-15x15-v0
+   - MiniHack-Corridor-R2-v0
+   - MiniHack-MazeWalk-9x9-v0
+ ood_envs:
+   - MiniHack-Room-Dark-15x15-v0
+   - MiniHack-Corridor-R5-v0
+   - MiniHack-MazeWalk-45x19-v0
+ crop_size: 9
+ map_h: 21
+ map_w: 79
+ action_dim: 12
+ mask_token: 12
+ pad_token: 13
+
+ # ── Model (matches checkpoint) ───────────────────────────────────────
+ n_embd: 256
+ n_head: 4
+ n_layer: 4
+ n_global_tokens: 8
+ seq_len: 64
+ global_gate_init: -3.0
+ dropout: 0.0
+ ema_decay: 0.999
+
+ # ── Diffusion (MDLM) — matches checkpoint ────────────────────────────
+ noise_schedule: linear
+ num_diffusion_steps: 100
+ loss_weight_clip: 1000.0
+ label_smoothing: 0.0
+ use_importance_weighting: false
+ eta: 0.15
+ remask_strategy: conf
+
+ # ── Inference / sampling — matches checkpoint ────────────────────────
+ diffusion_steps_eval: 10
+ diffusion_steps_collect: 5
+ temperature: 0.5
+ top_k: 4
+ replan_every: 16
+ physics_aware_sampling: false
+
+ # ── Shared training budget (DAgger only) ─────────────────────────────
+ # 5.65M env-steps reproduces the env-step budget consumed at iter600
+ # of the original DAgger run. Calibrated against a real DAgger run
+ # with the same recipe (see QMUL config header for the full derivation).
+ # The earlier 3M figure was based on the buggy single-episode env-step
+ # accounting in `online.py:155-169` — fixed in the same commit as this
+ # config bump. Used by `--mode dagger` only. Offline mode bypasses
+ # this via `offline_total_grad_steps` below.
+ total_timesteps: 5650000
+
+ # Eval/checkpoint cadence in env-step units (DAgger mode).
+ # Scaled with the corrected total_timesteps so the run still produces
+ # ~12 ID/OOD evals and ~6 checkpoints over its full duration.
+ id_eval_every_timesteps: 470000
+ ood_eval_every_timesteps: 470000
+ checkpoint_every_timesteps: 940000
+
+ # Final-eval episode count (used by both ID/OOD eval triggers and
+ # checkpoint-time evals; matches the original DAgger run).
+ eval_episodes_per_env: 50
+ checkpoint_eval_episodes: 50
+ weight_decay: 0.0001
+ aux_loss_weight: 0.5
+
+ # ── DAgger (matches checkpoint_final/online/config_iter600.yaml) ─────
+ dagger_lr: 0.00003
+ dagger_batch_size: 2048
+ dagger_grad_clip: 1.0
+ buffer_capacity: 10000
+ episodes_per_iteration: 30
+ grad_steps_per_iteration: 100
+ efficiency_multiplier: 1.5
+ curriculum_queue_size: 100
+ curriculum_preseed: true
+
+ # ── Offline BC (compute-matched fair baseline) ───────────────────────
+ # Per the fairness analysis (see QMUL config header):
+ # * Same gradient compute as DAgger (60k AdamW updates × 2048 batch)
+ # * Same model, diffusion, weight_decay, grad_clip, aux_loss
+ # * BC-tuned LR + cosine schedule (best practice from-scratch)
+ # * Eval/checkpoint counts matched to DAgger via grad-step overrides
+ #
+ # `offline_batch_size: 2048` is matched to DAgger (NOT the 4096 the
+ # previous UCL config used) so per-update SNR is identical between
+ # modes — this is the cleanest apples-to-apples optimisation
+ # comparison. The 24 GB VRAM can hold a larger batch but using one
+ # would confound the comparison.
+ offline_lr: 0.0003
+ offline_batch_size: 2048
+ offline_grad_clip: 1.0
+
+ # Compute pin: 60,000 AdamW updates = exactly DAgger@iter600.
+ offline_total_grad_steps: 60000
+
+ # Eval cadence: 5,000 grad steps → 12 evals (matches DAgger eval count).
+ offline_eval_every_grad_steps: 5000
+
+ # Checkpoint cadence: 10,000 grad steps → 6 checkpoints (matches DAgger).
+ offline_checkpoint_every_grad_steps: 10000
+
+ # Buffer cap for offline mode only — must hold the full pre-collected
+ # dataset (~1M sliding windows from 20k oracle trajectories). DAgger's
+ # `buffer_capacity: 10000` would silently FIFO-evict 99% of the data.
+ offline_buffer_capacity: 1500000
+
+ # ── Performance (cluster-tuned for 3090 Ti) ──────────────────────────
+ use_amp: true
+ torch_compile: true
+ num_collection_workers: 8
+
+ # ── Data collection (for offline BC dataset) ─────────────────────────
+ # 5000 eps × 4 ID envs = 20k oracle trajectories. Strictly more than
+ # the ~7k unique trajectories DAgger had in its filtered buffer at
+ # iter600 — offline always gets a richer pre-collected pool, which is
+ # the standard fairness asymmetry in BC vs DAgger comparisons.
+ collect_episodes_per_env: 5000
+ collect_num_workers: 8
+ collect_output: data/oracle_bc_ucl.pt
+
+ # ── Checkpointing & Logging ──────────────────────────────────────────
+ checkpoint_dir: checkpoints_ucl
+ save_policy: true
+ hub_run_id: null
+ hub_repo_id: null
+ use_wandb: true
+ wandb_project: remdm-minihack
+ wandb_entity: "mathis-weil-university-college-london-ucl-"
+ wandb_run_name: null
+ # wandb_resume_id intentionally omitted — fresh runs by default.
+ # Override on the CLI (`wandb_resume_id=...`) to continue an existing run.
+ offline_log_every: 50
+ seed: null
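The header of this config states that it differs from `final_qmul_gpu.yaml` only in hardware-specific settings. A claim like that is easy to check mechanically. The sketch below compares a hand-copied subset of keys from the two configs; the helper `differing_keys` is a hypothetical name, and the dictionaries hold only a few of the values shown in the diff, not the full files.

```python
# Minimal sketch: list the keys on which two config dicts disagree.
# `differing_keys` is a hypothetical helper, not part of the repository.
_MISSING = object()


def differing_keys(a: dict, b: dict) -> list[str]:
    """Return sorted keys whose values differ or that exist in only one dict."""
    return sorted(
        k for k in a.keys() | b.keys()
        if a.get(k, _MISSING) != b.get(k, _MISSING)
    )


# Subset of training hyperparameters that appear identically in both files.
shared = {
    "total_timesteps": 5650000,
    "dagger_lr": 0.00003,
    "dagger_batch_size": 2048,
    "offline_lr": 0.0003,
    "offline_batch_size": 2048,
    "offline_total_grad_steps": 60000,
}
qmul = {
    **shared,
    "num_collection_workers": 32,
    "collect_num_workers": 32,
    "collect_output": "data/oracle_bc_qmul.pt",
    "checkpoint_dir": "checkpoints_qmul",
}
ucl = {
    **shared,
    "num_collection_workers": 8,
    "collect_num_workers": 8,
    "collect_output": "data/oracle_bc_ucl.pt",
    "checkpoint_dir": "checkpoints_ucl",
}

diff = differing_keys(qmul, ucl)
print(diff)
# ['checkpoint_dir', 'collect_num_workers', 'collect_output', 'num_collection_workers']
```

On this subset the two configs disagree on the worker counts and on the output and checkpoint paths, while the training hyperparameters match. To compare the complete files, the same helper could be applied to the dictionaries returned by `yaml.safe_load` from PyYAML, which is listed in the project's dependencies.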
configs/smoke.yaml ADDED
@@ -0,0 +1,16 @@
+ # Smoke test overrides — fast end-to-end sanity check on CPU.
+ # With total_timesteps=5000 and ~2 eps × ~30 avg steps × 2 (model+oracle)
+ # = ~120 env steps/iter → ~40 iters → a few seconds per iter on CPU.
+ buffer_capacity: 50
+ dagger_batch_size: 256
+ offline_batch_size: 256
+ total_timesteps: 5000
+ id_eval_every_timesteps: 2500
+ ood_eval_every_timesteps: 2500
+ checkpoint_every_timesteps: 2500
+ episodes_per_iteration: 2
+ grad_steps_per_iteration: 5
+ eval_episodes_per_env: 2
+ checkpoint_eval_episodes: 2
+ num_collection_workers: 0
+ use_wandb: false
configs/ucl_gpu_bigger_model.yaml ADDED
@@ -0,0 +1,103 @@
+ # ── Environments ──────────────────────────────────────────────────────
+ id_envs:
+   - MiniHack-Room-Random-5x5-v0
+   - MiniHack-Room-Random-15x15-v0
+   - MiniHack-Corridor-R2-v0
+   - MiniHack-MazeWalk-9x9-v0
+ ood_envs:
+   - MiniHack-Room-Dark-15x15-v0
+   - MiniHack-Corridor-R5-v0
+   - MiniHack-MazeWalk-45x19-v0
+ crop_size: 9
+ map_h: 21
+ map_w: 79
+ action_dim: 12
+ mask_token: 12
+ pad_token: 13
+
+ # ── Model ─────────────────────────────────────────────────────────────
+ n_embd: 384
+ n_head: 6
+ n_layer: 4
+ n_global_tokens: 8
+ seq_len: 64
+ global_gate_init: -3.0
+ # Transformer dropout. 0.0 is deliberate — discrete diffusion forward masking
+ # already regularises; dropout on top is redundant.
+ dropout: 0.0
+ ema_decay: 0.999
+
+ # ── Diffusion (MDLM) ─────────────────────────────────────────────────
+ noise_schedule: linear
+ num_diffusion_steps: 100
+ loss_weight_clip: 1000.0
+ label_smoothing: 0.0
+ # Use SUBS importance weighting w(t) in loss. Off by default (flat average
+ # matching reference). Enable for MDLM ELBO experiments.
+ use_importance_weighting: false
+ # ReMDM stochastic remask base fraction
+ eta: 0.15
+ # Remasking strategy: rescale | cap | conf
+ remask_strategy: conf
+
+ # ── Inference ─────────────────────────────────────────────────────────
+ # Number of reverse denoising steps at inference.
+ # Reference uses 5 (aggressive). Higher = better quality, slower.
+ diffusion_steps_eval: 10
+ diffusion_steps_collect: 5
+ temperature: 0.5
+ top_k: 4
+ replan_every: 16
+ # Soft-penalise hazardous cardinal actions during stochastic sampling.
+ # Not active in the reference evaluation pipeline; off by default.
+ physics_aware_sampling: false
+
+ # ── Training budget (unified) ────────────────────────────────────────
+ total_timesteps: 20000000
+ id_eval_every_timesteps: 250000
+ ood_eval_every_timesteps: 250000
+ checkpoint_every_timesteps: 1250000
+
+ # ── Offline BC ────────────────────────────────────────────────────────
+ offline_lr: 0.0003
+ offline_batch_size: 4608
+ offline_grad_clip: 1.0
+ aux_loss_weight: 0.5
+
+ # ── DAgger ────────────────────────────────────────────────────────────
+ dagger_lr: 0.00003
+ dagger_batch_size: 4608
+ dagger_grad_clip: 1.0
+ weight_decay: 0.0001
+ buffer_capacity: 10000
+ episodes_per_iteration: 30
+ grad_steps_per_iteration: 100
+ efficiency_multiplier: 1.5
+ curriculum_queue_size: 100
+ # Pre-seed curriculum queues with 50/50 prior for uniform early sampling.
+ curriculum_preseed: true
+ eval_episodes_per_env: 50
+ checkpoint_eval_episodes: 50
+
+ # ── Performance ──────────────────────────────────────────────────────
+ # Mixed-precision (FP16) training via torch.cuda.amp.
+ # Speeds up forward/backward ~1.5-2x on GPU. No effect on CPU.
+ use_amp: true
+ # torch.compile the model for fused kernels (experimental).
+ # May cause slow first iteration due to compilation. No effect on CPU.
+ torch_compile: true
+ # Number of parallel workers for DAgger episode collection.
+ # 0 = sequential (reference behaviour). Recommended: 4-8 on multi-core.
+ num_collection_workers: 8
+
+ # ── Checkpointing & Logging ──────────────────────────────────────────
+ checkpoint_dir: checkpoints_ucl_bigger_model
+ save_policy: true
+ hub_run_id: null
+ hub_repo_id: null
+ use_wandb: true
+ wandb_project: remdm-minihack
+ wandb_entity: "mathis-weil-university-college-london-ucl-"
+ wandb_run_name: null
+ offline_log_every: 10
+ seed: null
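This config widens the model to `n_embd: 384` with `n_head: 6`, where the final-run configs use 256 and 4. The sketch below checks what that means per attention head. It assumes the denoiser splits `n_embd` evenly across heads, as standard multi-head attention does; the denoiser source is not part of this excerpt, and `head_dim` is a hypothetical helper.

```python
# Sketch under the assumption that the embedding width is split evenly
# across attention heads. `head_dim` is a hypothetical helper.
def head_dim(n_embd: int, n_head: int) -> int:
    """Return the per-head width, rejecting widths that do not divide evenly."""
    if n_embd % n_head != 0:
        raise ValueError(
            f"n_embd={n_embd} is not divisible by n_head={n_head}"
        )
    return n_embd // n_head


print(head_dim(256, 4))  # 64  (final-run configs)
print(head_dim(384, 6))  # 64  (this config)
```

Under that assumption both settings give 64 dimensions per head, so the bigger model adds heads rather than widening each one.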
configs/ucl_gpu_learning_behaviour.yaml ADDED
@@ -0,0 +1,103 @@
+ # ── Environments ──────────────────────────────────────────────────────
+ id_envs:
+   - MiniHack-Room-Random-5x5-v0
+   - MiniHack-Room-Random-15x15-v0
+   - MiniHack-Corridor-R2-v0
+   - MiniHack-MazeWalk-9x9-v0
+ ood_envs:
+   - MiniHack-Room-Dark-15x15-v0
+   - MiniHack-Corridor-R5-v0
+   - MiniHack-MazeWalk-45x19-v0
+ crop_size: 9
+ map_h: 21
+ map_w: 79
+ action_dim: 12
+ mask_token: 12
+ pad_token: 13
+
+ # ── Model ─────────────────────────────────────────────────────────────
+ n_embd: 256
+ n_head: 4
+ n_layer: 4
+ n_global_tokens: 8
+ seq_len: 64
+ global_gate_init: -3.0
+ # Transformer dropout. 0.0 is deliberate — discrete diffusion forward masking
+ # already regularises; dropout on top is redundant.
+ dropout: 0.0
+ ema_decay: 0.999
+
+ # ── Diffusion (MDLM) ─────────────────────────────────────────────────
+ noise_schedule: linear
+ num_diffusion_steps: 100
+ loss_weight_clip: 1000.0
+ label_smoothing: 0.0
+ # Use SUBS importance weighting w(t) in loss. Off by default (flat average
+ # matching reference). Enable for MDLM ELBO experiments.
+ use_importance_weighting: false
+ # ReMDM stochastic remask base fraction
+ eta: 0.18
+ # Remasking strategy: rescale | cap | conf
+ remask_strategy: conf
+
+ # ── Inference ─────────────────────────────────────────────────────────
+ # Number of reverse denoising steps at inference.
+ # Reference uses 5 (aggressive). Higher = better quality, slower.
+ diffusion_steps_eval: 10
+ diffusion_steps_collect: 5
+ temperature: 0.5
+ top_k: 4
+ replan_every: 16
+ # Soft-penalise hazardous cardinal actions during stochastic sampling.
+ # Not active in the reference evaluation pipeline; off by default.
+ physics_aware_sampling: false
+
+ # ── Training budget (unified) ────────────────────────────────────────
+ total_timesteps: 20000000
+ id_eval_every_timesteps: 250000
+ ood_eval_every_timesteps: 250000
+ checkpoint_every_timesteps: 1250000
+
+ # ── Offline BC ────────────────────────────────────────────────────────
+ offline_lr: 0.0003
+ offline_batch_size: 6144
+ offline_grad_clip: 1.0
+ aux_loss_weight: 0.5
+
+ # ── DAgger ────────────────────────────────────────────────────────────
+ dagger_lr: 0.00003
+ dagger_batch_size: 6144
+ dagger_grad_clip: 1.0
+ weight_decay: 0.0001
+ buffer_capacity: 10000
+ episodes_per_iteration: 30
+ grad_steps_per_iteration: 100
+ efficiency_multiplier: 1.5
+ curriculum_queue_size: 100
+ # Pre-seed curriculum queues with 50/50 prior for uniform early sampling.
+ curriculum_preseed: true
+ eval_episodes_per_env: 50
+ checkpoint_eval_episodes: 50
+
+ # ── Performance ──────────────────────────────────────────────────────
+ # Mixed-precision (FP16) training via torch.cuda.amp.
+ # Speeds up forward/backward ~1.5-2x on GPU. No effect on CPU.
+ use_amp: true
+ # torch.compile the model for fused kernels (experimental).
+ # May cause slow first iteration due to compilation. No effect on CPU.
+ torch_compile: true
+ # Number of parallel workers for DAgger episode collection.
+ # 0 = sequential (reference behaviour). Recommended: 4-8 on multi-core.
+ num_collection_workers: 8
+
+ # ── Checkpointing & Logging ──────────────────────────────────────────
+ checkpoint_dir: checkpoints_ucl_learning_behaviour
+ save_policy: true
+ hub_run_id: null
+ hub_repo_id: null
+ use_wandb: true
+ wandb_project: remdm-minihack
+ wandb_entity: "mathis-weil-university-college-london-ucl-"
+ wandb_run_name: null
+ offline_log_every: 10
+ seed: null
environments/.gitkeep ADDED
File without changes
main.py ADDED
@@ -0,0 +1,255 @@
+ from __future__ import annotations
+
+ import argparse
+ import logging
+ import random
+ from pathlib import Path
+ from typing import Any
+
+ import numpy as np
+ import torch
+
+ from src.config import load_config
+ from src.planners.baselines import ALL_BASELINE_ALGOS, run_baselines
+ from src.planners.logging import Logger
+ from src.planners.offline import run_offline
+ from src.planners.online import run_dagger
+ from src.planners.inference import run_inference
+ from src.planners.collect_oracle import run_collect
+ from src.planners.smoke import run_smoke
+
+
+ # =============================================================================
+ # Logging
+ # =============================================================================
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+ )
+ logger = logging.getLogger(__name__)
+
+
+ # =============================================================================
+ # Utils
+ # =============================================================================
+
+ def _parse_overrides(extras: list[str]) -> dict[str, Any]:
+     return {
+         k.lstrip("-"): v
+         for item in extras if "=" in item
+         for k, v in [item.split("=", 1)]
+     }
+
+
+ def _set_seed(seed: int | None) -> int:
+     if seed is None:
+         seed = random.randint(0, 2**31 - 1)
+
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+     return seed
+
+
+ # =============================================================================
+ # CLI
+ # =============================================================================
+
+ def parse_args() -> tuple[argparse.Namespace, list[str]]:
+     parser = argparse.ArgumentParser(
+         description="ReMDM-MiniHack: Masked Diffusion Planner",
+     )
+
+     parser.add_argument(
+         "--mode",
+         required=True,
+         choices=[
+             "smoke", "offline", "dagger", "inference", "collect", "baselines",
+         ],
+     )
+     parser.add_argument("--config", default="configs/defaults.yaml")
+     parser.add_argument(
+         "--algo", default=None, choices=list(ALL_BASELINE_ALGOS),
+         help="Baseline algorithm (required for --mode baselines)",
+     )
+     parser.add_argument(
+         "--seeds", type=int, nargs="+", default=None,
+         help=(
+             "Explicit list of seeds for --mode baselines "
+             "(e.g. --seeds 0 1 2)."
+         ),
+     )
+     parser.add_argument(
+         "--n-seeds", type=int, default=None,
+         help=(
+             "Number of seeds starting from 0 (alternative to --seeds; "
+             "only used by --mode baselines)."
+         ),
+     )
+
+     parser.add_argument("--data", default=None)
+     parser.add_argument("--checkpoint", default=None)
+     parser.add_argument(
+         "--wandb-artifact", default=None,
+         help=(
+             "W&B artifact reference to download as checkpoint, e.g. "
+             "'entity/project/checkpoint-iter1000:latest'"
+         ),
+     )
+     parser.add_argument("--no-warm-start", action="store_true")
+     parser.add_argument("--no-ema", action="store_true")
+
+     parser.add_argument("--envs", nargs="+", default=None)
+     parser.add_argument(
+         "--des", nargs="+", default=None,
+         help="Paths to .des scenario files for custom environment evaluation",
+     )
+     parser.add_argument("--episodes", type=int, default=50)
+     parser.add_argument("--output", default=None)
+     parser.add_argument(
+         "--blind-global", action="store_true",
+         help="Zero out global map observations (local-only ablation)",
+     )
+
+     return parser.parse_known_args()
+
+
+ # =============================================================================
+ # Config
+ # =============================================================================
+
+ def build_config(args, extras):
+     config_path = args.config
+     if args.mode == "smoke" and config_path == "configs/defaults.yaml":
+         config_path = "configs/smoke.yaml"
+
+     cfg = load_config(config_path, _parse_overrides(extras))
+
+     seed = _set_seed(cfg.seed)
+     logger.info(f"Seed: {seed}")
+
+     return cfg
+
+
+ # =============================================================================
+ # Validation
+ # =============================================================================
+
+ def validate(args) -> None:
+     if args.mode == "inference" and not args.checkpoint and not args.wandb_artifact:
+         raise ValueError(
+             "--checkpoint or --wandb-artifact required for inference mode"
+         )
+     if args.mode == "baselines" and args.algo is None:
+         raise ValueError(
+             "--algo is required for --mode baselines "
+             f"(choose one of {list(ALL_BASELINE_ALGOS)})"
+         )
+
+
+ def _resolve_seeds(args, cfg) -> list[int]:
+     """Build the seed list for --mode baselines."""
+     if args.seeds is not None:
+         return list(args.seeds)
+     if args.n_seeds is not None:
+         return list(range(int(args.n_seeds)))
+     return [cfg.seed if cfg.seed is not None else 0]
+
+
+ # =============================================================================
+ # Dispatch (no lambdas, cleaner)
+ # =============================================================================
+
+ def _resolve_path(p: str | None) -> str | None:
+     """Resolve a user-provided path to absolute, or return None."""
+     if p is None:
+         return None
+     return str(Path(p).resolve())
+
+
+ def _resolve_checkpoint(args, cfg) -> str | None:
+     """Return a local checkpoint path from --checkpoint or --wandb-artifact."""
+     if args.checkpoint:
+         return _resolve_path(args.checkpoint)
+     artifact_ref = args.wandb_artifact
+     if artifact_ref:
+         from src.planners.logging import download_artifact
+         path = download_artifact(artifact_ref)
+         if path is None:
+             raise RuntimeError(
+                 f"Failed to download W&B artifact: {artifact_ref}"
+             )
+         return path
+     return None
+
+
+ def run_mode(mode: str, cfg, args) -> None:
+     data_path = _resolve_path(args.data)
+     output_path = _resolve_path(args.output)
+     des_files = (
+         [str(Path(d).resolve()) for d in args.des]
+         if args.des else None
+     )
+
+     if mode == "smoke":
+         run_smoke(cfg)
+
+     elif mode == "offline":
+         ckpt = _resolve_checkpoint(args, cfg)
+         run_offline(cfg, data_path, checkpoint_path=ckpt)
+
+     elif mode == "dagger":
+         ckpt = _resolve_checkpoint(args, cfg)
+         run_dagger(cfg, ckpt, args.no_warm_start)
+
+     elif mode == "collect":
+         run_collect(cfg)
+
+     elif mode == "baselines":
+         run_baselines(
+             cfg,
+             algo=args.algo,
+             seeds=_resolve_seeds(args, cfg),
+             output_path=output_path,
+         )
+
+     elif mode == "inference":
+         ckpt = _resolve_checkpoint(args, cfg)
+         if ckpt is None:
+             raise ValueError(
+                 "--checkpoint or --wandb-artifact required for inference"
+             )
+         log = Logger(cfg)
+         run_inference(
+             cfg,
+             ckpt,
+             args.envs,
+             args.episodes,
+             output_path,
+             not args.no_ema,
+             log=log,
+             des_files=des_files,
+             blind_global=args.blind_global,
+         )
+         log.finish()
+
+
+ # =============================================================================
+ # Entry point
+ # =============================================================================
+
+ def main() -> None:
+     args, extras = parse_args()
+     validate(args)
+     cfg = build_config(args, extras)
+     if torch.cuda.is_available():
+         torch.set_float32_matmul_precision("high")
+     run_mode(args.mode, cfg, args)
+
+
+ if __name__ == "__main__":
+     main()
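The config overrides that `main.py` forwards to `load_config` come from whatever `parse_known_args` leaves unparsed, filtered through `_parse_overrides`. The sketch below copies that comprehension into a standalone function (renamed `parse_overrides` for illustration) to show two behaviours worth knowing: values stay strings, and items without an `=` are dropped silently. Whether `load_config` later coerces the strings to numbers is not visible in this excerpt. The `--verbose` flag in the example is hypothetical.

```python
from typing import Any


def parse_overrides(extras: list[str]) -> dict[str, Any]:
    """Standalone copy of the `_parse_overrides` logic from main.py."""
    return {
        k.lstrip("-"): v          # leading dashes are optional
        for item in extras if "=" in item   # items without "=" are skipped
        for k, v in [item.split("=", 1)]    # split on the first "=" only
    }


extras = ["seed=42", "wandb_resume_id=abc123", "--eta=0.18", "--verbose"]
overrides = parse_overrides(extras)
print(overrides)
# {'seed': '42', 'wandb_resume_id': 'abc123', 'eta': '0.18'}
```

One caveat when writing overrides with leading dashes: `argparse` accepts unambiguous prefixes of long options by default, so `--seed=42` would most likely be consumed as `--seeds 42` and never reach the overrides. The dash-free form `seed=42` avoids that.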
pyproject.toml ADDED
@@ -0,0 +1,22 @@
+ [project]
+ name = "minihack-remdm-planner"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "huggingface-hub>=1.8.0",
+     "ipython>=9.12.0",
+     "matplotlib>=3.10.8",
+     "minihack>=1.0.2",
+     "nle>=1.2.0",
+     "numpy>=2.4.4",
+     "orjson>=3.11.8",
+     "polars>=1.39.3",
+     "pyyaml>=6.0.3",
+     "sb3-contrib>=2.8.0",
+     "scipy>=1.17.1",
+     "stable-baselines3>=2.8.0",
+     "torch>=2.11.0",
+     "wandb>=0.25.1",
+ ]
src/__init__.py ADDED
File without changes
src/buffer.py ADDED
@@ -0,0 +1,268 @@
1
+ """Replay buffer with offline-protected FIFO eviction.
2
+
3
+ Ported from minihack_reference/src/buffer.py. Stores observation-action
4
+ windows of fixed length ``seq_len``. Offline data is pinned at the front
5
+ and never evicted; online samples use FIFO.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import numpy as np
11
+
12
+
13
+ class ReplayBuffer:
14
+ """Fixed-capacity buffer with offline-protected FIFO eviction.
15
+
16
+ Offline samples (loaded once via ``load_offline_data``) are pinned
17
+ and never evicted. Online samples added via ``add`` are FIFO-evicted
18
+ when the total count exceeds ``capacity``.
19
+
20
+ Args:
21
+ capacity: Maximum total number of windows.
22
+ seq_len: Action-sequence window length.
23
+ pad_token: Token used to pad short sequences.
24
+ """
25
+
26
+ def __init__(
27
+ self, capacity: int, seq_len: int, pad_token: int,
28
+ ) -> None:
29
+ self._capacity = capacity
30
+ self._seq_len = seq_len
31
+ self._pad_token = pad_token
32
+
33
+ # Each element: (local [9,9], global [21,79], actions [seq_len])
34
+ self._offline: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
35
+ self._online: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
36
+
37
+ # Stacked array cache for fast sampling
38
+ self._cache_valid = False
39
+ self._cached_local: np.ndarray | None = None
40
+ self._cached_global: np.ndarray | None = None
41
+ self._cached_actions: np.ndarray | None = None
42
+
43
+ # ── Offline data ─────────────────────────────────────────────
44
+
45
+ def load_offline_data(
46
+ self,
47
+ data: dict | list,
48
+ allowed_envs: list[str],
49
+ metadata: dict | None = None,
50
+ ) -> None:
51
+ """Load pre-collected trajectories and slice into windows.
52
+
53
+ Supports two dataset formats:
54
+
55
+ **New format** (dict): ``{"trajectories": [...]}`` where each entry
56
+ is a dict with ``"local"``, ``"global"``, ``"actions"``, ``"env_id"``.
57
+
58
+ **Legacy format** (list): Flat list of ``((local, global), action_seq)``
59
+ tuples produced by the reference pipeline (pre-windowed, already
60
+ ``seq_len``-length). Env filtering uses an optional *metadata* dict
61
+ with a ``"samples_per_env"`` key mapping env IDs to sample counts.
62
+
63
+ Args:
64
+ data: Dataset in new dict format or legacy list format.
65
+ allowed_envs: Only samples from these env IDs are kept.
66
+ metadata: Optional sidecar metadata for legacy format env
67
+ filtering. Ignored for the new format.
68
+ """
69
+ if isinstance(data, list):
70
+ self._load_legacy_offline_data(data, allowed_envs, metadata)
71
+ return
72
+
73
+ trajectories = data.get("trajectories", [data])
74
+ for traj in trajectories:
75
+ if traj.get("env_id", "") not in allowed_envs:
76
+ continue
77
+ windows = self._slice_trajectory(traj)
78
+ self._offline.extend(windows)
79
+ # Truncate to capacity
80
+ if len(self._offline) > self._capacity:
81
+ self._offline = self._offline[: self._capacity]
82
+ self._invalidate_cache()
83
+
84
+ def _load_legacy_offline_data(
85
+ self,
86
+ data: list,
87
+ allowed_envs: list[str],
88
+ metadata: dict | None = None,
89
+ ) -> None:
90
+ """Load reference-format datasets (pre-windowed tuples).
91
+
92
+ Args:
93
+ data: List of ``((local_crop, global_map), action_seq)`` tuples.
94
+ ``local_crop`` is ``[9, 9]``, ``global_map`` is ``[21, 79]``,
95
+ ``action_seq`` is a sequence of length ``seq_len``.
96
+ allowed_envs: Env IDs to retain.
97
+ metadata: Optional dict with ``"samples_per_env"`` key mapping
98
+ env IDs to per-env sample counts for precise filtering.
99
+ """
100
+ allowed = set(allowed_envs)
101
+
102
+ if metadata and "samples_per_env" in metadata:
103
+ # Build a per-sample env_id index from the metadata ordering
104
+ sample_to_env: list[str] = []
105
+ for env_id in sorted(metadata["samples_per_env"].keys()):
106
+ count = metadata["samples_per_env"][env_id]
107
+ sample_to_env.extend([env_id] * count)
108
+
109
+ for i, sample in enumerate(data):
110
+ env_id = (
111
+ sample_to_env[i] if i < len(sample_to_env) else None
112
+ )
113
+ if env_id is None or env_id in allowed:
114
+ self._offline.append(self._unpack_legacy_sample(sample))
115
+ else:
116
+ # No metadata — keep all samples (caller is responsible for
117
+ # pre-filtering)
118
+ for sample in data:
119
+ self._offline.append(self._unpack_legacy_sample(sample))
120
+
121
+ if len(self._offline) > self._capacity:
122
+ self._offline = self._offline[: self._capacity]
123
+ self._invalidate_cache()
124
+
125
+ @staticmethod
126
+ def _unpack_legacy_sample(
127
+ sample: tuple,
128
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
129
+ """Convert a legacy ``((local, global), action_seq)`` sample.
130
+
131
+ Args:
132
+ sample: Tuple of ``(state, action_seq)`` where state is
133
+ ``(local_crop, global_map)``.
134
+
135
+ Returns:
136
+ ``(local [9,9], global [21,79], actions [seq_len])`` as
137
+ numpy int16/int64 arrays.
138
+ """
139
+ (local, glb), action_seq = sample
140
+ return (
141
+ np.asarray(local, dtype=np.int16),
142
+ np.asarray(glb, dtype=np.int16),
143
+ np.asarray(action_seq, dtype=np.int64),
144
+ )
145
+
146
+ # ── Online data ──────────────────────────────────────────────
147
+
148
+ def _invalidate_cache(self) -> None:
149
+ """Mark the stacked array cache as stale."""
150
+ self._cache_valid = False
151
+
152
+ def _ensure_cache(self) -> None:
153
+ """Rebuild stacked arrays from offline + online windows."""
154
+ if self._cache_valid:
155
+ return
156
+ combined = self._offline + self._online
157
+ if not combined:
158
+ return
159
+ n = len(combined)
160
+ l0, g0, a0 = combined[0]
161
+ self._cached_local = np.empty(
162
+ (n, *l0.shape), dtype=l0.dtype,
163
+ )
164
+ self._cached_global = np.empty(
165
+ (n, *g0.shape), dtype=g0.dtype,
166
+ )
167
+ self._cached_actions = np.empty(
168
+ (n, *a0.shape), dtype=a0.dtype,
169
+ )
170
+ for i, (l, g, a) in enumerate(combined):
171
+ self._cached_local[i] = l
172
+ self._cached_global[i] = g
173
+ self._cached_actions[i] = a
174
+ self._cache_valid = True
175
+
176
+ def add(self, trajectory: dict) -> None:
177
+ """Add a trajectory, sliced into overlapping windows.
178
+
179
+ FIFO-evicts oldest online samples when over capacity.
180
+
181
+ Args:
182
+ trajectory: Dict with ``"local"`` ``[T,9,9]``,
183
+ ``"global"`` ``[T,21,79]``, ``"actions"`` ``[T]``.
184
+ """
185
+ windows = self._slice_trajectory(trajectory)
186
+ self._online.extend(windows)
187
+ max_online = self._capacity - len(self._offline)
188
+ if len(self._online) > max_online:
189
+ excess = len(self._online) - max_online
190
+ self._online = self._online[excess:]
191
+ self._invalidate_cache()
192
+
193
+ # ── Sampling ─────────────────────────────────────────────────
194
+
195
+ def sample(
196
+ self, batch_size: int,
197
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
198
+ """Random sample from offline + online combined.
199
+
200
+ Args:
201
+ batch_size: Number of windows to sample.
202
+
203
+ Returns:
204
+ ``(local [B,9,9], global [B,21,79], actions [B,seq_len])``
205
+ as numpy arrays, or ``None`` if the buffer is empty.
206
+ """
207
+ if len(self) == 0:
208
+ return None
209
+ self._ensure_cache()
210
+ if self._cached_local is None:
211
+ return None
212
+ indices = np.random.randint(0, len(self), size=batch_size)
213
+ return (
214
+ self._cached_local[indices],
215
+ self._cached_global[indices],
216
+ self._cached_actions[indices],
217
+ )
218
+
219
+ # ── Properties ───────────────────────────────────────────────
220
+
221
+ def __len__(self) -> int:
222
+ """Total number of windows (offline + online)."""
223
+ return len(self._offline) + len(self._online)
224
+
225
+ @property
226
+ def n_offline(self) -> int:
227
+ """Number of pinned offline windows."""
228
+ return len(self._offline)
229
+
230
+ @property
231
+ def offline_size(self) -> int:
232
+ """Number of pinned offline windows (alias)."""
233
+ return len(self._offline)
234
+
235
+ # ── Internals ────────────────────────────────────────────────
236
+
237
+ def _slice_trajectory(
238
+ self, traj: dict,
239
+ ) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
240
+ """Slice a trajectory into overlapping seq_len windows.
241
+
242
+ Args:
243
+ traj: Trajectory dict with ``"local"``, ``"global"``,
244
+ ``"actions"`` arrays.
245
+
246
+ Returns:
247
+ List of ``(local, global, actions)`` tuples.
248
+ """
249
+ local_arr = np.asarray(traj["local"])
250
+ global_arr = np.asarray(traj["global"])
251
+ actions_arr = np.asarray(traj["actions"])
252
+ T = len(actions_arr)
253
+ windows: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
254
+
255
+ for start in range(T):
256
+ end = start + self._seq_len
257
+ if end <= T:
258
+ a = actions_arr[start:end]
259
+ else:
260
+ a = np.full(self._seq_len, self._pad_token, dtype=np.int64)
261
+ a[: T - start] = actions_arr[start:]
262
+
263
+ # Use the observation at the window start
264
+ l = local_arr[min(start, len(local_arr) - 1)]
265
+ g = global_arr[min(start, len(global_arr) - 1)]
266
+ windows.append((l.copy(), g.copy(), a))
267
+
268
+ return windows
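The windowing rule in `_slice_trajectory` can be illustrated without NumPy. The sketch below is a simplified stand-in rather than the buffer's own code: `slice_actions` is a hypothetical helper that works on plain lists and ignores the paired observations, but it follows the same rule of one window per start index, right-padded with the PAD token.

```python
def slice_actions(actions: list[int], seq_len: int, pad_token: int) -> list[list[int]]:
    """Return one length-seq_len window per start index, right-padded with pad_token."""
    windows = []
    for start in range(len(actions)):
        window = actions[start:start + seq_len]
        # Windows that run past the end of the trajectory are padded on the right.
        window = window + [pad_token] * (seq_len - len(window))
        windows.append(window)
    return windows


# A 3-step trajectory with seq_len=2 yields 3 windows; only the last one is padded.
windows = slice_actions([1, 2, 3], seq_len=2, pad_token=9)
print(windows)  # [[1, 2], [2, 3], [3, 9]]
```

Because every start index produces a window, a trajectory of length `T` always contributes exactly `T` samples, and short trajectories are dominated by padded windows.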
src/config.py ADDED
@@ -0,0 +1,164 @@
1
+ """Configuration loader for ReMDM-MiniHack.
2
+
3
+ Loads YAML configs with deep-merge and CLI override support,
4
+ following the Craftax config pattern.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import os
11
+ import secrets
12
+ from datetime import datetime, timezone
13
+ from pathlib import Path
14
+ from types import SimpleNamespace
15
+
16
+ import yaml
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ _PROJECT_ROOT = Path(__file__).resolve().parent.parent
22
+
23
+
24
+ def _deep_merge(base: dict, override: dict) -> dict:
25
+ """Recursively merge *override* into *base* (mutates *base*).
26
+
27
+ Args:
28
+ base: Base dictionary to merge into.
29
+ override: Dictionary whose values take precedence.
30
+
31
+ Returns:
32
+ The merged dictionary (same object as *base*).
33
+ """
34
+ for key, value in override.items():
35
+ if (
36
+ key in base
37
+ and isinstance(base[key], dict)
38
+ and isinstance(value, dict)
39
+ ):
40
+ _deep_merge(base[key], value)
41
+ else:
42
+ base[key] = value
43
+ return base
44
+
45
+
46
+ def _cast_value(value: str) -> int | float | bool | str | None:
47
+ """Best-effort cast of a CLI string to a Python scalar.
48
+
49
+ Args:
50
+ value: Raw string from the command line.
51
+
52
+ Returns:
53
+ Parsed Python value (int, float, bool, str, or None).
54
+ """
55
+ if value.lower() in ("true", "yes"):
56
+ return True
57
+ if value.lower() in ("false", "no"):
58
+ return False
59
+ if value.lower() == "null":
60
+ return None
61
+ try:
62
+ return int(value)
63
+ except ValueError:
64
+ pass
65
+ try:
66
+ return float(value)
67
+ except ValueError:
68
+ pass
69
+ return value
70
+
71
+
72
+ def load_config(
73
+ config_path: str | None = None,
74
+ cli_overrides: dict | None = None,
75
+ ) -> SimpleNamespace:
76
+ """Load configuration from YAML with optional overrides.
77
+
78
+ 1. Load ``configs/defaults.yaml``.
79
+ 2. Deep-merge *config_path* on top (if provided and different from defaults).
80
+ 3. Apply *cli_overrides* key=value pairs.
81
+ 4. Auto-select device (``cuda`` if available, else ``cpu``; honour
82
+ ``DEVICE`` env-var).
83
+ 5. Validate invariants.
84
+
85
+ Args:
86
+ config_path: Path to a YAML file merged on top of defaults.
87
+ ``None`` uses defaults only.
88
+ cli_overrides: ``{key: value}`` pairs applied last.
89
+
90
+ Returns:
91
+ A ``SimpleNamespace`` containing all hyperparameters.
92
+
93
+ Raises:
94
+ AssertionError: If ``mask_token != action_dim`` or
95
+ ``pad_token != action_dim + 1``.
96
+ """
97
+ if cli_overrides is None:
98
+ cli_overrides = {}
99
+
100
+ defaults_path = _PROJECT_ROOT / "configs" / "defaults.yaml"
101
+ with open(defaults_path, "r") as fh:
102
+ cfg = yaml.safe_load(fh)
103
+
104
+ if config_path is not None:
105
+ config_path_resolved = Path(config_path)
106
+ if not config_path_resolved.is_absolute():
107
+ config_path_resolved = _PROJECT_ROOT / config_path_resolved
108
+ if config_path_resolved.resolve() != defaults_path.resolve():
109
+ with open(config_path_resolved, "r") as fh:
110
+ overrides = yaml.safe_load(fh) or {}
111
+ _deep_merge(cfg, overrides)
112
+
113
+ for key, value in cli_overrides.items():
114
+ if isinstance(value, str):
115
+ value = _cast_value(value)
116
+ cfg[key] = value
117
+
118
+ # Device selection
119
+ env_device = os.environ.get("DEVICE")
120
+ if env_device:
121
+ cfg["device"] = env_device
122
+ elif "device" not in cfg:
123
+ try:
124
+ import torch
125
+ cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu"
126
+ except ImportError:
127
+ cfg["device"] = "cpu"
128
+
129
+ ns = SimpleNamespace(**cfg)
130
+
131
+ # Validation
132
+ assert ns.mask_token == ns.action_dim, (
133
+ f"mask_token ({ns.mask_token}) must equal action_dim ({ns.action_dim})"
134
+ )
135
+ assert ns.pad_token == ns.action_dim + 1, (
136
+ f"pad_token ({ns.pad_token}) must equal action_dim + 1 "
137
+ f"({ns.action_dim + 1})"
138
+ )
139
+
140
+ return ns
141
+
142
+
143
+ def make_run_dir(cfg: SimpleNamespace, tag: str = "run") -> Path:
144
+ """Create a unique run subdirectory under ``cfg.checkpoint_dir``.
145
+
146
+ Generates a directory named ``{tag}_{YYYYMMDD}_{HHMMSS}_{hex4}``
147
+ to prevent concurrent runs from overwriting each other's
148
+ checkpoints. Updates ``cfg.checkpoint_dir`` in place.
149
+
150
+ Args:
151
+ cfg: Config namespace (``checkpoint_dir`` is mutated).
152
+ tag: Prefix for the directory name (e.g. ``"dagger"``,
153
+ ``"offline"``).
154
+
155
+ Returns:
156
+ The created directory path.
157
+ """
158
+ ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
159
+ suffix = secrets.token_hex(2)
160
+ run_dir = Path(cfg.checkpoint_dir).resolve() / f"{tag}_{ts}_{suffix}"
161
+ run_dir.mkdir(parents=True, exist_ok=True)
162
+ cfg.checkpoint_dir = str(run_dir)
163
+ logger.info("Checkpoint directory: %s", run_dir)
164
+ return run_dir
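The merge and cast steps in `load_config` can be tried in isolation. The sketch below re-implements them in standard-library Python under assumed names (`deep_merge`, `cast_value`); the config keys `model`, `hidden`, `layers`, and `seed` are made up for illustration and are not taken from `configs/defaults.yaml`.

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override into base (mutates and returns base)."""
    for key, value in override.items():
        if isinstance(base.get(key), dict) and isinstance(value, dict):
            deep_merge(base[key], value)
        else:
            base[key] = value
    return base


def cast_value(value: str):
    """Best-effort cast of a CLI string to bool, None, int, float, or str."""
    lowered = value.lower()
    if lowered in ("true", "yes"):
        return True
    if lowered in ("false", "no"):
        return False
    if lowered == "null":
        return None
    for parser in (int, float):
        try:
            return parser(value)
        except ValueError:
            continue
    return value


# Nested dicts are merged key by key; scalars are replaced outright.
merged = deep_merge(
    {"model": {"hidden": 128, "layers": 4}, "seed": 0},
    {"model": {"layers": 8}},
)
casted = [cast_value(s) for s in ("3", "1e-4", "no", "null", "cosine")]
```

Note that `load_config` writes CLI overrides with `cfg[key] = value`, so an override replaces a top-level key as a whole; it does not reach into nested sections the way the YAML merge does.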
src/curriculum.py ADDED
@@ -0,0 +1,143 @@
1
+ """Dynamic environment curriculum and efficiency filter.
2
+
3
+ Ported from minihack_reference/src/curriculum.py. Tracks per-environment
4
+ win rates in a rolling window and uses bucket-based sampling weights to
5
+ focus training on environments where the model is struggling.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import random
11
+ from collections import deque
12
+
13
+
14
+ class DynamicCurriculum:
15
+ """Rolling-window curriculum with bucket-based sampling weights.
16
+
17
+ Each environment maintains a deque of recent win/loss outcomes.
18
+ Sampling weight is set by win-rate bucket: mid-range environments get
19
+ full weight; near-unsolved and near-solved ones are down-weighted.
20
+
21
+ Args:
22
+ env_ids: List of environment IDs to track.
23
+ queue_size: Rolling window size per environment.
24
+ """
25
+
26
+ # Bucket thresholds and weights
27
+ _LOW_THRESHOLD = 0.15
28
+ _HIGH_THRESHOLD = 0.85
29
+ _WEIGHT_LOW = 0.2
30
+ _WEIGHT_MID = 1.0
31
+ _WEIGHT_HIGH = 0.1
32
+
33
+ def __init__(
34
+ self,
35
+ env_ids: list[str],
36
+ queue_size: int = 100,
37
+ preseed: bool = True,
38
+ ) -> None:
39
+ self._env_ids = list(env_ids)
40
+ self._queue_size = queue_size
41
+ self._queues: dict[str, deque[bool]] = {}
42
+ for eid in self._env_ids:
43
+ q: deque[bool] = deque(maxlen=queue_size)
44
+ if preseed:
45
+ # 50/50 prior for uniform early sampling
46
+ for _ in range(50):
47
+ q.append(True)
48
+ for _ in range(50):
49
+ q.append(False)
50
+ self._queues[eid] = q
51
+
52
+ def update(self, env_id: str, won: bool) -> None:
53
+ """Record an episode outcome.
54
+
55
+ Args:
56
+ env_id: Environment ID.
57
+ won: Whether the episode was won.
58
+ """
59
+ if env_id not in self._queues:
60
+ self._queues[env_id] = deque(maxlen=self._queue_size)
61
+ self._queues[env_id].append(won)
62
+
63
+ def win_rate(self, env_id: str) -> float:
64
+ """Rolling win rate for an environment.
65
+
66
+ Args:
67
+ env_id: Environment ID.
68
+
69
+ Returns:
70
+ Win rate in ``[0, 1]``. Default 0.5 if empty.
71
+ """
72
+ q = self._queues.get(env_id)
73
+ if q is None or len(q) == 0:
74
+ return 0.5
75
+ return sum(q) / len(q)
76
+
77
+ def sample_env(self) -> str:
78
+ """Sample an environment ID using bucket-weighted probabilities.
79
+
80
+ Returns:
81
+ Sampled environment ID.
82
+ """
83
+ weights: list[float] = []
84
+ for eid in self._env_ids:
85
+ w = self.win_rate(eid)
86
+ if w < self._LOW_THRESHOLD:
87
+ weights.append(self._WEIGHT_LOW)
88
+ elif w > self._HIGH_THRESHOLD:
89
+ weights.append(self._WEIGHT_HIGH)
90
+ else:
91
+ weights.append(self._WEIGHT_MID)
92
+ return random.choices(self._env_ids, weights=weights, k=1)[0]
93
+
94
+ def state_dict(self) -> dict:
95
+ """Serialise curriculum state.
96
+
97
+ Returns:
98
+ Dict with ``env_ids``, ``queue_size``, and per-env queues.
99
+ """
100
+ return {
101
+ "env_ids": self._env_ids,
102
+ "queue_size": self._queue_size,
103
+ "queues": {
104
+ eid: list(q) for eid, q in self._queues.items()
105
+ },
106
+ }
107
+
108
+ def load_state_dict(self, sd: dict) -> None:
109
+ """Restore curriculum state.
110
+
111
+ Args:
112
+ sd: State dict from ``state_dict()``.
113
+ """
114
+ self._queue_size = sd.get("queue_size", self._queue_size)
115
+ for eid, items in sd.get("queues", {}).items():
116
+ q: deque[bool] = deque(maxlen=self._queue_size)
117
+ q.extend(items)
118
+ self._queues[eid] = q
119
+
120
+
121
+ def efficiency_filter(
122
+ model_won: bool,
123
+ model_steps: int,
124
+ oracle_steps: int,
125
+ multiplier: float = 1.5,
126
+ ) -> bool:
127
+ """Decide whether to add oracle trajectory to the buffer.
128
+
129
+ Returns ``True`` (add oracle data) when the model either failed
130
+ or was substantially less efficient than the oracle.
131
+
132
+ Args:
133
+ model_won: Whether the model solved the episode.
134
+ model_steps: Steps the model took.
135
+ oracle_steps: Steps the oracle took.
136
+ multiplier: Efficiency threshold multiplier.
137
+
138
+ Returns:
139
+ ``True`` if oracle data should be added to the buffer.
140
+ """
141
+ if not model_won:
142
+ return True
143
+ return model_steps > multiplier * oracle_steps
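To see what the bucket weights in `DynamicCurriculum` imply, the sketch below normalises them into sampling probabilities. The thresholds and weights are copied from the class constants; the helper names (`bucket_weight`, `sampling_probs`) and the environment names are hypothetical.

```python
LOW_THRESHOLD, HIGH_THRESHOLD = 0.15, 0.85
WEIGHT_LOW, WEIGHT_MID, WEIGHT_HIGH = 0.2, 1.0, 0.1


def bucket_weight(win_rate: float) -> float:
    """Map a rolling win rate to its bucket weight."""
    if win_rate < LOW_THRESHOLD:
        return WEIGHT_LOW
    if win_rate > HIGH_THRESHOLD:
        return WEIGHT_HIGH
    return WEIGHT_MID


def sampling_probs(win_rates: dict[str, float]) -> dict[str, float]:
    """Normalise bucket weights into per-environment sampling probabilities."""
    weights = {env: bucket_weight(rate) for env, rate in win_rates.items()}
    total = sum(weights.values())
    return {env: w / total for env, w in weights.items()}


# Hypothetical environments, one per bucket.
probs = sampling_probs({"hard": 0.05, "medium": 0.50, "easy": 0.95})
for env, p in probs.items():
    print(f"{env}: {p:.3f}")
```

With one environment per bucket, the mid-range environment receives roughly three quarters of the samples, the near-unsolved one about twice the share of the near-solved one. The curriculum therefore concentrates on environments that are partly learned rather than on the hardest ones.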
src/diffusion/__init__.py ADDED
File without changes
src/diffusion/forward.py ADDED
@@ -0,0 +1,50 @@
1
+ """Forward masking process q(z_t | x_0).
2
+
3
+ Ported from the Craftax JAX implementation (src/diffusion/forward.py).
4
+ Each token is independently replaced with mask_token with probability
5
+ sigma_t = 1 - alpha_t. PAD positions are never masked.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Callable
11
+
12
+ import torch
13
+ from torch import Tensor
14
+
15
+
16
+ def q_sample(
17
+ x0: Tensor,
18
+ t: Tensor,
19
+ mask_token: int,
20
+ pad_token: int,
21
+ schedule_fn: Callable[[Tensor], Tensor],
22
+ ) -> Tensor:
23
+ """Sample z_t from the forward masking process.
24
+
25
+ Args:
26
+ x0: Clean action sequences. Shape ``[B, L]``, dtype int64.
27
+ t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
28
+ mask_token: Integer ID of the MASK token.
29
+ pad_token: Integer ID of the PAD token.
30
+ schedule_fn: Noise schedule returning alpha(t).
31
+
32
+ Returns:
33
+ Noisy sequence z_t. Shape ``[B, L]``, dtype int64.
34
+ PAD positions are preserved unchanged.
35
+ """
36
+ alpha_t = schedule_fn(t) # [B]
37
+ sigma_t = 1.0 - alpha_t # mask probability per sample
38
+ sigma_t = sigma_t.unsqueeze(-1) # [B, 1]
39
+
40
+ # Independent Bernoulli masking per position
41
+ mask_draws = torch.rand_like(x0, dtype=torch.float32) # [B, L]
42
+ do_mask = mask_draws < sigma_t # [B, L]
43
+
44
+ zt = torch.where(do_mask, mask_token, x0)
45
+
46
+ # Restore PAD positions — never mask padding
47
+ pad_mask = x0 == pad_token # [B, L]
48
+ zt = torch.where(pad_mask, pad_token, zt)
49
+
50
+ return zt
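The masking rule in `q_sample` can be checked on a single sequence with plain Python. This is a minimal sketch, not the batched tensor code: `q_sample_list` is a hypothetical helper, the mask probability `sigma_t` is passed in directly instead of being derived from a schedule, and the token IDs assume `action_dim = 12`, so that `mask_token = 12` and `pad_token = 13` satisfy the invariants checked in `load_config`.

```python
import random


def q_sample_list(
    x0: list[int],
    sigma_t: float,
    mask_token: int,
    pad_token: int,
    rng: random.Random,
) -> list[int]:
    """Mask each non-PAD token independently with probability sigma_t."""
    zt = []
    for token in x0:
        if token == pad_token:
            zt.append(pad_token)      # PAD positions are never masked
        elif rng.random() < sigma_t:
            zt.append(mask_token)     # Bernoulli(sigma_t) masking
        else:
            zt.append(token)
    return zt


MASK, PAD = 12, 13  # assumed action_dim = 12
rng = random.Random(0)
x0 = [3, 1, 0, 2, PAD, PAD]

fully_masked = q_sample_list(x0, sigma_t=1.0, mask_token=MASK, pad_token=PAD, rng=rng)
untouched = q_sample_list(x0, sigma_t=0.0, mask_token=MASK, pad_token=PAD, rng=rng)
partial = q_sample_list(x0, sigma_t=0.5, mask_token=MASK, pad_token=PAD, rng=rng)
```

At `sigma_t = 1.0` every action token becomes MASK while the padding survives, and at `sigma_t = 0.0` the sequence is returned unchanged, which matches the two endpoints of the forward process.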
src/diffusion/loss.py ADDED
@@ -0,0 +1,162 @@
1
+ """MDLM ELBO loss with SUBS parameterisation.
2
+
3
+ Ported from the Craftax JAX implementation (src/diffusion/loss.py).
4
+ Computes continuous-time loss on masked positions only, with optional
5
+ analytic SUBS weighting clipped for numerical stability.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Callable
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import Tensor
15
+
16
+ from src.diffusion.schedules import alpha_prime
17
+
18
+
19
+ _MAX_WEIGHT: float = 1000.0
20
+
21
+
22
+ def mdlm_loss(
23
+ logits: Tensor,
24
+ x0: Tensor,
25
+ zt: Tensor,
26
+ t: Tensor,
27
+ mask_token: int,
28
+ pad_token: int,
29
+ schedule_fn: Callable[[Tensor], Tensor],
30
+ weight_clip: float = _MAX_WEIGHT,
31
+ label_smoothing: float = 0.0,
32
+ use_importance_weighting: bool = False,
33
+ ) -> Tensor:
34
+ """Compute masked diffusion loss.
35
+
36
+ By default uses a simple masked cross-entropy average (matching the
37
+ reference implementation). When ``use_importance_weighting=True``,
38
+ applies SUBS weighting ``w(t) = -alpha'(t) / (1 - alpha_t)``.
39
+
40
+ Args:
41
+ logits: Model output. Shape ``[B, L, vocab]``.
42
+ x0: Clean action sequences. Shape ``[B, L]``, int64.
43
+ zt: Noisy sequences. Shape ``[B, L]``, int64.
44
+ t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
45
+ mask_token: MASK token ID.
46
+ pad_token: PAD token ID.
47
+ schedule_fn: Noise schedule returning alpha(t).
48
+ weight_clip: Upper clamp for SUBS weight (default 1000).
49
+ label_smoothing: Smoothing epsilon for cross-entropy.
50
+ use_importance_weighting: If ``True``, apply SUBS w(t) per sample.
51
+
52
+ Returns:
53
+ Scalar loss. Returns ``0.0`` when no masked positions exist.
54
+ """
55
+ B, L, V = logits.shape
56
+
57
+ # Mask: compute loss only on masked, non-PAD positions
58
+ is_masked = (zt == mask_token) & (x0 != pad_token) # [B, L]
59
+
60
+ if not is_masked.any():
61
+ return logits.new_tensor(0.0)
62
+
63
+ # Per-position cross-entropy
64
+ # Clamp targets to valid vocab range — out-of-range positions (PAD,
65
+ # MASK) will be zeroed out by is_masked anyway.
66
+ safe_targets = x0.clamp(0, V - 1) # [B, L]
67
+ ce = F.cross_entropy(
68
+ logits.reshape(-1, V),
69
+ safe_targets.reshape(-1),
70
+ reduction="none",
71
+ label_smoothing=label_smoothing,
72
+ ) # [B*L]
73
+ ce = ce.reshape(B, L) # [B, L]
74
+
75
+ # Zero out non-masked positions
76
+ ce = ce * is_masked.float() # [B, L]
77
+
78
+ # Global average over all masked positions (matches reference)
79
+ n_masked_total = is_masked.float().sum().clamp(min=1.0)
80
+ loss = ce.sum() / n_masked_total
81
+
82
+ if use_importance_weighting:
83
+ # SUBS weight: w_t = -alpha'(t) / (1 - alpha_t + eps)
84
+ alpha_t = schedule_fn(t) # [B]
85
+ d_alpha = alpha_prime(t, schedule_fn) # [B]
86
+ w_t = (-d_alpha) / (1.0 - alpha_t + 1e-8) # [B]
87
+ w_t = w_t.clamp(0.0, weight_clip) # [B]
88
+
89
+ # Per-sample weighted loss (needed for SUBS)
90
+ n_masked_per = is_masked.float().sum(dim=1).clamp(min=1.0) # [B]
91
+ per_sample = ce.sum(dim=1) / n_masked_per # [B]
92
+ loss = (per_sample * w_t).mean()
93
+
94
+ return loss
95
+
96
+
97
+ def auxiliary_goal_loss(
98
+ goal_pred: Tensor,
99
+ global_obs: Tensor,
100
+ pad_value: float = -1.0,
101
+ ) -> Tensor:
102
+ """MSE loss for auxiliary staircase-coordinate prediction.
103
+
104
+ Args:
105
+ goal_pred: Predicted normalised staircase coords. Shape ``[B, 2]``.
106
+ global_obs: Full map glyphs. Shape ``[B, 21, 79]``, int.
107
+ pad_value: Coordinate value used when staircase is not visible.
108
+
109
+ Returns:
110
+ Scalar MSE loss over samples where the staircase is visible.
111
+ Returns ``0.0`` when no staircase is visible in the batch.
112
+ """
113
+ targets = find_staircase_from_glyphs(global_obs) # [B, 2]
114
+ targets = targets.to(goal_pred.device, dtype=goal_pred.dtype)
115
+
116
+ # Only supervise where staircase is visible
117
+ valid = (targets[:, 0] != pad_value) # [B]
118
+ if not valid.any():
119
+ return goal_pred.new_tensor(0.0)
120
+
121
+ diff = (goal_pred[valid] - targets[valid]) ** 2 # [N, 2]
122
+ return diff.mean()
123
+
124
+
125
+ def find_staircase_from_glyphs(global_obs: Tensor) -> Tensor:
126
+ """Locate the staircase '>' in the global glyph map.
127
+
128
+ Searches for NLE staircase-down glyph (character code 62 = '>').
129
+ Returns normalised (row/(H-1), col/(W-1)) coordinates per batch element,
130
+ or (-1, -1) when the staircase is not visible.
131
+
132
+ Args:
133
+ global_obs: Glyph map. Shape ``[B, H, W]`` or ``[H, W]``, int.
134
+
135
+ Returns:
136
+ Normalised coordinates. Shape ``[B, 2]`` (float32).
137
+ """
138
+ if global_obs.ndim == 2:
139
+ global_obs = global_obs.unsqueeze(0)
140
+
141
+ B, H, W = global_obs.shape
142
+ # NLE staircase-down glyphs: ord('>') = 62, plus NLE tile variants
143
+ # 2310 (S_dnstair), 2368 (S_dnstairs), 2383 (S_vodoor).
144
+ is_stair = (
145
+ (global_obs == 62)
146
+ | (global_obs == 2310)
147
+ | (global_obs == 2368)
148
+ | (global_obs == 2383)
149
+ )
150
+
151
+ coords = torch.full(
152
+ (B, 2), -1.0, dtype=torch.float32, device=global_obs.device
153
+ )
154
+ for b in range(B):
155
+ positions = is_stair[b].nonzero(as_tuple=False) # [N, 2]
156
+ if positions.shape[0] > 0:
157
+ row = positions[0, 0].float() / max(1, H - 1)
158
+ col = positions[0, 1].float() / max(1, W - 1)
159
+ coords[b, 0] = row
160
+ coords[b, 1] = col
161
+
162
+ return coords
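The effect of `weight_clip` in `mdlm_loss` is easiest to see on a concrete schedule. The sketch below assumes a linear schedule `alpha(t) = 1 - t`, which is an illustration only (the schedules defined in `src/diffusion/schedules.py` are not shown here); under that assumption `alpha'(t) = -1` and the SUBS weight reduces to roughly `1 / t`. The function name `subs_weight_linear` is hypothetical.

```python
def subs_weight_linear(t: float, weight_clip: float = 1000.0, eps: float = 1e-8) -> float:
    """SUBS weight w(t) = -alpha'(t) / (1 - alpha(t) + eps), clipped to [0, weight_clip].

    Assumes the linear schedule alpha(t) = 1 - t, so alpha'(t) = -1.
    """
    alpha_t = 1.0 - t
    d_alpha = -1.0
    w = -d_alpha / (1.0 - alpha_t + eps)
    return min(max(w, 0.0), weight_clip)


for t in (1.0, 0.5, 0.1, 1e-6):
    print(f"t={t:g}  w={subs_weight_linear(t):.4f}")
```

Under this assumed schedule the weight grows without bound as `t` approaches 0, so lightly masked samples would dominate the batch mean if the clamp at `weight_clip` were absent.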
src/diffusion/sampling.py ADDED
@@ -0,0 +1,398 @@
1
+ """ReMDM reverse denoising with remasking strategies.
2
+
3
+ Ported from the Craftax JAX implementation (src/diffusion/sampling.py).
4
+ Implements MaskGIT-style progressive unmasking with optional stochastic
5
+ remasking (ReMDM) using three strategy variants.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from types import SimpleNamespace
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import Tensor
16
+ from torch.distributions import Categorical
17
+
18
+ from src.diffusion.schedules import get_schedule
19
+
20
+ # NLE hazard glyph IDs and char codes (walls, locked doors, lava, water)
21
+ _HAZARD_GLYPHS: frozenset[int] = frozenset({2359, 2360, 2389, 2390})
22
+ _HAZARD_CHARS: frozenset[int] = frozenset(
23
+ {ord("|"), ord("-"), ord("+"), ord("L"), ord("W")}
24
+ )
25
+ # Cardinal action → (dy, dx) offsets
26
+ _CARDINAL_OFFSETS: dict[int, tuple[int, int]] = {
27
+ 0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1),
28
+ }
29
+ _N_PHYSICS_CHECK = 8 # only inspect the first N plan positions
30
+
31
+
32
+ def _check_hazard(local_crop: np.ndarray, action: int) -> bool:
33
+ """Return True if *action* from the agent's centre steps into a hazard.
34
+
35
+ Args:
36
+ local_crop: ``[crop_size, crop_size]`` glyph array.
37
+ action: Cardinal action index (0=N, 1=E, 2=S, 3=W).
38
+
39
+ Returns:
40
+ ``True`` when the target cell contains a hazard glyph or lies outside the crop.
41
+ """
42
+ if action not in _CARDINAL_OFFSETS:
43
+ return False
44
+ cs = local_crop.shape[0]
45
+ cy, cx = cs // 2, cs // 2
46
+ dy, dx = _CARDINAL_OFFSETS[action]
47
+ ny, nx = cy + dy, cx + dx
48
+ if not (0 <= ny < cs and 0 <= nx < cs):
49
+ return True
50
+ glyph = int(local_crop[ny, nx])
51
+ return glyph in _HAZARD_GLYPHS or glyph in _HAZARD_CHARS
52
+
53
+
54
+ def top_k_filter(logits: Tensor, k: int) -> Tensor:
55
+ """Mask all but the top-k logits per position with ``-inf``.
56
+
57
+ Args:
58
+ logits: Raw logits. Shape ``[..., V]``.
59
+ k: Number of top entries to keep.
60
+
61
+ Returns:
62
+ Filtered logits with non-top-k set to ``-inf``.
63
+ """
64
+ if k <= 0 or k >= logits.shape[-1]:
65
+ return logits
66
+ topk_vals, _ = logits.topk(k, dim=-1) # [..., k]
67
+ threshold = topk_vals[..., -1:] # [..., 1]
68
+ return logits.masked_fill(logits < threshold, float("-inf"))
69
+
70
+
71
+ def _compute_remask_prob(
72
+ strategy: str,
73
+ eta: float,
74
+ sigma_max: float,
75
+ confidence: Tensor | None,
76
+ ) -> Tensor | float:
77
+ """Compute per-token remasking probability.
78
+
79
+ Args:
80
+ strategy: One of ``"rescale"``, ``"cap"``, ``"conf"``.
81
+ eta: Base remasking strength hyperparameter.
82
+ sigma_max: ``1 - alpha_t(ratio)`` at current step.
83
+ confidence: Per-token confidence scores. Shape ``[B, L]``.
84
+ Required only for the ``"conf"`` strategy.
85
+
86
+ Returns:
87
+ Scalar or ``[B, L]`` tensor of remasking probabilities.
88
+ """
89
+ if strategy == "rescale":
90
+ return eta * sigma_max
91
+ if strategy == "cap":
92
+ return min(eta, sigma_max)
93
+ if strategy == "conf":
94
+ assert confidence is not None, "conf strategy requires confidence"
95
+ return eta * sigma_max * (1.0 - confidence)
96
+ raise ValueError(f"Unknown remask strategy: {strategy}")
97
+
98
+
99
+ @torch.no_grad()
100
+ def remdm_sample(
101
+ model: torch.nn.Module,
102
+ local_obs: Tensor,
103
+ global_obs: Tensor,
104
+ cfg: SimpleNamespace,
105
+ device: torch.device | str,
106
+ physics_aware: bool = True,
107
+ blind_global: bool = False,
108
+ return_analytics: bool = False,
109
+ num_steps: int | None = None,
110
+ ) -> Tensor | tuple[Tensor, list, list[float], list[int]]:
111
+ """Generate action sequences via iterative ReMDM denoising.
112
+
113
+ Args:
114
+ model: Denoising model with forward signature
115
+ ``(local_obs, global_obs, action_seq, t_discrete) -> dict``.
116
+ local_obs: Local crop observations. Shape ``[B, 9, 9]``.
117
+ global_obs: Global map observations. Shape ``[B, 21, 79]``.
118
+ cfg: Config namespace with ``seq_len``, ``mask_token``,
119
+ ``action_dim``, ``diffusion_steps_eval``, ``temperature``,
120
+ ``top_k``, ``eta``, ``remask_strategy``, ``noise_schedule``.
121
+ device: Torch device.
122
+ physics_aware: If ``True``, soft-penalise hazardous cardinal actions
123
+ by overriding their confidence to ``0.001`` before commitment
124
+ ranking. Only checks the first ``_N_PHYSICS_CHECK`` positions.
125
+ blind_global: If ``True``, zero out the global map observation
126
+ (local-only ablation).
127
+ return_analytics: If ``True``, also return per-step analytics as
128
+ ``(seq, path_per_step, tracking_confidence, tracking_masked_count)``.
129
+ num_steps: Override number of denoising steps (default uses
130
+ ``cfg.diffusion_steps_eval``).
131
+
132
+ Returns:
133
+ When ``return_analytics=False`` (default): fully committed action
134
+ sequence of shape ``[B, seq_len]``, int64, with no MASK tokens.
135
+
136
+ When ``return_analytics=True``: tuple
137
+ ``(seq, path_per_step, tracking_confidence, tracking_masked_count)``
138
+ where ``path_per_step`` is a list of ``[seq_len]`` numpy arrays,
139
+ ``tracking_confidence`` a list of per-step avg unmasked confidence
140
+ floats, and ``tracking_masked_count`` a list of masked-token counts.
141
+ """
142
+ B = local_obs.shape[0]
143
+ seq_len = cfg.seq_len
144
+ mask_token = cfg.mask_token
145
+ action_dim = cfg.action_dim
146
+ K = num_steps if num_steps is not None else cfg.diffusion_steps_eval
147
+ schedule_fn = get_schedule(cfg.noise_schedule)
148
+ min_keep = max(1, int(seq_len * 0.10)) # Safety Net: always unmask ≥10%
149
+
150
+ local_obs = local_obs.to(device)
151
+ global_obs = global_obs.to(device)
152
+
153
+ if blind_global:
154
+ global_obs = torch.zeros_like(global_obs)
155
+
156
+ # Pre-compute numpy local crops for physics checks (CPU, batch loop)
157
+ local_np: np.ndarray | None = None # [B, crop, crop]
158
+ if physics_aware:
159
+ local_np = local_obs.cpu().numpy()
160
+
161
+ # Analytics buffers (only populated when return_analytics=True)
162
+ path_per_step: list[np.ndarray] = []
163
+ tracking_confidence: list[float] = []
164
+ tracking_masked_count: list[int] = []
165
+
166
+ # Start fully masked
167
+ seq = torch.full(
168
+ (B, seq_len), mask_token, dtype=torch.long, device=device
169
+ )
170
+
171
+ for k in range(1, K + 1):
172
+ ratio = k / K
173
+ # Pass as tensor (not Python int) to avoid torch.compile recompilation
174
+ t_discrete = torch.full(
175
+ (B,), int(cfg.num_diffusion_steps * (1.0 - ratio)),
176
+ dtype=torch.long, device=device,
177
+ )
178
+
179
+ # Forward pass
180
+ out = model(local_obs, global_obs, seq, t_discrete)
181
+ logits = out["actions"] # [B, seq_len, vocab]
182
+
183
+ # Mask invalid action tokens (indices >= action_dim)
184
+ logits[:, :, action_dim:] = float("-inf")
185
+
186
+ # Temperature scaling
187
+ logits = logits / cfg.temperature
188
+
189
+ # Top-K filtering
190
+ logits = top_k_filter(logits, cfg.top_k)
191
+
192
+ # Sample predictions
193
+ probs = F.softmax(logits, dim=-1) # [B, seq_len, action_dim]
194
+ preds = Categorical(probs=probs).sample() # [B, seq_len]
195
+
196
+ # Confidence: probability of the sampled token
197
+ conf = probs.gather(
198
+ -1, preds.unsqueeze(-1)
199
+ ).squeeze(-1) # [B, seq_len]
200
+
201
+ # Physics softener: demote hazardous cardinal actions to conf=0.001
202
+ if physics_aware and local_np is not None:
203
+ preds_np = preds.cpu().numpy() # [B, seq_len]
204
+ conf_override = conf.clone()
205
+ for b in range(B):
206
+ crop_b = np.asarray(local_np[b]) # [crop, crop]
207
+ for pos in range(min(_N_PHYSICS_CHECK, seq_len)):
208
+ action = int(preds_np[b, pos])
209
+ if _check_hazard(crop_b, action):
210
+ conf_override[b, pos] = 0.001
211
+ conf = conf_override
212
+
213
+ is_masked = seq == mask_token # [B, seq_len]
214
+
215
+ if k < K:
216
+ # MaskGIT progressive unmasking with min-keep guarantee
217
+ n_unmask = max(min_keep, max(1, int(seq_len * ratio)))
218
+
219
+ # Set confidence of non-masked positions to -1 so they
220
+ # are not selected for unmasking
221
+ unmask_scores = conf.clone()
222
+ unmask_scores[~is_masked] = -1.0
223
+
224
+ # For each batch element, unmask top-confidence masked positions
225
+ _, topk_indices = unmask_scores.topk(
226
+ n_unmask, dim=-1
227
+ ) # [B, n_unmask]
228
+
229
+ # Build scatter mask for positions to unmask
230
+ unmask_mask = torch.zeros_like(seq, dtype=torch.bool)
231
+ unmask_mask.scatter_(1, topk_indices, True)
232
+ unmask_mask = unmask_mask & is_masked # only unmask masked pos
233
+
234
+ seq = torch.where(unmask_mask, preds, seq)
235
+
236
+ # ReMDM stochastic remasking of committed (non-masked) positions
237
+ is_committed = seq != mask_token # [B, seq_len]
238
+ alpha_t_ratio = schedule_fn(
239
+ torch.tensor(ratio, device=device)
240
+ )
241
+ sigma_max = (1.0 - alpha_t_ratio).item()
242
+
243
+ remask_prob = _compute_remask_prob(
244
+ cfg.remask_strategy, cfg.eta, sigma_max, conf
245
+ )
246
+ if isinstance(remask_prob, Tensor):
247
+ do_remask = (
248
+ torch.rand_like(conf) < remask_prob
249
+ ) & is_committed
250
+ else:
251
+ do_remask = (
252
+ torch.rand(B, seq_len, device=device) < remask_prob
253
+ ) & is_committed
254
+ seq = torch.where(do_remask, mask_token, seq)
255
+ else:
256
+ # Final step: commit all remaining MASK tokens
257
+ seq = torch.where(is_masked, preds, seq)
258
+
259
+ # Analytics tracking
260
+ if return_analytics:
261
+ path_per_step.append(seq[0].cpu().numpy().copy())
262
+ still_masked = (seq[0] == mask_token)
263
+ unmasked_conf = conf[0][~still_masked]
264
+ avg_conf = (
265
+ unmasked_conf.mean().item()
266
+ if unmasked_conf.numel() > 0 else 0.0
267
+ )
268
+ tracking_confidence.append(avg_conf)
269
+ tracking_masked_count.append(int(still_masked.sum().item()))
270
+
271
+ assert (seq != mask_token).all(), (
272
+ "remdm_sample produced MASK tokens in final output"
273
+ )
274
+ if return_analytics:
275
+ return seq, path_per_step, tracking_confidence, tracking_masked_count
276
+ return seq
277
+
278
+
279
+ @torch.no_grad()
280
+ def greedy_sample(
281
+ model: torch.nn.Module,
282
+ local_obs: Tensor,
283
+ global_obs: Tensor,
284
+ cfg: SimpleNamespace,
285
+ device: torch.device | str,
286
+ blind_global: bool = False,
287
+ num_steps: int | None = None,
288
+ ) -> Tensor:
289
+ """Greedy (argmax) MaskGIT sampling — no temperature, top-K, or remasking.
290
+
291
+ Used by ``DataCollector`` during DAgger for deterministic rollouts,
292
+ matching the reference ``run_model_episode`` behaviour.
293
+
294
+ Args:
295
+ model: Denoising model.
296
+ local_obs: Shape ``[B, 9, 9]``.
297
+ global_obs: Shape ``[B, 21, 79]``.
298
+ cfg: Config namespace.
299
+ device: Torch device.
300
+ blind_global: Zero out global map (local-only ablation).
+ num_steps: Number of unmasking steps; defaults to ``cfg.diffusion_steps_eval``.
301
+
302
+ Returns:
303
+ Fully committed action sequence ``[B, seq_len]``, int64.
304
+ """
305
+ B = local_obs.shape[0]
306
+ seq_len = cfg.seq_len
307
+ mask_token = cfg.mask_token
308
+ action_dim = cfg.action_dim
309
+ K = num_steps if num_steps is not None else cfg.diffusion_steps_eval
310
+
311
+ local_obs = local_obs.to(device)
312
+ global_obs = global_obs.to(device)
313
+ if blind_global:
314
+ global_obs = torch.zeros_like(global_obs)
315
+
316
+ seq = torch.full(
317
+ (B, seq_len), mask_token, dtype=torch.long, device=device,
318
+ )
319
+
320
+ for k in range(1, K + 1):
321
+ ratio = k / K
322
+ t_discrete = torch.full(
323
+ (B,), int(cfg.num_diffusion_steps * (1.0 - ratio)),
324
+ dtype=torch.long, device=device,
325
+ )
326
+
327
+ out = model(local_obs, global_obs, seq, t_discrete)
328
+ logits = out["actions"] # [B, seq_len, vocab]
329
+
330
+ # Mask invalid action tokens
331
+ logits[:, :, action_dim:] = float("-inf")
332
+
333
+ # Greedy: argmax over softmax (no temperature, no top-K)
334
+ probs = F.softmax(logits, dim=-1) # [B, seq_len, vocab]; entries past action_dim are 0
335
+ confidences, preds = probs.max(dim=-1) # [B, seq_len] each
336
+
337
+ # MaskGIT progressive unmasking by confidence
338
+ num_to_unmask = max(1, int(seq_len * ratio))
339
+ is_masked = seq == mask_token # [B, seq_len]
340
+
341
+ # Score only masked positions for unmasking
342
+ scores = confidences.clone()
343
+ scores[~is_masked] = -1.0
344
+ _, topk_idx = scores.topk(num_to_unmask, dim=-1)
345
+
346
+ unmask_mask = torch.zeros_like(seq, dtype=torch.bool)
347
+ unmask_mask.scatter_(1, topk_idx, True)
348
+ unmask_mask = unmask_mask & is_masked
349
+
350
+ seq = torch.where(unmask_mask, preds, seq)
351
+
352
+ # No remasking in greedy mode
353
+
354
+ # Force-commit any remaining masked tokens
355
+ still_masked = seq == mask_token
356
+ if still_masked.any():
357
+ t_zero = torch.zeros(B, dtype=torch.long, device=device)
358
+ out = model(local_obs, global_obs, seq, t_zero)
359
+ logits = out["actions"]
360
+ logits[:, :, action_dim:] = float("-inf")
361
+ preds = logits.argmax(dim=-1)
362
+ seq = torch.where(still_masked, preds, seq)
363
+
364
+ return seq
365
+
366
+
367
+ def select_action(
368
+ model: torch.nn.Module,
369
+ local_obs: Tensor,
370
+ global_obs: Tensor,
371
+ cfg: SimpleNamespace,
372
+ device: torch.device | str,
373
+ physics_aware: bool = True,
374
+ blind_global: bool = False,
375
+ ) -> int:
376
+ """Sample a single action from a length-1 batch.
377
+
378
+ Args:
379
+ model: Denoising model.
380
+ local_obs: Shape ``[9, 9]`` or ``[1, 9, 9]``.
381
+ global_obs: Shape ``[21, 79]`` or ``[1, 21, 79]``.
382
+ cfg: Config namespace.
383
+ device: Torch device.
384
+ physics_aware: Forward to ``remdm_sample``.
385
+ blind_global: Forward to ``remdm_sample``.
386
+
387
+ Returns:
388
+ The first action of the generated plan (int).
389
+ """
390
+ if local_obs.ndim == 2:
391
+ local_obs = local_obs.unsqueeze(0)
392
+ if global_obs.ndim == 2:
393
+ global_obs = global_obs.unsqueeze(0)
394
+ seq = remdm_sample(
395
+ model, local_obs, global_obs, cfg, device,
396
+ physics_aware=physics_aware, blind_global=blind_global,
397
+ )
398
+ return seq[0, 0].item()
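Reviewer sketch (not part of the commit): the confidence-ordered unmasking in `greedy_sample` can be traced on plain Python lists. The helper names `unmask_counts` and `unmask_step` are hypothetical, and the per-step count assumes `ratio = k / K` as in `greedy_sample`. The committed code does the same thing on batched tensors with `topk` and `scatter_`.

```python
from __future__ import annotations


def unmask_counts(seq_len: int, num_steps: int, min_keep: int = 1) -> list[int]:
    """Number of positions selected for unmasking at each step k = 1..K."""
    counts = []
    for k in range(1, num_steps + 1):
        ratio = k / num_steps
        counts.append(max(min_keep, max(1, int(seq_len * ratio))))
    return counts


def unmask_step(
    seq: list[int],
    preds: list[int],
    conf: list[float],
    mask_token: int,
    n_unmask: int,
) -> list[int]:
    """Commit the n_unmask most confident positions that are still masked."""
    # Committed positions get score -1 so they rank below every masked one.
    scores = [c if tok == mask_token else -1.0 for tok, c in zip(seq, conf)]
    order = sorted(range(len(seq)), key=lambda i: scores[i], reverse=True)
    chosen = set(order[:n_unmask])
    # Mirrors `unmask_mask & is_masked`: committed tokens are never overwritten.
    return [
        preds[i] if (i in chosen and seq[i] == mask_token) else seq[i]
        for i in range(len(seq))
    ]
```

With `seq_len=8` and `K=4` the per-step counts are 2, 4, 6, 8. Each count applies to the positions still masked at that step, so without remasking a sequence can be fully committed before the final step.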
src/diffusion/schedules.py ADDED
@@ -0,0 +1,88 @@
1
+ """Noise schedule functions for MDLM diffusion.
2
+
3
+ Ported from the Craftax JAX implementation (src/diffusion/schedules.py).
4
+ All functions operate on PyTorch tensors and are pure (no global state).
5
+
6
+ Convention: alpha(t) is the fraction of tokens that remain *unmasked*.
7
+ - alpha(0) = 1.0 (fully clean)
8
+ - alpha(1) = 0.0 (fully masked)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import math
14
+ from typing import Callable
15
+
16
+ import torch
17
+ from torch import Tensor
18
+
19
+
20
+ def linear_schedule(t: Tensor) -> Tensor:
21
+ """Linear noise schedule: alpha(t) = 1 - t.
22
+
23
+ Args:
24
+ t: Diffusion time in [0, 1]. Any shape.
25
+
26
+ Returns:
27
+ Retention probability alpha_t, same shape as *t*.
28
+ """
29
+ return 1.0 - t
30
+
31
+
32
+ def cosine_schedule(t: Tensor) -> Tensor:
33
+ """Cosine noise schedule: alpha(t) = cos(pi/2 * t)^2.
34
+
35
+ Args:
36
+ t: Diffusion time in [0, 1]. Any shape.
37
+
38
+ Returns:
39
+ Retention probability alpha_t, same shape as *t*.
40
+ """
41
+ return torch.cos(t * (math.pi / 2.0)) ** 2
42
+
43
+
44
+ _SCHEDULE_MAP: dict[str, Callable[[Tensor], Tensor]] = {
45
+ "linear": linear_schedule,
46
+ "cosine": cosine_schedule,
47
+ }
48
+
49
+
50
+ def get_schedule(name: str) -> Callable[[Tensor], Tensor]:
51
+ """Look up a noise schedule by name.
52
+
53
+ Args:
54
+ name: One of ``"linear"`` or ``"cosine"``.
55
+
56
+ Returns:
57
+ The schedule function ``alpha(t)``.
58
+
59
+ Raises:
60
+ KeyError: If *name* is not registered.
61
+ """
62
+ if name not in _SCHEDULE_MAP:
63
+ raise KeyError(
64
+ f"Unknown schedule '{name}'. "
65
+ f"Available: {list(_SCHEDULE_MAP.keys())}"
66
+ )
67
+ return _SCHEDULE_MAP[name]
68
+
69
+
70
+ def alpha_prime(
71
+ t: Tensor,
72
+ schedule_fn: Callable[[Tensor], Tensor],
73
+ eps: float = 1e-5,
74
+ ) -> Tensor:
75
+ """Numerical derivative d(alpha)/dt via central difference.
76
+
77
+ Args:
78
+ t: Diffusion time in [0, 1]. Any shape.
79
+ schedule_fn: Noise schedule returning alpha(t).
80
+ eps: Half-width for finite-difference stencil.
81
+
82
+ Returns:
83
+ Approximate derivative, same shape as *t*.
84
+ """
85
+ t_clamped = t.clamp(eps, 1.0 - eps)
86
+ return (schedule_fn(t_clamped + eps) - schedule_fn(t_clamped - eps)) / (
87
+ 2.0 * eps
88
+ )
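Reviewer sketch (not part of the commit): a scalar, standard-library version of the two schedules and the central-difference derivative, useful for sanity-checking `alpha_prime` against the closed forms d/dt(1 - t) = -1 and d/dt cos²(πt/2) = -(π/2)·sin(πt). The names `linear_alpha`, `cosine_alpha` and `alpha_prime_scalar` are hypothetical stand-ins for the tensor functions above.

```python
import math


def linear_alpha(t: float) -> float:
    """alpha(t) = 1 - t."""
    return 1.0 - t


def cosine_alpha(t: float) -> float:
    """alpha(t) = cos(pi/2 * t)^2."""
    return math.cos(t * (math.pi / 2.0)) ** 2


def alpha_prime_scalar(t: float, schedule_fn, eps: float = 1e-5) -> float:
    """Central difference, with t clamped so the stencil stays inside [0, 1]."""
    t = min(max(t, eps), 1.0 - eps)
    return (schedule_fn(t + eps) - schedule_fn(t - eps)) / (2.0 * eps)


if __name__ == "__main__":
    t = 0.5
    print(alpha_prime_scalar(t, linear_alpha))   # close to -1
    print(alpha_prime_scalar(t, cosine_alpha))   # close to -(pi/2) * sin(pi * t)
    print(-(math.pi / 2.0) * math.sin(math.pi * t))
```

Both schedules satisfy the stated convention, alpha(0) = 1 and alpha(1) = 0 (the cosine endpoint up to floating-point error).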
src/envs/__init__.py ADDED
File without changes
src/envs/discovery.py ADDED
@@ -0,0 +1,166 @@
1
+ """MiniHack environment discovery and diagnostic utilities.
2
+
3
+ Provides tools for scanning the gymnasium registry, validating action-space
4
+ consistency across environments, and benchmarking inference throughput.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import time
11
+ from types import SimpleNamespace
12
+
13
+ import torch
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ _NAV_KEYWORDS = ("Room", "Corridor", "Maze", "River")
18
+ _EXCLUDED_KEYWORDS = ("KeyRoom",)
19
+ _REFERENCE_ENV_ID = "MiniHack-Room-15x15-v0"
20
+
21
+
22
+ def list_working_minihack_tasks() -> list[str]:
23
+ """Scan the gymnasium registry for working MiniHack navigation tasks.
24
+
25
+ Filters to environments whose names contain at least one navigation
26
+ keyword and attempts to instantiate each. Returns the IDs of all
27
+ successfully created environments.
28
+
29
+ Returns:
30
+ Sorted list of working MiniHack navigation environment IDs.
31
+ """
32
+ import gymnasium as gym
33
+ import minihack # noqa: F401 — registers envs
34
+
35
+ all_ids = list(gym.envs.registry.keys())
36
+ candidates = [
37
+ e for e in all_ids
38
+ if "MiniHack" in e
39
+ and any(k in e for k in _NAV_KEYWORDS)
40
+ and not any(x in e for x in _EXCLUDED_KEYWORDS)
41
+ ]
42
+
43
+ working: list[str] = []
44
+ broken: list[str] = []
45
+ for env_id in sorted(candidates):
46
+ try:
47
+ env = gym.make(env_id)
48
+ working.append(env_id)
49
+ env.close()
50
+ except Exception:
51
+ broken.append(env_id)
52
+
53
+ logger.info(
54
+ f"MiniHack navigation tasks — working: {len(working)}, "
55
+ f"broken: {len(broken)}"
56
+ )
57
+ return working
58
+
59
+
60
+ def check_action_consistency_with_fixed_ref(
61
+ env_list: list[str],
62
+ ) -> list[tuple[str, str, int]]:
63
+ """Validate action-space ordering against a fixed reference environment.
64
+
65
+ Compares each environment's action list against
66
+ ``MiniHack-Room-15x15-v0`` and classifies the relationship as one of:
67
+ ``REFERENCE``, ``EXACT``, ``SUPERSET (+N)``, ``SUBSET (-N)``,
68
+ ``CONFLICT``, or ``CRASHED``.
69
+
70
+ Args:
71
+ env_list: MiniHack environment IDs to check.
72
+
73
+ Returns:
74
+ List of ``(env_id, status, action_space_size)`` tuples.
75
+ """
76
+ import gymnasium as gym
77
+ import minihack # noqa: F401
78
+
79
+ ref_env = gym.make(_REFERENCE_ENV_ID)
80
+ reference_actions = ref_env.unwrapped.actions # type: ignore[attr-defined]
81
+ ref_env.close()
82
+
83
+ results: list[tuple[str, str, int]] = []
84
+ for env_id in sorted(env_list):
85
+ if env_id == _REFERENCE_ENV_ID:
86
+ results.append((env_id, "REFERENCE", len(reference_actions)))
87
+ continue
88
+ try:
89
+ env = gym.make(env_id)
90
+ try:
91
+ env_actions = env.unwrapped.actions # type: ignore[attr-defined]
92
+ limit = min(len(reference_actions), len(env_actions))
93
+ is_match = all(
94
+ reference_actions[i] == env_actions[i]
95
+ for i in range(limit)
96
+ )
97
+ diff = len(env_actions) - len(reference_actions)
98
+ if is_match and diff == 0:
99
+ status = "EXACT"
100
+ elif is_match and diff > 0:
101
+ status = f"SUPERSET (+{diff})"
102
+ elif is_match:
103
+ status = f"SUBSET ({diff})"
104
+ else:
105
+ status = "CONFLICT"
106
+ results.append((env_id, status, len(env_actions)))
107
+ finally:
108
+ env.close()
109
+ except Exception:
110
+ results.append((env_id, "CRASHED", 0))
111
+
112
+ for name, status, size in results:
113
+ logger.info(f" {name:<40} | {status:<14} | n_actions={size}")
114
+ return results
115
+
116
+
117
+ def benchmark_inference(
118
+ model: torch.nn.Module,
119
+ cfg: SimpleNamespace,
120
+ device: torch.device | str,
121
+ n_actions: int = 100,
122
+ ) -> tuple[float, float]:
123
+ """Measure ReMDM inference throughput.
124
+
125
+ Runs ``n_actions`` planning calls with dummy observations and
126
+ measures wall-clock time.
127
+
128
+ Args:
129
+ model: Denoising model in eval mode.
130
+ cfg: Config namespace (used for ``seq_len``, ``mask_token``, etc.).
131
+ device: Torch device.
132
+ n_actions: Number of planning calls to benchmark.
133
+
134
+ Returns:
135
+ ``(diffusion_steps_per_sec, actions_per_sec)`` as floats.
136
+ """
137
+ from src.diffusion.sampling import remdm_sample
138
+
139
+ model.eval()
140
+ local_dummy = torch.zeros(
141
+ (1, cfg.crop_size, cfg.crop_size), dtype=torch.long, device=device,
142
+ )
143
+ global_dummy = torch.zeros(
144
+ (1, cfg.map_h, cfg.map_w), dtype=torch.long, device=device,
145
+ )
146
+
147
+ if torch.cuda.is_available():
148
+ torch.cuda.synchronize()
149
+
150
+ t0 = time.perf_counter()
151
+ for _ in range(n_actions):
152
+ remdm_sample(model, local_dummy, global_dummy, cfg, device)
153
+ if torch.cuda.is_available():
154
+ torch.cuda.synchronize()
155
+ elapsed = time.perf_counter() - t0
156
+
157
+ total_steps = n_actions * cfg.diffusion_steps_eval
158
+ steps_per_sec = total_steps / elapsed if elapsed > 0 else 0.0
159
+ actions_per_sec = n_actions / elapsed if elapsed > 0 else 0.0
160
+
161
+ logger.info(
162
+ f"Benchmark ({n_actions} actions): "
163
+ f"{steps_per_sec:.1f} diffusion-steps/s | "
164
+ f"{actions_per_sec:.1f} actions/s"
165
+ )
166
+ return steps_per_sec, actions_per_sec
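Reviewer sketch (not part of the commit): the status classification described in the `check_action_consistency_with_fixed_ref` docstring, isolated as a pure function so it can be tested without creating environments. The name `classify_actions` is hypothetical. The sketch assumes that both `SUPERSET` and `SUBSET` require the shared prefix to agree, and that any disagreement in the prefix is a `CONFLICT`.

```python
from __future__ import annotations

from typing import Sequence


def classify_actions(reference: Sequence, candidate: Sequence) -> str:
    """Compare a candidate action list against the reference ordering."""
    limit = min(len(reference), len(candidate))
    prefix_ok = all(reference[i] == candidate[i] for i in range(limit))
    diff = len(candidate) - len(reference)

    if not prefix_ok:
        return "CONFLICT"            # same index maps to different actions
    if diff == 0:
        return "EXACT"
    if diff > 0:
        return f"SUPERSET (+{diff})"  # reference actions plus extras at the end
    return f"SUBSET ({diff})"         # diff is negative, so this prints e.g. (-2)


if __name__ == "__main__":
    ref = ["N", "E", "S", "W"]
    print(classify_actions(ref, ["N", "E", "S", "W", "KICK"]))
    print(classify_actions(ref, ["E", "N", "S", "W", "KICK"]))
```

Checking the prefix first matters for a policy that emits action indices: a longer action list whose first entries are reordered would otherwise look compatible.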
src/envs/minihack_env.py ADDED
@@ -0,0 +1,454 @@
1
+ """MiniHack environment wrapper with BFS oracle and shaped rewards.
2
+
3
+ Ported from minihack_reference/src/env.py. Provides dual-stream
4
+ observations (9x9 local crop + 21x79 global map), a multi-tier BFS
5
+ oracle, and reward shaping (win bonus, BFS progress, exploration, step
6
+ penalty).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import collections
12
+ import logging
13
+ from types import SimpleNamespace
14
+
15
+ import gymnasium as gym
16
+ import minihack # noqa: F401 — registers MiniHack envs
17
+ import numpy as np
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ # Suppress noisy NLE INFO spam ("Not saving any NLE data." on every env create)
22
+ logging.getLogger("nle.env.base").setLevel(logging.WARNING)
23
+
24
+
25
+ # ── Staircase detection ──────────────────────────────────────────────
26
+
27
+
28
+ def find_staircase_from_glyphs(global_obs: np.ndarray) -> np.ndarray:
29
+ """Locate the staircase '>' in the global glyph map.
30
+
31
+ Args:
32
+ global_obs: Glyph map, shape ``[B, H, W]`` or ``[H, W]``.
33
+
34
+ Returns:
35
+ Normalised ``(row/(H-1), col/(W-1))`` coords, shape ``[B, 2]``
35
+ (float32; ``B = 1`` for 2-D input). ``(-1, -1)`` when not visible.
37
+ """
38
+ squeeze = global_obs.ndim == 2
39
+ if squeeze:
40
+ global_obs = global_obs[np.newaxis]
41
+ B, H, W = global_obs.shape
42
+ coords = np.full((B, 2), -1.0, dtype=np.float32)
43
+ for b in range(B):
44
+ is_stair = (
45
+ (global_obs[b] == 62)
46
+ | (global_obs[b] == 2310)
47
+ | (global_obs[b] == 2368)
48
+ | (global_obs[b] == 2383)
49
+ )
50
+ positions = np.argwhere(is_stair)
51
+ if positions.shape[0] > 0:
52
+ coords[b, 0] = positions[0, 0] / max(1, H - 1)
53
+ coords[b, 1] = positions[0, 1] / max(1, W - 1)
54
+ return coords
55
+
56
+
57
+ # ── Environment wrapper ──────────────────────────────────────────────
58
+
59
+
60
+ class AdvancedObservationEnv(gym.Env):
61
+ """MiniHack wrapper with dual-stream obs, BFS oracle, shaped rewards.
62
+
63
+ Observations are ``(local_crop, global_map)`` where
64
+ ``local_crop`` is a ``[crop_size, crop_size]`` glyph window centred
65
+ on the agent and ``global_map`` is the full ``[21, 79]`` glyph grid.
66
+
67
+ Args:
68
+ env_id: MiniHack registry ID.
69
+ des_file: Optional ``.des`` file content (for custom levels).
70
+ cfg: Configuration namespace with ``crop_size``, ``action_dim``,
71
+ ``pad_token``, ``map_h``, ``map_w``.
72
+ """
73
+
74
+ _UNWALKABLE = frozenset({32, 45, 124, 125}) # space, -, |, }
75
+ _CLOSED_DOOR = 43 # '+'
76
+ _DIR_MAP = {(-1, 0): 0, (0, 1): 1, (1, 0): 2, (0, -1): 3}
77
+ _CARDINAL = [(-1, 0), (0, 1), (1, 0), (0, -1)]
78
+
79
+ def __init__(
80
+ self,
81
+ env_id: str,
82
+ des_file: str | None,
83
+ cfg: SimpleNamespace,
84
+ ) -> None:
85
+ super().__init__()
86
+ self.env_id = env_id
87
+ self._cfg = cfg
88
+ self._crop_half = cfg.crop_size // 2
89
+
90
+ obs_keys = ("glyphs", "chars", "pixel")
91
+ if des_file is not None:
92
+ self._inner = gym.make(
93
+ "MiniHack-Navigation-Custom-v0",
94
+ des_file=des_file,
95
+ observation_keys=obs_keys,
96
+ )
97
+ else:
98
+ self._inner = gym.make(
99
+ env_id, observation_keys=obs_keys,
100
+ )
101
+
102
+ self.observation_space = gym.spaces.Box(
103
+ low=0, high=6000,
104
+ shape=(cfg.crop_size, cfg.crop_size),
105
+ dtype=np.int16,
106
+ )
107
+ self.action_space: gym.spaces.Discrete = gym.spaces.Discrete(cfg.action_dim)
108
+
109
+ self._visited: set[tuple[int, int]] = set()
110
+ self._prev_bfs_dist: int | None = None
111
+ self.last_raw_obs: dict | None = None
112
+
113
+ # ── gym.Env interface ────────────────────────────────────────────
114
+
115
+ def reset(
116
+ self, seed: int | None = None, options: dict | None = None,
117
+ ) -> tuple[tuple[np.ndarray, np.ndarray], dict]:
118
+ """Reset environment and tracking state.
119
+
120
+ Args:
121
+ seed: Optional RNG seed.
122
+ options: Passed through to the inner env.
123
+
124
+ Returns:
125
+ ``((local_crop, global_map), info)``
126
+ """
127
+ obs, info = self._inner.reset(seed=seed, options=options)
128
+ self.last_raw_obs = obs
129
+ self._prev_bfs_dist = self._get_bfs_distance(obs)
130
+ self._visited = set()
131
+ agent_pos = self._get_agent_pos(obs)
132
+ if agent_pos is not None:
133
+ self._visited.add(agent_pos)
134
+ return self._get_obs(obs), info
135
+
136
+ def step(
137
+ self, action: int,
138
+ ) -> tuple[tuple[np.ndarray, np.ndarray], float, bool, bool, dict]:
139
+ """Execute one environment step with shaped reward.
140
+
141
+ Reward shaping:
142
+ - Win bonus: ``+20.0``
143
+ - BFS progress toward staircase: ``+0.5 * (prev - curr)``
144
+ - New-tile exploration: ``+0.05``
145
+ - Step penalty: ``-0.01``
146
+
147
+ Args:
148
+ action: Integer action in ``[0, action_dim)``.
149
+
150
+ Returns:
151
+ ``(obs, shaped_reward, terminated, truncated, info)``
152
+ """
153
+ inner_n = self._inner.action_space.n
154
+ if action >= inner_n:
155
+ action = action % inner_n
156
+
157
+ obs, raw_reward, terminated, truncated, info = self._inner.step(action)
158
+ self.last_raw_obs = obs
159
+ reward = float(raw_reward)
160
+
161
+ # Win bonus
162
+ if terminated and reward > 0:
163
+ info["won"] = True
164
+ reward += 20.0
165
+ else:
166
+ info["won"] = False
167
+
168
+ # BFS shaping
169
+ curr_dist = self._get_bfs_distance(obs)
170
+ if curr_dist is not None and self._prev_bfs_dist is not None:
171
+ reward += (self._prev_bfs_dist - curr_dist) * 0.5
172
+ self._prev_bfs_dist = curr_dist
173
+
174
+ # Exploration bonus
175
+ agent_pos = self._get_agent_pos(obs)
176
+ if agent_pos is not None and agent_pos not in self._visited:
177
+ reward += 0.05
178
+ self._visited.add(agent_pos)
179
+
180
+ # Step penalty
181
+ reward -= 0.01
182
+
183
+ return self._get_obs(obs), reward, terminated, truncated, info
184
+
185
+ @property
186
+ def unwrapped(self):
187
+ """Access the inner MiniHack env."""
188
+ return self._inner.unwrapped
189
+
190
+ def close(self) -> None:
191
+ """Close the inner environment."""
192
+ self._inner.close()
193
+
194
+ # ── Observation helpers ──────────────────────────────────────────
195
+
196
+ def _get_obs(
197
+ self, obs: dict,
198
+ ) -> tuple[np.ndarray, np.ndarray]:
199
+ """Extract dual-stream observation.
200
+
201
+ Args:
202
+ obs: Raw NLE observation dict.
203
+
204
+ Returns:
205
+ ``(local_crop [crop,crop], global_map [H,W])`` as int16.
206
+ """
207
+ return self._get_crop(obs), obs["glyphs"].copy().astype(np.int16)
208
+
209
+ def _get_crop(self, obs: dict) -> np.ndarray:
210
+ """Crop local glyph window centred on agent.
211
+
212
+ Args:
213
+ obs: Raw NLE observation dict.
214
+
215
+ Returns:
216
+ ``[crop_size, crop_size]`` int16 array.
217
+ """
218
+ glyphs = obs["glyphs"]
219
+ chars = obs["chars"]
220
+ agent_pos = np.argwhere(chars == ord("@"))
221
+ cs = self._cfg.crop_size
222
+ if len(agent_pos) == 0:
223
+ return np.full((cs, cs), self._cfg.pad_token, dtype=np.int16)
224
+ y, x = agent_pos[0]
225
+ h = self._crop_half
226
+ padded = np.pad(
227
+ glyphs, h, mode="constant",
228
+ constant_values=self._cfg.pad_token,
229
+ )
230
+ return padded[y:y + cs, x:x + cs].astype(np.int16)
231
+
232
+ def _get_agent_pos(self, obs: dict) -> tuple[int, int] | None:
233
+ """Find agent '@' position in the chars grid.
234
+
235
+ Args:
236
+ obs: Raw NLE observation dict.
237
+
238
+ Returns:
239
+ ``(row, col)`` or ``None``.
240
+ """
241
+ chars = obs["chars"]
242
+ pos = np.argwhere(chars == ord("@"))
243
+ return tuple(pos[0]) if len(pos) > 0 else None
244
+
245
+ def _get_bfs_distance(self, obs: dict) -> int | None:
246
+ """BFS shortest-path distance from agent to staircase.
247
+
248
+ Args:
249
+ obs: Raw NLE observation dict.
250
+
251
+ Returns:
252
+ Integer distance or ``None`` if unreachable / not visible.
253
+ """
254
+ chars = obs["chars"]
255
+ start = np.argwhere(chars == ord("@"))
256
+ target = np.argwhere(chars == ord(">"))
257
+ if len(start) == 0 or len(target) == 0:
258
+ return None
259
+ start = tuple(start[0])
260
+ target = tuple(target[0])
261
+ if start == target:
262
+ return 0
263
+ queue: collections.deque = collections.deque([(start, 0)])
264
+ visited = {start}
265
+ while queue:
266
+ (r, c), dist = queue.popleft()
267
+ if (r, c) == target:
268
+ return dist
269
+ for dr, dc in self._CARDINAL:
270
+ nr, nc = r + dr, c + dc
271
+ if (
272
+ 0 <= nr < self._cfg.map_h
273
+ and 0 <= nc < self._cfg.map_w
274
+ and (nr, nc) not in visited
275
+ and chars[nr, nc] not in self._UNWALKABLE
276
+ ):
277
+ visited.add((nr, nc))
278
+ queue.append(((nr, nc), dist + 1))
279
+ return None
280
+
281
+ # ── BFS Oracle ───────────────────────────────────────────────────
282
+
283
+ def get_oracle_action(self, obs: dict) -> int:
284
+ """5-tier BFS oracle action.
285
+
286
+ Priority:
287
+ 1. Kick adjacent closed door.
288
+ 2. BFS to staircase '>'.
289
+ 3. BFS to frontier (adjacent to unexplored space).
290
+ 4. BFS to farthest reachable tile.
291
+ 5. Random cardinal direction.
292
+
293
+ Args:
294
+ obs: Raw NLE observation dict (needs ``'chars'`` key).
295
+
296
+ Returns:
297
+ Action index in ``[0, action_dim)``.
298
+ """
299
+ if obs is None:
300
+ return 0
301
+ chars = obs["chars"]
302
+ start = np.argwhere(chars == ord("@"))
303
+ if len(start) == 0:
304
+ return np.random.randint(0, 4)
305
+ start = tuple(start[0])
306
+ target_list = np.argwhere(chars == ord(">"))
307
+
308
+ # 1. Adjacent closed door → kick
309
+ for dr, dc in self._CARDINAL:
310
+ nr, nc = start[0] + dr, start[1] + dc
311
+ if (
312
+ 0 <= nr < self._cfg.map_h
313
+ and 0 <= nc < self._cfg.map_w
314
+ and chars[nr, nc] == self._CLOSED_DOOR
315
+ ):
316
+ return 11 # KICK
317
+
318
+ # BFS to gather reachable tiles + check staircase
319
+ queue: collections.deque = collections.deque([(start, [])])
320
+ visited = {start}
321
+ reachable: list[tuple[tuple[int, int], list[tuple[int, int]]]] = []
322
+ target_path: list[tuple[int, int]] | None = None
323
+
324
+ while queue:
325
+ (r, c), path = queue.popleft()
326
+ reachable.append(((r, c), path))
327
+ for t_r, t_c in target_list:
328
+ if r == t_r and c == t_c:
329
+ target_path = path
330
+ break
331
+ if target_path is not None:
332
+ break
333
+ for dr, dc in self._CARDINAL:
334
+ nr, nc = r + dr, c + dc
335
+ if (
336
+ 0 <= nr < self._cfg.map_h
337
+ and 0 <= nc < self._cfg.map_w
338
+ and (nr, nc) not in visited
339
+ ):
340
+ ch = chars[nr, nc]
341
+ if ch not in self._UNWALKABLE and ch != self._CLOSED_DOOR:
342
+ visited.add((nr, nc))
343
+ queue.append(((nr, nc), path + [(dr, dc)]))
344
+
345
+ # 2. Path to staircase
346
+ if target_path:
347
+ return self._DIR_MAP.get(target_path[0], 0)
348
+
349
+ # 3. Frontier exploration — tiles adjacent to unexplored space
350
+ frontier: list[list[tuple[int, int]]] = []
351
+ for (r, c), path in reachable:
352
+ if not path:
353
+ continue
354
+ for dr, dc in self._CARDINAL:
355
+ nr, nc = r + dr, c + dc
356
+ if (
357
+ 0 <= nr < self._cfg.map_h
358
+ and 0 <= nc < self._cfg.map_w
359
+ and chars[nr, nc] == 32
360
+ ):
361
+ frontier.append(path)
362
+ break
363
+ if frontier:
364
+ frontier.sort(key=len)
365
+ return self._DIR_MAP.get(frontier[0][0], 0)
366
+
367
+ # 4. Farthest reachable tile
368
+ if reachable:
369
+ reachable.sort(key=lambda x: len(x[1]), reverse=True)
370
+ farthest = reachable[0][1]
371
+ if farthest:
372
+ return self._DIR_MAP.get(farthest[0], 0)
373
+
374
+ # 5. Random cardinal
375
+ return np.random.randint(0, 4)
376
+
377
+
378
+ # ── Factory ──────────────────────────────────────────────────────────
379
+
380
+
381
+ def make_env(
382
+ env_id: str,
383
+ des_file: str | None,
384
+ cfg: SimpleNamespace,
385
+ ) -> AdvancedObservationEnv:
386
+ """Create a wrapped MiniHack environment.
387
+
388
+ Args:
389
+ env_id: MiniHack registry ID.
390
+ des_file: Optional ``.des`` file content.
391
+ cfg: Configuration namespace.
392
+
393
+ Returns:
394
+ Wrapped environment.
395
+ """
396
+ return AdvancedObservationEnv(env_id, des_file, cfg)
397
+
398
+
399
+ def collect_oracle_trajectory(
400
+ env_id: str,
401
+ seed: int,
402
+ cfg: SimpleNamespace,
403
+ max_steps: int = 500,
404
+ ) -> dict | None:
405
+ """Roll out the BFS oracle on a single episode.
406
+
407
+ Args:
408
+ env_id: MiniHack registry ID.
409
+ seed: RNG seed for the episode.
410
+ cfg: Configuration namespace.
411
+ max_steps: Maximum episode length.
412
+
413
+ Returns:
414
+ ``{"local": [T,9,9], "global": [T,21,79],
415
+ "actions": [T], "env_id": str}`` on success,
416
+ or ``None`` on failure.
417
+ """
418
+ env = make_env(env_id, None, cfg)
419
+ try:
420
+ (local, glb), _info = env.reset(seed=seed)
421
+ locals_list = [local]
422
+ globals_list = [glb]
423
+ actions_list: list[int] = []
424
+
425
+ for _ in range(max_steps):
426
+ action = env.get_oracle_action(env.last_raw_obs)
427
+ actions_list.append(action)
428
+ (local, glb), _reward, terminated, truncated, _info = env.step(
429
+ action
430
+ )
431
+ locals_list.append(local)
432
+ globals_list.append(glb)
433
+ if terminated or truncated:
434
+ break
435
+
436
+ # Trim trailing obs (one more obs than actions)
437
+ locals_arr = np.stack(locals_list[:-1], axis=0).astype(np.int16)
438
+ globals_arr = np.stack(globals_list[:-1], axis=0).astype(np.int16)
439
+ actions_arr = np.array(actions_list, dtype=np.int64)
440
+
441
+ return {
442
+ "local": locals_arr,
443
+ "global": globals_arr,
444
+ "actions": actions_arr,
445
+ "env_id": env_id,
446
+ }
447
+ except Exception:
448
+ logger.error(
449
+ f"Oracle trajectory failed for {env_id} seed={seed}",
450
+ exc_info=True,
451
+ )
452
+ return None
453
+ finally:
454
+ env.close()
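Reviewer sketch (not part of the commit): the BFS distance and the reward shaping from `AdvancedObservationEnv.step`, reduced to plain Python on a list-of-strings grid. The names `find_char`, `bfs_distance` and `shaped_reward` are hypothetical. The unwalkable characters and the constants (+20.0, 0.5, +0.05, -0.01) are copied from the code above; the grid used in the example is made up.

```python
from __future__ import annotations

from collections import deque

UNWALKABLE = frozenset(" -|}")                  # same characters as _UNWALKABLE
CARDINAL = [(-1, 0), (0, 1), (1, 0), (0, -1)]   # N, E, S, W


def find_char(grid: list[str], ch: str) -> tuple[int, int] | None:
    """First (row, col) of ch in row-major order, or None if absent."""
    for r, row in enumerate(grid):
        c = row.find(ch)
        if c != -1:
            return (r, c)
    return None


def bfs_distance(grid: list[str]) -> int | None:
    """Shortest 4-connected path length from '@' to '>', or None."""
    start, target = find_char(grid, "@"), find_char(grid, ">")
    if start is None or target is None:
        return None
    queue = deque([(start, 0)])
    visited = {start}
    while queue:
        (r, c), dist = queue.popleft()
        if (r, c) == target:
            return dist
        for dr, dc in CARDINAL:
            nr, nc = r + dr, c + dc
            if (
                0 <= nr < len(grid)
                and 0 <= nc < len(grid[nr])
                and (nr, nc) not in visited
                and grid[nr][nc] not in UNWALKABLE
            ):
                visited.add((nr, nc))
                queue.append(((nr, nc), dist + 1))
    return None


def shaped_reward(
    raw: float,
    won: bool,
    prev_dist: int | None,
    curr_dist: int | None,
    new_tile: bool,
) -> float:
    """Win bonus + BFS progress + exploration bonus - step penalty."""
    reward = float(raw)
    if won:
        reward += 20.0
    if prev_dist is not None and curr_dist is not None:
        reward += 0.5 * (prev_dist - curr_dist)
    if new_tile:
        reward += 0.05
    return reward - 0.01


if __name__ == "__main__":
    grid = ["-----", "|@..|", "|.-.|", "|..>|", "-----"]
    print(bfs_distance(grid))
    print(shaped_reward(0.0, False, prev_dist=4, curr_dist=3, new_tile=True))
```

One step closer to the staircase onto a new tile yields 0.5 + 0.05 - 0.01 on top of the raw reward, so the progress term dominates the exploration bonus by a factor of ten.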
src/models/__init__.py ADDED
File without changes
src/models/denoiser.py ADDED
@@ -0,0 +1,415 @@
1
+ """Dual-stream denoising transformer for MiniHack.
2
+
3
+ Ported from minihack_reference/src/model.py. Architecture follows the
4
+ Craftax denoiser conventions (forward return format, obs-encoder pattern)
5
+ while using the MiniHack dual-stream design (local CNN + gated global
6
+ CNN + auxiliary goal head).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import copy
12
+ import logging
13
+ import shutil
14
+ from types import SimpleNamespace
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class LocalDiffusionPlannerWithGlobal(nn.Module):
24
+ """Dual-stream transformer for masked diffusion action planning.
25
+
26
+ Combines a local 9x9 glyph crop with a gated global 21x79 map
27
+ context. Produces action logits and an auxiliary staircase-coordinate
28
+ prediction.
29
+
30
+ Architecture:
31
+ Local stream: Embedding(6000,64) -> CNN(64->32->64) -> Linear -> 1 token
32
+ Global stream: Embedding(6000,32) -> CNN(32->32->64) -> Pool(2,4)
33
+ -> Linear -> 8 tokens, gated by sigmoid(learnable scalar)
34
+ Goal head: mean(global_tokens) -> MLP -> [B,2] (before gate)
35
+ Action stream: Embedding(14, n_embd) + timestep + position
36
+ Transformer: concat all -> TransformerEncoder -> last 64 tokens -> head
37
+
38
+ Args:
39
+ cfg: Config namespace with ``action_dim``, ``n_embd``, ``n_head``,
40
+ ``n_layer``, ``n_global_tokens``, ``seq_len``,
41
+ ``global_gate_init``, ``num_diffusion_steps``.
42
+ """
43
+
44
+ def __init__(self, cfg: SimpleNamespace) -> None:
45
+ super().__init__()
46
+ action_dim = cfg.action_dim
47
+ n_embd = cfg.n_embd
48
+ n_head = cfg.n_head
49
+ n_layer = cfg.n_layer
50
+ n_global_tokens = cfg.n_global_tokens
51
+ seq_len = cfg.seq_len
52
+
53
+ assert n_embd % n_head == 0, (
54
+ f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
55
+ )
56
+
57
+ self.n_global_tokens = n_global_tokens
58
+
59
+ # ── Local stream: 9x9 crop -> 1 token ──────────────────────
60
+ self.embedding = nn.Embedding(6000, 64)
61
+ self.cnn = nn.Sequential(
62
+ nn.Conv2d(64, 32, 3, padding=1),
63
+ nn.GELU(),
64
+ nn.Conv2d(32, 64, 3, padding=1),
65
+ nn.GELU(),
66
+ nn.Flatten(),
67
+ nn.Linear(64 * 9 * 9, n_embd),
68
+ )
69
+
70
+ # ── Action stream ──────────────────────────────────────────
71
+ self.action_emb = nn.Embedding(action_dim + 2, n_embd)
72
+ self.timestep_emb = nn.Embedding(
73
+ cfg.num_diffusion_steps, n_embd,
74
+ )
75
+ self.pos_emb = nn.Embedding(seq_len, n_embd)
76
+
77
+ # ── Transformer ───────────────────────────────────────────
78
+ encoder_layer = nn.TransformerEncoderLayer(
79
+ d_model=n_embd,
80
+ nhead=n_head,
81
+ dim_feedforward=n_embd * 4,
82
+ dropout=getattr(cfg, "dropout", 0.0),
83
+ activation="gelu",
84
+ norm_first=True,
85
+ batch_first=True,
86
+ )
87
+ self.transformer = nn.TransformerEncoder(
88
+ encoder_layer, num_layers=n_layer, enable_nested_tensor=False,
89
+ )
90
+ self.head = nn.Linear(n_embd, action_dim)
91
+
92
+ # ── Global stream: 21x79 map -> 8 tokens ──────────────────
93
+ self.global_embedding = nn.Embedding(6000, 32)
94
+ self.global_cnn = nn.Sequential(
95
+ nn.Conv2d(32, 32, 5, stride=2, padding=2),
96
+ nn.GELU(),
97
+ nn.Conv2d(32, 64, 3, stride=2, padding=1),
98
+ nn.GELU(),
99
+ )
100
+ self.global_pool = nn.AdaptiveAvgPool2d((2, 4))
101
+ self.global_proj = nn.Linear(64, n_embd)
102
+ self.global_gate = nn.Parameter(
103
+ torch.tensor(cfg.global_gate_init)
104
+ )
105
+
106
+ # ── Auxiliary goal head (before gate) ──────────────────────
107
+ self.goal_head = nn.Sequential(
108
+ nn.Linear(n_embd, 128),
109
+ nn.GELU(),
110
+ nn.Linear(128, 2),
111
+ )
112
+
113
+ def forward(
114
+ self,
115
+ local_obs: Tensor,
116
+ global_obs: Tensor,
117
+ action_seq: Tensor,
118
+ t_discrete: int | Tensor,
119
+ ) -> dict[str, Tensor]:
120
+ """Forward pass producing action logits and goal prediction.
121
+
122
+ Args:
123
+ local_obs: Local glyph crop. Shape ``[B, 9, 9]``, int.
124
+ global_obs: Full glyph map. Shape ``[B, 21, 79]``, int.
125
+ action_seq: Noisy action sequence. Shape ``[B, seq_len]``, int.
126
+ t_discrete: Discrete timestep index (scalar int or ``[B]``).
127
+
128
+ Returns:
129
+ Dict with keys:
130
+ - ``"actions"``: ``[B, seq_len, action_dim]`` logits.
131
+ - ``"goal_pred"``: ``[B, 2]`` normalised staircase coords.
132
+ """
133
+ B, Seq = action_seq.shape
134
+ device = local_obs.device
135
+
136
+ # Local stream -> [B, 1, n_embd]
137
+ x_local = self.embedding(local_obs) # [B, 9, 9, 64]
138
+ x_local = x_local.permute(0, 3, 1, 2) # [B, 64, 9, 9]
139
+ local_token = self.cnn(x_local).unsqueeze(1) # [B, 1, n_embd]
140
+
141
+ # Global stream -> [B, 8, n_embd]
142
+ x_global = self.global_embedding(global_obs) # [B, 21, 79, 32]
143
+ x_global = x_global.permute(0, 3, 1, 2) # [B, 32, 21, 79]
144
+ gf = self.global_cnn(x_global) # [B, 64, 6, 20]
145
+ gf = self.global_pool(gf) # [B, 64, 2, 4]
146
+ global_tokens = gf.permute(0, 2, 3, 1) # [B, 2, 4, 64]
147
+ global_tokens = global_tokens.reshape(
148
+ B, self.n_global_tokens, -1
149
+ ) # [B, 8, 64]
150
+ global_tokens = self.global_proj(global_tokens) # [B, 8, n_embd]
151
+
152
+ # Aux goal head (before gate for direct gradient to CNN)
153
+ goal_pred = self.goal_head(
154
+ global_tokens.mean(dim=1)
155
+ ) # [B, 2]
156
+
157
+ # Apply gate
158
+ gate = torch.sigmoid(self.global_gate)
159
+ global_tokens = global_tokens * gate # [B, 8, n_embd]
160
+
161
+ # Action stream -> [B, seq_len, n_embd]
162
+ positions = torch.arange(
163
+ Seq, device=device,
164
+ ).unsqueeze(0).expand(B, -1) # [B, seq_len]
165
+
166
+ if isinstance(t_discrete, int):
167
+ t_tensor = torch.full(
168
+ (B,), t_discrete, dtype=torch.long, device=device,
169
+ )
170
+ else:
171
+ t_tensor = t_discrete.long().to(device)
172
+
173
+ seq_emb = (
174
+ self.action_emb(action_seq)
175
+ + self.timestep_emb(t_tensor).unsqueeze(1)
176
+ + self.pos_emb(positions)
177
+ ) # [B, seq_len, n_embd]
178
+
179
+ # Concatenate: [local(1), global(8), actions(seq_len)]
180
+ x = torch.cat(
181
+ [local_token, global_tokens, seq_emb], dim=1,
182
+ ) # [B, 1+8+seq_len, n_embd]
183
+
184
+ # Transformer
185
+ out = self.transformer(x) # [B, 1+8+seq_len, n_embd]
186
+
187
+ # Take last seq_len tokens for action predictions
188
+ n_prefix = 1 + self.n_global_tokens
189
+ action_logits = self.head(
190
+ out[:, n_prefix:, :]
191
+ ) # [B, seq_len, action_dim]
192
+
193
+ return {"actions": action_logits, "goal_pred": goal_pred}
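The shape comments in the global stream can be checked by hand. The following is a minimal sketch, assuming the `[B, 21, 79]` map size stated in the docstring and the standard no-dilation convolution size formula. The helper names `conv2d_out` and `global_stream_shape` are hypothetical and not part of the repository.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one axis for a convolution without dilation."""
    return (size + 2 * padding - kernel) // stride + 1


def global_stream_shape(h: int = 21, w: int = 79) -> tuple[int, int]:
    """Spatial size after the two strided convolutions of the global stream."""
    # Conv2d(32, 32, 5, stride=2, padding=2)
    h, w = conv2d_out(h, 5, 2, 2), conv2d_out(w, 5, 2, 2)
    # Conv2d(32, 64, 3, stride=2, padding=1)
    h, w = conv2d_out(h, 3, 2, 1), conv2d_out(w, 3, 2, 1)
    return h, w


feature_hw = global_stream_shape()        # spatial size fed to the adaptive pool
pooled_hw = (2, 4)                        # AdaptiveAvgPool2d((2, 4))
n_tokens = pooled_hw[0] * pooled_hw[1]    # number of global tokens after the reshape
local_hw = conv2d_out(9, 3, 1, 1)         # the local 3x3, padding=1 convs keep 9x9
```

Because the pool size is fixed at `(2, 4)`, the reshape to `self.n_global_tokens` appears to work only when `cfg.n_global_tokens` is 8.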
194
+
195
+
196
+ class LocalDiffusionPlanner(nn.Module):
197
+ """Local-only ablation model (no global stream, no goal head).
198
+
199
+ Args:
200
+ cfg: Config namespace.
201
+ """
202
+
203
+ def __init__(self, cfg: SimpleNamespace) -> None:
204
+ super().__init__()
205
+ action_dim = cfg.action_dim
206
+ n_embd = cfg.n_embd
207
+ seq_len = cfg.seq_len
208
+
209
+ self.embedding = nn.Embedding(6000, 64)
210
+ self.cnn = nn.Sequential(
211
+ nn.Conv2d(64, 32, 3, padding=1),
212
+ nn.GELU(),
213
+ nn.Conv2d(32, 64, 3, padding=1),
214
+ nn.GELU(),
215
+ nn.Flatten(),
216
+ nn.Linear(64 * 9 * 9, n_embd),
217
+ )
218
+ self.action_emb = nn.Embedding(action_dim + 2, n_embd)
219
+ self.timestep_emb = nn.Embedding(cfg.num_diffusion_steps, n_embd)
220
+ self.pos_emb = nn.Embedding(seq_len, n_embd)
221
+ encoder_layer = nn.TransformerEncoderLayer(
222
+ d_model=n_embd,
223
+ nhead=cfg.n_head,
224
+ dim_feedforward=n_embd * 4,
225
+ dropout=getattr(cfg, "dropout", 0.0),
226
+ activation="gelu",
227
+ norm_first=True,
228
+ batch_first=True,
229
+ )
230
+ self.transformer = nn.TransformerEncoder(
231
+ encoder_layer, num_layers=cfg.n_layer,
232
+ )
233
+ self.head = nn.Linear(n_embd, action_dim)
234
+
235
+ def forward(
236
+ self,
237
+ local_obs: Tensor,
238
+ global_obs: Tensor,
239
+ action_seq: Tensor,
240
+ t_discrete: int | Tensor,
241
+ ) -> dict[str, Tensor]:
242
+ """Forward pass (ignores global_obs).
243
+
244
+ Args:
245
+ local_obs: ``[B, 9, 9]`` int.
246
+ global_obs: ``[B, 21, 79]`` int (ignored).
247
+ action_seq: ``[B, seq_len]`` int.
248
+ t_discrete: Timestep index.
249
+
250
+ Returns:
251
+ Dict with ``"actions"`` key only (no goal_pred).
252
+ """
253
+ B, Seq = action_seq.shape
254
+ device = local_obs.device
255
+
256
+ x_state = self.embedding(local_obs).permute(0, 3, 1, 2)
257
+ state_emb = self.cnn(x_state).unsqueeze(1) # [B, 1, n_embd]
258
+
259
+ positions = torch.arange(
260
+ Seq, device=device,
261
+ ).unsqueeze(0).expand(B, -1)
262
+
263
+ if isinstance(t_discrete, int):
264
+ t_tensor = torch.full(
265
+ (B,), t_discrete, dtype=torch.long, device=device,
266
+ )
267
+ else:
268
+ t_tensor = t_discrete.long().to(device)
269
+
270
+ seq_emb = (
271
+ self.action_emb(action_seq)
272
+ + self.timestep_emb(t_tensor).unsqueeze(1)
273
+ + self.pos_emb(positions)
274
+ )
275
+ x = torch.cat([state_emb, seq_emb], dim=1)
276
+ out = self.transformer(x)
277
+ return {"actions": self.head(out[:, 1:, :])}
278
+
279
+
280
+ # ── Factory ──────────────────────────────────────────────────────────
281
+
282
+
283
+ def make_model(cfg: SimpleNamespace) -> nn.Module:
284
+ """Instantiate the default MiniHack denoising model.
285
+
286
+ Args:
287
+ cfg: Config namespace.
288
+
289
+ Returns:
290
+ ``LocalDiffusionPlannerWithGlobal`` instance.
291
+ """
292
+ return LocalDiffusionPlannerWithGlobal(cfg)
293
+
294
+
295
+ def _has_c_compiler() -> bool:
296
+ """Check whether a C compiler is reachable by Triton.
297
+
298
+ Checks the ``CC`` env var (set by conda activation scripts),
299
+ then falls back to ``cc`` and ``gcc`` on ``PATH``.
300
+ """
301
+ import os
302
+ cc_env = os.environ.get("CC")
303
+ if cc_env and shutil.which(cc_env):
304
+ return True
305
+ return shutil.which("cc") is not None or shutil.which("gcc") is not None
306
+
307
+
308
+ def try_compile(model: nn.Module, cfg: SimpleNamespace) -> nn.Module:
309
+ """Wrap *model* with ``torch.compile`` if enabled and a C compiler exists.
310
+
311
+ Falls back to the uncompiled model when ``torch.compile`` is
312
+ unavailable or Triton cannot find a C compiler (common on managed
313
+ GPU nodes that lack ``gcc``/``cc``).
314
+
315
+ Args:
316
+ model: The raw (uncompiled) model.
317
+ cfg: Config namespace; reads ``torch_compile`` bool.
318
+
319
+ Returns:
320
+ Compiled model, or *model* unchanged on fallback.
321
+ """
322
+ if not getattr(cfg, "torch_compile", False):
323
+ return model
324
+ if not hasattr(torch, "compile"):
325
+ return model
326
+ if not _has_c_compiler():
327
+ logger.warning(
328
+ "torch.compile requested but no C compiler found "
329
+ "(CC env var, cc, gcc); falling back to eager mode"
330
+ )
331
+ return model
332
+ logger.info("Compiling model with torch.compile")
333
+ return torch.compile(model, mode="default") # type: ignore[return-value]
334
+
335
+
336
+ # ── EMA ──────────────────────────────────────────────────────────────
337
+
338
+
339
+ class ModelEMA:
340
+ """Exponential moving average of model parameters.
341
+
342
+ Maintains a shadow copy of parameters updated as
343
+ ``theta_ema <- decay * theta_ema + (1 - decay) * theta``.
344
+
345
+ Args:
346
+ model: Source model.
347
+ decay: EMA decay factor (default 0.999).
348
+ """
349
+
350
+ def __init__(self, model: nn.Module, decay: float = 0.999) -> None:
351
+ self._decay = decay
352
+ self._shadow: dict[str, Tensor] = {}
353
+ for name, param in model.named_parameters():
354
+ self._shadow[name] = param.data.clone()
355
+
356
+ @torch.no_grad()
357
+ def update(self, model: nn.Module) -> None:
358
+ """Update shadow parameters from *model*.
359
+
360
+ Args:
361
+ model: Source model whose parameters are blended in.
362
+ """
363
+ for name, param in model.named_parameters():
364
+ self._shadow[name].mul_(self._decay).add_(
365
+ param.data, alpha=1.0 - self._decay,
366
+ )
367
+
368
+ def apply_to(self, model: nn.Module) -> None:
369
+ """Copy shadow parameters into *model* (for inference).
370
+
371
+ Args:
372
+ model: Target model to overwrite.
373
+ """
374
+ for name, param in model.named_parameters():
375
+ param.data.copy_(self._shadow[name])
376
+
377
+ def state_dict(self) -> dict[str, Tensor]:
378
+ """Return shadow parameter dict for serialisation.
379
+
380
+ Returns:
381
+ Dict mapping parameter names to EMA tensors.
382
+ """
383
+ return {k: v.clone() for k, v in self._shadow.items()}
384
+
385
+ def load_state_dict(self, sd: dict[str, Tensor]) -> None:
386
+ """Restore shadow parameters from *sd*.
387
+
388
+ Args:
389
+ sd: State dict from a prior ``state_dict()`` call.
390
+ """
391
+ for k, v in sd.items():
392
+ if k in self._shadow:
393
+ self._shadow[k].copy_(v)
394
+
395
+ def parameters(self):
396
+ """Iterate over shadow parameter tensors.
397
+
398
+ Yields:
399
+ EMA parameter tensors.
400
+ """
401
+ yield from self._shadow.values()
402
+
403
+ def make_eval_model(self, model: nn.Module) -> nn.Module:
404
+ """Return a deep copy of *model* with EMA weights applied.
405
+
406
+ Args:
407
+ model: Template model (architecture).
408
+
409
+ Returns:
410
+ New model with shadow parameters.
411
+ """
412
+ eval_model = copy.deepcopy(model)
413
+ self.apply_to(eval_model)
414
+ eval_model.eval()
415
+ return eval_model
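The update rule in the `ModelEMA` docstring, `theta_ema <- decay * theta_ema + (1 - decay) * theta`, can be illustrated without tensors. This is a sketch only: it uses plain floats in a dict, and `ema_update` is a hypothetical stand-in for `ModelEMA.update`, not the class itself.

```python
def ema_update(
    shadow: dict[str, float], params: dict[str, float], decay: float = 0.999,
) -> None:
    """Blend the current parameters into the shadow copy in place."""
    for name, value in params.items():
        shadow[name] = decay * shadow[name] + (1.0 - decay) * value


# A shadow value starting at 0.0 that tracks a parameter held constant at 1.0.
shadow = {"w": 0.0}
for _ in range(1000):
    ema_update(shadow, {"w": 1.0})

# Closed form for this constant-target case: 1 - decay ** n_steps.
expected = 1.0 - 0.999 ** 1000
```

With `decay=0.999` the shadow value has covered only about 63% of the gap after 1000 updates. That lag is presumably why evaluation uses a separate copy built by `make_eval_model` rather than the live training weights.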
src/planners/__init__.py ADDED
File without changes
src/planners/baselines.py ADDED
@@ -0,0 +1,1247 @@
1
+ """SB3 + Decision Transformer baselines for the ReMDM diffusion planner.
2
+
3
+ This module wraps standard discrete-action RL baselines (PPO, A2C, DQN,
4
+ recurrent PPO) plus two imitation baselines (Behavioural Cloning and
5
+ Decision Transformer) into the project's unified config + dispatch
6
+ surface so they can be compared head-to-head against the DAgger /
7
+ offline-BC diffusion planner on the same MiniHack environments.
8
+
9
+ Entry point: :func:`run_baselines`.
10
+
11
+ Hyperparameters live in ``configs/defaults.yaml`` under the
12
+ ``baselines_*`` namespace; the unified env-step training budget
13
+ (``cfg.total_timesteps``) is shared with DAgger and offline BC.
14
+
15
+ W&B logging routes through the project's :class:`Logger` (with the W&B
16
+ project temporarily swapped to ``cfg.baselines_wandb_project``); SB3's
17
+ standard ``WandbCallback`` piggybacks on the active run and syncs its
18
+ tensorboard scalars automatically. No code in this module calls
19
+ ``wandb.log(...)`` directly.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import logging
25
+ import os
26
+ import random
27
+ from pathlib import Path
28
+ from types import SimpleNamespace
29
+ from typing import Any
30
+
31
+ import gymnasium as gym
32
+ import numpy as np
33
+ import orjson
34
+ import torch
35
+ import torch.nn as nn
36
+ from sb3_contrib import RecurrentPPO
37
+ from stable_baselines3 import A2C, DQN, PPO
38
+ from stable_baselines3.common.callbacks import CallbackList, EvalCallback
39
+ from stable_baselines3.common.monitor import Monitor
40
+ from stable_baselines3.common.policies import ActorCriticPolicy
41
+ from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
42
+ from stable_baselines3.common.vec_env import SubprocVecEnv
43
+ from torch.utils.data import DataLoader, Dataset
44
+ from wandb.integration.sb3 import WandbCallback
45
+
46
+ from src.envs.minihack_env import (
47
+ AdvancedObservationEnv,
48
+ collect_oracle_trajectory,
49
+ )
50
+ from src.planners.logging import Logger
51
+
52
+ logger = logging.getLogger(__name__)
53
+
54
+
55
+ SB3_RL_ALGOS: tuple[str, ...] = ("ppo", "a2c", "dqn", "ppo-rnn")
56
+ IMITATION_ALGOS: tuple[str, ...] = ("bc", "dt")
57
+ ALL_BASELINE_ALGOS: tuple[str, ...] = SB3_RL_ALGOS + IMITATION_ALGOS
58
+
59
+
60
+ # =============================================================================
61
+ # Observation wrapper for SB3 dict-policies
62
+ # =============================================================================
63
+
64
+
65
+ class _SB3MiniHackWrapper(gym.Wrapper):
66
+ """Reshape ``AdvancedObservationEnv`` tuple obs into an SB3 dict obs.
67
+
68
+ The underlying env returns ``(local_crop, global_map)`` with shapes
69
+ ``(crop, crop)`` and ``(map_h, map_w)``; SB3's ``MultiInputPolicy``
70
+ needs a ``Dict`` space with explicit channel dims. Also remaps
71
+ ``info["won"]`` -> ``info["is_success"]`` so SB3's success tracking
72
+ reports our win rate.
73
+ """
74
+
75
+ def __init__(self, env: AdvancedObservationEnv) -> None:
76
+ super().__init__(env)
77
+ local_h, local_w = env.observation_space.shape
78
+ cfg = env._cfg # AdvancedObservationEnv stores cfg here
79
+ self.observation_space = gym.spaces.Dict(
80
+ {
81
+ "local": gym.spaces.Box(
82
+ low=0, high=6000, shape=(1, local_h, local_w), dtype=np.int16,
83
+ ),
84
+ "global": gym.spaces.Box(
85
+ low=0, high=6000, shape=(1, cfg.map_h, cfg.map_w), dtype=np.int16,
86
+ ),
87
+ }
88
+ )
89
+
90
+ def reset(self, **kwargs: Any) -> tuple[dict[str, np.ndarray], dict]:
91
+ (local, glob), info = self.env.reset(**kwargs)
92
+ return self._pack(local, glob), info
93
+
94
+ def step(
95
+ self, action: int,
96
+ ) -> tuple[dict[str, np.ndarray], float, bool, bool, dict]:
97
+ (local, glob), reward, terminated, truncated, info = self.env.step(action)
98
+ if "won" in info:
99
+ info["is_success"] = info["won"]
100
+ return self._pack(local, glob), reward, terminated, truncated, info
101
+
102
+ @staticmethod
103
+ def _pack(
104
+ local: np.ndarray, glob: np.ndarray,
105
+ ) -> dict[str, np.ndarray]:
106
+ return {
107
+ "local": np.expand_dims(local, axis=0), # [1, crop, crop]
108
+ "global": np.expand_dims(glob, axis=0), # [1, H, W]
109
+ }
110
+
111
+
112
+ # =============================================================================
113
+ # CNN feature extractor (shared by SB3 RL + BC)
114
+ # =============================================================================
115
+
116
+
117
+ class _MiniHackCNN(BaseFeaturesExtractor):
118
+ """Dual-stream CNN for the SB3 dict observation.
119
+
120
+ Local stream: ``Conv(1->16, 3) -> Conv(16->32, 3)``.
121
+ Global stream: ``Conv(1->16, 5, stride 2) -> Conv(16->32, 3, stride 2)``.
122
+ Both streams are flattened and concatenated, then projected to
123
+ ``features_dim`` via a single linear + ReLU.
124
+ """
125
+
126
+ def __init__(
127
+ self, observation_space: gym.spaces.Dict, features_dim: int = 256,
128
+ ) -> None:
129
+ super().__init__(observation_space, features_dim)
130
+ self.local_cnn = nn.Sequential(
131
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
132
+ nn.ReLU(),
133
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
134
+ nn.ReLU(),
135
+ nn.Flatten(),
136
+ )
137
+ self.global_cnn = nn.Sequential(
138
+ nn.Conv2d(1, 16, kernel_size=5, stride=2),
139
+ nn.ReLU(),
140
+ nn.Conv2d(16, 32, kernel_size=3, stride=2),
141
+ nn.ReLU(),
142
+ nn.Flatten(),
143
+ )
144
+ with torch.no_grad():
145
+ dummy_loc = torch.zeros(1, *observation_space["local"].shape)
146
+ dummy_glob = torch.zeros(1, *observation_space["global"].shape)
147
+ n_flatten = (
148
+ self.local_cnn(dummy_loc).shape[1]
149
+ + self.global_cnn(dummy_glob).shape[1]
150
+ )
151
+ self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())
152
+
153
+ def forward(
154
+ self, observations: dict[str, torch.Tensor],
155
+ ) -> torch.Tensor:
156
+ loc = self.local_cnn(observations["local"].float()) # [B, F_l]
157
+ glob = self.global_cnn(observations["global"].float()) # [B, F_g]
158
+ return self.linear(torch.cat([loc, glob], dim=1))
159
+
160
+
161
+ # =============================================================================
162
+ # Decision Transformer
163
+ # =============================================================================
164
+
165
+
166
+ class _MiniHackStateEncoder(nn.Module):
167
+ """CNN encoder mapping a (local, global) obs pair to a state embedding."""
168
+
169
+ def __init__(
170
+ self,
171
+ embed_dim: int = 128,
172
+ crop_h: int = 9,
173
+ crop_w: int = 9,
174
+ map_h: int = 21,
175
+ map_w: int = 79,
176
+ ) -> None:
177
+ super().__init__()
178
+ self.local_cnn = nn.Sequential(
179
+ nn.Conv2d(1, 16, kernel_size=3, padding=1),
180
+ nn.ReLU(),
181
+ nn.Conv2d(16, 32, kernel_size=3, padding=1),
182
+ nn.ReLU(),
183
+ nn.Flatten(),
184
+ )
185
+ self.global_cnn = nn.Sequential(
186
+ nn.Conv2d(1, 16, kernel_size=5, stride=2),
187
+ nn.ReLU(),
188
+ nn.Conv2d(16, 32, kernel_size=3, stride=2),
189
+ nn.ReLU(),
190
+ nn.Flatten(),
191
+ )
192
+ with torch.no_grad():
193
+ dummy_loc = torch.zeros(1, 1, crop_h, crop_w)
194
+ dummy_glob = torch.zeros(1, 1, map_h, map_w)
195
+ local_flat = self.local_cnn(dummy_loc).shape[1]
196
+ global_flat = self.global_cnn(dummy_glob).shape[1]
197
+ self.proj = nn.Linear(local_flat + global_flat, embed_dim)
198
+
199
+ def forward(
200
+ self, local_obs: torch.Tensor, global_obs: torch.Tensor,
201
+ ) -> torch.Tensor:
202
+ # Accepts (B, T, 1, H, W) or (B, 1, H, W).
203
+ if local_obs.dim() == 5:
204
+ B, T = local_obs.shape[:2]
205
+ local_obs = local_obs.view(B * T, *local_obs.shape[2:])
206
+ global_obs = global_obs.view(B * T, *global_obs.shape[2:])
207
+ reshape = True
208
+ else:
209
+ B, T = local_obs.shape[0], 1
210
+ reshape = False
211
+
212
+ loc_feat = self.local_cnn(local_obs.float()) # [B*T, F_l]
213
+ glob_feat = self.global_cnn(global_obs.float()) # [B*T, F_g]
214
+ out = self.proj(torch.cat([loc_feat, glob_feat], dim=-1)) # [B*T, D]
215
+ if reshape:
216
+ out = out.view(B, T, -1)
217
+ return out
218
+
219
+
220
+ class _DecisionTransformer(nn.Module):
221
+ """Causal Decision Transformer over interleaved (R, s, a) tokens."""
222
+
223
+ def __init__(
224
+ self,
225
+ n_actions: int,
226
+ embed_dim: int = 128,
227
+ n_heads: int = 4,
228
+ n_layers: int = 3,
229
+ context_len: int = 30,
230
+ max_ep_len: int = 500,
231
+ dropout: float = 0.1,
232
+ crop_h: int = 9,
233
+ crop_w: int = 9,
234
+ map_h: int = 21,
235
+ map_w: int = 79,
236
+ ) -> None:
237
+ super().__init__()
238
+ self.embed_dim = embed_dim
239
+ self.context_len = context_len
240
+ self.n_actions = n_actions
241
+ self.max_ep_len = max_ep_len
242
+
243
+ self.state_encoder = _MiniHackStateEncoder(
244
+ embed_dim, crop_h, crop_w, map_h, map_w,
245
+ )
246
+ self.action_embed = nn.Embedding(n_actions + 1, embed_dim) # +1 for pad
247
+ self.return_embed = nn.Linear(1, embed_dim)
248
+ self.pos_embed = nn.Embedding(max_ep_len, embed_dim)
249
+ self.token_type_embed = nn.Embedding(3, embed_dim)
250
+ self.embed_ln = nn.LayerNorm(embed_dim)
251
+ self.dropout = nn.Dropout(dropout)
252
+
253
+ encoder_layer = nn.TransformerEncoderLayer(
254
+ d_model=embed_dim,
255
+ nhead=n_heads,
256
+ dim_feedforward=embed_dim * 4,
257
+ dropout=dropout,
258
+ activation="gelu",
259
+ batch_first=True,
260
+ )
261
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
262
+ self.action_head = nn.Linear(embed_dim, n_actions)
263
+
264
+ self.apply(self._init_weights)
265
+
266
+ @staticmethod
267
+ def _init_weights(module: nn.Module) -> None:
268
+ if isinstance(module, nn.Linear):
269
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
270
+ if module.bias is not None:
271
+ nn.init.zeros_(module.bias)
272
+ elif isinstance(module, nn.Embedding):
273
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
274
+ elif isinstance(module, nn.LayerNorm):
275
+ nn.init.ones_(module.weight)
276
+ nn.init.zeros_(module.bias)
277
+
278
+ def forward(
279
+ self,
280
+ returns_to_go: torch.Tensor, # [B, T, 1]
281
+ local_obs: torch.Tensor, # [B, T, 1, H_l, W_l]
282
+ global_obs: torch.Tensor, # [B, T, 1, H_g, W_g]
283
+ actions: torch.Tensor, # [B, T]
284
+ timesteps: torch.Tensor, # [B, T]
285
+ attention_mask: torch.Tensor | None = None, # [B, T]
286
+ ) -> torch.Tensor:
287
+ B, T = returns_to_go.shape[:2]
288
+ device = returns_to_go.device
289
+
290
+ rtg_embed = self.return_embed(returns_to_go) # [B, T, D]
291
+ state_embed = self.state_encoder(local_obs, global_obs) # [B, T, D]
292
+ action_embed = self.action_embed(actions) # [B, T, D]
293
+
294
+ pos_embed = self.pos_embed(timesteps) # [B, T, D]
295
+ rtg_embed = rtg_embed + pos_embed + self.token_type_embed.weight[0]
296
+ state_embed = state_embed + pos_embed + self.token_type_embed.weight[1]
297
+ action_embed = action_embed + pos_embed + self.token_type_embed.weight[2]
298
+
299
+ # Interleave (R_0, s_0, a_0, R_1, s_1, a_1, ...) -> [B, 3T, D]
300
+ stacked = torch.stack([rtg_embed, state_embed, action_embed], dim=2)
301
+ stacked = stacked.view(B, 3 * T, self.embed_dim)
302
+ stacked = self.dropout(self.embed_ln(stacked))
303
+
304
+ seq_len = 3 * T
305
+ causal_mask = torch.triu(
306
+ torch.ones(seq_len, seq_len, device=device), diagonal=1,
307
+ ).bool()
308
+
309
+ key_padding_mask = None
310
+ if attention_mask is not None:
311
+ expanded = attention_mask.unsqueeze(-1).repeat(1, 1, 3).view(B, 3 * T)
312
+ key_padding_mask = expanded == 0
313
+
314
+ hidden = self.transformer(
315
+ stacked, mask=causal_mask, src_key_padding_mask=key_padding_mask,
316
+ )
317
+ # State token positions are 1, 4, 7, ... -> stride 3.
318
+ state_hidden = hidden[:, 1::3, :] # [B, T, D]
319
+ return self.action_head(state_hidden) # [B, T, A]
320
+
321
+ @torch.no_grad()
322
+ def get_action(
323
+ self,
324
+ returns_to_go: torch.Tensor,
325
+ local_obs: torch.Tensor,
326
+ global_obs: torch.Tensor,
327
+ actions: torch.Tensor,
328
+ timesteps: torch.Tensor,
329
+ ) -> torch.Tensor:
330
+ self.eval()
331
+ logits = self.forward(
332
+ returns_to_go, local_obs, global_obs, actions, timesteps,
333
+ )
334
+ return logits[:, -1, :].argmax(dim=-1)
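The token layout and causal mask used by `_DecisionTransformer.forward` can be sketched in plain Python. This is an illustrative sketch: `interleave` and `causal_mask` are hypothetical helpers, and the mask follows the same convention as `torch.triu(..., diagonal=1).bool()`, where `True` means "blocked".

```python
def interleave(n_steps: int) -> list[tuple[str, int]]:
    """Token order (R_0, s_0, a_0, R_1, s_1, a_1, ...)."""
    return [(kind, t) for t in range(n_steps) for kind in ("R", "s", "a")]


def causal_mask(n: int) -> list[list[bool]]:
    """mask[i][j] is True when token i may NOT attend to token j (j > i)."""
    return [[j > i for j in range(n)] for i in range(n)]


tokens = interleave(3)
state_positions = [i for i, (kind, _) in enumerate(tokens) if kind == "s"]

mask = causal_mask(len(tokens))
s1_index = state_positions[1]  # position of s_1 in the interleaved sequence
visible_to_s1 = [tokens[j] for j, blocked in enumerate(mask[s1_index]) if not blocked]
```

Reading the action logits from the state positions means the prediction for step `t` sees `R_t` and `s_t` but not `a_t`, which sits one position later and is masked out.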
335
+
336
+
337
+ class _DTDataset(Dataset):
338
+ """Sliding-window dataset over Decision Transformer trajectories."""
339
+
340
+ def __init__(
341
+ self,
342
+ trajectories: list[dict[str, np.ndarray]],
343
+ context_len: int,
344
+ max_ep_len: int,
345
+ n_actions: int,
346
+ ) -> None:
347
+ self.trajectories = trajectories
348
+ self.context_len = context_len
349
+ self.max_ep_len = max_ep_len
350
+ self.n_actions = n_actions
351
+ self.indices: list[tuple[int, int]] = [
352
+ (traj_idx, start)
353
+ for traj_idx, traj in enumerate(trajectories)
354
+ for start in range(len(traj["actions"]))
355
+ ]
356
+
357
+ def __len__(self) -> int:
358
+ return len(self.indices)
359
+
360
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
361
+ traj_idx, start = self.indices[idx]
362
+ traj = self.trajectories[traj_idx]
363
+ traj_len = len(traj["actions"])
364
+ end = min(start + self.context_len, traj_len)
365
+ actual_len = end - start
366
+
367
+ local = traj["local"][start:end].copy()
368
+ glob = traj["global"][start:end].copy()
369
+ actions = traj["actions"][start:end].copy()
370
+ rtg = traj["returns_to_go"][start:end].copy()
371
+ timesteps = np.arange(start, end)
372
+
373
+ # Clamp to valid embedding ranges.
374
+ timesteps = np.clip(timesteps, 0, self.max_ep_len - 1)
375
+ actions = np.clip(actions, 0, self.n_actions - 1)
376
+
377
+ pad_len = self.context_len - actual_len
378
+ if pad_len > 0:
379
+ local = np.pad(
380
+ local, ((0, pad_len), (0, 0), (0, 0), (0, 0)), mode="constant",
381
+ )
382
+ glob = np.pad(
383
+ glob, ((0, pad_len), (0, 0), (0, 0), (0, 0)), mode="constant",
384
+ )
385
+ actions = np.pad(actions, (0, pad_len), mode="constant")
386
+ rtg = np.pad(rtg, (0, pad_len), mode="constant")
387
+ timesteps = np.pad(timesteps, (0, pad_len), mode="constant")
388
+
389
+ attention_mask = np.zeros(self.context_len, dtype=np.float32)
390
+ attention_mask[:actual_len] = 1.0
391
+
392
+ return {
393
+ "local": torch.tensor(local, dtype=torch.float32),
394
+ "global": torch.tensor(glob, dtype=torch.float32),
395
+ "actions": torch.tensor(actions, dtype=torch.long),
396
+ "returns_to_go": torch.tensor(rtg, dtype=torch.float32).unsqueeze(-1),
397
+ "timesteps": torch.tensor(timesteps, dtype=torch.long),
398
+ "attention_mask": torch.tensor(attention_mask, dtype=torch.float32),
399
+ }
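The windowing in `_DTDataset.__getitem__` takes up to `context_len` steps from `start`, right-pads with zeros, and marks real steps in the attention mask. A minimal sketch on a plain list of actions, where `window` is a hypothetical helper and not part of the module:

```python
def window(
    actions: list[int], start: int, context_len: int,
) -> tuple[list[int], list[float]]:
    """Slice a fixed-length window and build its padding mask."""
    end = min(start + context_len, len(actions))
    chunk = list(actions[start:end])
    pad_len = context_len - len(chunk)
    # 1.0 marks a real step, 0.0 marks right-padding.
    mask = [1.0] * len(chunk) + [0.0] * pad_len
    return chunk + [0] * pad_len, mask


# A window that starts near the end of a 5-step trajectory gets padded.
padded_actions, attention_mask = window([3, 1, 4, 1, 5], start=3, context_len=4)
# A window fully inside the trajectory needs no padding.
full_actions, full_mask = window([3, 1, 4, 1, 5], start=0, context_len=4)
```

Padded actions reuse index 0, which is also a valid action id, so the attention mask is what separates padding from real steps.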
400
+
401
+
402
+ # =============================================================================
403
+ # SB3 callbacks + env factory
404
+ # =============================================================================
405
+
406
+
407
+ class _PrefixedEvalCallback(EvalCallback):
408
+ """``EvalCallback`` that records mean_reward / avg_steps / win_rate
409
+ under a unique per-environment prefix.
410
+
411
+ SB3 truncates metric names at 36 chars, which collides on long
412
+ MiniHack env IDs; the prefix lets us strip ``MiniHack-`` / ``-v0``
413
+ cleanly.
414
+ """
415
+
416
+ def __init__(
417
+ self, eval_env: SubprocVecEnv, prefix: str, **kwargs: Any,
418
+ ) -> None:
419
+ super().__init__(eval_env, **kwargs)
420
+ self.prefix = prefix
421
+
422
+ def _on_step(self) -> bool:
423
+ cont = super()._on_step()
424
+ if self.evaluations_results:
425
+ self.logger.record(
426
+ f"{self.prefix}/mean_reward", float(np.mean(self.evaluations_results[-1])),
427
+ )
428
+ self.logger.record(
429
+ f"{self.prefix}/avg_steps", float(np.mean(self.evaluations_length[-1])),
430
+ )
431
+ if self.evaluations_successes:
432
+ self.logger.record(
433
+ f"{self.prefix}/win_rate",
434
+ float(np.mean(self.evaluations_successes[-1])),
435
+ )
436
+ return cont
437
+
438
+
439
+ def _make_sb3_env_fn(env_id: str, cfg: SimpleNamespace, log_dir: str):
440
+ """Return a picklable thunk that builds one wrapped+monitored env."""
441
+
442
+ def _init() -> Monitor:
443
+ os.makedirs(log_dir, exist_ok=True)
444
+ env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
445
+ env = _SB3MiniHackWrapper(env)
446
+ return Monitor(env, log_dir)
447
+
448
+ return _init
449
+
450
+
451
+ # =============================================================================
452
+ # Helpers
453
+ # =============================================================================
454
+
455
+
456
+ def _short(env_id: str) -> str:
457
+ return env_id.replace("MiniHack-", "").replace("-v0", "")
458
+
459
+
460
+ def _eval_episodes_per_env(cfg: SimpleNamespace) -> int:
461
+ override = getattr(cfg, "baselines_eval_episodes_per_env", None)
462
+ if override is not None:
463
+ return int(override)
464
+ return int(cfg.eval_episodes_per_env)
465
+
466
+
467
+ def _seed_everything(seed: int) -> None:
468
+ random.seed(seed)
469
+ np.random.seed(seed)
470
+ torch.manual_seed(seed)
471
+ if torch.cuda.is_available():
472
+ torch.cuda.manual_seed(seed)
473
+ torch.cuda.manual_seed_all(seed)
474
+
475
+
476
+ def _resolve_output_dir(cfg: SimpleNamespace, override: str | None) -> Path:
477
+ if override:
478
+ out = Path(override)
479
+ else:
480
+ out = Path(cfg.baselines_output_dir)
481
+ out.mkdir(parents=True, exist_ok=True)
482
+ return out
483
+
484
+
485
+ def _init_baseline_logger(
486
+ cfg: SimpleNamespace, run_name: str,
487
+ ) -> Logger:
488
+ """Init the project Logger with W&B project swapped to baselines.
489
+
490
+ Mutates ``cfg.wandb_project`` / ``cfg.wandb_run_name`` /
491
+ ``cfg.wandb_resume_id`` in place so the existing
492
+ Logger constructor picks them up. We deliberately do not restore the
493
+ originals — each baseline seed reuses this helper, and main.py exits
494
+ after ``run_baselines`` returns.
495
+ """
496
+
497
+ project_override = getattr(cfg, "baselines_wandb_project", None)
498
+ if project_override:
499
+ cfg.wandb_project = project_override
500
+ cfg.wandb_run_name = run_name
501
+ cfg.wandb_resume_id = None
502
+ return Logger(cfg)
503
+
504
+
505
+ # =============================================================================
506
+ # BC training
507
+ # =============================================================================
508
+
509
+
510
+ def _collect_bc_dataset(
511
+ cfg: SimpleNamespace,
512
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
513
+ """Roll out the BFS oracle on each ID env and stack flat (s, a) pairs."""
514
+
515
+ n_per_env = int(cfg.baselines_bc_oracle_episodes_per_env)
516
+ locals_, globals_, actions_ = [], [], []
517
+ for env_id in cfg.id_envs:
518
+ for traj_seed in range(n_per_env):
519
+ traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
520
+ if traj is None:
521
+ continue
522
+ # (T, H, W) -> (T, 1, H, W)
523
+ locals_.append(np.expand_dims(traj["local"], axis=1))
524
+ globals_.append(np.expand_dims(traj["global"], axis=1))
525
+ actions_.append(traj["actions"])
526
+ if not actions_:
527
+ raise RuntimeError("BC oracle collection produced zero trajectories")
528
+ return (
529
+ np.concatenate(locals_, axis=0),
530
+ np.concatenate(globals_, axis=0),
531
+ np.concatenate(actions_, axis=0),
532
+ )
533
+
534
+
535
+ class _BCDataset(Dataset):
536
+ def __init__(
537
+ self, loc: np.ndarray, glob: np.ndarray, acts: np.ndarray,
538
+ ) -> None:
539
+ self.loc = torch.tensor(loc, dtype=torch.float32)
540
+ self.glob = torch.tensor(glob, dtype=torch.float32)
541
+ self.acts = torch.tensor(acts, dtype=torch.int64)
542
+
543
+ def __len__(self) -> int:
544
+ return len(self.acts)
545
+
546
+ def __getitem__(
547
+ self, idx: int,
548
+ ) -> dict[str, dict[str, torch.Tensor] | torch.Tensor]:
549
+ return {
550
+ "obs": {"local": self.loc[idx], "global": self.glob[idx]},
551
+ "acts": self.acts[idx],
552
+ }
553
+
554
+
555
+ def _eval_sb3_policy_manually(
556
+ policy: ActorCriticPolicy,
557
+ env_id: str,
558
+ cfg: SimpleNamespace,
559
+ log_dir: str,
560
+ n_episodes: int,
561
+ ) -> tuple[float, float]:
562
+ """Run ``policy.predict`` on a Monitor-wrapped vec env and return
563
+ (win_rate, avg_steps)."""
564
+
565
+ eval_env = SubprocVecEnv([_make_sb3_env_fn(env_id, cfg, log_dir)])
566
+ try:
567
+ obs = eval_env.reset()
568
+ wins = 0
569
+ total_steps = 0
570
+ completed = 0
571
+ while completed < n_episodes:
572
+ action, _ = policy.predict(obs, deterministic=True)
573
+ obs, _rewards, dones, infos = eval_env.step(action)
574
+ if dones[0]:
575
+ completed += 1
576
+ if infos[0].get("won", False):
577
+ wins += 1
578
+ total_steps += infos[0]["episode"]["l"]
579
+ finally:
580
+ eval_env.close()
581
+ return wins / n_episodes, total_steps / n_episodes
582
+
583
+
584
+ def _train_bc(
585
+ cfg: SimpleNamespace,
586
+ train_env: SubprocVecEnv,
587
+ log: Logger,
588
+ log_dir: str,
589
+ seed: int,
590
+ ) -> tuple[ActorCriticPolicy, dict[str, float]]:
591
+ """Train a Behavioural Cloning baseline. Returns (policy, seed_metrics)."""
592
+
593
+ device = torch.device(cfg.device)
594
+ n_eval = _eval_episodes_per_env(cfg)
595
+
596
+ logger.info("Collecting oracle demonstrations for BC...")
597
+ loc_arr, glob_arr, acts_arr = _collect_bc_dataset(cfg)
598
+ logger.info("BC dataset: %d transitions", len(acts_arr))
599
+
600
+ bc_loader = DataLoader(
601
+ _BCDataset(loc_arr, glob_arr, acts_arr),
602
+ batch_size=int(cfg.baselines_bc_batch_size),
603
+ shuffle=True,
604
+ num_workers=4,
605
+ pin_memory=torch.cuda.is_available(),
606
+ )
607
+
608
+ lr = float(cfg.baselines_bc_lr)
609
+ policy = ActorCriticPolicy(
610
+ observation_space=train_env.observation_space,
611
+ action_space=train_env.action_space,
612
+ lr_schedule=lambda _progress: lr,
613
+ features_extractor_class=_MiniHackCNN,
614
+ features_extractor_kwargs={"features_dim": 256},
615
+ ).to(device)
616
+
617
+ n_epochs = int(cfg.baselines_bc_epochs)
618
+ optimizer = torch.optim.AdamW(
619
+ policy.parameters(),
620
+ lr=lr,
621
+ weight_decay=float(cfg.weight_decay),
622
+ )
623
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
624
+ optimizer, T_max=n_epochs,
625
+ )
626
+ policy.train()
627
+ for epoch in range(n_epochs):
628
+ total_loss = 0.0
629
+ for batch in bc_loader:
630
+ obs = {k: v.to(policy.device) for k, v in batch["obs"].items()}
631
+ acts = batch["acts"].to(policy.device)
632
+ _values, log_prob, _entropy = policy.evaluate_actions(obs, acts)
633
+ loss = -log_prob.mean()
634
+ optimizer.zero_grad()
635
+ loss.backward()
636
+ torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
637
+ optimizer.step()
638
+ total_loss += loss.item()
639
+ scheduler.step()
640
+ avg_loss = total_loss / max(1, len(bc_loader))
641
+ current_lr = scheduler.get_last_lr()[0]
642
+ log.log(
643
+ {
644
+ "train/bc_loss": avg_loss,
645
+ "train/lr": current_lr,
646
+ "train/epoch": epoch + 1,
647
+ },
648
+ step=epoch + 1,
649
+ )
650
+ logger.info(
651
+ "BC epoch %02d/%02d | loss=%.4f | lr=%.2e",
652
+ epoch + 1, n_epochs, avg_loss, current_lr,
653
+ )
654
+
655
+ seed_metrics: dict[str, float] = {}
656
+ for split, env_list in (("ID", cfg.id_envs), ("OOD", cfg.ood_envs)):
657
+ logger.info("--- BC %s evaluation (seed=%d) ---", split, seed)
658
+ for env_id in env_list:
659
+ short = _short(env_id)
660
+ win_rate, avg_steps = _eval_sb3_policy_manually(
661
+ policy,
662
+ env_id,
663
+ cfg,
664
+ f"{log_dir}/eval_{split.lower()}/{env_id}",
665
+ n_eval,
666
+ )
667
+ seed_metrics[f"{split}/{short}/win_rate"] = win_rate * 100
668
+ seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
669
+ logger.info(
670
+ "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
671
+ short, win_rate * 100, avg_steps,
672
+ )
673
+ log.log(seed_metrics, step=n_epochs + 1)
674
+ return policy, seed_metrics
675
+
676
+
677
+ # =============================================================================
678
+ # Decision Transformer training
679
+ # =============================================================================
680
+
681
+
682
+ def _collect_dt_trajectories(
683
+ cfg: SimpleNamespace,
684
+ ) -> list[dict[str, np.ndarray]]:
685
+ """Collect oracle trajectories with sparse reward + return-to-go labels."""
686
+
687
+ n_per_env = int(cfg.baselines_dt_oracle_episodes_per_env)
688
+ trajectories: list[dict[str, np.ndarray]] = []
689
+ for env_id in cfg.id_envs:
690
+ for traj_seed in range(n_per_env):
691
+ traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
692
+ if traj is None:
693
+ continue
694
+ T = len(traj["actions"])
695
+ rewards = np.zeros(T, dtype=np.float32)
696
+ rewards[-1] = 1.0 # sparse goal reward
697
+ rtg = np.zeros(T, dtype=np.float32)
698
+ rtg[-1] = rewards[-1]
699
+ for t in range(T - 2, -1, -1):
700
+ rtg[t] = rewards[t] + rtg[t + 1]
701
+ trajectories.append(
702
+ {
703
+ "local": np.expand_dims(traj["local"], axis=1),
704
+ "global": np.expand_dims(traj["global"], axis=1),
705
+ "actions": traj["actions"],
706
+ "rewards": rewards,
707
+ "returns_to_go": rtg,
708
+ }
709
+ )
710
+ return trajectories
711
+
712
+
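Note on the return-to-go labels in `_collect_dt_trajectories`: the backward loop is an undiscounted suffix sum over the rewards. The sketch below reproduces that recurrence with the standard library only; `returns_to_go` is a hypothetical helper name, not something defined in this commit. With the sparse terminal reward of 1.0 used above, every entry comes out as 1.0, which is why `target_return` in `_train_dt` resolves to 1.0 for oracle data.

```python
from itertools import accumulate


def returns_to_go(rewards: list[float]) -> list[float]:
    """Undiscounted suffix sums: rtg[t] = rewards[t] + rewards[t + 1] + ..."""
    # Accumulate over the reversed rewards, then flip back into time order.
    return list(accumulate(reversed(rewards)))[::-1]


# Sparse goal reward on the final step, mirroring the oracle trajectories.
sparse = [0.0, 0.0, 0.0, 1.0]
print(returns_to_go(sparse))
```

The same function also covers dense rewards, so it could serve as a quick cross-check if the reward scheme ever changes.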
713
+ def _eval_dt(
714
+ model: _DecisionTransformer,
715
+ env_id: str,
716
+ cfg: SimpleNamespace,
717
+ target_return: float,
718
+ n_episodes: int,
719
+ max_ep_len: int,
720
+ eval_max_steps: int,
721
+ context_len: int,
722
+ ) -> tuple[float, float]:
723
+ """Roll out a trained Decision Transformer with target-return conditioning."""
724
+
725
+ device = torch.device(cfg.device)
726
+ env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
727
+ env = _SB3MiniHackWrapper(env)
728
+ model.eval()
729
+ wins = 0
730
+ total_steps = 0
731
+ try:
732
+ for _ep in range(n_episodes):
733
+ obs, _ = env.reset()
734
+ done = False
735
+
736
+ local_hist: list[np.ndarray] = []
737
+ global_hist: list[np.ndarray] = []
738
+ action_hist: list[int] = []
739
+ rtg_hist: list[float] = []
740
+ ts_hist: list[int] = []
741
+
742
+ current_rtg = float(target_return)
743
+ t = 0
744
+ info: dict = {}
745
+ while not done and t < eval_max_steps:
746
+ local_hist.append(obs["local"])
747
+ global_hist.append(obs["global"])
748
+ rtg_hist.append(current_rtg)
749
+ ts_hist.append(min(t, max_ep_len - 1))
750
+
751
+ ctx = min(len(local_hist), context_len)
752
+ local_in = np.stack(local_hist[-ctx:], axis=0)
753
+ global_in = np.stack(global_hist[-ctx:], axis=0)
754
+ rtg_in = np.array(rtg_hist[-ctx:], dtype=np.float32)
755
+ ts_in = np.array(ts_hist[-ctx:], dtype=np.int64)
756
+ if len(action_hist) < ctx:
757
+ act_in = np.zeros(ctx, dtype=np.int64)
758
+ if action_hist:
759
+ act_in[-len(action_hist):] = action_hist[-ctx:]
760
+ else:
761
+ act_in = np.array(action_hist[-ctx:], dtype=np.int64)
762
+
763
+ local_t = torch.tensor(local_in, dtype=torch.float32).unsqueeze(0).to(device)
764
+ global_t = torch.tensor(global_in, dtype=torch.float32).unsqueeze(0).to(device)
765
+ rtg_t = torch.tensor(rtg_in, dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
766
+ act_t = torch.tensor(act_in, dtype=torch.long).unsqueeze(0).to(device)
767
+ ts_t = torch.tensor(ts_in, dtype=torch.long).unsqueeze(0).to(device)
768
+
769
+ with torch.no_grad():
770
+ action = int(
771
+ model.get_action(rtg_t, local_t, global_t, act_t, ts_t).item()
772
+ )
773
+ action = max(0, min(action, int(cfg.action_dim) - 1))
774
+ action_hist.append(action)
775
+
776
+ obs, reward, terminated, truncated, info = env.step(action)
777
+ done = terminated or truncated
778
+ current_rtg -= float(reward)
779
+ t += 1
780
+
781
+ if info.get("won", False):
782
+ wins += 1
783
+ total_steps += t
784
+ finally:
785
+ env.close()
786
+
787
+ return wins / n_episodes, total_steps / n_episodes
788
+
789
+
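Note on the rollout bookkeeping in `_eval_dt`: each step keeps only the most recent `context_len` entries of every history list and clamps the timestep index to the embedding table size. A minimal sketch of those two operations follows; the helper names `context_window` and `clamp_timestep` are illustrative assumptions and do not appear in the commit.

```python
def context_window(history: list, context_len: int) -> list:
    """Return the most recent ``context_len`` entries, or all if fewer exist."""
    ctx = min(len(history), context_len)
    # Guard the empty case: history[-0:] would return the whole list.
    return history[-ctx:] if ctx else []


def clamp_timestep(t: int, max_ep_len: int) -> int:
    """Keep the timestep inside the positional-embedding range [0, max_ep_len)."""
    return min(t, max_ep_len - 1)


print(context_window([0, 1, 2, 3, 4], context_len=3))
print(clamp_timestep(10, max_ep_len=8))
```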
790
+ def _train_dt(
791
+ cfg: SimpleNamespace,
792
+ log: Logger,
793
+ log_dir: str,
794
+ seed: int,
795
+ ) -> tuple[_DecisionTransformer, dict[str, float]]:
796
+ """Train a Decision Transformer baseline. Returns (model, seed_metrics)."""
797
+
798
+ device = torch.device(cfg.device)
799
+ context_len = int(cfg.baselines_dt_context_len)
800
+ max_ep_len = int(cfg.baselines_dt_max_ep_len)
801
+ eval_max_steps = int(cfg.baselines_dt_eval_max_steps)
802
+ n_eval = _eval_episodes_per_env(cfg)
803
+ n_epochs = int(cfg.baselines_dt_epochs)
804
+
805
+ logger.info("Collecting oracle demonstrations for DT...")
806
+ trajectories = _collect_dt_trajectories(cfg)
807
+ if not trajectories:
808
+ raise RuntimeError("DT oracle collection produced zero trajectories")
809
+
810
+ traj_lengths = [len(t["actions"]) for t in trajectories]
811
+ logger.info(
812
+ "DT dataset: %d trajectories, %d transitions (len: min=%d max=%d mean=%.1f)",
813
+ len(trajectories),
814
+ sum(traj_lengths),
815
+ min(traj_lengths),
816
+ max(traj_lengths),
817
+ float(np.mean(traj_lengths)),
818
+ )
819
+ if max(traj_lengths) > max_ep_len:
820
+ logger.warning(
821
+ "Longest oracle trajectory (%d) exceeds baselines_dt_max_ep_len (%d); "
822
+ "positions will be clamped.",
823
+ max(traj_lengths),
824
+ max_ep_len,
825
+ )
826
+
827
+ target_return = float(max(t["returns_to_go"][0] for t in trajectories))
828
+
829
+ dataset = _DTDataset(
830
+ trajectories,
831
+ context_len=context_len,
832
+ max_ep_len=max_ep_len,
833
+ n_actions=int(cfg.action_dim),
834
+ )
835
+ loader = DataLoader(
836
+ dataset,
837
+ batch_size=int(cfg.baselines_dt_batch_size),
838
+ shuffle=True,
839
+ num_workers=4,
840
+ pin_memory=torch.cuda.is_available(),
841
+ )
842
+
843
+ model = _DecisionTransformer(
844
+ n_actions=int(cfg.action_dim),
845
+ embed_dim=int(cfg.baselines_dt_embed_dim),
846
+ n_heads=int(cfg.baselines_dt_n_heads),
847
+ n_layers=int(cfg.baselines_dt_n_layers),
848
+ context_len=context_len,
849
+ max_ep_len=max_ep_len,
850
+ crop_h=int(cfg.crop_size),
851
+ crop_w=int(cfg.crop_size),
852
+ map_h=int(cfg.map_h),
853
+ map_w=int(cfg.map_w),
854
+ ).to(device)
855
+ n_params = sum(p.numel() for p in model.parameters())
856
+ logger.info("DT parameters: %d", n_params)
857
+
858
+ optimizer = torch.optim.AdamW(
859
+ model.parameters(),
860
+ lr=float(cfg.baselines_dt_lr),
861
+ weight_decay=float(cfg.weight_decay),
862
+ )
863
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
864
+ optimizer, T_max=n_epochs,
865
+ )
866
+
867
+ for epoch in range(n_epochs):
868
+ model.train()
869
+ total_loss = 0.0
870
+ n_batches = 0
871
+ for batch in loader:
872
+ local = batch["local"].to(device)
873
+ glob = batch["global"].to(device)
874
+ actions = batch["actions"].to(device)
875
+ rtg = batch["returns_to_go"].to(device)
876
+ timesteps = batch["timesteps"].to(device)
877
+ attention_mask = batch["attention_mask"].to(device)
878
+
879
+ logits = model(rtg, local, glob, actions, timesteps, attention_mask)
880
+ logits_flat = logits.reshape(-1, int(cfg.action_dim))
881
+ targets_flat = actions.reshape(-1)
882
+ mask_flat = attention_mask.reshape(-1)
883
+ ce = nn.functional.cross_entropy(
884
+ logits_flat, targets_flat, reduction="none",
885
+ )
886
+ loss = (ce * mask_flat).sum() / mask_flat.sum().clamp(min=1.0)
887
+
888
+ optimizer.zero_grad()
889
+ loss.backward()
890
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
891
+ optimizer.step()
892
+
893
+ total_loss += loss.item()
894
+ n_batches += 1
895
+ scheduler.step()
896
+ avg_loss = total_loss / max(1, n_batches)
897
+ log.log(
898
+ {
899
+ "train/dt_loss": avg_loss,
900
+ "train/lr": float(scheduler.get_last_lr()[0]),
901
+ "train/epoch": epoch + 1,
902
+ },
903
+ step=epoch + 1,
904
+ )
905
+ logger.info(
906
+ "DT epoch %02d/%02d | loss=%.4f | lr=%.2e",
907
+ epoch + 1,
908
+ n_epochs,
909
+ avg_loss,
910
+ float(scheduler.get_last_lr()[0]),
911
+ )
912
+
913
+ seed_metrics: dict[str, float] = {}
914
+ logger.info("DT eval target return = %.2f", target_return)
915
+ for split, env_list in (("ID", cfg.id_envs), ("OOD", cfg.ood_envs)):
916
+ logger.info("--- DT %s evaluation (seed=%d) ---", split, seed)
917
+ for env_id in env_list:
918
+ short = _short(env_id)
919
+ win_rate, avg_steps = _eval_dt(
920
+ model,
921
+ env_id,
922
+ cfg,
923
+ target_return=target_return,
924
+ n_episodes=n_eval,
925
+ max_ep_len=max_ep_len,
926
+ eval_max_steps=eval_max_steps,
927
+ context_len=context_len,
928
+ )
929
+ seed_metrics[f"{split}/{short}/win_rate"] = win_rate * 100
930
+ seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
931
+ logger.info(
932
+ "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
933
+ short, win_rate * 100, avg_steps,
934
+ )
935
+ log.log(seed_metrics, step=n_epochs + 1)
936
+ return model, seed_metrics
937
+
938
+
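Note on the loss in `_train_dt`: the per-token cross-entropy is multiplied by the attention mask and divided by the number of unmasked tokens, so padded positions do not dilute the average. The sketch below shows the same reduction in plain Python so it can be checked by hand; it is an illustration under that reading, not the training code, and `masked_cross_entropy` is a hypothetical name.

```python
import math


def masked_cross_entropy(
    logits: list[list[float]], targets: list[int], mask: list[float],
) -> float:
    """Mean cross-entropy over positions where mask == 1."""
    total = 0.0
    count = 0.0
    for row, target, m in zip(logits, targets, mask):
        # Numerically stable log-sum-exp for the softmax normaliser.
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += m * (log_z - row[target])
        count += m
    # Mirrors ``.clamp(min=1.0)``: avoid dividing by zero on all-padding input.
    return total / max(count, 1.0)


# The second position is padding (mask 0) and must not affect the result.
print(masked_cross_entropy([[0.0, 0.0], [5.0, -5.0]], [0, 1], [1.0, 0.0]))
```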
939
+ # =============================================================================
940
+ # SB3 RL training
941
+ # =============================================================================
942
+
943
+
944
+ def _build_sb3_model(
945
+ algo: str,
946
+ train_env: SubprocVecEnv,
947
+ cfg: SimpleNamespace,
948
+ seed: int,
949
+ tb_log_dir: str,
950
+ ):
951
+ """Construct one of {ppo, a2c, dqn, ppo-rnn} with the MiniHack CNN."""
952
+
953
+ policy_kwargs = {
954
+ "features_extractor_class": _MiniHackCNN,
955
+ "features_extractor_kwargs": {"features_dim": 256},
956
+ }
957
+ if algo == "ppo":
958
+ return PPO(
959
+ "MultiInputPolicy", train_env, policy_kwargs=policy_kwargs,
960
+ verbose=1, tensorboard_log=tb_log_dir, seed=seed,
961
+ )
962
+ if algo == "ppo-rnn":
963
+ return RecurrentPPO(
964
+ "MultiInputLstmPolicy", train_env, policy_kwargs=policy_kwargs,
965
+ verbose=1, tensorboard_log=tb_log_dir, seed=seed,
966
+ )
967
+ if algo == "a2c":
968
+ return A2C(
969
+ "MultiInputPolicy", train_env, policy_kwargs=policy_kwargs,
970
+ verbose=1, tensorboard_log=tb_log_dir, seed=seed,
971
+ )
972
+ if algo == "dqn":
973
+ return DQN(
974
+ "MultiInputPolicy", train_env, policy_kwargs=policy_kwargs,
975
+ verbose=1, tensorboard_log=tb_log_dir, seed=seed,
976
+ buffer_size=int(cfg.baselines_dqn_buffer_size),
977
+ )
978
+ raise ValueError(f"Unknown SB3 algo: {algo!r}")
979
+
980
+
981
+ def _build_sb3_callbacks(
982
+ cfg: SimpleNamespace,
983
+ train_env: SubprocVecEnv,
984
+ log_dir: str,
985
+ model_dir: str,
986
+ ) -> CallbackList:
987
+ callbacks: list = [WandbCallback(model_save_path=model_dir)]
988
+ n_eval = _eval_episodes_per_env(cfg)
989
+ eval_freq = max(
990
+ 1, int(cfg.baselines_eval_freq_env_steps) // train_env.num_envs,
991
+ )
992
+ for env_id in cfg.id_envs:
993
+ short = _short(env_id)
994
+ eval_env = SubprocVecEnv(
995
+ [_make_sb3_env_fn(env_id, cfg, f"{log_dir}/eval_id/{env_id}")]
996
+ )
997
+ callbacks.append(
998
+ _PrefixedEvalCallback(
999
+ eval_env,
1000
+ prefix=f"ID/{short}",
1001
+ best_model_save_path=f"{model_dir}/best_{env_id}/",
1002
+ log_path=f"{log_dir}/eval_id/{env_id}/",
1003
+ eval_freq=eval_freq,
1004
+ n_eval_episodes=n_eval,
1005
+ deterministic=True,
1006
+ )
1007
+ )
1008
+ for env_id in cfg.ood_envs:
1009
+ short = _short(env_id)
1010
+ eval_env = SubprocVecEnv(
1011
+ [_make_sb3_env_fn(env_id, cfg, f"{log_dir}/eval_ood/{env_id}")]
1012
+ )
1013
+ callbacks.append(
1014
+ _PrefixedEvalCallback(
1015
+ eval_env,
1016
+ prefix=f"OOD/{short}",
1017
+ best_model_save_path=None,
1018
+ log_path=f"{log_dir}/eval_ood/{env_id}/",
1019
+ eval_freq=eval_freq,
1020
+ n_eval_episodes=n_eval,
1021
+ deterministic=True,
1022
+ )
1023
+ )
1024
+ return CallbackList(callbacks)
1025
+
1026
+
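Note on `eval_freq` in `_build_sb3_callbacks`: Stable-Baselines3 counts `eval_freq` in callback calls, and each call advances every parallel environment by one step, so a budget expressed in environment steps has to be divided by `num_envs`. The conversion is small enough to sketch on its own; `eval_freq_in_calls` is a hypothetical helper name.

```python
def eval_freq_in_calls(eval_freq_env_steps: int, num_envs: int) -> int:
    """Convert an env-step interval into a per-call interval for a vec env."""
    # Floor division can reach 0 when num_envs exceeds the interval; clamp to 1.
    return max(1, eval_freq_env_steps // num_envs)


print(eval_freq_in_calls(10_000, 8))
print(eval_freq_in_calls(3, 8))
```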
1027
+ # =============================================================================
1028
+ # Aggregation
1029
+ # =============================================================================
1030
+
1031
+
1032
+ def _aggregate(
1033
+ all_seed_results: list[dict[str, Any]],
1034
+ ) -> dict[str, dict[str, float | list[float]]]:
1035
+ """Compute mean/std across seeds for every shared metric key."""
1036
+
1037
+ if not all_seed_results:
1038
+ return {}
1039
+ metric_keys = [k for k in all_seed_results[0].keys() if k != "seed"]
1040
+ agg: dict[str, dict[str, float | list[float]]] = {}
1041
+ for key in metric_keys:
1042
+ values = [r[key] for r in all_seed_results if key in r]
1043
+ if values:
1044
+ agg[key] = {
1045
+ "mean": float(np.mean(values)),
1046
+ "std": float(np.std(values)),
1047
+ "values": [float(v) for v in values],
1048
+ }
1049
+ return agg
1050
+
1051
+
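Note on `_aggregate`: metric keys are taken from the first seed's result, and `np.std` with its default `ddof=0` yields the population standard deviation. The standard-library sketch below follows the same logic, using `statistics.pstdev` as the population counterpart; the function name `aggregate` and the sample keys are assumptions made for illustration.

```python
from statistics import fmean, pstdev


def aggregate(per_seed: list[dict]) -> dict[str, dict]:
    """Mean and population std across seeds for each metric key."""
    if not per_seed:
        return {}
    # Keys come from the first seed only, as in the original helper.
    keys = [k for k in per_seed[0] if k != "seed"]
    out: dict[str, dict] = {}
    for key in keys:
        values = [float(r[key]) for r in per_seed if key in r]
        if values:
            out[key] = {
                "mean": fmean(values),
                "std": pstdev(values),
                "values": values,
            }
    return out


results = [
    {"seed": 0, "ID/room/win_rate": 50.0},
    {"seed": 1, "ID/room/win_rate": 100.0},
]
print(aggregate(results))
```

With only a handful of seeds, the population and sample standard deviations differ noticeably, so it is worth stating which one a results table reports.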
1052
+ def _print_aggregated(seeds: list[int], agg: dict[str, dict[str, Any]]) -> None:
1053
+ if not agg:
1054
+ logger.info("No per-environment metrics to aggregate (RL eval is callback-driven)")
1055
+ return
1056
+ logger.info("Aggregated results across %d seeds: %s", len(seeds), seeds)
1057
+ for split in ("ID", "OOD"):
1058
+ env_metrics: dict[str, dict[str, dict[str, Any]]] = {}
1059
+ for key, stats in agg.items():
1060
+ if not key.startswith(f"{split}/"):
1061
+ continue
1062
+ _split, env_name, metric_name = key.split("/", 2)
1063
+ env_metrics.setdefault(env_name, {})[metric_name] = stats
1064
+ if not env_metrics:
1065
+ continue
1066
+ logger.info("--- %s environments ---", split)
1067
+ for env_name, metrics in sorted(env_metrics.items()):
1068
+ wr = metrics.get("win_rate", {})
1069
+ steps = metrics.get("avg_steps", {})
1070
+ logger.info(
1071
+ "%-30s | win_rate=%5.1f%% +/- %4.1f | avg_steps=%5.1f +/- %4.1f",
1072
+ env_name,
1073
+ wr.get("mean", 0.0),
1074
+ wr.get("std", 0.0),
1075
+ steps.get("mean", 0.0),
1076
+ steps.get("std", 0.0),
1077
+ )
1078
+
1079
+
1080
+ def _save_aggregated(
1081
+ out_path: Path,
1082
+ algo: str,
1083
+ seeds: list[int],
1084
+ all_seed_results: list[dict[str, Any]],
1085
+ agg: dict[str, dict[str, Any]],
1086
+ ) -> None:
1087
+ payload = {
1088
+ "algorithm": algo,
1089
+ "seeds": seeds,
1090
+ "n_seeds": len(seeds),
1091
+ "per_seed_results": all_seed_results,
1092
+ "aggregated": {
1093
+ k: {"mean": v["mean"], "std": v["std"]} for k, v in agg.items()
1094
+ },
1095
+ }
1096
+ out_path.write_bytes(orjson.dumps(payload, option=orjson.OPT_INDENT_2))
1097
+ logger.info("Aggregated results written to %s", out_path)
1098
+
1099
+
1100
+ # =============================================================================
1101
+ # Public entry point
1102
+ # =============================================================================
1103
+
1104
+
1105
+ def run_baselines(
1106
+ cfg: SimpleNamespace,
1107
+ algo: str,
1108
+ seeds: list[int] | None = None,
1109
+ output_path: str | None = None,
1110
+ ) -> None:
1111
+ """Train and evaluate one baseline algorithm across one or more seeds.
1112
+
1113
+ Args:
1114
+ cfg: Project config namespace (must contain ``baselines_*`` keys).
1115
+ algo: One of ``ppo``, ``a2c``, ``dqn``, ``ppo-rnn``, ``bc``, ``dt``.
1116
+ seeds: Optional list of seeds. ``None`` -> ``[cfg.seed]`` (or
1117
+ a single seed of ``0`` if ``cfg.seed`` is ``None``).
1118
+ output_path: Optional override for the aggregated-results JSON
1119
+ destination. When ``None``, results land under
1120
+ ``cfg.baselines_output_dir``.
1121
+ """
1122
+
1123
+ if algo not in ALL_BASELINE_ALGOS:
1124
+ raise ValueError(
1125
+ f"Unknown algo {algo!r}. Choose one of {ALL_BASELINE_ALGOS}."
1126
+ )
1127
+
1128
+ if seeds is None:
1129
+ seeds = [cfg.seed if cfg.seed is not None else 0]
1130
+ if not seeds:
1131
+ raise ValueError("seeds must be non-empty")
1132
+
1133
+ out_dir = _resolve_output_dir(cfg, None)
1134
+ if output_path is not None:
1135
+ agg_json_path = Path(output_path)
1136
+ agg_json_path.parent.mkdir(parents=True, exist_ok=True)
1137
+ else:
1138
+ agg_json_path = out_dir / f"results_{algo}_{len(seeds)}seeds.json"
1139
+
1140
+ logger.info(
1141
+ "Running baseline %s on %d seed(s): %s (output -> %s)",
1142
+ algo, len(seeds), seeds, agg_json_path,
1143
+ )
1144
+
1145
+ all_seed_results: list[dict[str, Any]] = []
1146
+ n_envs_per_id = int(cfg.baselines_n_envs_per_id)
1147
+
1148
+ for seed_idx, seed in enumerate(seeds):
1149
+ logger.info(
1150
+ "============================================================\n"
1151
+ " %s seed %d (%d/%d)\n"
1152
+ "============================================================",
1153
+ algo.upper(), seed, seed_idx + 1, len(seeds),
1154
+ )
1155
+ _seed_everything(seed)
1156
+
1157
+ run_name = f"{algo}-multitask-seed{seed}"
1158
+ log = _init_baseline_logger(cfg, run_name)
1159
+ run_id = (
1160
+ log._run.id # type: ignore[union-attr]
1161
+ if log._use_wandb and log._run is not None
1162
+ else f"local-{algo}-seed{seed}"
1163
+ )
1164
+ log_dir = str(out_dir / "logs" / run_id)
1165
+ model_dir = str(out_dir / "models" / run_id)
1166
+ os.makedirs(log_dir, exist_ok=True)
1167
+ os.makedirs(model_dir, exist_ok=True)
1168
+
1169
+ seed_results: dict[str, Any] = {"seed": seed}
1170
+ try:
1171
+ if algo == "dt":
1172
+ model, dt_metrics = _train_dt(cfg, log, log_dir, seed)
1173
+ seed_results.update(dt_metrics)
1174
+ torch.save(
1175
+ {
1176
+ "model_state_dict": model.state_dict(),
1177
+ "config": {
1178
+ "n_actions": int(cfg.action_dim),
1179
+ "embed_dim": int(cfg.baselines_dt_embed_dim),
1180
+ "n_heads": int(cfg.baselines_dt_n_heads),
1181
+ "n_layers": int(cfg.baselines_dt_n_layers),
1182
+ "context_len": int(cfg.baselines_dt_context_len),
1183
+ "max_ep_len": int(cfg.baselines_dt_max_ep_len),
1184
+ },
1185
+ },
1186
+ f"{model_dir}/dt_final_seed{seed}.pt",
1187
+ )
1188
+ else:
1189
+ # SB3 RL families and BC both need the parallel train env.
1190
+ train_env_fns = [
1191
+ _make_sb3_env_fn(env_id, cfg, log_dir)
1192
+ for env_id in list(cfg.id_envs) * n_envs_per_id
1193
+ ]
1194
+ train_env = SubprocVecEnv(train_env_fns)
1195
+ try:
1196
+ if algo == "bc":
1197
+ policy, bc_metrics = _train_bc(
1198
+ cfg, train_env, log, log_dir, seed,
1199
+ )
1200
+ seed_results.update(bc_metrics)
1201
+ policy.save(f"{model_dir}/bc_final_seed{seed}")
1202
+ else:
1203
+ sb3_model = _build_sb3_model(
1204
+ algo, train_env, cfg, seed,
1205
+ tb_log_dir=str(out_dir / "tb" / run_id),
1206
+ )
1207
+ callbacks = _build_sb3_callbacks(
1208
+ cfg, train_env, log_dir, model_dir,
1209
+ )
1210
+ logger.info(
1211
+ "Training %s for %d env-steps across %d ID maps "
1212
+ "(%d parallel envs)...",
1213
+ algo.upper(),
1214
+ int(cfg.total_timesteps),
1215
+ len(cfg.id_envs),
1216
+ train_env.num_envs,
1217
+ )
1218
+ sb3_model.learn(
1219
+ total_timesteps=int(cfg.total_timesteps),
1220
+ callback=callbacks,
1221
+ )
1222
+ sb3_model.save(f"{model_dir}/{algo}_final_seed{seed}")
1223
+ finally:
1224
+ train_env.close()
1225
+
1226
+ all_seed_results.append(seed_results)
1227
+ finally:
1228
+ log.finish()
1229
+ logger.info("%s seed %d complete.", algo.upper(), seed)
1230
+
1231
+ agg = _aggregate(all_seed_results)
1232
+ _print_aggregated(seeds, agg)
1233
+ if agg:
1234
+ _save_aggregated(agg_json_path, algo, seeds, all_seed_results, agg)
1235
+ # Final summary write to the project Logger so the aggregated
1236
+ # numbers land on a dedicated W&B run.
1237
+ summary_run_name = f"{algo}-multitask-summary"
1238
+ summary_log = _init_baseline_logger(cfg, summary_run_name)
1239
+ try:
1240
+ summary_payload: dict[str, float] = {}
1241
+ for key, stats in agg.items():
1242
+ summary_payload[f"summary/{key}/mean"] = stats["mean"]
1243
+ summary_payload[f"summary/{key}/std"] = stats["std"]
1244
+ summary_log.log_summary(summary_payload)
1245
+ finally:
1246
+ summary_log.finish()
1247
+ logger.info("All %d seed(s) complete.", len(seeds))
src/planners/collect.py ADDED
@@ -0,0 +1,588 @@
1
+ """Data collection with DAgger and oracle replay.
2
+
3
+ Implements model episode rollout with replanning and DAgger-style
4
+ data collection using the BFS oracle and efficiency filter.
5
+ Supports parallel episode collection via ``ThreadPoolExecutor``.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import copy
11
+ import logging
12
+ import os
13
+ import random
14
+ import time
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ from typing import TYPE_CHECKING
17
+ from types import SimpleNamespace
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
+ from src.buffer import ReplayBuffer
23
+ from src.curriculum import DynamicCurriculum, efficiency_filter
24
+ from src.diffusion.sampling import greedy_sample, remdm_sample
25
+ from src.envs.minihack_env import collect_oracle_trajectory, make_env
26
+
27
+ if TYPE_CHECKING:
28
+ from src.models.denoiser import ModelEMA
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ @torch.no_grad()
34
+ def run_model_episode(
35
+ model: torch.nn.Module,
36
+ env_id: str,
37
+ cfg: SimpleNamespace,
38
+ device: torch.device | str,
39
+ seed: int | None = None,
40
+ max_steps: int = 500,
41
+ des_file: str | None = None,
42
+ blind_global: bool = False,
43
+ stochastic: bool = False,
44
+ ) -> dict:
45
+ """Roll out the diffusion model on a single episode.
46
+
47
+ Maintains a ``seq_len``-length plan and replans every
48
+ ``cfg.replan_every`` steps.
49
+
50
+ Args:
51
+ model: Denoising model (eval mode).
52
+ env_id: MiniHack registry ID.
53
+ cfg: Config namespace.
54
+ device: Torch device.
55
+ seed: Optional RNG seed.
56
+ max_steps: Maximum episode length.
57
+ des_file: Optional ``.des`` file content for custom scenarios.
58
+ blind_global: If ``True``, zero out global map (local-only ablation).
59
+ stochastic: If ``True``, use stochastic ReMDM sampling (evaluation).
60
+ If ``False`` (default), use greedy argmax (DAgger collection).
61
+
62
+ Returns:
63
+ Dict with ``"local"`` ``[T,9,9]``, ``"global"`` ``[T,21,79]``,
64
+ ``"actions"`` ``[T]``, ``"won"`` bool, ``"steps"`` int,
65
+ ``"total_reward"`` float, ``"seed"`` int.
66
+ """
67
+ if seed is None:
68
+ seed = random.randint(0, 2**31 - 1)
69
+
70
+ _use_stochastic = stochastic
71
+
72
+ env = make_env(env_id, des_file, cfg)
73
+ try:
74
+ (local, glb), _info = env.reset(seed=seed)
75
+
76
+ locals_list = [local]
77
+ globals_list = [glb]
78
+ actions_list: list[int] = []
79
+ won = False
80
+ total_reward = 0.0
81
+ plan: torch.Tensor | None = None
82
+ step_in_plan = 0
83
+
84
+ model.eval()
85
+ for step_idx in range(max_steps):
86
+ # Replan when needed
87
+ if plan is None or step_in_plan >= cfg.replan_every:
88
+ local_t = torch.from_numpy(
89
+ local[np.newaxis]
90
+ ).long().to(device) # [1, 9, 9]
91
+ glb_t = torch.from_numpy(
92
+ glb[np.newaxis]
93
+ ).long().to(device) # [1, 21, 79]
94
+ if _use_stochastic:
95
+ plan = remdm_sample(
96
+ model, local_t, glb_t, cfg, device,
97
+ physics_aware=getattr(
98
+ cfg, "physics_aware_sampling", False,
99
+ ),
100
+ blind_global=blind_global,
101
+ )
102
+ else:
103
+ plan = greedy_sample(
104
+ model, local_t, glb_t, cfg, device,
105
+ blind_global=blind_global,
106
+ ) # [1, seq_len]
107
+ step_in_plan = 0
108
+
109
+ action = plan[0, step_in_plan].item()
110
+ action = max(0, min(action, cfg.action_dim - 1))
111
+ actions_list.append(action)
112
+ step_in_plan += 1
113
+
114
+ (local, glb), reward, terminated, truncated, info = env.step(
115
+ action,
116
+ )
117
+ total_reward += reward
118
+ locals_list.append(local)
119
+ globals_list.append(glb)
120
+
121
+ if info.get("won", False):
122
+ won = True
123
+ if terminated or truncated:
124
+ break
125
+ finally:
126
+ env.close()
127
+
128
+ # Trim trailing obs
129
+ locals_arr = np.stack(locals_list[:-1], axis=0).astype(np.int16)
130
+ globals_arr = np.stack(globals_list[:-1], axis=0).astype(np.int16)
131
+ actions_arr = np.array(actions_list, dtype=np.int64)
132
+
133
+ return {
134
+ "local": locals_arr,
135
+ "global": globals_arr,
136
+ "actions": actions_arr,
137
+ "won": won,
138
+ "steps": len(actions_list),
139
+ "total_reward": total_reward,
140
+ "seed": seed,
141
+ }
142
+
143
+
144
+ def _collect_episode_thread(
145
+ model: torch.nn.Module,
146
+ env_id: str,
147
+ seed: int,
148
+ cfg: SimpleNamespace,
149
+ ) -> dict | None:
150
+ """Thread worker: run one paired (model + oracle) episode.
151
+
152
+ Both NLE (C code) and PyTorch CPU inference release the GIL,
153
+ so true parallelism is achieved with threads. Each call uses
154
+ its own model copy and env instance.
155
+
156
+ Args:
157
+ model: CPU-resident eval-mode model (thread's own copy).
158
+ env_id: MiniHack environment ID.
159
+ seed: RNG seed for the episode.
160
+ cfg: Config namespace.
161
+
162
+ Returns:
163
+ Stats dict or ``None`` on failure.
164
+ """
165
+ try:
166
+ model_result = run_model_episode(
167
+ model, env_id, cfg, "cpu", seed,
168
+ )
169
+ oracle_result = collect_oracle_trajectory(env_id, seed, cfg)
170
+ oracle_steps = (
171
+ len(oracle_result["actions"]) if oracle_result is not None else 999
172
+ )
173
+ return {
174
+ "env_id": env_id,
175
+ "seed": seed,
176
+ "model_won": model_result["won"],
177
+ "model_steps": model_result["steps"],
178
+ "oracle_steps": oracle_steps,
179
+ "oracle_result": oracle_result,
180
+ }
181
+ except Exception:
182
+ logger.error(
183
+ "Thread worker failed for %s seed=%s", env_id, seed, exc_info=True,
184
+ )
185
+ return None
186
+
187
+
188
+ class DataCollector:
189
+ """DAgger-style data collector.
190
+
191
+ Each iteration: sample an environment from the curriculum, run the
192
+ model, run the oracle on the same seed, apply efficiency filter, and
193
+ optionally add the oracle trajectory to the buffer.
194
+
195
+ Supports parallel episode collection via ``cfg.num_collection_workers``.
196
+
197
+ Uses a live reference to the ``ModelEMA`` object so the collector
198
+ always uses the latest EMA weights (synced before each rollout).
199
+
200
+ Args:
201
+ ema: EMA tracker holding shadow weights.
202
+ model: Training model (architecture template for EMA snapshot).
203
+ buffer: Replay buffer to populate.
204
+ curriculum: Dynamic environment curriculum.
205
+ cfg: Config namespace.
206
+ device: Torch device.
207
+ """
208
+
209
+ def __init__(
210
+ self,
211
+ ema: "ModelEMA",
212
+ model: torch.nn.Module,
213
+ buffer: ReplayBuffer,
214
+ curriculum: DynamicCurriculum,
215
+ cfg: SimpleNamespace,
216
+ device: torch.device | str,
217
+ ) -> None:
218
+ self._ema = ema
219
+ self._model_template = model
220
+ # Materialise an eval-mode copy; refreshed before each rollout
221
+ self.ema_model = ema.make_eval_model(model)
222
+ self.buffer = buffer
223
+ self.curriculum = curriculum
224
+ self.cfg = cfg
225
+ self.device = device
226
+ self._num_workers = getattr(cfg, "num_collection_workers", 0)
227
+ self._last_profile: dict[str, float] = {}
228
+ self._thread_pool: ThreadPoolExecutor | None = None
229
+ self._thread_models: list[torch.nn.Module] = []
230
+ if self._num_workers > 0:
231
+ n = min(self._num_workers, os.cpu_count() or 4)
232
+ self._thread_pool = ThreadPoolExecutor(max_workers=n)
233
+ # Create one CPU model copy per thread
234
+ for _ in range(n):
235
+ m = copy.deepcopy(model).cpu()
236
+ m.eval()
237
+ self._thread_models.append(m)
238
+
239
+ def _sync_ema(self) -> None:
240
+ """Copy latest EMA shadow weights into the eval model."""
241
+ self._ema.apply_to(self.ema_model)
242
+ self.ema_model.eval()
243
+
244
+ def collect_one_iteration(self) -> dict:
245
+ """Run one DAgger collection iteration (single episode).
246
+
247
+ Returns:
248
+ Stats dict with ``"env_id"``, ``"model_won"``,
249
+ ``"model_steps"``, ``"oracle_steps"``,
250
+ ``"added_to_buffer"`` keys.
251
+ """
252
+ self._sync_ema()
253
+ env_id = self.curriculum.sample_env()
254
+ seed = random.randint(0, 2**31 - 1)
255
+
256
+ # Model rollout
257
+ model_result = run_model_episode(
258
+ self.ema_model, env_id, self.cfg, self.device, seed,
259
+ )
260
+
261
+ # Oracle rollout (same seed)
262
+ oracle_result = collect_oracle_trajectory(
263
+ env_id, seed, self.cfg,
264
+ )
265
+ oracle_steps = (
266
+ len(oracle_result["actions"]) if oracle_result else 999
267
+ )
268
+
269
+ # Efficiency filter
270
+ add = efficiency_filter(
271
+ model_result["won"],
272
+ model_result["steps"],
273
+ oracle_steps,
274
+ self.cfg.efficiency_multiplier,
275
+ )
276
+
277
+ if add and oracle_result is not None:
278
+ self.buffer.add(oracle_result)
279
+
280
+ self.curriculum.update(env_id, model_result["won"])
281
+
282
+ return {
283
+ "env_id": env_id,
284
+ "model_won": model_result["won"],
285
+ "model_steps": model_result["steps"],
286
+ "oracle_steps": oracle_steps,
287
+ "added_to_buffer": add and oracle_result is not None,
288
+ }
289
+
290
+ def collect_batch_parallel(
291
+ self, n_episodes: int,
292
+ ) -> list[dict]:
293
+ """Collect multiple episodes in parallel using threads.
294
+
295
+ Both NLE env calls and PyTorch CPU inference release the GIL,
296
+ enabling true parallelism. Each thread uses a pre-allocated
297
+ CPU model copy. Weights are synced from EMA once per call.
298
+
299
+ Args:
300
+ n_episodes: Number of episodes to collect.
301
+
302
+ Returns:
303
+ List of per-episode stats dicts.
304
+ """
305
+ assert self._thread_pool is not None, (
306
+ "collect_batch_parallel requires num_collection_workers > 0"
307
+ )
308
+ self._sync_ema()
309
+
310
+ # Sync EMA weights to all thread-local CPU models
311
+ ema_sd = self.ema_model.state_dict()
312
+ cpu_sd = {k: v.cpu() for k, v in ema_sd.items()}
313
+ for tm in self._thread_models:
314
+ tm.load_state_dict(cpu_sd)
315
+ tm.eval()
316
+
317
+ # Build task list
318
+ tasks = []
319
+ for _ in range(n_episodes):
320
+ env_id = self.curriculum.sample_env()
321
+ seed = random.randint(0, 2**31 - 1)
322
+ tasks.append((env_id, seed))
323
+
324
+ # Round-robin assign models to tasks
325
+ n_models = len(self._thread_models)
326
+ futures = []
327
+ for i, (env_id, seed) in enumerate(tasks):
328
+ model = self._thread_models[i % n_models]
329
+ f = self._thread_pool.submit(
330
+ _collect_episode_thread, model, env_id, seed, self.cfg,
331
+ )
332
+ futures.append(f)
333
+
334
+ results = [f.result() for f in futures]
335
+
336
+ # Process results: efficiency filter + buffer add
337
+ stats_list = []
338
+ for res in results:
339
+ if res is None:
340
+ continue
341
+
342
+ add = efficiency_filter(
343
+ res["model_won"],
344
+ res["model_steps"],
345
+ res["oracle_steps"],
346
+ self.cfg.efficiency_multiplier,
347
+ )
348
+
349
+ oracle_result = res["oracle_result"]
350
+ if add and oracle_result is not None:
351
+ self.buffer.add(oracle_result)
352
+
353
+ self.curriculum.update(res["env_id"], res["model_won"])
354
+
355
+ stats_list.append({
356
+ "env_id": res["env_id"],
357
+ "model_won": res["model_won"],
358
+ "model_steps": res["model_steps"],
359
+ "oracle_steps": res["oracle_steps"],
360
+ "added_to_buffer": add and oracle_result is not None,
361
+ })
362
+
363
+ return stats_list
364
+
365
+ # ── GPU-batched collection ──────────────────────────────────
366
+
367
+ def collect_batch_gpu(self, n_episodes: int) -> list[dict]:
368
+ """Collect episodes with GPU-batched model inference.
369
+
370
+ Runs all model episodes with batched GPU forward passes
371
+ (B=n_episodes instead of B=1), then runs oracle rollouts
372
+ in parallel threads for efficiency filtering.
373
+
374
+ Args:
375
+ n_episodes: Number of episodes to collect.
376
+
377
+ Returns:
378
+ List of per-episode stats dicts.
379
+ """
380
+ self._sync_ema()
381
+ cfg = self.cfg
382
+ self._last_profile = {}
383
+
384
+ tasks = [
385
+ (self.curriculum.sample_env(), random.randint(0, 2**31 - 1))
386
+ for _ in range(n_episodes)
387
+ ]
388
+
389
+ # Phase 1: GPU-batched model rollouts
390
+ t0 = time.perf_counter()
391
+ model_results = self._run_model_episodes_batched(tasks)
392
+ model_time = time.perf_counter() - t0
393
+
394
+ # Phase 2: Oracle rollouts (threaded, CPU-only BFS)
395
+ t0 = time.perf_counter()
396
+ n_workers = min(n_episodes, os.cpu_count() or 4)
397
+ with ThreadPoolExecutor(max_workers=n_workers) as pool:
398
+ oracle_futures = [
399
+ pool.submit(
400
+ collect_oracle_trajectory, env_id, seed, cfg,
401
+ )
402
+ for env_id, seed in tasks
403
+ ]
404
+ oracle_results = [f.result() for f in oracle_futures]
405
+ oracle_time = time.perf_counter() - t0
406
+
407
+ # Phase 3: Efficiency filter + buffer add
408
+ stats_list: list[dict] = []
409
+ for (env_id, _seed), m_res, o_res in zip(
410
+ tasks, model_results, oracle_results,
411
+ ):
412
+ oracle_steps = (
413
+ len(o_res["actions"]) if o_res else 999
414
+ )
415
+ add = efficiency_filter(
416
+ m_res["won"],
417
+ m_res["steps"],
418
+ oracle_steps,
419
+ cfg.efficiency_multiplier,
420
+ )
421
+ if add and o_res is not None:
422
+ self.buffer.add(o_res)
423
+ self.curriculum.update(env_id, m_res["won"])
424
+ stats_list.append({
425
+ "env_id": env_id,
426
+ "model_won": m_res["won"],
427
+ "model_steps": m_res["steps"],
428
+ "oracle_steps": oracle_steps,
429
+ "added_to_buffer": add and o_res is not None,
430
+ })
431
+
432
+ self._last_profile["model_rollout_sec"] = model_time
433
+ self._last_profile["oracle_rollout_sec"] = oracle_time
434
+ return stats_list
435
+
436
+ @torch.no_grad()
437
+ def _run_model_episodes_batched(
438
+ self,
439
+ tasks: list[tuple[str, int]],
440
+ ) -> list[dict]:
441
+ """Run model episodes with batched GPU forward passes.
442
+
443
+ Creates one env per episode, steps them in lockstep, and
444
+ batches all replanning into single GPU forward passes
445
+ (B = number of active envs needing a replan).
446
+
447
+ Args:
448
+ tasks: List of ``(env_id, seed)`` pairs.
449
+
450
+ Returns:
451
+ List of trajectory dicts matching
452
+ ``run_model_episode`` output format.
453
+ """
454
+ cfg = self.cfg
455
+ device = self.device
456
+ model = self.ema_model
457
+ model.eval()
458
+ n = len(tasks)
459
+ max_steps = 500
460
+ K = getattr(
461
+ cfg, "diffusion_steps_collect", cfg.diffusion_steps_eval,
462
+ )
463
+ cs = cfg.crop_size
464
+
465
+ # Create and reset all envs
466
+ envs: list = []
467
+ cur_local = np.zeros((n, cs, cs), dtype=np.int16)
468
+ cur_global = np.zeros(
469
+ (n, cfg.map_h, cfg.map_w), dtype=np.int16,
470
+ )
471
+
472
+ t_reset = time.perf_counter()
473
+ for i, (env_id, seed) in enumerate(tasks):
474
+ env = make_env(env_id, None, cfg)
475
+ (local, glb), _ = env.reset(seed=seed)
476
+ envs.append(env)
477
+ cur_local[i] = local
478
+ cur_global[i] = glb
479
+ reset_time = time.perf_counter() - t_reset
480
+
481
+ # Pre-allocate history buffers
482
+ obs_local = np.zeros(
483
+ (n, max_steps + 1, cs, cs), dtype=np.int16,
484
+ )
485
+ obs_global = np.zeros(
486
+ (n, max_steps + 1, cfg.map_h, cfg.map_w),
487
+ dtype=np.int16,
488
+ )
489
+ act_buf = np.zeros((n, max_steps), dtype=np.int64)
490
+ obs_local[:, 0] = cur_local
491
+ obs_global[:, 0] = cur_global
492
+
493
+ # Per-episode state vectors
494
+ plans = np.zeros((n, cfg.seq_len), dtype=np.int64)
495
+ step_in_plan = np.zeros(n, dtype=np.int32)
496
+ need_replan = np.ones(n, dtype=bool)
497
+ done = np.zeros(n, dtype=bool)
498
+ won = np.zeros(n, dtype=bool)
499
+ total_reward = np.zeros(n, dtype=np.float64)
500
+ n_steps = np.zeros(n, dtype=np.int32)
501
+
502
+ inference_time = 0.0
503
+ env_step_time = 0.0
504
+
505
+ try:
506
+ for _ in range(max_steps):
507
+ # Batch replan on GPU
508
+ replan_idx = np.where(
509
+ need_replan & ~done,
510
+ )[0]
511
+ if len(replan_idx) > 0:
512
+ t0 = time.perf_counter()
513
+ local_t = torch.from_numpy(
514
+ cur_local[replan_idx],
515
+ ).long().to(device)
516
+ glb_t = torch.from_numpy(
517
+ cur_global[replan_idx],
518
+ ).long().to(device)
519
+ batch_plans = greedy_sample(
520
+ model, local_t, glb_t, cfg, device,
521
+ num_steps=K,
522
+ ).cpu().numpy()
523
+ plans[replan_idx] = batch_plans
524
+ step_in_plan[replan_idx] = 0
525
+ need_replan[replan_idx] = False
526
+ inference_time += time.perf_counter() - t0
527
+
528
+ # Step all active envs
529
+ t0 = time.perf_counter()
530
+ any_active = False
531
+ for i in range(n):
532
+ if done[i]:
533
+ continue
534
+ any_active = True
535
+
536
+ action = int(plans[i, step_in_plan[i]])
537
+ action = max(
538
+ 0, min(action, cfg.action_dim - 1),
539
+ )
540
+ act_buf[i, n_steps[i]] = action
541
+ step_in_plan[i] += 1
542
+ n_steps[i] += 1
543
+
544
+ if step_in_plan[i] >= cfg.replan_every:
545
+ need_replan[i] = True
546
+
547
+ obs, reward, term, trunc, info = (
548
+ envs[i].step(action)
549
+ )
550
+ local, glb = obs
551
+ total_reward[i] += reward
552
+ cur_local[i] = local
553
+ cur_global[i] = glb
554
+ obs_local[i, n_steps[i]] = local
555
+ obs_global[i, n_steps[i]] = glb
556
+
557
+ if info.get("won", False):
558
+ won[i] = True
559
+ if term or trunc:
560
+ done[i] = True
561
+ env_step_time += time.perf_counter() - t0
562
+
563
+ if not any_active:
564
+ break
565
+ finally:
566
+ for env in envs:
567
+ env.close()
568
+
569
+ # Build result dicts
570
+ results: list[dict] = []
571
+ for i in range(n):
572
+ T = int(n_steps[i])
573
+ results.append({
574
+ "local": obs_local[i, :T].copy(),
575
+ "global": obs_global[i, :T].copy(),
576
+ "actions": act_buf[i, :T].copy(),
577
+ "won": bool(won[i]),
578
+ "steps": T,
579
+ "total_reward": float(total_reward[i]),
580
+ "seed": tasks[i][1],
581
+ })
582
+
583
+ self._last_profile.update({
584
+ "env_reset_sec": reset_time,
585
+ "gpu_inference_sec": inference_time,
586
+ "env_step_sec": env_step_time,
587
+ })
588
+ return results
src/planners/collect_oracle.py ADDED
@@ -0,0 +1,185 @@
1
+ """Standalone BFS oracle data collection for offline training datasets.
2
+
3
+ Runs the BFS oracle across in-distribution MiniHack environments using
4
+ multiprocessing and saves the resulting trajectories in the dict format
5
+ expected by ``ReplayBuffer.load_offline_data()``.
6
+
7
+ Usage::
8
+
9
+ python main.py --mode collect
10
+ python main.py --mode collect collect_episodes_per_env=2000
11
+ python main.py --mode collect collect_output=data/small.pt
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ import os
18
+ import time
19
+ from concurrent.futures import ProcessPoolExecutor, as_completed
20
+ from pathlib import Path
21
+ from types import SimpleNamespace
22
+
23
+ import torch
24
+
25
+ from src.envs.minihack_env import collect_oracle_trajectory
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ def _collect_single(
+     args: tuple[str, int, SimpleNamespace],
+ ) -> dict | None:
+     """Process-pool worker: collect one oracle trajectory.
+
+     Module-level function so ``ProcessPoolExecutor`` can pickle it.
+
+     Args:
+         args: ``(env_id, seed, cfg)`` tuple.
+
+     Returns:
+         Trajectory dict with ``"local"``, ``"global"``,
+         ``"actions"``, ``"env_id"`` keys, or ``None`` on failure.
+     """
+     env_id, seed, cfg = args
+     return collect_oracle_trajectory(env_id, seed, cfg)
+
+
+ def _format_eta(seconds: float) -> str:
+     """Format seconds into a human-readable ETA string.
+
+     Args:
+         seconds: Remaining time in seconds.
+
+     Returns:
+         Formatted string like ``"2m 30s"`` or ``"45s"``.
+     """
+     if seconds < 60:
+         return f"{seconds:.0f}s"
+     minutes = int(seconds // 60)
+     secs = int(seconds % 60)
+     return f"{minutes}m {secs:02d}s"
62
+
63
+
64
+ def run_collect(cfg: SimpleNamespace) -> None:
65
+ """Collect BFS oracle demonstrations and save as a .pt dataset.
66
+
67
+ Collects ``collect_episodes_per_env`` episodes per ID environment
68
+ using ``ProcessPoolExecutor`` for parallelism, then saves the
69
+ trajectories in the dict format consumed by
70
+ ``ReplayBuffer.load_offline_data()``.
71
+
72
+ The output file can be loaded directly by ``--mode offline``::
73
+
74
+ python main.py --mode collect
75
+ python main.py --mode offline --data data/dataset.pt
76
+
77
+ Args:
78
+ cfg: Config namespace. Reads ``collect_episodes_per_env``,
79
+ ``collect_num_workers``, ``collect_output``, ``id_envs``,
80
+ ``seed``.
81
+ """
82
+ eps_per_env: int = cfg.collect_episodes_per_env
83
+ max_workers: int = min(
84
+ cfg.collect_num_workers, os.cpu_count() or 4,
85
+ )
86
+ output_path: str = cfg.collect_output
87
+ id_envs: list[str] = cfg.id_envs
88
+ base_seed: int = cfg.seed if cfg.seed is not None else 0
89
+
90
+ total_episodes = eps_per_env * len(id_envs)
91
+ logger.info(
92
+ "Collecting %d oracle episodes "
93
+ "(%d per env, %d envs, %d workers)",
94
+ total_episodes, eps_per_env, len(id_envs), max_workers,
95
+ )
96
+
97
+ # Deterministic task list: (env_id, seed, cfg) per episode
98
+ tasks: list[tuple[str, int, SimpleNamespace]] = []
99
+ for env_idx, env_id in enumerate(id_envs):
100
+ for ep in range(eps_per_env):
101
+ seed = base_seed + env_idx * eps_per_env + ep
102
+ tasks.append((env_id, seed, cfg))
103
+
104
+ trajectories: list[dict] = []
105
+ per_env_count: dict[str, int] = {eid: 0 for eid in id_envs}
106
+ per_env_steps: dict[str, int] = {eid: 0 for eid in id_envs}
107
+ failures = 0
108
+ completed = 0
109
+ t_start = time.perf_counter()
110
+ log_interval = max(1, total_episodes // 50)
111
+
112
+ with ProcessPoolExecutor(max_workers=max_workers) as executor:
113
+ future_to_env: dict = {
114
+ executor.submit(_collect_single, task): task[0]
115
+ for task in tasks
116
+ }
117
+
118
+ for future in as_completed(future_to_env):
119
+ env_id = future_to_env[future]
120
+ completed += 1
121
+
122
+ try:
123
+ result = future.result()
124
+ except Exception:
125
+ logger.error(
126
+ "Worker crashed for %s", env_id, exc_info=True,
127
+ )
128
+ result = None
129
+
130
+ if result is not None:
131
+ trajectories.append(result)
132
+ per_env_count[env_id] += 1
133
+ per_env_steps[env_id] += len(result["actions"])
134
+ else:
135
+ failures += 1
136
+
137
+ if (
138
+ completed % log_interval == 0
139
+ or completed == total_episodes
140
+ ):
141
+ elapsed = time.perf_counter() - t_start
142
+ rate = completed / max(elapsed, 1e-6)
143
+ eta = (total_episodes - completed) / max(rate, 1e-6)
144
+ env_summary = " ".join(
145
+ f"{eid.split('-')[-2]}:{per_env_count[eid]}"
146
+ for eid in id_envs
147
+ )
148
+ logger.info(
149
+ " %d/%d (%.1f%%) %.1f eps/s ETA: %s | %s",
150
+ completed, total_episodes,
151
+ 100 * completed / total_episodes,
152
+ rate, _format_eta(eta), env_summary,
153
+ )
154
+
155
+ elapsed = time.perf_counter() - t_start
156
+
157
+ # Summary
158
+ total_steps = sum(per_env_steps.values())
159
+ logger.info("Collection complete in %.1fs", elapsed)
160
+ logger.info(
161
+ " Trajectories: %d (%d failures)",
162
+ len(trajectories), failures,
163
+ )
164
+ logger.info(" Total steps: %d", total_steps)
165
+ for env_id in id_envs:
166
+ n = per_env_count[env_id]
167
+ s = per_env_steps[env_id]
168
+ avg = s / max(n, 1)
169
+ logger.info(
170
+ " %s: %d eps, %d steps, avg %.1f steps/ep",
171
+ env_id, n, s, avg,
172
+ )
173
+
174
+ # Save in the dict format expected by ReplayBuffer.load_offline_data()
175
+ out = Path(output_path).resolve()
176
+ out.parent.mkdir(parents=True, exist_ok=True)
177
+
178
+ dataset: dict = {"trajectories": trajectories}
179
+ torch.save(dataset, str(out))
180
+
181
+ file_mb = out.stat().st_size / (1024 * 1024)
182
+ logger.info(
183
+ "Saved %d trajectories to %s (%.1f MB)",
184
+ len(trajectories), out, file_mb,
185
+ )
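The deterministic task list in `run_collect` gives each environment a contiguous block of seeds. A minimal sketch of that layout, assuming only the seed formula shown above (`build_tasks` and the environment names are hypothetical and used for illustration):

```python
def build_tasks(id_envs, eps_per_env, base_seed=0):
    """Return (env_id, seed) pairs with one contiguous seed block per env."""
    tasks = []
    for env_idx, env_id in enumerate(id_envs):
        for ep in range(eps_per_env):
            # Same formula as run_collect: blocks of eps_per_env seeds.
            seed = base_seed + env_idx * eps_per_env + ep
            tasks.append((env_id, seed))
    return tasks


tasks = build_tasks(["env-a", "env-b"], eps_per_env=3, base_seed=100)
seeds = [seed for _, seed in tasks]
```

One consequence of this layout worth keeping in mind: seeds are unique across environments for a given `eps_per_env`, but changing `eps_per_env` shifts the seed block of every environment after the first, so datasets collected with different episode counts are not prefixes of one another.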
src/planners/inference.py ADDED
@@ -0,0 +1,360 @@
1
+ """Stateless evaluation runner.
2
+
3
+ Runs episodes using the diffusion model and collects per-environment
+ win rates, average rewards, and step counts. All episodes for a given
+ environment are rolled out in lockstep so that replanning calls are
+ batched into single GPU forward passes (B = number of active episodes
+ needing a replan, at most n_episodes).
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ import logging
13
+ from datetime import datetime, timezone
14
+ from pathlib import Path
15
+ from types import SimpleNamespace
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from src.models.denoiser import ModelEMA, make_model
21
+ from src.planners.logging import Logger
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class Evaluator:
27
+ """Stateless evaluation runner.
28
+
29
+ Runs the model on a set of environments and returns aggregate
30
+ statistics per environment. Episodes within each environment are
31
+ executed in lockstep so replanning calls are GPU-batched.
32
+ """
33
+
34
+ @torch.no_grad()
35
+ def evaluate(
36
+ self,
37
+ env_ids: list[str],
38
+ model: torch.nn.Module,
39
+ n_episodes: int,
40
+ cfg: SimpleNamespace,
41
+ device: torch.device | str,
42
+ des_files: list[str] | None = None,
43
+ blind_global: bool = False,
44
+ ) -> dict[str, dict]:
45
+ """Evaluate *model* on each environment in *env_ids*.
46
+
47
+ All *n_episodes* for a given environment run in lockstep so
48
+ that replanning forward passes are batched (B = active envs
49
+ needing a replan).
50
+
51
+ Args:
52
+ env_ids: List of MiniHack environment IDs.
53
+ model: Denoising model (eval mode).
54
+ n_episodes: Episodes per environment.
55
+ cfg: Config namespace.
56
+ device: Torch device.
57
+ des_files: Optional list of ``.des`` file paths for custom
58
+ scenario evaluation. Each file yields one extra env entry
59
+ keyed by its filename stem.
60
+ blind_global: If ``True``, zero out global map observations
61
+ (local-only ablation mode).
62
+
63
+ Returns:
64
+ ``{env_id: {"win_rate", "wins", "avg_reward", "avg_steps",
65
+ "n_episodes"}}``
66
+ """
67
+ model.eval()
68
+ results: dict[str, dict] = {}
69
+
70
+ # Build list of (env_id, des_content) pairs
71
+ eval_targets: list[tuple[str, str | None]] = [
72
+ (eid, None) for eid in env_ids
73
+ ]
74
+         if des_files:
+             for des_path in des_files:
+                 # Path is already imported at module level.
+                 stem = Path(des_path).stem
+                 with open(des_path) as fh:
+                     eval_targets.append((stem, fh.read()))
80
+
81
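+         import zlib  # local import, in line with the other lazy imports here
+
+         for env_id, des_content in eval_targets:
+             # Built-in hash() of a tuple containing a str varies between
+             # processes (PYTHONHASHSEED), so it cannot give reproducible
+             # seeds. crc32 is stable, and the modulo keeps seeds < 2**31.
+             seeds = [
+                 (42 + zlib.crc32(f"{env_id}:{ep}".encode())) % (2**31)
+                 for ep in range(n_episodes)
+             ]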
86
+ ep_results = self._run_episodes_batched(
87
+ model, env_id, n_episodes, cfg, device,
88
+ seeds=seeds,
89
+ des_content=des_content,
90
+ blind_global=blind_global,
91
+ )
92
+
93
+ wins = sum(1 for r in ep_results if r["won"])
94
+ total_reward = sum(r["total_reward"] for r in ep_results)
95
+ total_steps = sum(r["steps"] for r in ep_results)
96
+ n = max(len(ep_results), 1)
97
+ results[env_id] = {
98
+ "win_rate": wins / n,
99
+ "wins": wins,
100
+ "avg_reward": total_reward / n,
101
+ "avg_steps": total_steps / n,
102
+ "n_episodes": len(ep_results),
103
+ }
104
+
105
+ return results
106
+
107
+ @torch.no_grad()
108
+ def _run_episodes_batched(
109
+ self,
110
+ model: torch.nn.Module,
111
+ env_id: str,
112
+ n_episodes: int,
113
+ cfg: SimpleNamespace,
114
+ device: torch.device | str,
115
+ seeds: list[int],
116
+ des_content: str | None = None,
117
+ blind_global: bool = False,
118
+ ) -> list[dict]:
119
+ """Run episodes in lockstep with batched model inference.
120
+
121
+ Creates one environment per episode, steps them in lockstep,
122
+ and batches all replanning calls into single forward passes
123
+ (B = number of active envs needing a replan at each step).
124
+
125
+ Args:
126
+ model: Denoising model (eval mode).
127
+ env_id: MiniHack environment ID.
128
+ n_episodes: Number of episodes to run.
129
+ cfg: Config namespace.
130
+ device: Torch device.
131
+ seeds: Per-episode RNG seeds (length *n_episodes*).
132
+ des_content: Optional ``.des`` file content for custom
133
+ scenarios.
134
+ blind_global: If ``True``, zero out global map observations.
135
+
136
+ Returns:
137
+ List of per-episode dicts with ``"won"``, ``"steps"``,
138
+ ``"total_reward"`` keys. Failed episodes report
139
+ ``won=False``.
140
+ """
141
+ from src.diffusion.sampling import remdm_sample
142
+ from src.envs.minihack_env import make_env
143
+
144
+ n = n_episodes
145
+ max_steps = 500
146
+ cs = cfg.crop_size
147
+
148
+ # Create and reset all envs
149
+ envs: list = []
150
+ cur_local = np.zeros((n, cs, cs), dtype=np.int16)
151
+ cur_global = np.zeros(
152
+ (n, cfg.map_h, cfg.map_w), dtype=np.int16,
153
+ )
154
+ failed = np.zeros(n, dtype=bool)
155
+
156
+ for i in range(n):
157
+ try:
158
+ env = make_env(env_id, des_content, cfg)
159
+ (local, glb), _ = env.reset(seed=seeds[i])
160
+ envs.append(env)
161
+ cur_local[i] = local
162
+ cur_global[i] = glb
163
+ except Exception:
164
+ logger.warning(
165
+ "Failed to create env %s (ep %d)",
166
+ env_id, i, exc_info=True,
167
+ )
168
+ envs.append(None)
169
+ failed[i] = True
170
+
171
+ # Per-episode state vectors
172
+ plans = np.zeros((n, cfg.seq_len), dtype=np.int64)
173
+ step_in_plan = np.zeros(n, dtype=np.int32)
174
+ need_replan = np.ones(n, dtype=bool)
175
+ done = failed.copy()
176
+ won = np.zeros(n, dtype=bool)
177
+ total_reward = np.zeros(n, dtype=np.float64)
178
+ n_steps = np.zeros(n, dtype=np.int32)
179
+
180
+ try:
181
+ for _ in range(max_steps):
182
+ # Batch replan for active envs that need it
183
+ replan_idx = np.where(need_replan & ~done)[0]
184
+ if len(replan_idx) > 0:
185
+ local_t = torch.from_numpy(
186
+ cur_local[replan_idx],
187
+ ).long().to(device) # [B_r, cs, cs]
188
+ glb_t = torch.from_numpy(
189
+ cur_global[replan_idx],
190
+ ).long().to(device) # [B_r, map_h, map_w]
191
+ batch_plans = remdm_sample(
192
+ model, local_t, glb_t, cfg, device,
193
+ physics_aware=getattr(
194
+ cfg, "physics_aware_sampling", False,
195
+ ),
196
+ blind_global=blind_global,
197
+ ).cpu().numpy() # [B_r, seq_len]
198
+ plans[replan_idx] = batch_plans
199
+ step_in_plan[replan_idx] = 0
200
+ need_replan[replan_idx] = False
201
+
202
+ # Step all active envs
203
+ any_active = False
204
+ for i in range(n):
205
+ if done[i]:
206
+ continue
207
+ any_active = True
208
+
209
+ action = int(plans[i, step_in_plan[i]])
210
+ action = max(
211
+ 0, min(action, cfg.action_dim - 1),
212
+ )
213
+ step_in_plan[i] += 1
214
+ n_steps[i] += 1
215
+
216
+ if step_in_plan[i] >= cfg.replan_every:
217
+ need_replan[i] = True
218
+
219
+ try:
220
+ obs, reward, term, trunc, info = (
221
+ envs[i].step(action)
222
+ )
223
+ local, glb = obs
224
+ total_reward[i] += reward
225
+ cur_local[i] = local
226
+ cur_global[i] = glb
227
+
228
+ if info.get("won", False):
229
+ won[i] = True
230
+ if term or trunc:
231
+ done[i] = True
232
+ except Exception:
233
+ logger.warning(
234
+ "Episode %d step failed for %s",
235
+ i, env_id, exc_info=True,
236
+ )
237
+ done[i] = True
238
+
239
+ if not any_active:
240
+ break
241
+ finally:
242
+ for env in envs:
243
+ if env is not None:
244
+ env.close()
245
+
246
+ return [
247
+ {
248
+ "won": bool(won[i]),
249
+ "steps": int(n_steps[i]),
250
+ "total_reward": float(total_reward[i]),
251
+ }
252
+ for i in range(n)
253
+ ]
254
+
255
+
256
+ def format_eval_results(
257
+ results: dict[str, dict], label: str = "Eval",
258
+ ) -> str:
259
+ """Format evaluation results as an ASCII table.
260
+
261
+ Args:
262
+ results: Output of ``Evaluator.evaluate``.
263
+ label: Table header label.
264
+
265
+ Returns:
266
+ Formatted string.
267
+ """
268
+ lines = [f"{'=' * 60}", f" {label} Results", f"{'=' * 60}"]
269
+ lines.append(
270
+ f" {'Environment':<35} {'WinRate':>8} {'Steps':>8}"
271
+ )
272
+ lines.append(f" {'-' * 53}")
273
+ for env_id, stats in results.items():
274
+ wr = f"{stats['win_rate']:.2%}"
275
+ st = f"{stats['avg_steps']:.1f}"
276
+ lines.append(f" {env_id:<35} {wr:>8} {st:>8}")
277
+ lines.append(f"{'=' * 60}")
278
+ return "\n".join(lines)
279
+
280
+
281
+ def save_eval_json(
282
+ results: dict,
283
+ path: str,
284
+ metadata: dict | None = None,
285
+ ) -> None:
286
+ """Save evaluation results to a JSON file.
287
+
288
+ Args:
289
+ results: Evaluation results dict.
290
+ path: Output file path.
291
+ metadata: Optional extra metadata (e.g. iteration).
292
+ """
293
+ payload = {
294
+ "timestamp": datetime.now(timezone.utc).isoformat(),
295
+ "results": results,
296
+ }
297
+ if metadata:
298
+ payload["metadata"] = metadata
299
+ resolved = str(Path(path).resolve())
300
+ Path(resolved).parent.mkdir(parents=True, exist_ok=True)
301
+ try:
302
+ with open(resolved, "w") as f:
303
+ json.dump(payload, f, indent=2, default=str)
304
+ except Exception:
305
+ logger.error(f"Failed to save eval JSON to {resolved}", exc_info=True)
306
+
307
+
308
+ def run_inference(
309
+ cfg,
310
+ checkpoint_path: str,
311
+ env_ids: list[str] | None,
312
+ episodes: int,
313
+ output_path: str | None,
314
+ use_ema: bool,
315
+ log: Logger | None = None,
316
+ des_files: list[str] | None = None,
317
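+     blind_global: bool = False,
+ ) -> None:
+     """Evaluate a checkpoint on specified environments.
+
+     Args:
+         cfg: Config namespace. Reads ``device``, ``ema_decay``,
+             ``id_envs``, ``ood_envs``.
+         checkpoint_path: Path to a checkpoint file, either a dict with
+             ``"model_state_dict"`` (and optionally ``"ema_state_dict"``)
+             or a bare state dict.
+         env_ids: Environments to evaluate; ``None`` means
+             ``cfg.id_envs + cfg.ood_envs``.
+         episodes: Episodes per environment.
+         output_path: Optional path for the JSON results file.
+         use_ema: If ``True`` and the checkpoint holds EMA weights,
+             evaluate with the EMA weights.
+         log: Optional ``Logger`` for W&B logging.
+         des_files: Optional ``.des`` scenario files.
+         blind_global: If ``True``, zero out global map observations.
+     """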
320
+
321
+ device = cfg.device
322
+ logger.info(f"Inference on {device}")
323
+
324
+ model = make_model(cfg).to(device)
325
+ ckpt = torch.load(
326
+ checkpoint_path, map_location=device, weights_only=False,
327
+ )
328
+
329
+ if "model_state_dict" in ckpt:
330
+ model.load_state_dict(ckpt["model_state_dict"])
331
+ if use_ema and "ema_state_dict" in ckpt:
332
+ ema = ModelEMA(model, decay=cfg.ema_decay)
333
+ ema.load_state_dict(ckpt["ema_state_dict"])
334
+ ema.apply_to(model)
335
+ else:
336
+ model.load_state_dict(ckpt)
337
+
338
+ model.eval()
339
+
340
+ if env_ids is None:
341
+ env_ids = cfg.id_envs + cfg.ood_envs
342
+
343
+ evaluator = Evaluator()
344
+ results = evaluator.evaluate(
345
+ env_ids, model, episodes, cfg, device,
346
+ des_files=des_files, blind_global=blind_global,
347
+ )
348
+
349
+ print(format_eval_results(results, label="Inference"))
350
+
351
+ if log is not None:
352
+ log.log_eval(results, step=0, prefix="inference")
353
+ log.log_summary(
354
+ {f"inference/{env_id}/win_rate": stats["win_rate"]
355
+ for env_id, stats in results.items()}
356
+ )
357
+
358
+ if output_path:
359
+ save_eval_json(results, output_path)
360
+ logger.info(f"Results saved to {output_path}")
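The per-environment statistics returned by `Evaluator.evaluate` are plain averages over the episode dicts. A small standalone sketch of that aggregation, assuming the same keys as above (`aggregate` is a hypothetical helper name, and the episode values are made up for illustration):

```python
def aggregate(ep_results):
    """Collapse per-episode dicts into the per-environment stats dict."""
    wins = sum(1 for r in ep_results if r["won"])
    # max(..., 1) avoids division by zero when no episode ran.
    n = max(len(ep_results), 1)
    return {
        "win_rate": wins / n,
        "wins": wins,
        "avg_reward": sum(r["total_reward"] for r in ep_results) / n,
        "avg_steps": sum(r["steps"] for r in ep_results) / n,
        "n_episodes": len(ep_results),
    }


episodes = [
    {"won": True, "steps": 10, "total_reward": 1.0},
    {"won": False, "steps": 20, "total_reward": 0.0},
    {"won": False, "steps": 30, "total_reward": 0.0},
    {"won": False, "steps": 40, "total_reward": 0.0},
]
stats = aggregate(episodes)
empty_stats = aggregate([])
```

Note that episodes whose environment failed to build still count in the denominator as losses with zero steps, so a high failure rate lowers both `win_rate` and `avg_steps`.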
src/planners/logging.py ADDED
@@ -0,0 +1,291 @@
1
+ """Centralised W&B and stdout logging.
2
+
3
+ Mirrors the Craftax logging conventions with metric namespaces:
4
+ ``diffusion/``, ``train/``, ``eval_id/``, ``eval_ood/``.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import torch
11
+ from typing import TYPE_CHECKING
12
+ from types import SimpleNamespace
13
+
14
+ if TYPE_CHECKING:
15
+ from wandb.sdk.wandb_run import Run as _WandbRun
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def download_artifact(
21
+ artifact_ref: str, dst_dir: str = "artifacts",
22
+ ) -> str | None:
23
+ """Download a W&B artifact via the public API (no active run needed).
24
+
25
+ Args:
26
+ artifact_ref: Fully qualified artifact reference, e.g.
27
+ ``"entity/project/checkpoint-iter1000:latest"``.
28
+ dst_dir: Local directory to download into.
29
+
30
+ Returns:
31
+ Path to the ``.pth`` file inside the downloaded artifact
32
+ directory, or ``None`` on failure.
33
+ """
34
+ try:
35
+ import wandb
36
+ from pathlib import Path
37
+
38
+ api = wandb.Api()
39
+ artifact = api.artifact(artifact_ref)
40
+ artifact_dir = artifact.download(root=dst_dir)
41
+ pth_files = list(Path(artifact_dir).glob("*.pth"))
42
+ if not pth_files:
43
+ logger.error(
44
+ f"No .pth file found in artifact {artifact_ref}"
45
+ )
46
+ return None
47
+ path = str(pth_files[0])
48
+ logger.info(f"Downloaded artifact {artifact_ref} -> {path}")
49
+ return path
50
+ except Exception:
51
+ logger.error(
52
+ f"Failed to download artifact {artifact_ref}",
53
+ exc_info=True,
54
+ )
55
+ return None
56
+
57
+
58
+ def _auto_run_name(cfg: SimpleNamespace) -> str:
59
+ """Generate a descriptive W&B run name from key hyperparameters.
60
+
61
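+     Format: ``seq{seq_len}_d{n_embd}_L{n_layer}_lr{dagger_lr}_bs{batch}_eta{eta}_{remask}``,
+     followed by the optional suffixes ``_subs`` (importance weighting),
+     ``_phys`` (physics-aware sampling), and ``_s{seed}``.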
62
+
63
+ Args:
64
+ cfg: Config namespace.
65
+
66
+ Returns:
67
+ A concise, human-readable run name.
68
+ """
69
+ parts = [
70
+ f"seq{cfg.seq_len}",
71
+ f"d{cfg.n_embd}",
72
+ f"L{cfg.n_layer}",
73
+ f"lr{cfg.dagger_lr:.0e}",
74
+ f"bs{cfg.dagger_batch_size}",
75
+ f"eta{cfg.eta}",
76
+ f"{cfg.remask_strategy}",
77
+ ]
78
+ if cfg.use_importance_weighting:
79
+ parts.append("subs")
80
+ if getattr(cfg, "physics_aware_sampling", False):
81
+ parts.append("phys")
82
+ if cfg.seed is not None:
83
+ parts.append(f"s{cfg.seed}")
84
+ return "_".join(parts)
85
+
86
+
87
+ class Logger:
88
+ """Centralised logger for W&B and stdout.
89
+
90
+ Args:
91
+ cfg: Config namespace with ``use_wandb``, ``wandb_project``,
92
+ ``wandb_entity``, ``seed``.
93
+ """
94
+
95
+ def __init__(self, cfg: SimpleNamespace) -> None:
96
+ self._use_wandb = cfg.use_wandb
97
+ self._run: _WandbRun | None = None
98
+ if self._use_wandb:
99
+ try:
100
+ import wandb
101
+ run_name = getattr(cfg, "wandb_run_name", None)
102
+ if not run_name:
103
+ run_name = _auto_run_name(cfg)
104
+ resume_id = getattr(cfg, "wandb_resume_id", None)
105
+ self._run = wandb.init(
106
+ project=cfg.wandb_project,
107
+ entity=cfg.wandb_entity or None,
108
+ name=run_name,
109
+ config=vars(cfg),
110
+ id=resume_id or None,
111
+ resume="must" if resume_id else "never",
112
+ )
113
+ # Define custom metric x-axes
114
+ wandb.define_metric("iteration")
115
+ for ns in (
116
+ "diffusion/*", "train/*", "perf/*", "speed/*",
117
+ "model/*",
118
+ "eval_id/*", "eval_ood/*",
119
+ "curriculum/*",
120
+ "ckpt_eval_id/*", "ckpt_eval_ood/*", "ckpt_eval/*",
121
+ "inference/*",
122
+ ):
123
+ wandb.define_metric(ns, step_metric="iteration")
124
+ except Exception:
125
+ logger.error("W&B init failed", exc_info=True)
126
+ self._use_wandb = False
127
+
128
+ def log_summary(self, metrics: dict) -> None:
129
+ """Write key/value pairs to the wandb run summary (final aggregates).
130
+
131
+ Args:
132
+ metrics: Flat ``{key: value}`` dict.
133
+ """
134
+ if self._use_wandb and self._run is not None:
135
+ try:
136
+ self._run.summary.update(metrics)
137
+ except Exception:
138
+ pass
139
+
140
+ def log(self, metrics: dict, step: int) -> None:
141
+ """Log a dict of metrics.
142
+
143
+ Args:
144
+ metrics: Flat ``{namespace/key: value}`` dict.
145
+ step: Global step index.
146
+ """
147
+ if self._use_wandb and self._run is not None:
148
+ try:
149
+ import wandb
150
+ # Include "iteration" so define_metric(step_metric="iteration") works
151
+ wandb.log({**metrics, "iteration": step}, step=step)
152
+ except Exception:
153
+ pass
154
+
155
+ # Stdout summary every 10 steps
156
+ if step % 10 == 0:
157
+ parts = [f"step={step}"]
158
+ for k, v in metrics.items():
159
+ if isinstance(v, float):
160
+ if abs(v) < 1e-3 and v != 0.0:
161
+ parts.append(f"{k}={v:.2e}")
162
+ else:
163
+ parts.append(f"{k}={v:.4f}")
164
+ else:
165
+ parts.append(f"{k}={v}")
166
+ logger.info(" ".join(parts))
167
+
168
+ def log_eval(
169
+ self, results: dict[str, dict], step: int, prefix: str,
170
+ ) -> None:
171
+ """Flatten evaluation results and log them.
172
+
173
+ Args:
174
+ results: ``{env_id: {"win_rate", ...}}``
175
+ step: Global step.
176
+ prefix: Metric namespace prefix (e.g. ``"eval_id"``).
177
+ """
178
+ flat: dict[str, float] = {}
179
+ for env_id, stats in results.items():
180
+ for key, val in stats.items():
181
+ if isinstance(val, (int, float)):
182
+ flat[f"{prefix}/{env_id}/{key}"] = val
183
+ self.log(flat, step=step)
184
+
185
+ def log_checkpoint_artifact(
186
+ self,
187
+ checkpoint_path: str,
188
+ config_path: str | None,
189
+ iteration: int,
190
+ metadata: dict | None = None,
191
+ artifact_name: str | None = None,
192
+ ) -> None:
193
+ """Upload a checkpoint as a W&B artifact with config attached.
194
+
195
+ Args:
196
+ checkpoint_path: Path to the ``.pth`` checkpoint file.
197
+ config_path: Path to the YAML config snapshot to attach.
198
+ If ``None``, only the checkpoint is uploaded.
199
+ iteration: Iteration number (used in the default artifact
200
+ name when ``artifact_name`` is not provided).
201
+ metadata: Optional metadata dict stored on the artifact.
202
+ artifact_name: Optional explicit artifact name. When
203
+ ``None``, defaults to ``f"checkpoint-iter{iteration}"``.
204
+ Offline BC passes a step-based name to avoid the
205
+ misleading "iter" prefix.
206
+ """
207
+ if not self._use_wandb or self._run is None:
208
+ return
209
+ try:
210
+ import wandb
211
+
212
+ name = artifact_name or f"checkpoint-iter{iteration}"
213
+ artifact = wandb.Artifact(
214
+ name=name,
215
+ type="model",
216
+ metadata=metadata or {},
217
+ )
218
+ artifact.add_file(checkpoint_path)
219
+ if config_path is not None:
220
+ artifact.add_file(config_path, name="config.yaml")
221
+ logged = self._run.log_artifact(artifact) # type: ignore[union-attr]
222
+ logged.wait() # block until upload completes
223
+ logger.info("W&B artifact uploaded: %s", name)
224
+ except Exception:
225
+ logger.error("W&B artifact upload failed", exc_info=True)
226
+
227
+ def finish(self) -> None:
228
+ """Close the W&B run if active."""
229
+ if self._use_wandb and self._run is not None:
230
+ try:
231
+ import wandb
232
+ wandb.finish()
233
+ except Exception:
234
+ pass
235
+
236
+
237
+ # ---------------------------------------------------------------------------
238
+ # Metric helper functions (used by both src/ and experiments/)
239
+ # ---------------------------------------------------------------------------
240
+
241
+
242
+ def gpu_memory_mb() -> float:
243
+ """Return peak GPU memory allocated in MB since last reset.
244
+
245
+ Returns:
246
+ Peak memory in MB, or 0.0 if CUDA is unavailable.
247
+ """
248
+ if torch.cuda.is_available():
249
+ return torch.cuda.max_memory_allocated() / (1024 * 1024)
250
+ return 0.0
251
+
252
+
253
+ def reset_gpu_memory_stats() -> None:
254
+ """Reset GPU peak memory stats for the current device."""
255
+ if torch.cuda.is_available():
256
+ torch.cuda.reset_peak_memory_stats()
257
+
258
+
259
+ def compute_param_norm(model: torch.nn.Module) -> float:
260
+ """Compute total L2 norm of all model parameters.
261
+
262
+ Args:
263
+ model: The model.
264
+
265
+ Returns:
266
+ Total L2 norm as a float.
267
+ """
268
+ total = 0.0
269
+ for p in model.parameters():
270
+ total += p.data.norm(2).item() ** 2
271
+ return total ** 0.5
272
+
273
+
274
+ def compute_param_drift(
275
+ model: torch.nn.Module,
276
+ ref_state: dict[str, torch.Tensor],
277
+ ) -> float:
278
+ """Compute L2 distance between current model params and a reference state.
279
+
280
+ Args:
281
+ model: Current model.
282
+ ref_state: Reference state_dict (e.g. pretrained weights).
283
+
284
+ Returns:
285
+ L2 distance as a float.
286
+ """
287
+ total = 0.0
288
+ for name, p in model.named_parameters():
289
+ if name in ref_state:
290
+ total += (p.data - ref_state[name]).norm(2).item() ** 2
291
+ return total ** 0.5
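A minimal standalone sketch of the flattening that `Logger.log_eval` performs in the file above, with no W&B dependency. The function name `flatten_eval_results` and the environment IDs `EnvA` / `EnvB` are illustrative assumptions, not part of the commit.

```python
def flatten_eval_results(results, prefix):
    """Flatten {env_id: {metric: value}} into {"prefix/env_id/metric": value}."""
    flat = {}
    for env_id, stats in results.items():
        for key, val in stats.items():
            # Same filter as Logger.log_eval: keep numeric scalars only.
            # bool is a subclass of int in Python, so booleans would pass too.
            if isinstance(val, (int, float)):
                flat[f"{prefix}/{env_id}/{key}"] = val
    return flat


# Hypothetical evaluation output for two made-up environment IDs.
example = {
    "EnvA": {"win_rate": 0.75, "n_episodes": 20, "notes": "skipped"},
    "EnvB": {"win_rate": 0.5, "n_episodes": 20},
}
flat = flatten_eval_results(example, "eval_id")
print(flat)
```

Non-numeric entries such as the `notes` string are dropped, so only scalar metrics would reach the `eval_id/*` namespace.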
src/planners/offline.py ADDED
@@ -0,0 +1,727 @@
1
+ """Offline behavioural cloning trainer.
2
+
3
+ Mirrors the Craftax ``make_train`` closure pattern. Trains the diffusion
4
+ model on pre-collected oracle demonstrations using the MDLM ELBO loss
5
+ with optional auxiliary goal loss.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import sys
11
+ import time
12
+ from pathlib import Path
13
+ import logging
14
+ from types import SimpleNamespace
15
+ from typing import Callable
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import yaml
20
+
21
+ from src.buffer import ReplayBuffer
22
+ from src.config import make_run_dir
23
+ from src.diffusion.forward import q_sample
24
+ from src.diffusion.loss import auxiliary_goal_loss, mdlm_loss
25
+ from src.diffusion.schedules import get_schedule
26
+ from src.models.denoiser import ModelEMA, make_model, try_compile
27
+ from src.planners.inference import Evaluator, save_eval_json
28
+ from src.planners.logging import (
29
+ Logger,
30
+ compute_param_drift,
31
+ compute_param_norm,
32
+ gpu_memory_mb,
33
+ reset_gpu_memory_stats,
34
+ )
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ def make_offline_trainer(cfg: SimpleNamespace) -> Callable:
40
+ """Build the offline BC training closure.
41
+
42
+ Args:
43
+ cfg: Config namespace.
44
+
45
+ Returns:
46
+ ``train_offline(model, ema_model, buffer, cfg, device) -> dict``
47
+ """
48
+ schedule_fn = get_schedule(cfg.noise_schedule)
49
+
50
+ def train_offline(
51
+ model: nn.Module,
52
+ ema_model: ModelEMA,
53
+ buffer: ReplayBuffer,
54
+ cfg: SimpleNamespace,
55
+ device: torch.device | str,
56
+ log: Logger | None = None,
57
+ raw_model: nn.Module | None = None,
58
+ resume_state: dict | None = None,
59
+ evaluator: Evaluator | None = None,
60
+ id_envs: list[str] | None = None,
61
+ ood_envs: list[str] | None = None,
62
+ ) -> dict:
63
+ """Run offline BC training.
64
+
65
+ Args:
66
+ model: Denoising model (may be torch.compiled).
67
+ ema_model: EMA tracker.
68
+ buffer: Replay buffer with offline data.
69
+ cfg: Config namespace.
70
+ device: Torch device.
71
+ log: Optional Logger for wandb and stdout metrics.
72
+ raw_model: Uncompiled model for EMA updates. If ``None``,
73
+ uses *model* directly.
74
+ resume_state: Checkpoint dict to resume from. If provided,
75
+ restores optimizer, scheduler, and step state.
76
+ evaluator: Optional ``Evaluator`` instance for periodic ID/OOD
77
+ evaluation. When ``None``, no eval is run during training.
78
+ id_envs: In-distribution environment IDs for periodic eval.
79
+ Required (non-empty) if ``evaluator`` is provided and
80
+ ``cfg.id_eval_every_timesteps > 0``.
81
+ ood_envs: Out-of-distribution environment IDs for periodic
82
+ eval. Required (non-empty) if ``evaluator`` is provided
83
+ and ``cfg.ood_eval_every_timesteps > 0``.
84
+
85
+ Returns:
86
+ Dict with ``"final_loss"`` and ``"loss_history"``.
87
+ """
88
+ _ema_source = raw_model if raw_model is not None else model
89
+ model.train()
90
+ optimizer = torch.optim.AdamW(
91
+ model.parameters(), lr=cfg.offline_lr,
92
+ weight_decay=cfg.weight_decay,
93
+ )
94
+
95
+ # Unified budget: `total_timesteps` counts env.step()-equivalent
96
+ # samples consumed during training. Each gradient step consumes
97
+ # `offline_batch_size` samples, so total grad steps derives
98
+ # directly from the budget and is independent of dataset size
99
+ # — this is what gives offline / DAgger / SB3 runs a common
100
+ # denominator when comparing curves.
101
+ total_grad_steps = max(
102
+ 1, cfg.total_timesteps // cfg.offline_batch_size,
103
+ )
104
+ # Optional override: pin offline gradient budget independently
105
+ # of `total_timesteps`. Used for paper-fair compute matching
106
+ # against a specific DAgger iteration count, e.g.
107
+ # `offline_total_grad_steps: 60000` to match 600 DAgger iters
108
+ # × `grad_steps_per_iteration: 100` AdamW updates regardless of
109
+ # what env-step budget DAgger consumed in those iters.
110
+ _grad_override = getattr(cfg, "offline_total_grad_steps", None)
111
+ if _grad_override is not None and _grad_override > 0:
112
+ total_grad_steps = int(_grad_override)
113
+ logger.info(
114
+ "Offline grad budget pinned via offline_total_grad_steps="
115
+ f"{total_grad_steps} (overrides total_timesteps)"
116
+ )
117
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
118
+ optimizer, T_max=total_grad_steps,
119
+ eta_min=cfg.offline_lr * 0.1,
120
+ )
121
+ # Checkpoint cadence — defaults to deriving from
122
+ # `checkpoint_every_timesteps` (env-step units → grad-step units
123
+ # via // batch_size). The optional `offline_checkpoint_every_grad_steps`
124
+ # override is used when an offline run is pinned via
125
+ # `offline_total_grad_steps` and needs an aligned cadence in
126
+ # grad-step units (env-step cadence diverges wildly from grad-step
127
+ # cadence between offline and DAgger because their sample-to-step
128
+ # ratios differ by ~50x).
129
+ _ckpt_grad_override = getattr(
130
+ cfg, "offline_checkpoint_every_grad_steps", None,
131
+ )
132
+ if _ckpt_grad_override is not None and _ckpt_grad_override > 0:
133
+ ckpt_every_step = int(_ckpt_grad_override)
134
+ else:
135
+ ckpt_every_step = (
136
+ cfg.checkpoint_every_timesteps // cfg.offline_batch_size
137
+ if cfg.checkpoint_every_timesteps > 0 else 0
138
+ )
139
+ # Eval cadence — same override pattern. Without this, an offline
140
+ # run pinned at e.g. 60k grad steps with the default
141
+ # `id_eval_every_timesteps=250000` would fire ~491 evals
142
+ # (250000 // 2048 = 122 grad steps per eval), which is
143
+ # impractically dense.
144
+ _eval_grad_override = getattr(
145
+ cfg, "offline_eval_every_grad_steps", None,
146
+ )
147
+ if _eval_grad_override is not None and _eval_grad_override > 0:
148
+ id_eval_every_env_steps = (
149
+ int(_eval_grad_override) * cfg.offline_batch_size
150
+ )
151
+ ood_eval_every_env_steps = id_eval_every_env_steps
152
+ else:
153
+ id_eval_every_env_steps = cfg.id_eval_every_timesteps
154
+ ood_eval_every_env_steps = cfg.ood_eval_every_timesteps
155
+ # Logging cadence. `offline_log_every` is the *minimum* cadence;
156
+ # the actual `log_every` is clamped on both ends so the number of
157
+ # log points stays in [~10, ~1000] regardless of run length:
158
+ #
159
+ # * Lower bound (`floor`): on very long runs, force `log_every`
160
+ # up so total log points cap at ~1000. Without this, a 600k
161
+ # grad-step run with the default `offline_log_every=10` would
162
+ # emit 60,000 W&B points — silent log spam.
163
+ #
164
+ # * Upper bound (`ceiling`): on very short runs (smoke, fast
165
+ # ablations) clamp `log_every` down so every run emits at
166
+ # least ~10 log points and curves stay comparable across
167
+ # budgets.
168
+ #
169
+ # When the configured value sits inside the [floor, ceiling]
170
+ # window (the common case), it is used unchanged.
171
+ _floor = max(1, total_grad_steps // 1000)
172
+ _ceiling = max(1, total_grad_steps // 10)
173
+ log_every = min(
174
+ _ceiling, max(_floor, cfg.offline_log_every),
175
+ )
176
+
177
+ # Restore optimizer/scheduler state if resuming
178
+ step = 0
179
+ if resume_state is not None:
180
+ if "optimizer_state_dict" in resume_state:
181
+ optimizer.load_state_dict(
182
+ resume_state["optimizer_state_dict"],
183
+ )
184
+ if "scheduler_state_dict" in resume_state:
185
+ scheduler.load_state_dict(
186
+ resume_state["scheduler_state_dict"],
187
+ )
188
+ step = resume_state.get("step", 0)
189
+ logger.info(
190
+ f"Resumed offline training from step {step}/"
191
+ f"{total_grad_steps}"
192
+ )
193
+
194
+ # AMP: enabled when use_amp=true and on CUDA
195
+ _use_amp = (
196
+ getattr(cfg, "use_amp", False)
197
+ and str(device).startswith("cuda")
198
+ )
199
+ scaler = torch.amp.GradScaler("cuda", enabled=_use_amp)
200
+
201
+ loss_history: list[float] = []
202
+ _batch_start = time.perf_counter()
203
+ last_ckpt_step = step
204
+ # Periodic eval anchors (env-step units, mirroring online.py).
205
+ # Snapping to current env_steps avoids accumulated drift across
206
+ # resumes; the next eval fires once another full interval has
207
+ # been processed since the resume point.
208
+ last_id_eval_env_steps = step * cfg.offline_batch_size
209
+ last_ood_eval_env_steps = step * cfg.offline_batch_size
210
+
211
+ # Snapshot of initial weights for `model/param_drift_from_init`.
212
+ # Mirrors online.py:Trainer.__init__.
213
+ _init_state = {
214
+ k: v.detach().clone()
215
+ for k, v in _ema_source.state_dict().items()
216
+ if v.is_floating_point()
217
+ }
218
+ # Counts logging emissions (not raw grad steps), used to gate
219
+ # the once-per-10-windows model health metrics analogously to
220
+ # online.py's `iteration % 10 == 0` cadence.
221
+ log_windows = 0
222
+ reset_gpu_memory_stats()
223
+
224
+ while step < total_grad_steps:
225
+ batch = buffer.sample(cfg.offline_batch_size)
226
+ if batch is None:
227
+ break
228
+ local_np, global_np, actions_np = batch
229
+ local_t = torch.from_numpy(local_np).long().to(device)
230
+ global_t = torch.from_numpy(global_np).long().to(device)
231
+ actions_t = torch.from_numpy(actions_np).long().to(device)
232
+
233
+ B = actions_t.shape[0]
234
+ t = torch.rand(B, device=device) # [B] in [0, 1)
235
+ t = t.clamp(1e-5, 1.0 - 1e-5)
236
+
237
+ zt = q_sample(
238
+ actions_t, t, cfg.mask_token, cfg.pad_token,
239
+ schedule_fn,
240
+ )
241
+ t_discrete = (
242
+ t * cfg.num_diffusion_steps
243
+ ).long().clamp(0, cfg.num_diffusion_steps - 1) # [B]
244
+
245
+ optimizer.zero_grad()
246
+ with torch.amp.autocast("cuda", enabled=_use_amp):
247
+ out = model(local_t, global_t, zt, t_discrete)
248
+
249
+ loss_diff = mdlm_loss(
250
+ out["actions"], actions_t, zt, t,
251
+ cfg.mask_token, cfg.pad_token, schedule_fn,
252
+ weight_clip=cfg.loss_weight_clip,
253
+ label_smoothing=cfg.label_smoothing,
254
+ use_importance_weighting=cfg.use_importance_weighting,
255
+ )
256
+
257
+ loss_aux = torch.tensor(0.0, device=device)
258
+ if "goal_pred" in out:
259
+ loss_aux = auxiliary_goal_loss(
260
+ out["goal_pred"], global_t,
261
+ )
262
+
263
+ loss = loss_diff + cfg.aux_loss_weight * loss_aux
264
+
265
+ scaler.scale(loss).backward()
266
+ scaler.unscale_(optimizer)
267
+ grad_norm = nn.utils.clip_grad_norm_(
268
+ model.parameters(), cfg.offline_grad_clip,
269
+ )
270
+ scaler.step(optimizer)
271
+ scaler.update()
272
+ scheduler.step()
273
+
274
+ ema_model.update(_ema_source)
275
+ loss_history.append(loss.item())
276
+ step += 1
277
+
278
+ # env-step equivalent: samples processed so far.
279
+ env_steps = step * cfg.offline_batch_size
280
+
281
+ if log is not None and step % log_every == 0:
282
+ step_time = time.perf_counter() - _batch_start
283
+ log_windows += 1
284
+
285
+ # Buffer state — for offline mode `offline_size` always
286
+ # equals `len(buffer)` (no online appends), so the
287
+ # online fraction is always 0.0. Logged anyway for
288
+ # symmetry with the DAgger curves.
289
+ buf_total = len(buffer)
290
+ buf_online_frac = (
291
+ (buf_total - buffer.offline_size) / max(buf_total, 1)
292
+ if hasattr(buffer, "offline_size")
293
+ else 0.0
294
+ )
295
+
296
+ # Throughput: samples processed in this logging window.
297
+ samples_window = log_every * cfg.offline_batch_size
298
+ samples_per_sec = samples_window / max(step_time, 1e-6)
299
+
300
+ _ema_source_ref = _ema_source
301
+ metrics = {
302
+ "diffusion/loss": loss.item(),
303
+ "diffusion/loss_diff": loss_diff.item(),
304
+ "diffusion/loss_aux": loss_aux.item(),
305
+ "train/buffer_size": buf_total,
306
+ "train/buffer_online_frac": buf_online_frac,
307
+ "train/lr": scheduler.get_last_lr()[0],
308
+ "train/env_steps": env_steps,
309
+ "train/progress": step / total_grad_steps,
310
+ "train/grad_norm": grad_norm.item(),
311
+ "speed/train_step_time_sec": step_time,
312
+ "speed/samples_per_sec": samples_per_sec,
313
+ "speed/gpu_memory_mb": gpu_memory_mb(),
314
+ # Legacy `perf/` mirror keys (kept for backward compat
315
+ # with existing dashboards / DAgger curves).
316
+ "perf/train_time_s": step_time,
317
+ "perf/grad_steps_per_sec": (
318
+ log_every / max(step_time, 1e-6)
319
+ ),
320
+ }
321
+ if hasattr(_ema_source_ref, "global_gate"):
322
+ gate_val = torch.sigmoid(
323
+ _ema_source_ref.global_gate,
324
+ ).item()
325
+ metrics["train/global_gate"] = gate_val
326
+ metrics["model/ema_gate_value"] = gate_val
327
+
328
+ # Model health: emitted on the 1st, 11th, 21st, ... logging window
+ # to keep overhead low (analogous to online.py's `iteration % 10 == 0`).
330
+ if log_windows % 10 == 1:
331
+ metrics["model/param_norm"] = compute_param_norm(
332
+ _ema_source_ref,
333
+ )
334
+ metrics["model/param_drift_from_init"] = (
335
+ compute_param_drift(
336
+ _ema_source_ref, _init_state,
337
+ )
338
+ )
339
+
340
+ log.log(metrics, step=step)
341
+ _batch_start = time.perf_counter()
342
+ reset_gpu_memory_stats()
343
+ logger.info(
344
+ f"step {step}/{total_grad_steps} "
345
+ f"(env_steps={env_steps}) loss={loss.item():.4f}"
346
+ )
347
+
348
+ # Periodic ID eval — env-step delta-check (mirrors
349
+ # online.py:277-305). Eval is opt-in: skipped entirely when
350
+ # no Evaluator was threaded through. The cadence variable
351
+ # already accounts for the optional
352
+ # `offline_eval_every_grad_steps` override.
353
+ if (
354
+ evaluator is not None
355
+ and id_envs
356
+ and id_eval_every_env_steps > 0
357
+ and env_steps - last_id_eval_env_steps
358
+ >= id_eval_every_env_steps
359
+ ):
360
+ eval_model = ema_model.make_eval_model(_ema_source)
361
+ results = evaluator.evaluate(
362
+ id_envs, eval_model, cfg.eval_episodes_per_env,
363
+ cfg, device,
364
+ )
365
+ if log is not None:
366
+ log.log_eval(results, step=step, prefix="eval_id")
367
+ mean_id_wr = (
368
+ sum(s["win_rate"] for s in results.values())
369
+ / len(results)
370
+ ) if results else 0.0
371
+ log.log(
372
+ {"eval_id/mean_win_rate": mean_id_wr},
373
+ step=step,
374
+ )
375
+ last_id_eval_env_steps = env_steps
376
+
377
+ # Periodic OOD eval — same delta-check pattern.
378
+ if (
379
+ evaluator is not None
380
+ and ood_envs
381
+ and ood_eval_every_env_steps > 0
382
+ and env_steps - last_ood_eval_env_steps
383
+ >= ood_eval_every_env_steps
384
+ ):
385
+ eval_model = ema_model.make_eval_model(_ema_source)
386
+ results = evaluator.evaluate(
387
+ ood_envs, eval_model, cfg.eval_episodes_per_env,
388
+ cfg, device,
389
+ )
390
+ if log is not None:
391
+ log.log_eval(results, step=step, prefix="eval_ood")
392
+ mean_ood_wr = (
393
+ sum(s["win_rate"] for s in results.values())
394
+ / len(results)
395
+ ) if results else 0.0
396
+ log.log(
397
+ {"eval_ood/mean_win_rate": mean_ood_wr},
398
+ step=step,
399
+ )
400
+ last_ood_eval_env_steps = env_steps
401
+
402
+ # Periodic step-level checkpoint (cadence derived from
403
+ # checkpoint_every_timesteps)
404
+ if (
405
+ ckpt_every_step > 0
406
+ and step - last_ckpt_step >= ckpt_every_step
407
+ ):
408
+ _save_offline_checkpoint(
409
+ _ema_source, ema_model, optimizer, scheduler,
410
+ step, cfg, log,
411
+ evaluator=evaluator,
412
+ id_envs=id_envs,
413
+ ood_envs=ood_envs,
414
+ device=device,
415
+ )
416
+ last_ckpt_step = step
417
+
418
+ if log is not None:
419
+ log.log_summary({
420
+ "offline/final_loss": loss_history[-1] if loss_history else 0.0,
421
+ "offline/total_steps": step,
422
+ "offline/total_timesteps": step * cfg.offline_batch_size,
423
+ })
424
+
425
+ return {
426
+ "final_loss": loss_history[-1] if loss_history else 0.0,
427
+ "loss_history": loss_history,
428
+ }
429
+
430
+ return train_offline
431
+
432
+
433
+ def _save_offline_checkpoint(
434
+ model: nn.Module,
435
+ ema_model: ModelEMA,
436
+ optimizer: torch.optim.Optimizer,
437
+ scheduler: torch.optim.lr_scheduler.LRScheduler,
438
+ step: int,
439
+ cfg: SimpleNamespace,
440
+ log: Logger | None,
441
+ evaluator: Evaluator | None = None,
442
+ id_envs: list[str] | None = None,
443
+ ood_envs: list[str] | None = None,
444
+ device: torch.device | str | None = None,
445
+ ) -> None:
446
+ """Save an offline training checkpoint, eval, and W&B artifact.
447
+
448
+ Mirrors the DAgger ``Trainer.save_checkpoint`` flow:
449
+ 1. Persist model + EMA + optimizer + scheduler state to disk.
450
+ 2. Save a YAML config snapshot alongside the checkpoint.
451
+ 3. Run an EMA-weight ID + OOD eval and emit ``ckpt_eval_*``
452
+ metrics + an eval JSON sidecar.
453
+ 4. Upload the checkpoint + config snapshot as a W&B artifact.
454
+
455
+ Step 3 is skipped gracefully when ``evaluator`` / envs / ``device`` are
+ not provided, and step 4 is a no-op when ``log`` is ``None`` or W&B is
+ not initialised, so callers that just want the bare state dump still work.
458
+
459
+ Args:
460
+ model: Raw (uncompiled) model — used both for ``state_dict``
461
+ persistence and as the source argument to
462
+ ``ema_model.make_eval_model``.
463
+ ema_model: EMA tracker.
464
+ optimizer: Optimizer.
465
+ scheduler: LR scheduler.
466
+ step: Global gradient step count (used in filenames + metadata).
467
+ cfg: Config namespace.
468
+ log: Logger (used to extract W&B run ID, log eval metrics,
469
+ and upload artifact).
470
+ evaluator: Optional evaluator. When ``None``, the checkpoint
471
+ eval is skipped.
472
+ id_envs: ID env IDs for the checkpoint eval.
473
+ ood_envs: OOD env IDs for the checkpoint eval.
474
+ device: Torch device for the checkpoint eval.
475
+ """
476
+ wandb_run_id: str | None = None
477
+ if log is not None and log._use_wandb and log._run is not None:
478
+ wandb_run_id = log._run.id
479
+
480
+ ckpt_dir = Path(cfg.checkpoint_dir)
481
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
482
+ path = ckpt_dir / f"offline_step{step}.pth"
483
+ torch.save(
484
+ {
485
+ "model_state_dict": model.state_dict(),
486
+ "ema_state_dict": ema_model.state_dict(),
487
+ "optimizer_state_dict": optimizer.state_dict(),
488
+ "scheduler_state_dict": scheduler.state_dict(),
489
+ "step": step,
490
+ "env_steps": step * cfg.offline_batch_size,
491
+ "wandb_run_id": wandb_run_id,
492
+ },
493
+ path,
494
+ )
495
+ logger.info(f"Offline checkpoint saved: {path}")
496
+
497
+ # Save config snapshot alongside checkpoint (mirrors DAgger).
498
+ config_path: Path | None = ckpt_dir / f"config_offline_step{step}.yaml"
499
+ try:
500
+ cfg_dict = {
501
+ k: v for k, v in vars(cfg).items() if not k.startswith("_")
502
+ }
503
+ with open(config_path, "w") as f:
504
+ yaml.dump(cfg_dict, f, default_flow_style=False)
505
+ except Exception:
506
+ logger.error("Failed to save config snapshot", exc_info=True)
507
+ config_path = None
508
+
509
+ # Checkpoint-time eval — mirrors Trainer.save_checkpoint in online.py.
510
+ # Skipped when the caller did not thread an evaluator through.
511
+ if (
512
+ evaluator is not None
513
+ and id_envs
514
+ and ood_envs
515
+ and device is not None
516
+ ):
517
+ try:
518
+ eval_model = ema_model.make_eval_model(model)
519
+ id_results = evaluator.evaluate(
520
+ id_envs, eval_model, cfg.checkpoint_eval_episodes,
521
+ cfg, device,
522
+ )
523
+ ood_results = evaluator.evaluate(
524
+ ood_envs, eval_model, cfg.checkpoint_eval_episodes,
525
+ cfg, device,
526
+ )
527
+
528
+ id_winrate = (
529
+ sum(s["win_rate"] for s in id_results.values())
530
+ / len(id_results)
531
+ ) if id_results else 0.0
532
+ ood_winrate = (
533
+ sum(s["win_rate"] for s in ood_results.values())
534
+ / len(ood_results)
535
+ ) if ood_results else 0.0
536
+
537
+ current_lr = scheduler.get_last_lr()[0]
538
+ training_meta = {
539
+ "step": step,
540
+ "env_steps": step * cfg.offline_batch_size,
541
+ "total_timesteps": cfg.total_timesteps,
542
+ "lr": current_lr,
543
+ "offline_batch_size": cfg.offline_batch_size,
544
+ "aux_loss_weight": cfg.aux_loss_weight,
545
+ "ema_decay": cfg.ema_decay,
546
+ "id_winrate": id_winrate,
547
+ "ood_winrate": ood_winrate,
548
+ "per_env_id": {
549
+ env_id: {
550
+ "win_rate": s["win_rate"],
551
+ "wins": s.get("wins", 0),
552
+ "avg_reward": s["avg_reward"],
553
+ "avg_steps": s["avg_steps"],
554
+ "n_episodes": s["n_episodes"],
555
+ }
556
+ for env_id, s in id_results.items()
557
+ },
558
+ "per_env_ood": {
559
+ env_id: {
560
+ "win_rate": s["win_rate"],
561
+ "wins": s.get("wins", 0),
562
+ "avg_reward": s["avg_reward"],
563
+ "avg_steps": s["avg_steps"],
564
+ "n_episodes": s["n_episodes"],
565
+ }
566
+ for env_id, s in ood_results.items()
567
+ },
568
+ }
569
+
570
+ json_path = ckpt_dir / f"eval_offline_step{step}.json"
571
+ save_eval_json(
572
+ {"id": id_results, "ood": ood_results},
573
+ str(json_path),
574
+ metadata=training_meta,
575
+ )
576
+
577
+ if log is not None:
578
+ log.log_eval(
579
+ id_results, step=step, prefix="ckpt_eval_id",
580
+ )
581
+ log.log_eval(
582
+ ood_results, step=step, prefix="ckpt_eval_ood",
583
+ )
584
+ log.log(
585
+ {
586
+ "ckpt_eval/id_winrate": id_winrate,
587
+ "ckpt_eval/ood_winrate": ood_winrate,
588
+ },
589
+ step=step,
590
+ )
591
+ log.log_summary({
592
+ f"ckpt_offline_step{step}/id_winrate": id_winrate,
593
+ f"ckpt_offline_step{step}/ood_winrate": ood_winrate,
594
+ })
595
+ except Exception:
596
+ logger.error(
597
+ "Offline checkpoint eval failed", exc_info=True,
598
+ )
599
+
600
+ # W&B artifact upload (no-op when wandb is not initialised).
601
+ if log is not None:
602
+ log.log_checkpoint_artifact(
603
+ checkpoint_path=str(path),
604
+ config_path=str(config_path) if config_path else None,
605
+ iteration=step,
606
+ metadata={"step": step, "mode": "offline"},
607
+ artifact_name=f"checkpoint-offline-step{step}",
608
+ )
609
+
610
+
611
+ def load_offline_dataset(
612
+ path: str | None, cfg: SimpleNamespace,
613
+ ) -> dict | None:
614
+ """Load an offline dataset from disk.
615
+
616
+ Args:
617
+ path: Path to a ``.pt`` file, or ``None``.
618
+ cfg: Config namespace (currently unused; reserved for future use).
619
+
620
+ Returns:
621
+ Loaded dict or ``None``.
622
+ """
623
+ if path is None:
624
+ return None
625
+ try:
626
+ import torch as _torch
627
+ return _torch.load(path, map_location="cpu", weights_only=False)
628
+ except Exception:
629
+ logger.error(f"Failed to load dataset from {path}", exc_info=True)
630
+ return None
631
+
632
+
633
+ def run_offline(
634
+ cfg: SimpleNamespace,
635
+ data_path: str | None,
636
+ checkpoint_path: str | None = None,
637
+ ) -> None:
638
+ """Offline BC training on pre-collected data.
639
+
640
+ Args:
641
+ cfg: Config namespace.
642
+ data_path: Path to ``.pt`` dataset file.
643
+ checkpoint_path: Optional checkpoint to resume from. Restores
644
+ model, EMA, optimizer, scheduler, and W&B run for curve
645
+ continuity.
646
+ """
647
+ make_run_dir(cfg, tag="offline")
648
+
649
+ device = cfg.device
650
+ logger.info(f"Offline BC on {device}")
651
+
652
+ data = load_offline_dataset(data_path, cfg)
653
+ if data is None:
654
+ logger.error("No dataset provided or failed to load. Exiting.")
655
+ sys.exit(1)
656
+
657
+ # Offline buffer must hold the full pre-collected dataset. DAgger's
658
+ # `buffer_capacity` (typically 10k) would silently FIFO-evict 99% of
659
+ # the dataset, so honour the optional `offline_buffer_capacity`
660
+ # override when present.
661
+ _offline_buf_cap = (
662
+ getattr(cfg, "offline_buffer_capacity", None) or cfg.buffer_capacity
663
+ )
664
+ buffer = ReplayBuffer(_offline_buf_cap, cfg.seq_len, cfg.pad_token)
665
+ buffer.load_offline_data(data, cfg.id_envs)
666
+ logger.info(f"Loaded {len(buffer)} windows")
667
+
668
+ if len(buffer) == 0:
669
+ logger.error(
670
+ "Buffer is empty after loading dataset — no trajectories matched "
671
+ f"id_envs={cfg.id_envs}. Exiting."
672
+ )
673
+ sys.exit(1)
674
+
675
+ raw_model = make_model(cfg).to(device)
676
+
677
+ # torch.compile: wrap for training only; shares params with raw_model
678
+ model = try_compile(raw_model, cfg)
679
+
680
+ ema = ModelEMA(raw_model, decay=cfg.ema_decay)
681
+
682
+ # If resuming, extract W&B run ID from checkpoint before Logger init
683
+ resume_state: dict | None = None
684
+ if checkpoint_path:
685
+ resume_state = torch.load(
686
+ checkpoint_path, map_location=device, weights_only=False,
687
+ )
688
+ raw_model.load_state_dict(resume_state["model_state_dict"])
689
+ ema.load_state_dict(resume_state["ema_state_dict"])
690
+ resume_id = getattr(cfg, "wandb_resume_id", None)
691
+ if not resume_id:
692
+ saved_id = resume_state.get("wandb_run_id")
693
+ if saved_id:
694
+ cfg.wandb_resume_id = saved_id
695
+ logger.info(f"W&B run ID from checkpoint: {saved_id}")
696
+
697
+ log = Logger(cfg)
698
+ evaluator = Evaluator()
699
+ train_fn = make_offline_trainer(cfg)
700
+ result = train_fn(
701
+ model, ema, buffer, cfg, device, log=log,
702
+ raw_model=raw_model, resume_state=resume_state,
703
+ evaluator=evaluator,
704
+ id_envs=cfg.id_envs,
705
+ ood_envs=cfg.ood_envs,
706
+ )
707
+ logger.info(
708
+ f"Offline training done. Final loss: {result['final_loss']:.4f}"
709
+ )
710
+
711
+ # Save final checkpoint for downstream compatibility (DAgger, inference)
712
+ wandb_run_id: str | None = None
713
+ if log._use_wandb and log._run is not None:
714
+ wandb_run_id = log._run.id
715
+
716
+ ckpt_dir = Path(cfg.checkpoint_dir)
717
+ path = ckpt_dir / "offline_final.pth"
718
+ torch.save(
719
+ {
720
+ "model_state_dict": raw_model.state_dict(),
721
+ "ema_state_dict": ema.state_dict(),
722
+ "wandb_run_id": wandb_run_id,
723
+ },
724
+ path,
725
+ )
726
+ logger.info(f"Saved offline checkpoint: {path}")
727
+ log.finish()
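The capacity fallback near the top of this hunk can be illustrated in isolation. This is a minimal sketch, assuming a plain `collections.deque` as a stand-in for `ReplayBuffer` (whose real eviction logic is not shown in this diff); `resolve_offline_capacity` and `surviving_windows` are hypothetical helper names, and the dataset size of one million windows is only an illustrative figure.

```python
from collections import deque
from types import SimpleNamespace


def resolve_offline_capacity(cfg: SimpleNamespace) -> int:
    """Prefer the optional offline override; fall back to buffer_capacity."""
    # Same expression as in the diff: a missing, None, or 0 override falls back.
    return getattr(cfg, "offline_buffer_capacity", None) or cfg.buffer_capacity


def surviving_windows(n_windows: int, capacity: int) -> int:
    """Count the windows left after pushing n_windows through a FIFO."""
    fifo = deque(maxlen=capacity)  # stand-in for ReplayBuffer; oldest entries drop first
    fifo.extend(range(n_windows))
    return len(fifo)


# Hypothetical configs: one without the override, one with it.
default_cfg = SimpleNamespace(buffer_capacity=10_000)
offline_cfg = SimpleNamespace(
    buffer_capacity=10_000, offline_buffer_capacity=1_000_000,
)

kept_default = surviving_windows(1_000_000, resolve_offline_capacity(default_cfg))
kept_offline = surviving_windows(1_000_000, resolve_offline_capacity(offline_cfg))
print(kept_default, kept_offline)
```

With the default capacity only the newest 10,000 windows survive (1% of the illustrative dataset), which is the silent eviction the comment warns about. Note that the `or` fallback also treats an override of `0` as "not set".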
src/planners/online.py ADDED
@@ -0,0 +1,721 @@
1
+ """DAgger online training loop.
2
+
3
+ Orchestrates the full DAgger pipeline: collect data via model + oracle,
4
+ train on buffer, evaluate periodically, and checkpoint.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import random
11
+ import time
12
+ from pathlib import Path
13
+ from types import SimpleNamespace
14
+
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn as nn
18
+ import yaml
19
+
20
+ from src.buffer import ReplayBuffer
21
+ from src.config import make_run_dir
22
+ from src.diffusion.forward import q_sample
23
+ from src.diffusion.loss import auxiliary_goal_loss, mdlm_loss
24
+ from src.diffusion.schedules import get_schedule
25
+ from src.models.denoiser import ModelEMA, make_model, try_compile
26
+ from src.planners.collect import DataCollector
27
+ from src.planners.inference import Evaluator, save_eval_json
28
+ from src.planners.logging import (
29
+ Logger, gpu_memory_mb, reset_gpu_memory_stats,
30
+ compute_param_norm, compute_param_drift,
31
+ )
32
+ from src.curriculum import DynamicCurriculum
33
+ from src.envs.minihack_env import collect_oracle_trajectory
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class Trainer:
39
+ """Full DAgger training loop.
40
+
41
+ Args:
42
+ model: Denoising model.
43
+ ema_model: EMA tracker.
44
+ optimizer: Torch optimizer.
45
+ scheduler: Optional LR scheduler.
46
+ buffer: Replay buffer.
47
+ collector: DAgger data collector.
48
+ evaluator: Evaluation runner.
49
+ log: Centralised logger.
50
+ cfg: Config namespace.
51
+ device: Torch device.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ model: nn.Module,
57
+ ema_model: ModelEMA,
58
+ optimizer: torch.optim.Optimizer,
59
+ scheduler: torch.optim.lr_scheduler.LRScheduler | None,
60
+ buffer: ReplayBuffer,
61
+ collector: DataCollector,
62
+ evaluator: Evaluator,
63
+ log: Logger,
64
+ cfg: SimpleNamespace,
65
+ device: torch.device | str,
66
+ raw_model: nn.Module | None = None,
67
+ ) -> None:
68
+ self.model = model
69
+ # raw_model is the uncompiled model used for eval deep-copies.
70
+ # When torch.compile is off, raw_model is the same as model.
71
+ self._raw_model = raw_model if raw_model is not None else model
72
+ self.ema_model = ema_model
73
+ self.optimizer = optimizer
74
+ self.scheduler = scheduler
75
+ self.buffer = buffer
76
+ self.collector = collector
77
+ self.evaluator = evaluator
78
+ self.log = log
79
+ self.cfg = cfg
80
+ self.device = device
81
+ self._schedule_fn = get_schedule(cfg.noise_schedule)
82
+ # Snapshot of initial weights for param drift tracking
83
+ self._init_state = {
84
+ k: v.clone() for k, v in self._raw_model.state_dict().items()
85
+ if v.is_floating_point()
86
+ }
87
+ # AMP scaler: enabled only when use_amp=true and on CUDA
88
+ self._use_amp = (
89
+ getattr(cfg, "use_amp", False) and str(device).startswith("cuda")
90
+ )
91
+ self._scaler = torch.amp.GradScaler("cuda", enabled=self._use_amp)
92
+
93
+ # ── Main loop ────────────────────────────────────────────────
94
+
95
+ def train(
96
+ self, start_iter: int = 0, start_env_steps: int = 0,
97
+ ) -> None:
98
+ """Run the DAgger training loop.
99
+
100
+ The budget is ``cfg.total_timesteps`` — total env.step() calls
101
+ across model + oracle rollouts. Iteration count is derived; it
102
+ depends on how many env steps each iteration consumes (which in
103
+ turn depends on episode length and efficiency filter outcomes).
104
+
105
+ Args:
106
+ start_iter: Iteration index to resume from (for logging).
107
+ start_env_steps: Cumulative env steps already consumed.
108
+ """
109
+ cfg = self.cfg
110
+ env_steps_total = start_env_steps
111
+ iteration = start_iter
112
+ last_id_eval_step = start_env_steps
113
+ last_ood_eval_step = start_env_steps
114
+ last_ckpt_step = start_env_steps
115
+
116
+ while env_steps_total < cfg.total_timesteps:
117
+ reset_gpu_memory_stats()
118
+ iter_start = time.perf_counter()
119
+
120
+ # 1. Collect N episodes per iteration
121
+ n_eps = getattr(cfg, "episodes_per_iteration", 1)
122
+ num_workers = getattr(cfg, "num_collection_workers", 0)
123
+ model_wins = 0
124
+ added_total = 0
125
+ # Accumulators across all n_eps episodes — must be summed,
126
+ # NOT taken from a single (last) episode, otherwise the
127
+ # unified env-step budget undercounts by ~n_eps×.
128
+ model_steps_iter = 0
129
+ oracle_steps_iter = 0
130
+ last_env_id: str = ""
131
+
132
+ collect_start = time.perf_counter()
133
+ use_gpu_batch = (
134
+ str(self.device).startswith("cuda") and n_eps > 1
135
+ )
136
+ if use_gpu_batch:
137
+ # GPU-batched collection (all envs in lockstep)
138
+ batch_stats = self.collector.collect_batch_gpu(n_eps)
139
+ for s in batch_stats:
140
+ model_wins += int(s["model_won"])
141
+ added_total += int(s["added_to_buffer"])
142
+ model_steps_iter += int(s["model_steps"])
143
+ oracle_steps_iter += int(s["oracle_steps"])
144
+ last_env_id = s.get("env_id", last_env_id)
145
+ elif num_workers > 0 and n_eps > 1:
146
+ # Threaded CPU collection (fallback)
147
+ batch_stats = self.collector.collect_batch_parallel(
148
+ n_eps,
149
+ )
150
+ for s in batch_stats:
151
+ model_wins += int(s["model_won"])
152
+ added_total += int(s["added_to_buffer"])
153
+ model_steps_iter += int(s["model_steps"])
154
+ oracle_steps_iter += int(s["oracle_steps"])
155
+ last_env_id = s.get("env_id", last_env_id)
156
+ else:
157
+ # Sequential collection (reference behaviour)
158
+ for _ in range(n_eps):
159
+ s = self.collector.collect_one_iteration()
160
+ model_wins += int(s["model_won"])
161
+ added_total += int(s["added_to_buffer"])
162
+ model_steps_iter += int(s["model_steps"])
163
+ oracle_steps_iter += int(s["oracle_steps"])
164
+ last_env_id = s.get("env_id", last_env_id)
165
+ collect_time = time.perf_counter() - collect_start
166
+
167
+ collect_stats = {
168
+ "env_id": last_env_id,
169
+ "model_won": model_wins,
170
+ "added_to_buffer": added_total,
171
+ "model_steps": model_steps_iter,
172
+ "oracle_steps": oracle_steps_iter,
173
+ }
174
+
175
+ # Advance the unified env-step budget. Both model and oracle
176
+ # rollouts consume real env.step() calls (the oracle rollout
177
+ # runs in its own env instance in collect_oracle_trajectory),
178
+ # so both contribute to the budget.
179
+ iter_env_steps = model_steps_iter + oracle_steps_iter
180
+ env_steps_total += iter_env_steps
181
+
182
+ # 2. Gradient steps (EMA updated after each step)
183
+ self.model.train()
184
+ step_metrics: list[dict[str, float]] = []
185
+ train_start = time.perf_counter()
186
+ for _ in range(cfg.grad_steps_per_iteration):
187
+ m = self._train_step()
188
+ step_metrics.append(m)
189
+ self.ema_model.update(self._raw_model)
190
+ train_time = time.perf_counter() - train_start
191
+
192
+ iter_time = time.perf_counter() - iter_start
193
+
194
+ # 4. Log
195
+ n_steps = len(step_metrics) or 1
196
+ avg_loss = sum(m["loss"] for m in step_metrics) / n_steps
197
+ avg_loss_diff = sum(m["loss_diff"] for m in step_metrics) / n_steps
198
+ avg_loss_aux = sum(m["loss_aux"] for m in step_metrics) / n_steps
199
+ avg_grad_norm = sum(m["grad_norm"] for m in step_metrics) / n_steps
200
+ current_lr = (
201
+ self.scheduler.get_last_lr()[0]
202
+ if self.scheduler is not None
203
+ else self.cfg.dagger_lr
204
+ )
205
+
206
+ # Global gate value (how open is the global stream)
207
+ gate_val = None
208
+ if hasattr(self._raw_model, "global_gate"):
209
+ gate_val = torch.sigmoid(
210
+ self._raw_model.global_gate
211
+ ).item()
212
+
213
+ # Buffer online fraction
214
+ buf_total = len(self.buffer)
215
+ buf_online_frac = (
216
+ (buf_total - self.buffer.offline_size) / max(buf_total, 1)
217
+ if hasattr(self.buffer, "offline_size")
218
+ else 0.0
219
+ )
220
+
221
+ # Samples per second
222
+ total_samples = n_steps * cfg.dagger_batch_size
223
+ samples_per_sec = total_samples / max(train_time, 1e-6)
224
+
225
+ # Env steps per second (uses the iter-summed total, not a
226
+ # single episode — same bug class as the env-step budget).
227
+ env_steps_per_sec = iter_env_steps / max(collect_time, 1e-6)
228
+
229
+ metrics = {
230
+ "diffusion/loss": avg_loss,
231
+ "diffusion/loss_diff": avg_loss_diff,
232
+ "diffusion/loss_aux": avg_loss_aux,
233
+ "train/buffer_size": buf_total,
234
+ "train/buffer_online_frac": buf_online_frac,
235
+ "train/model_won": int(collect_stats["model_won"]),
236
+ "train/added_to_buffer": int(
237
+ collect_stats["added_to_buffer"]
238
+ ),
239
+ "train/episodes_collected": n_eps,
240
+ "train/model_steps": collect_stats["model_steps"],
241
+ "train/oracle_steps": collect_stats["oracle_steps"],
242
+ "train/efficiency_ratio": (
243
+ collect_stats["model_steps"]
244
+ / max(collect_stats["oracle_steps"], 1)
245
+ ),
246
+ "train/lr": current_lr,
247
+ "train/grad_norm": avg_grad_norm,
248
+ "train/env_steps": env_steps_total,
249
+ "train/progress": env_steps_total / cfg.total_timesteps,
250
+ "speed/iter_time_sec": iter_time,
251
+ "speed/collect_time_sec": collect_time,
252
+ "speed/train_step_time_sec": train_time,
253
+ "speed/samples_per_sec": samples_per_sec,
254
+ "speed/env_steps_per_sec": env_steps_per_sec,
255
+ "speed/gpu_memory_mb": gpu_memory_mb(),
256
+ # Keep old perf/ keys for backward compatibility
257
+ "perf/iter_time_s": iter_time,
258
+ "perf/collect_time_s": collect_time,
259
+ "perf/train_time_s": train_time,
260
+ "perf/grad_steps_per_sec": (
261
+ cfg.grad_steps_per_iteration / max(train_time, 1e-6)
262
+ ),
263
+ }
264
+ if gate_val is not None:
265
+ metrics["train/global_gate"] = gate_val
266
+ metrics["model/ema_gate_value"] = gate_val
267
+
268
+ # Model health (every 10 iters to avoid overhead)
269
+ if iteration % 10 == 0:
270
+ metrics["model/param_norm"] = compute_param_norm(
271
+ self._raw_model
272
+ )
273
+ metrics["model/param_drift_from_init"] = compute_param_drift(
274
+ self._raw_model, self._init_state
275
+ )
276
+
277
+ # Profile breakdown from GPU-batched collection
278
+ _profile = getattr(self.collector, "_last_profile", {})
279
+ for _pk, _pv in _profile.items():
280
+ metrics[f"profile/{_pk}"] = _pv
281
+
282
+ self.log.log(metrics, step=iteration)
283
+
284
+ # 5. ID eval — triggered when env-step delta crosses threshold
285
+ if (
286
+ cfg.id_eval_every_timesteps > 0
287
+ and env_steps_total - last_id_eval_step
288
+ >= cfg.id_eval_every_timesteps
289
+ ):
290
+ eval_model = self.ema_model.make_eval_model(self._raw_model)
291
+ results = self.evaluator.evaluate(
292
+ cfg.id_envs,
293
+ eval_model,
294
+ cfg.eval_episodes_per_env,
295
+ cfg,
296
+ self.device,
297
+ )
298
+ self.log.log_eval(results, step=iteration, prefix="eval_id")
299
+ mean_id_wr = float(np.mean(
300
+ [s["win_rate"] for s in results.values()]
301
+ )) if results else 0.0
302
+ self.log.log(
303
+ {
304
+ "eval_id/mean_win_rate": mean_id_wr,
305
+ **{
306
+ f"curriculum/{env_id}/win_rate":
307
+ self.collector.curriculum.win_rate(env_id)
308
+ for env_id in self.cfg.id_envs
309
+ },
310
+ },
311
+ step=iteration,
312
+ )
313
+ last_id_eval_step = env_steps_total
314
+
315
+ # 6. OOD eval — env-step-triggered
316
+ if (
317
+ cfg.ood_eval_every_timesteps > 0
318
+ and env_steps_total - last_ood_eval_step
319
+ >= cfg.ood_eval_every_timesteps
320
+ ):
321
+ eval_model = self.ema_model.make_eval_model(self._raw_model)
322
+ results = self.evaluator.evaluate(
323
+ cfg.ood_envs,
324
+ eval_model,
325
+ cfg.eval_episodes_per_env,
326
+ cfg,
327
+ self.device,
328
+ )
329
+ self.log.log_eval(results, step=iteration, prefix="eval_ood")
330
+ mean_ood_wr = float(np.mean(
331
+ [s["win_rate"] for s in results.values()]
332
+ )) if results else 0.0
333
+ self.log.log(
334
+ {"eval_ood/mean_win_rate": mean_ood_wr}, step=iteration,
335
+ )
336
+ last_ood_eval_step = env_steps_total
337
+
338
+ # 7. Checkpoint — env-step-triggered
339
+ if (
340
+ cfg.checkpoint_every_timesteps > 0
341
+ and env_steps_total - last_ckpt_step
342
+ >= cfg.checkpoint_every_timesteps
343
+ ):
344
+ self.save_checkpoint(iteration, env_steps_total)
345
+ last_ckpt_step = env_steps_total
346
+
347
+ iteration += 1
348
+
349
+ # Final checkpoint
350
+ if cfg.save_policy:
351
+ self.save_checkpoint(iteration, env_steps_total)
352
+
353
+ # ── Single gradient step ─────────────────────────────────────
354
+
355
+ def _train_step(self) -> dict[str, float]:
356
+ """One gradient step on a buffer sample.
357
+
358
+ Uses AMP (mixed precision) when ``cfg.use_amp`` is ``True``
359
+ and training on CUDA.
360
+
361
+ Returns:
362
+ Dict with ``"loss"``, ``"loss_diff"``, ``"loss_aux"``,
363
+ and ``"grad_norm"`` scalars.
364
+ """
365
+ cfg = self.cfg
366
+ batch = self.buffer.sample(cfg.dagger_batch_size)
367
+ if batch is None:
368
+ return {"loss": 0.0, "loss_diff": 0.0,
369
+ "loss_aux": 0.0, "grad_norm": 0.0}
370
+ local_np, global_np, actions_np = batch
371
+ local_t = torch.from_numpy(local_np).long().to(self.device)
372
+ global_t = torch.from_numpy(global_np).long().to(self.device)
373
+ actions_t = torch.from_numpy(actions_np).long().to(self.device)
374
+
375
+ B = actions_t.shape[0]
376
+ t = torch.rand(B, device=self.device).clamp(1e-5, 1.0 - 1e-5)
377
+
378
+ zt = q_sample(
379
+ actions_t, t, cfg.mask_token, cfg.pad_token,
380
+ self._schedule_fn,
381
+ )
382
+ t_discrete = (t * cfg.num_diffusion_steps).long().clamp(
383
+ 0, cfg.num_diffusion_steps - 1,
384
+ )
385
+
386
+ self.optimizer.zero_grad()
387
+ with torch.amp.autocast("cuda", enabled=self._use_amp):
388
+ out = self.model(local_t, global_t, zt, t_discrete)
389
+
390
+ loss_diff = mdlm_loss(
391
+ out["actions"], actions_t, zt, t,
392
+ cfg.mask_token, cfg.pad_token, self._schedule_fn,
393
+ weight_clip=cfg.loss_weight_clip,
394
+ label_smoothing=cfg.label_smoothing,
395
+ use_importance_weighting=cfg.use_importance_weighting,
396
+ )
397
+
398
+ loss_aux = torch.tensor(0.0, device=self.device)
399
+ if "goal_pred" in out:
400
+ loss_aux = auxiliary_goal_loss(out["goal_pred"], global_t)
401
+
402
+ loss = loss_diff + cfg.aux_loss_weight * loss_aux
403
+
404
+ self._scaler.scale(loss).backward()
405
+ self._scaler.unscale_(self.optimizer)
406
+ grad_norm = nn.utils.clip_grad_norm_(
407
+ self.model.parameters(), cfg.dagger_grad_clip,
408
+ )
409
+ self._scaler.step(self.optimizer)
410
+ self._scaler.update()
411
+ if self.scheduler is not None:
412
+ self.scheduler.step()
413
+
414
+ return {
415
+ "loss": loss.item(),
416
+ "loss_diff": loss_diff.item(),
417
+ "loss_aux": loss_aux.item(),
418
+ "grad_norm": grad_norm.item(),
419
+ }
420
+
421
+ # ── Checkpointing ────────────────────────────────────────────
422
+
423
+ def save_checkpoint(
424
+ self, iteration: int, env_steps: int,
425
+ ) -> None:
426
+ """Save a training checkpoint.
427
+
428
+ Args:
429
+ iteration: Current iteration number (for filename + metadata).
430
+ env_steps: Cumulative env.step() count consumed so far.
431
+ """
432
+ ckpt_dir = Path(self.cfg.checkpoint_dir)
433
+ ckpt_dir.mkdir(parents=True, exist_ok=True)
434
+ path = ckpt_dir / f"iter{iteration}.pth"
435
+
436
+ # Capture W&B run ID for seamless resumption
437
+ wandb_run_id: str | None = None
438
+ if self.log._use_wandb and self.log._run is not None:
439
+ wandb_run_id = self.log._run.id
440
+
441
+ state = {
442
+ "model_state_dict": self._raw_model.state_dict(),
443
+ "ema_state_dict": self.ema_model.state_dict(),
444
+ "optimizer_state_dict": self.optimizer.state_dict(),
445
+ "scheduler_state_dict": (
446
+ self.scheduler.state_dict()
447
+ if self.scheduler is not None
448
+ else None
449
+ ),
450
+ "curriculum_state": self.collector.curriculum.state_dict(),
451
+ "iteration": iteration,
452
+ "env_steps": env_steps,
453
+ "wandb_run_id": wandb_run_id,
454
+ "rng_states": {
455
+ "torch": torch.get_rng_state(),
456
+ "numpy": np.random.get_state(),
457
+ "python": random.getstate(),
458
+ },
459
+ }
460
+
461
+ try:
462
+ torch.save(state, path)
463
+ logger.info(f"Checkpoint saved: {path}")
464
+ except Exception:
465
+ logger.error(
466
+ f"Failed to save checkpoint to {path}", exc_info=True,
467
+ )
468
+
469
+ # Save config snapshot alongside checkpoint
470
+ config_path = ckpt_dir / f"config_iter{iteration}.yaml"
471
+ try:
472
+ cfg_dict = {
473
+ k: v for k, v in vars(self.cfg).items()
474
+ if not k.startswith("_")
475
+ }
476
+ with open(config_path, "w") as f:
477
+ yaml.dump(cfg_dict, f, default_flow_style=False)
478
+ except Exception:
479
+ logger.error("Failed to save config snapshot", exc_info=True)
480
+ config_path = None
481
+
482
+ # Run eval at checkpoint and save JSON
483
+ try:
484
+ eval_model = self.ema_model.make_eval_model(self._raw_model)
485
+ id_results = self.evaluator.evaluate(
486
+ self.cfg.id_envs, eval_model,
487
+ self.cfg.checkpoint_eval_episodes,
488
+ self.cfg, self.device,
489
+ )
490
+ ood_results = self.evaluator.evaluate(
491
+ self.cfg.ood_envs, eval_model,
492
+ self.cfg.checkpoint_eval_episodes,
493
+ self.cfg, self.device,
494
+ )
495
+
496
+ id_winrate = float(np.mean(
497
+ [s["win_rate"] for s in id_results.values()]
498
+ )) if id_results else 0.0
499
+ ood_winrate = float(np.mean(
500
+ [s["win_rate"] for s in ood_results.values()]
501
+ )) if ood_results else 0.0
502
+ current_lr = (
503
+ self.scheduler.get_last_lr()[0]
504
+ if self.scheduler is not None
505
+ else self.cfg.dagger_lr
506
+ )
507
+ training_meta = {
508
+ "iteration": iteration,
509
+ "env_steps": env_steps,
510
+ "total_timesteps": self.cfg.total_timesteps,
511
+ "lr": current_lr,
512
+ "dagger_batch_size": self.cfg.dagger_batch_size,
513
+ "aux_loss_weight": self.cfg.aux_loss_weight,
514
+ "buffer_size": len(self.buffer),
515
+ "buffer_capacity": self.cfg.buffer_capacity,
516
+ "ema_decay": self.cfg.ema_decay,
517
+ "grad_steps_per_iteration": self.cfg.grad_steps_per_iteration,
518
+ "episodes_per_iteration": getattr(
519
+ self.cfg, "episodes_per_iteration", 1
520
+ ),
521
+ "id_winrate": id_winrate,
522
+ "ood_winrate": ood_winrate,
523
+ "per_env_id": {
524
+ env_id: {
525
+ "win_rate": s["win_rate"],
526
+ "wins": s.get("wins", 0),
527
+ "avg_reward": s["avg_reward"],
528
+ "avg_steps": s["avg_steps"],
529
+ "n_episodes": s["n_episodes"],
530
+ }
531
+ for env_id, s in id_results.items()
532
+ },
533
+ "per_env_ood": {
534
+ env_id: {
535
+ "win_rate": s["win_rate"],
536
+ "wins": s.get("wins", 0),
537
+ "avg_reward": s["avg_reward"],
538
+ "avg_steps": s["avg_steps"],
539
+ "n_episodes": s["n_episodes"],
540
+ }
541
+ for env_id, s in ood_results.items()
542
+ },
543
+ }
544
+
545
+ json_path = ckpt_dir / f"eval_iter{iteration}.json"
546
+ save_eval_json(
547
+ {"id": id_results, "ood": ood_results},
548
+ str(json_path),
549
+ metadata=training_meta,
550
+ )
551
+
552
+ # W&B checkpoint log — per-env step metrics + aggregates
553
+ self.log.log_eval(
554
+ id_results, step=iteration, prefix="ckpt_eval_id",
555
+ )
556
+ self.log.log_eval(
557
+ ood_results, step=iteration, prefix="ckpt_eval_ood",
558
+ )
559
+ self.log.log(
560
+ {
561
+ "ckpt_eval/id_winrate": id_winrate,
562
+ "ckpt_eval/ood_winrate": ood_winrate,
563
+ },
564
+ step=iteration,
565
+ )
566
+ self.log.log_summary({
567
+ f"ckpt_{iteration}/id_winrate": id_winrate,
568
+ f"ckpt_{iteration}/ood_winrate": ood_winrate,
569
+ })
570
+ except Exception:
571
+ logger.error("Checkpoint eval failed", exc_info=True)
572
+
573
+ # Hugging Face Hub upload (no-op if HF_TOKEN or hub_run_id is not set)
574
+ try:
575
+ from scripts.hf_upload import maybe_upload_checkpoint
576
+ maybe_upload_checkpoint(
577
+ str(ckpt_dir),
578
+ getattr(self.cfg, "hub_run_id", None),
579
+ getattr(self.cfg, "hub_repo_id", None),
580
+ )
581
+ except Exception:
582
+ logger.error("HF Hub upload failed", exc_info=True)
583
+
584
+ # W&B artifact upload
585
+ self.log.log_checkpoint_artifact(
586
+ checkpoint_path=str(path),
587
+ config_path=str(config_path) if config_path else None,
588
+ iteration=iteration,
589
+ metadata={
590
+ "iteration": iteration,
591
+ "buffer_size": len(self.buffer),
592
+ },
593
+ )
594
+
595
+ def load_checkpoint(self, path: str) -> tuple[int, int]:
596
+ """Load a training checkpoint.
597
+
598
+ Args:
599
+ path: Path to ``.pth`` checkpoint file.
600
+
601
+ Returns:
602
+ ``(start_iter, start_env_steps)`` — the iteration and
603
+ cumulative env-step count to resume from.
604
+ """
605
+ ckpt = torch.load(
606
+ path, map_location=self.device, weights_only=False,
607
+ )
608
+ self._raw_model.load_state_dict(ckpt["model_state_dict"])
609
+ self.ema_model.load_state_dict(ckpt["ema_state_dict"])
610
+ self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
611
+
612
+ if (
613
+ self.scheduler is not None
614
+ and ckpt.get("scheduler_state_dict") is not None
615
+ ):
616
+ self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
617
+
618
+ if "curriculum_state" in ckpt:
619
+ self.collector.curriculum.load_state_dict(
620
+ ckpt["curriculum_state"],
621
+ )
622
+
623
+ # Restore RNG states (best-effort)
624
+ rng = ckpt.get("rng_states", {})
625
+ try:
626
+ if "torch" in rng:
627
+ torch.set_rng_state(rng["torch"])
628
+ if "numpy" in rng:
629
+ np.random.set_state(rng["numpy"])
630
+ if "python" in rng:
631
+ random.setstate(rng["python"])
632
+ except Exception:
633
+ logger.warning(
634
+ "RNG state restore failed; continuing with fresh state",
635
+ )
636
+
637
+ iteration = ckpt.get("iteration", 0)
638
+ env_steps = ckpt.get("env_steps", 0)
639
+ resume_from = iteration + 1
640
+ logger.info(
641
+ f"Resumed from checkpoint: {path} (iter {iteration}, "
642
+ f"env_steps={env_steps}), starting at iter {resume_from}"
643
+ )
644
+ return resume_from, env_steps
645
+
646
+
647
+ def run_dagger(
648
+ cfg: SimpleNamespace,
649
+ checkpoint_path: str | None,
650
+ no_warm_start: bool,
651
+ ) -> None:
652
+ """DAgger online training loop."""
653
+ make_run_dir(cfg, tag="dagger")
654
+
655
+ device = cfg.device
656
+ logger.info(f"DAgger training on {device}")
657
+
658
+ raw_model = make_model(cfg).to(device)
659
+
660
+ # EMA and eval always use the raw (uncompiled) model — deep-copying
661
+ # a compiled model breaks FX tracing.
662
+ ema = ModelEMA(raw_model, decay=cfg.ema_decay)
663
+
664
+ # torch.compile: wrap for training only; shares parameters with raw_model
665
+ model = try_compile(raw_model, cfg)
666
+
667
+ optimizer = torch.optim.AdamW(
668
+ raw_model.parameters(), lr=cfg.dagger_lr,
669
+ weight_decay=cfg.weight_decay,
670
+ )
671
+
672
+ buffer = ReplayBuffer(cfg.buffer_capacity, cfg.seq_len, cfg.pad_token)
673
+ curriculum = DynamicCurriculum(
674
+ cfg.id_envs, cfg.curriculum_queue_size, cfg.curriculum_preseed,
675
+ )
676
+
677
+ # Seed the buffer with up to three oracle trajectories per ID environment
678
+ for i, env_id in enumerate(cfg.id_envs):
679
+ for s in range(3):
680
+ traj = collect_oracle_trajectory(env_id, seed=i * 100 + s, cfg=cfg)
681
+ if traj is not None:
682
+ buffer.add(traj)
683
+ logger.info(f"Buffer seeded with {len(buffer)} windows")
684
+
685
+ # If resuming, extract W&B run ID from checkpoint before Logger init
686
+ # so the same W&B run is continued (curve continuity).
687
+ if checkpoint_path and not no_warm_start:
688
+ resume_id = getattr(cfg, "wandb_resume_id", None)
689
+ if not resume_id:
690
+ ckpt_peek = torch.load(
691
+ checkpoint_path, map_location="cpu", weights_only=False,
692
+ )
693
+ saved_id = ckpt_peek.get("wandb_run_id")
694
+ if saved_id:
695
+ cfg.wandb_resume_id = saved_id
696
+ logger.info(
697
+ f"W&B run ID from checkpoint: {saved_id}"
698
+ )
699
+ del ckpt_peek
700
+
701
+ # DataCollector uses raw_model for eval copies (not compiled)
702
+ collector = DataCollector(ema, raw_model, buffer, curriculum, cfg, device)
703
+ evaluator = Evaluator()
704
+ log = Logger(cfg)
705
+
706
+ trainer = Trainer(
707
+ model, ema, optimizer, None, buffer, collector,
708
+ evaluator, log, cfg, device, raw_model=raw_model,
709
+ )
710
+
711
+ start_iter = 0
712
+ start_env_steps = 0
713
+ if checkpoint_path and not no_warm_start:
714
+ start_iter, start_env_steps = trainer.load_checkpoint(
715
+ checkpoint_path,
716
+ )
717
+
718
+ trainer.train(
719
+ start_iter=start_iter, start_env_steps=start_env_steps,
720
+ )
721
+ log.finish()
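The ID eval, OOD eval, and checkpoint steps in `Trainer.train` all fire on the same rule: the env-step delta since the last firing has reached a configured threshold, and a non-positive threshold disables the event. The sketch below isolates that rule under stated assumptions: `StepTrigger` and `simulate` are hypothetical names that do not exist in the repository, and the per-iteration step counts are made-up values.

```python
from dataclasses import dataclass


@dataclass
class StepTrigger:
    """Fire once at least `every` env steps have accumulated since the last firing."""

    every: int
    last_fired_at: int = 0

    def should_fire(self, env_steps_total: int) -> bool:
        if self.every <= 0:  # mirrors the `> 0` guard: non-positive disables the event
            return False
        if env_steps_total - self.last_fired_at >= self.every:
            self.last_fired_at = env_steps_total
            return True
        return False


def simulate(steps_per_iter: list[int], total_timesteps: int, every: int) -> list[int]:
    """Return the iteration indices at which the trigger fired."""
    trigger = StepTrigger(every=every)
    env_steps_total, iteration, fired = 0, 0, []
    while env_steps_total < total_timesteps:
        # Each iteration consumes a variable number of env steps (model + oracle).
        env_steps_total += steps_per_iter[iteration % len(steps_per_iter)]
        if trigger.should_fire(env_steps_total):
            fired.append(iteration)
        iteration += 1
    return fired


fired_at = simulate(steps_per_iter=[30, 50, 20], total_timesteps=300, every=100)
print(fired_at)
```

Because the trigger compares against the step count at the last firing rather than a fixed multiple of the threshold, events are spaced by at least `every` env steps even when iterations consume uneven amounts, and the iteration index at which they fire is a derived quantity.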
src/planners/smoke.py ADDED
@@ -0,0 +1,63 @@
1
+ import logging
2
+
3
+ import torch
4
+
5
+ from src.buffer import ReplayBuffer
6
+ from src.curriculum import DynamicCurriculum
7
+ from src.envs.minihack_env import collect_oracle_trajectory
8
+ from src.models.denoiser import ModelEMA, make_model, try_compile
9
+ from src.planners.collect import DataCollector
10
+ from src.planners.inference import Evaluator, format_eval_results
11
+ from src.planners.logging import Logger
12
+ from src.planners.online import Trainer
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def run_smoke(cfg) -> None:
18
+ """Smoke test: collect oracle data, train briefly, eval."""
19
+
20
+ device = cfg.device
21
+ logger.info(f"Smoke test on {device}")
22
+
23
+ # Collect one oracle trajectory per ID environment into the buffer
24
+ buffer = ReplayBuffer(cfg.buffer_capacity, cfg.seq_len, cfg.pad_token)
25
+ for i, env_id in enumerate(cfg.id_envs):
26
+ traj = collect_oracle_trajectory(env_id, seed=i, cfg=cfg)
27
+ if traj is not None:
28
+ buffer.add(traj)
29
+ logger.info(f"Buffer seeded with {len(buffer)} windows")
30
+
31
+ raw_model = make_model(cfg).to(device)
32
+
33
+ model = try_compile(raw_model, cfg)
34
+
35
+ ema = ModelEMA(raw_model, decay=cfg.ema_decay)
36
+ optimizer = torch.optim.AdamW(
37
+ raw_model.parameters(), lr=cfg.dagger_lr,
38
+ weight_decay=cfg.weight_decay,
39
+ )
40
+ curriculum = DynamicCurriculum(
41
+ cfg.id_envs, cfg.curriculum_queue_size, cfg.curriculum_preseed,
42
+ )
43
+
44
+ collector = DataCollector(ema, raw_model, buffer, curriculum, cfg, device)
45
+ evaluator = Evaluator()
46
+ log = Logger(cfg)
47
+
48
+ trainer = Trainer(
49
+ model, ema, optimizer, None, buffer, collector,
50
+ evaluator, log, cfg, device, raw_model=raw_model,
51
+ )
52
+ trainer.train(start_iter=0)
53
+
54
+ # Final eval
55
+ eval_model = ema.make_eval_model(raw_model)
56
+ results = evaluator.evaluate(
57
+ cfg.id_envs, eval_model, cfg.eval_episodes_per_env, cfg, device,
58
+ )
59
+ print(format_eval_results(results, label="Smoke"))
60
+ log.log_eval(results, step=0, prefix="smoke_eval")
61
+ mean_wr = float(sum(s["win_rate"] for s in results.values()) / len(results)) if results else 0.0
62
+ log.log({"smoke_eval/mean_win_rate": mean_wr}, step=0)
63
+ log.finish()
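The smoke test and the trainer both reduce per-environment results to one unweighted mean win rate, guarding against an empty result set. A minimal sketch of that reduction follows; `mean_win_rate` is a hypothetical helper name, and the environment names and values are invented for illustration, not measured results.

```python
def mean_win_rate(results: dict[str, dict[str, float]]) -> float:
    """Unweighted mean of per-environment win rates; 0.0 when nothing was evaluated."""
    if not results:
        return 0.0
    return float(sum(s["win_rate"] for s in results.values()) / len(results))


# Hypothetical environment names and values, for illustration only.
example = {
    "env_a": {"win_rate": 0.5},
    "env_b": {"win_rate": 1.0},
}
print(mean_win_rate(example))
print(mean_win_rate({}))
```

The mean is taken over environments, not episodes, so every environment carries equal weight regardless of how many episodes it ran.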