Demo notebook payload (source + checkpoint + assets)
Browse files- .gitattributes +3 -0
- README.md +739 -0
- ablation_assets/diagnosis_decision_tree.png +0 -0
- ablation_assets/grad_alignment.png +3 -0
- ablation_assets/gradient_conflict_map.png +0 -0
- ablation_assets/group_comparison.png +0 -0
- ablation_assets/group_summary.csv +5 -0
- ablation_assets/hypothesis_verdicts.csv +22 -0
- ablation_assets/main_results.csv +22 -0
- ablation_assets/per_env_delta.png +3 -0
- ablation_assets/per_env_win_rates.csv +22 -0
- ablation_assets/repr_drift.png +0 -0
- ablation_assets/results.json +0 -0
- ablation_assets/score_comparison.png +3 -0
- ablation_assets/score_delta.png +0 -0
- checkpoint_inference.pth +3 -0
- configs/defaults.yaml +242 -0
- configs/final_qmul_gpu.yaml +176 -0
- configs/final_ucl_gpu.yaml +158 -0
- configs/smoke.yaml +16 -0
- configs/ucl_gpu_bigger_model.yaml +103 -0
- configs/ucl_gpu_learning_behaviour.yaml +103 -0
- environments/.gitkeep +0 -0
- main.py +255 -0
- pyproject.toml +22 -0
- src/__init__.py +0 -0
- src/buffer.py +268 -0
- src/config.py +164 -0
- src/curriculum.py +143 -0
- src/diffusion/__init__.py +0 -0
- src/diffusion/forward.py +50 -0
- src/diffusion/loss.py +162 -0
- src/diffusion/sampling.py +398 -0
- src/diffusion/schedules.py +88 -0
- src/envs/__init__.py +0 -0
- src/envs/discovery.py +166 -0
- src/envs/minihack_env.py +454 -0
- src/models/__init__.py +0 -0
- src/models/denoiser.py +415 -0
- src/planners/__init__.py +0 -0
- src/planners/baselines.py +1247 -0
- src/planners/collect.py +588 -0
- src/planners/collect_oracle.py +185 -0
- src/planners/inference.py +360 -0
- src/planners/logging.py +291 -0
- src/planners/offline.py +727 -0
- src/planners/online.py +721 -0
- src/planners/smoke.py +63 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
ablation_assets/grad_alignment.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
ablation_assets/per_env_delta.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
ablation_assets/score_comparison.png filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,739 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReMDM Planner for MiniHack
|
| 2 |
+
|
| 3 |
+
PyTorch implementation of **ReMDM** (Remasking Discrete Diffusion Model) for action-sequence planning in [MiniHack](https://github.com/facebookresearch/minihack) navigation environments. A dual-stream transformer generates 64-step action plans by iteratively denoising masked token sequences, conditioned on a 9x9 local crop and the full 21x79 dungeon map.
|
| 4 |
+
|
| 5 |
+
> The primary training method is **DAgger** with BFS oracle supervision: the model is trained from scratch, with the buffer seeded by pure expert trajectories on the first iteration. A standalone **offline BC** mode is also available as an independent baseline trained on pre-collected datasets. The paper compares both methods head-to-head; neither depends on the other. An offline BC checkpoint can optionally warm-start DAgger, but this is not used in the paper. The trained model generalises **zero-shot** from the 4 in-distribution training environments to 3 out-of-distribution environments.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Pipeline
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
[Primary] DAgger online training main.py --mode dagger
|
| 13 |
+
| (seed buffer with oracle demos on iter 0,
|
| 14 |
+
| collect with model, label with oracle,
|
| 15 |
+
| efficiency filter, curriculum sampling)
|
| 16 |
+
v checkpoint
|
| 17 |
+
[Evaluate] ID + OOD evaluation main.py --mode inference --checkpoint iter8000.pth
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
**Other modes:**
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
[Collect] Collect oracle demonstrations main.py --mode collect
|
| 24 |
+
[Offline BC] Train on pre-collected data main.py --mode offline --data dataset.pt
|
| 25 |
+
[Smoke test] Quick end-to-end check main.py --mode smoke
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
DAgger trains from scratch and is the recommended pipeline. Offline BC (`--mode collect` + `--mode offline`) is an independent training method compared against DAgger in the paper. An offline BC checkpoint can optionally warm-start DAgger via `--checkpoint`, but this was not used in the paper results.
|
| 29 |
+
|
| 30 |
+
---
|
| 31 |
+
|
| 32 |
+
## Environments
|
| 33 |
+
|
| 34 |
+
**In-distribution (training):**
|
| 35 |
+
|
| 36 |
+
| Environment | Description |
|
| 37 |
+
|---|---|
|
| 38 |
+
| `MiniHack-Room-Random-5x5-v0` | Small random room |
|
| 39 |
+
| `MiniHack-Room-Random-15x15-v0` | Large random room |
|
| 40 |
+
| `MiniHack-Corridor-R2-v0` | Two-room corridor |
|
| 41 |
+
| `MiniHack-MazeWalk-9x9-v0` | Small maze |
|
| 42 |
+
|
| 43 |
+
**Out-of-distribution (zero-shot evaluation):**
|
| 44 |
+
|
| 45 |
+
| Environment | Description |
|
| 46 |
+
|---|---|
|
| 47 |
+
| `MiniHack-Room-Dark-15x15-v0` | Dark room (limited visibility) |
|
| 48 |
+
| `MiniHack-Corridor-R5-v0` | Five-room corridor |
|
| 49 |
+
| `MiniHack-MazeWalk-45x19-v0` | Large maze |
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
## Installation
|
| 54 |
+
|
| 55 |
+
### Prerequisites
|
| 56 |
+
|
| 57 |
+
**Python 3.12+** is required.
|
| 58 |
+
|
| 59 |
+
**macOS (arm64):** Install cmake via Homebrew (needed to compile `nle` from source):
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
brew install cmake
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
**Linux (x86_64):** Pre-built wheels are available, but if building from source:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
sudo apt-get install build-essential cmake bison flex libbz2-dev
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### Setup
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
uv sync
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
This installs all dependencies from the lockfile, including `nle>=1.2.0` (from the maintained [NetHack-LE](https://github.com/NetHack-LE/nle) fork), `minihack`, `torch>=2.11.0`, `wandb`, `polars`, `orjson`, and `scipy`.
|
| 78 |
+
|
| 79 |
+
### GPU support (optional)
|
| 80 |
+
|
| 81 |
+
By default PyTorch runs on CPU. For NVIDIA CUDA 12:
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
uv pip install torch --index-url https://download.pytorch.org/whl/cu121
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
Verify GPU is detected:
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
uv run python -c "import torch; print(torch.cuda.is_available())"
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## Usage
|
| 96 |
+
|
| 97 |
+
All modes share a single entry point. Defaults load from `configs/defaults.yaml`; any value can be overridden via `key=value` pairs.
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
python main.py --mode <MODE> [--config PATH] [key=value ...]
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
### Smoke test
|
| 104 |
+
|
| 105 |
+
Collects a few oracle trajectories, trains under a tiny 5k env-step budget, and prints ID evaluation results.
|
| 106 |
+
|
| 107 |
+
```bash
|
| 108 |
+
python main.py --mode smoke
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
### Collect oracle demonstrations
|
| 112 |
+
|
| 113 |
+
Run the BFS oracle across all 4 ID environments and save the trajectories as a `.pt` dataset for offline BC training. Uses multiprocessing for parallelism.
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
# Default: 5000 episodes per env, output to data/dataset.pt
|
| 117 |
+
python main.py --mode collect
|
| 118 |
+
|
| 119 |
+
# Custom episode count and output
|
| 120 |
+
python main.py --mode collect collect_episodes_per_env=2000 \
|
| 121 |
+
collect_output=data/small_dataset.pt
|
| 122 |
+
|
| 123 |
+
# Fewer workers (default: 8)
|
| 124 |
+
python main.py --mode collect collect_num_workers=4
|
| 125 |
+
|
| 126 |
+
# Reproducible with fixed seed
|
| 127 |
+
python main.py --mode collect seed=42
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
The output `.pt` file is directly consumable by `--mode offline`:
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
python main.py --mode collect
|
| 134 |
+
python main.py --mode offline --data data/dataset.pt
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
### Offline BC (optional)
|
| 138 |
+
|
| 139 |
+
Train the diffusion model on pre-collected oracle demonstrations. The run length
|
| 140 |
+
is controlled by `total_timesteps` — each env-step of the unified budget
|
| 141 |
+
corresponds to one dataset sample, so total gradient steps =
|
| 142 |
+
`total_timesteps // offline_batch_size`.
|
| 143 |
+
|
| 144 |
+
Periodic ID + OOD evaluation runs during training on the cadence defined by
|
| 145 |
+
`id_eval_every_timesteps` / `ood_eval_every_timesteps` (env-step units,
|
| 146 |
+
converted internally to grad-step deltas via `// offline_batch_size`),
|
| 147 |
+
mirroring the DAgger eval pattern. Results are logged to `eval_id/` and
|
| 148 |
+
`eval_ood/` W&B namespaces.
|
| 149 |
+
|
| 150 |
+
```bash
|
| 151 |
+
python main.py --mode offline --data path/to/dataset.pt
|
| 152 |
+
|
| 153 |
+
# Shorter / longer run (the same knob the DAgger and SB3 baselines use):
|
| 154 |
+
python main.py --mode offline --data dataset.pt total_timesteps=500000
|
| 155 |
+
|
| 156 |
+
# Resume from a step-level checkpoint (restores optimizer, scheduler,
|
| 157 |
+
# step counter, and W&B run)
|
| 158 |
+
python main.py --mode offline --data path/to/dataset.pt \
|
| 159 |
+
--checkpoint checkpoints/offline_step2000.pth
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
Step-level checkpoints are written every `checkpoint_every_timesteps` env-step
|
| 163 |
+
equivalents (converted internally to grad steps via `// offline_batch_size`).
|
| 164 |
+
Set to `0` to disable:
|
| 165 |
+
|
| 166 |
+
```bash
|
| 167 |
+
python main.py --mode offline --data dataset.pt checkpoint_every_timesteps=0
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
#### Compute-match overrides (paper-fair BC vs DAgger)
|
| 171 |
+
|
| 172 |
+
For research comparisons against a specific DAgger checkpoint, four optional
|
| 173 |
+
offline-only overrides bypass the env-step budget derivation. The
|
| 174 |
+
sample-to-grad-step ratio between the two modes (~50×) makes a single shared
|
| 175 |
+
`total_timesteps` budget unfair to one side; these knobs pin offline metrics
|
| 176 |
+
in grad-step units instead. All default to `null` (backwards compatible).
|
| 177 |
+
|
| 178 |
+
| Key | Purpose |
|
| 179 |
+
|---|---|
|
| 180 |
+
| `offline_total_grad_steps` | Pin gradient budget. Overrides `total_timesteps // offline_batch_size`. Use to match a DAgger iteration count (e.g. `60000` = 600 iters × 100 grad_steps_per_iter). |
|
| 181 |
+
| `offline_eval_every_grad_steps` | ID/OOD eval cadence in grad-step units. Without this, env-step cadence applied to BC's dense per-sample budget yields hundreds of evals. |
|
| 182 |
+
| `offline_checkpoint_every_grad_steps` | Checkpoint cadence in grad-step units. Same motivation. |
|
| 183 |
+
| `offline_buffer_capacity` | Distinct from `buffer_capacity` (sized for DAgger's small FIFO). The full BC dataset has ~500k–1M sliding windows; using DAgger's cap silently truncates. |
|
| 184 |
+
|
| 185 |
+
Example: train a fair offline BC baseline matched to DAgger@iter600
|
| 186 |
+
(60k AdamW updates × 2048 batch):
|
| 187 |
+
|
| 188 |
+
```bash
|
| 189 |
+
python main.py --mode offline --data data/oracle_bc_qmul.pt \
|
| 190 |
+
--config configs/final_qmul_gpu.yaml
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
The `final_qmul_gpu.yaml` and `final_ucl_gpu.yaml` configs both ship with
|
| 194 |
+
these overrides pre-set and with cross-cluster-identical training
|
| 195 |
+
hyperparameters (only collection-worker counts and output paths differ).
|
| 196 |
+
|
| 197 |
+
### DAgger online training
|
| 198 |
+
|
| 199 |
+
Full DAgger loop: seed buffer with oracle data, collect with model, label with BFS oracle, filter by efficiency, train on buffer.
|
| 200 |
+
|
| 201 |
+
```bash
|
| 202 |
+
# From scratch (seeds buffer with oracle data automatically)
|
| 203 |
+
python main.py --mode dagger
|
| 204 |
+
|
| 205 |
+
# Resume from local checkpoint
|
| 206 |
+
python main.py --mode dagger --checkpoint checkpoints/iter3000.pth
|
| 207 |
+
|
| 208 |
+
# Resume from a W&B artifact
|
| 209 |
+
python main.py --mode dagger \
|
| 210 |
+
--wandb-artifact entity/project/checkpoint-iter3000:latest
|
| 211 |
+
|
| 212 |
+
# Skip warm-start from checkpoint (reinitialise model, keep config)
|
| 213 |
+
python main.py --mode dagger --checkpoint checkpoints/iter3000.pth --no-warm-start
|
| 214 |
+
|
| 215 |
+
# Override hyperparameters (total_timesteps is the unified run-length knob)
|
| 216 |
+
python main.py --mode dagger total_timesteps=1000000 dagger_lr=0.0001
|
| 217 |
+
|
| 218 |
+
# Use a GPU-optimised config (paper run, QMUL H200)
|
| 219 |
+
python main.py --mode dagger --config configs/final_qmul_gpu.yaml
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
### Inference
|
| 223 |
+
|
| 224 |
+
Evaluate a checkpoint on specified environments. Accepts either `--checkpoint` (local path) or `--wandb-artifact` (W&B artifact reference).
|
| 225 |
+
|
| 226 |
+
```bash
|
| 227 |
+
# All ID + OOD environments
|
| 228 |
+
python main.py --mode inference --checkpoint checkpoints/iter8000.pth
|
| 229 |
+
|
| 230 |
+
# From a W&B artifact
|
| 231 |
+
python main.py --mode inference \
|
| 232 |
+
--wandb-artifact entity/project/checkpoint-iter8000:latest
|
| 233 |
+
|
| 234 |
+
# Specific environments, save JSON
|
| 235 |
+
python main.py --mode inference \
|
| 236 |
+
--checkpoint checkpoints/iter8000.pth \
|
| 237 |
+
--envs MiniHack-Room-Random-5x5-v0 MiniHack-MazeWalk-45x19-v0 \
|
| 238 |
+
--episodes 100 \
|
| 239 |
+
--output results.json
|
| 240 |
+
|
| 241 |
+
# Custom .des scenario files
|
| 242 |
+
python main.py --mode inference \
|
| 243 |
+
--checkpoint checkpoints/iter8000.pth \
|
| 244 |
+
--des environments/custom_level.des
|
| 245 |
+
|
| 246 |
+
# Local-only ablation (zero out global map)
|
| 247 |
+
python main.py --mode inference \
|
| 248 |
+
--checkpoint checkpoints/iter8000.pth --blind-global
|
| 249 |
+
|
| 250 |
+
# Use training weights instead of EMA
|
| 251 |
+
python main.py --mode inference --checkpoint iter8000.pth --no-ema
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
### Baselines (SB3 + Decision Transformer)
|
| 255 |
+
|
| 256 |
+
Train and evaluate the head-to-head baselines used in the paper comparison.
|
| 257 |
+
Six algorithms are wired in: standard discrete-action RL via Stable-Baselines3
|
| 258 |
+
(`ppo`, `a2c`, `dqn`, `ppo-rnn`), Behavioural Cloning (`bc`) on oracle
|
| 259 |
+
demonstrations, and a causal Decision Transformer (`dt`) with target-return
|
| 260 |
+
conditioning. All six share the unified `cfg.total_timesteps` budget so the
|
| 261 |
+
numbers are directly comparable to DAgger and offline BC.
|
| 262 |
+
|
| 263 |
+
Hyperparameters live under the `baselines_*` namespace in `configs/defaults.yaml`
|
| 264 |
+
(BC epochs / batch / LR, DT context length / depth / width, oracle episodes per
|
| 265 |
+
env, eval cadence, DQN replay buffer, parallel SubprocVecEnv count, etc.). The
|
| 266 |
+
runner writes per-seed checkpoints, SB3 logs, and an aggregated results JSON
|
| 267 |
+
under `cfg.baselines_output_dir` (default `outputs/baselines/`); W&B runs land
|
| 268 |
+
in a separate project (`cfg.baselines_wandb_project`, default `remdm-baselines`)
|
| 269 |
+
so they don't pollute the main training leaderboards.
|
| 270 |
+
|
| 271 |
+
```bash
|
| 272 |
+
# PPO on the 4 ID maps for the unified env-step budget, 1 seed
|
| 273 |
+
python main.py --mode baselines --algo ppo
|
| 274 |
+
|
| 275 |
+
# DQN with a custom budget and 3 seeds
|
| 276 |
+
python main.py --mode baselines --algo dqn \
|
| 277 |
+
--seeds 0 1 2 \
|
| 278 |
+
total_timesteps=1000000
|
| 279 |
+
|
| 280 |
+
# Behavioural Cloning baseline (oracle demos -> SB3 ActorCriticPolicy)
|
| 281 |
+
python main.py --mode baselines --algo bc --n-seeds 3
|
| 282 |
+
|
| 283 |
+
# Decision Transformer (causal R/s/a transformer with target-return)
|
| 284 |
+
python main.py --mode baselines --algo dt --seeds 0 1 2
|
| 285 |
+
|
| 286 |
+
# Override the aggregated-results JSON destination
|
| 287 |
+
python main.py --mode baselines --algo ppo --output results/ppo_smoke.json
|
| 288 |
+
|
| 289 |
+
# Paper-fair comparison against the ReMDM online budget (~5.65M env-steps)
|
| 290 |
+
python main.py --mode baselines --algo ppo total_timesteps=5650000
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
The BC and DT defaults (50 epochs, 5000 oracle trajectories per ID env, 64-token
|
| 294 |
+
DT context, 256-D DT embedding) are tuned to match the data and compute scale of
|
| 295 |
+
the offline BC and ReMDM runs reported in the paper.
|
| 296 |
+
|
| 297 |
+
### CLI flags
|
| 298 |
+
|
| 299 |
+
| Flag | Description |
|
| 300 |
+
|---|---|
|
| 301 |
+
| `--mode` | Required. One of `smoke`, `collect`, `offline`, `dagger`, `inference`, `baselines` |
|
| 302 |
+
| `--config PATH` | Config file (default: `configs/defaults.yaml`) |
|
| 303 |
+
| `--algo NAME` | Baseline algorithm (`ppo`, `a2c`, `dqn`, `ppo-rnn`, `bc`, `dt`); required with `--mode baselines` |
|
| 304 |
+
| `--seeds N [N ...]` | Explicit seed list for `--mode baselines` |
|
| 305 |
+
| `--n-seeds N` | Number of seeds starting from 0 (alternative to `--seeds`) |
|
| 306 |
+
| `--data PATH` | Dataset `.pt` file (offline mode) |
|
| 307 |
+
| `--checkpoint PATH` | Checkpoint `.pth` file |
|
| 308 |
+
| `--wandb-artifact REF` | W&B artifact reference (e.g. `entity/project/name:latest`) |
|
| 309 |
+
| `--no-warm-start` | Skip model warm-start from checkpoint (DAgger) |
|
| 310 |
+
| `--no-ema` | Use training weights instead of EMA for inference |
|
| 311 |
+
| `--envs ENV [ENV ...]` | Override evaluation environments |
|
| 312 |
+
| `--des PATH [PATH ...]` | Custom `.des` scenario files for evaluation |
|
| 313 |
+
| `--episodes N` | Episodes per environment (default: 50) |
|
| 314 |
+
| `--output PATH` | Save evaluation results / aggregated baselines JSON |
|
| 315 |
+
| `--blind-global` | Zero out global map observations (local-only ablation) |
|
| 316 |
+
|
| 317 |
+
---
|
| 318 |
+
|
| 319 |
+
## Architecture
|
| 320 |
+
|
| 321 |
+
**`LocalDiffusionPlannerWithGlobal`** (~5.2M parameters):
|
| 322 |
+
|
| 323 |
+
```
|
| 324 |
+
Local stream: 9x9 glyphs -> Embedding(6000,64) -> CNN(64->32->64) -> Linear -> 1 token
|
| 325 |
+
Global stream: 21x79 glyphs -> Embedding(6000,32) -> CNN(32->32->64) -> Pool(2,4) -> 8 tokens
|
| 326 |
+
Goal head: mean(global) -> MLP -> [B,2] staircase coords (aux loss)
|
| 327 |
+
Gate: sigmoid(learnable scalar, init=-3.0) * global_tokens
|
| 328 |
+
Action stream: Embedding(14, 256) + timestep_emb(100, 256) + position_emb(64, 256)
|
| 329 |
+
Transformer: concat [1 + 8 + 64 = 73 tokens] -> 4-layer encoder (256D, 4 heads, pre-norm)
|
| 330 |
+
Output head: last 64 tokens -> Linear(256, 12) -> action logits
|
| 331 |
+
```
|
| 332 |
+
|
| 333 |
+
The model takes `(local_obs, global_obs, noisy_action_seq, t_discrete)` and returns `{"actions": [B,64,12], "goal_pred": [B,2]}`.
|
| 334 |
+
|
| 335 |
+
A `LocalDiffusionPlanner` variant (no global stream, no goal head) is also available for ablation studies.
|
| 336 |
+
|
| 337 |
+
---
|
| 338 |
+
|
| 339 |
+
## Diffusion
|
| 340 |
+
|
| 341 |
+
**Forward process (MDLM):** Each action token is independently replaced with `MASK` (token 12) with probability `1 - alpha(t)`, where `alpha(t)` follows a linear or cosine schedule. PAD tokens (13) are never masked.
|
| 342 |
+
|
| 343 |
+
**Loss:** Cross-entropy on masked positions only, averaged globally across the batch. By default uses a flat average (matching the reference implementation). Optional SUBS importance weighting `w(t) = -alpha'(t) / (1 - alpha(t))`, clipped to `[0, 1000]`, can be enabled via `use_importance_weighting: true`. Optional label smoothing via `label_smoothing` (default 0.0).
|
| 344 |
+
|
| 345 |
+
**Reverse sampling (ReMDM):** Over `diffusion_steps_eval` denoising steps (default 10):
|
| 346 |
+
1. Model predicts logits; apply temperature scaling and top-K filtering.
|
| 347 |
+
2. Sample predictions; compute per-token confidence.
|
| 348 |
+
3. **MaskGIT unmask:** commit the `n_unmask` highest-confidence masked positions.
|
| 349 |
+
4. **ReMDM remask:** stochastically re-mask committed positions to allow refinement.
|
| 350 |
+
5. Final step: commit all remaining positions.
|
| 351 |
+
|
| 352 |
+
**Greedy sampling:** Used during DAgger data collection for deterministic rollouts. Same MaskGIT progressive unmasking loop but with argmax decoding (no temperature, no top-K, no remasking). Uses fewer denoising steps (`diffusion_steps_collect: 5`) for faster collection.
|
| 353 |
+
|
| 354 |
+
### Remasking strategies
|
| 355 |
+
|
| 356 |
+
| Strategy | Formula | Description |
|
| 357 |
+
|---|---|---|
|
| 358 |
+
| `rescale` | `p = eta * sigma_max` | Proportional to noise level |
|
| 359 |
+
| `cap` | `p = min(eta, sigma_max)` | Fixed upper bound |
|
| 360 |
+
| `conf` | `p = eta * sigma_max * (1 - confidence)` | Low-confidence tokens remasked more |
|
| 361 |
+
|
| 362 |
+
---
|
| 363 |
+
|
| 364 |
+
## Configuration
|
| 365 |
+
|
| 366 |
+
### Key hyperparameters
|
| 367 |
+
|
| 368 |
+
**Model**
|
| 369 |
+
|
| 370 |
+
| Parameter | Default | Description |
|
| 371 |
+
|---|---|---|
|
| 372 |
+
| `n_embd` | 256 | Transformer hidden dimension |
|
| 373 |
+
| `n_head` | 4 | Attention heads |
|
| 374 |
+
| `n_layer` | 4 | Transformer blocks |
|
| 375 |
+
| `n_global_tokens` | 8 | Global stream context tokens |
|
| 376 |
+
| `seq_len` | 64 | Action plan length |
|
| 377 |
+
| `dropout` | 0.0 | Transformer dropout (0.0 -- forward masking regularises) |
|
| 378 |
+
| `ema_decay` | 0.999 | EMA smoothing for inference weights |
|
| 379 |
+
| `global_gate_init` | -3.0 | Initial value for global gate logit |
|
| 380 |
+
|
| 381 |
+
**Diffusion**
|
| 382 |
+
|
| 383 |
+
| Parameter | Default | Description |
|
| 384 |
+
|---|---|---|
|
| 385 |
+
| `noise_schedule` | `linear` | `linear` or `cosine` |
|
| 386 |
+
| `num_diffusion_steps` | 100 | Discrete timestep resolution |
|
| 387 |
+
| `diffusion_steps_eval` | 10 | Denoising iterations at inference |
|
| 388 |
+
| `diffusion_steps_collect` | 5 | Denoising iterations during DAgger collection |
|
| 389 |
+
| `remask_strategy` | `conf` | `rescale`, `cap`, or `conf` |
|
| 390 |
+
| `eta` | 0.15 | Remasking strength |
|
| 391 |
+
| `temperature` | 0.5 | Sampling temperature |
|
| 392 |
+
| `top_k` | 4 | Top-K filtering |
|
| 393 |
+
| `replan_every` | 16 | Env steps before replanning |
|
| 394 |
+
| `loss_weight_clip` | 1000.0 | SUBS importance weight clip bound |
|
| 395 |
+
| `label_smoothing` | 0.0 | Label smoothing for cross-entropy |
|
| 396 |
+
| `use_importance_weighting` | false | SUBS w(t) in loss (off = flat average) |
|
| 397 |
+
| `physics_aware_sampling` | false | Penalise hazardous actions at inference |
|
| 398 |
+
|
| 399 |
+
**Training budget (unified)**
|
| 400 |
+
|
| 401 |
+
Offline BC, DAgger, and the SB3 baselines all share a single env-step budget
|
| 402 |
+
expressed in `total_timesteps` (matching the SB3 convention). This is the only
|
| 403 |
+
knob that should change to scale a run up or down.
|
| 404 |
+
|
| 405 |
+
| Parameter | Default | Description |
|
| 406 |
+
|---|---|---|
|
| 407 |
+
| `total_timesteps` | 2,000,000 | Env-step budget shared across offline / DAgger / SB3 |
|
| 408 |
+
| `id_eval_every_timesteps` | 25,000 | ID eval cadence (env-steps) |
|
| 409 |
+
| `ood_eval_every_timesteps` | 25,000 | OOD eval cadence (env-steps) |
|
| 410 |
+
| `checkpoint_every_timesteps` | 125,000 | Checkpoint cadence (env-steps) |
|
| 411 |
+
|
| 412 |
+
- **Offline BC:** each dataset sample is one env.step() equivalent, so total
|
| 413 |
+
gradient steps = `total_timesteps // offline_batch_size`. The cosine LR
|
| 414 |
+
schedule's `T_max` derives from the same quantity, so runs of different
|
| 415 |
+
lengths still decay to the 10% floor at their end.
|
| 416 |
+
- **DAgger:** the training loop tracks cumulative `env.step()` calls (model +
|
| 417 |
+
oracle rollouts combined) and halts when the running total reaches
|
| 418 |
+
`total_timesteps`. `episodes_per_iteration` and `grad_steps_per_iteration`
|
| 419 |
+
control the collect/train ratio but **must not** scale with the budget.
|
| 420 |
+
- **Fairness caveat — `ema_decay`:** this is an absolute-update-count constant
|
| 421 |
+
(time constant ~ `1 / (1 − decay)` steps, i.e. half-life ~ `0.69 / (1 − decay)`). If `total_timesteps` shifts by more
|
| 422 |
+
than ~2× from the default, the fraction of training covered by the EMA
|
| 423 |
+
window changes. For very short or very long runs, consider setting a
|
| 424 |
+
matching decay manually.
|
| 425 |
+
|
| 426 |
+
**Training**
|
| 427 |
+
|
| 428 |
+
| Parameter | Default | Description |
|
| 429 |
+
|---|---|---|
|
| 430 |
+
| `offline_lr` | 0.0003 | BC learning rate (cosine-decayed to 10% over `total_grad_steps`) |
|
| 431 |
+
| `dagger_lr` | 0.00003 | DAgger learning rate (constant) |
|
| 432 |
+
| `offline_batch_size` | 3584 | Offline BC batch size |
|
| 433 |
+
| `dagger_batch_size` | 3584 | DAgger batch size |
|
| 434 |
+
| `offline_grad_clip` | 1.0 | Gradient norm clip (offline) |
|
| 435 |
+
| `dagger_grad_clip` | 1.0 | Gradient norm clip (DAgger) |
|
| 436 |
+
| `weight_decay` | 0.0001 | AdamW weight decay (both optimizers) |
|
| 437 |
+
| `grad_steps_per_iteration` | 100 | Gradient steps per DAgger iteration |
|
| 438 |
+
| `episodes_per_iteration` | 30 | Episodes collected per DAgger iteration |
|
| 439 |
+
| `aux_loss_weight` | 0.5 | Weight for auxiliary goal loss |
|
| 440 |
+
| `buffer_capacity` | 10000 | Replay buffer size (windows) |
|
| 441 |
+
| `efficiency_multiplier` | 1.5 | DAgger efficiency filter threshold |
|
| 442 |
+
| `curriculum_preseed` | true | Pre-seed curriculum with 50/50 prior |
|
| 443 |
+
| `curriculum_queue_size` | 100 | Curriculum window size per environment |
|
| 444 |
+
|
| 445 |
+
**Data Collection**
|
| 446 |
+
|
| 447 |
+
| Parameter | Default | Description |
|
| 448 |
+
|---|---|---|
|
| 449 |
+
| `collect_episodes_per_env` | 5000 | Oracle episodes per ID environment |
|
| 450 |
+
| `collect_num_workers` | 8 | Parallel process workers for collection |
|
| 451 |
+
| `collect_output` | `data/dataset.pt` | Output path for collected dataset |
|
| 452 |
+
|
| 453 |
+
**Evaluation**
|
| 454 |
+
|
| 455 |
+
| Parameter | Default | Description |
|
| 456 |
+
|---|---|---|
|
| 457 |
+
| `eval_episodes_per_env` | 50 | Episodes per environment at eval time |
|
| 458 |
+
| `checkpoint_eval_episodes` | 50 | Episodes per env at checkpoint eval |
|
| 459 |
+
|
| 460 |
+
(Eval and checkpoint *cadences* are expressed in env-steps under
|
| 461 |
+
**Training budget (unified)** above.)
|
| 462 |
+
|
| 463 |
+
**Performance**
|
| 464 |
+
|
| 465 |
+
| Parameter | Default | Description |
|
| 466 |
+
|---|---|---|
|
| 467 |
+
| `use_amp` | false | Mixed-precision (FP16) training via `torch.amp` |
|
| 468 |
+
| `torch_compile` | false | `torch.compile` the model for fused kernels |
|
| 469 |
+
| `num_collection_workers` | 8 | Parallel workers for DAgger episode collection |
|
| 470 |
+
|
| 471 |
+
**Logging**
|
| 472 |
+
|
| 473 |
+
| Parameter | Default | Description |
|
| 474 |
+
|---|---|---|
|
| 475 |
+
| `use_wandb` | true | Enable W&B logging |
|
| 476 |
+
| `wandb_project` | `remdm-minihack` | W&B project name |
|
| 477 |
+
| `wandb_resume_id` | null | W&B run ID for resumption |
|
| 478 |
+
| `offline_log_every` | 10 | Stdout/W&B log frequency (offline steps) |
|
| 479 |
+
| `seed` | null | RNG seed (null = random) |
|
| 480 |
+
|
| 481 |
+
### Config presets
|
| 482 |
+
|
| 483 |
+
| File | Purpose |
|
| 484 |
+
|---|---|
|
| 485 |
+
| `configs/defaults.yaml` | Base defaults for all modes |
|
| 486 |
+
| `configs/smoke.yaml` | Fast smoke test (`total_timesteps=5000`, small buffer, W&B off) |
|
| 487 |
+
| `configs/ucl_gpu_bigger_model.yaml` | UCL GPU exploration with a larger model (384D, 6 heads) |
|
| 488 |
+
| `configs/ucl_gpu_learning_behaviour.yaml` | UCL GPU learning-behaviour study (eta=0.18, B=6144) |
|
| 489 |
+
| `configs/final_qmul_gpu.yaml` | **Paper run, QMUL H200.** Drives both `--mode dagger` (reproduces the iter600 checkpoint) and `--mode offline` (compute-matched fair BC baseline: 60k grad steps × B=2048). AMP + torch.compile + 32 collection workers. |
|
| 490 |
+
| `configs/final_ucl_gpu.yaml` | **Paper run, UCL 3090 Ti 24 GB.** Identical training hyperparams to the QMUL config for cross-cluster fairness; only `num_collection_workers` (8 instead of 32) and output paths differ. |
|
| 491 |
+
|
| 492 |
+
---
|
| 493 |
+
|
| 494 |
+
## DAgger Training Loop
|
| 495 |
+
|
| 496 |
+
Each DAgger iteration:
|
| 497 |
+
|
| 498 |
+
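1. **Curriculum sampling:** Select an environment weighted by its rolling win-rate bucket (mid-range environments are sampled most; both very-low and very-high win-rate environments are down-weighted — see the Curriculum entry under Implementation Notes).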
|
| 499 |
+
2. **Model rollout:** Generate plans with the EMA model using greedy sampling; execute with replanning every 16 steps. Collects `episodes_per_iteration` (default 30) episodes per iteration.
|
| 500 |
+
3. **Oracle rollout:** Run the BFS oracle on the **same seed** for comparison.
|
| 501 |
+
4. **Efficiency filter:** Add the oracle trajectory to the buffer if the model failed or took more than `efficiency_multiplier` (default 1.5) times the oracle's steps.
|
| 502 |
+
5. **Budget accounting:** Advance `env_steps_total += model_steps + oracle_steps`. The training loop halts when the running total reaches `total_timesteps`.
|
| 503 |
+
6. **Training:** Sample from the replay buffer; run `grad_steps_per_iteration` gradient steps, updating EMA weights after each gradient step.
|
| 504 |
+
|
| 505 |
+
Collection uses GPU-batched rollouts when on CUDA with `episodes_per_iteration > 1`, falling back to threaded CPU collection or sequential collection as appropriate.
|
| 506 |
+
|
| 507 |
+
The BFS oracle uses a 5-tier priority: (1) kick adjacent doors, (2) BFS to staircase, (3) BFS to frontier, (4) BFS to farthest tile, (5) random cardinal.
|
| 508 |
+
|
| 509 |
+
---
|
| 510 |
+
|
| 511 |
+
## Reward Shaping
|
| 512 |
+
|
| 513 |
+
The environment wrapper applies shaped rewards to guide learning:
|
| 514 |
+
|
| 515 |
+
| Component | Value | Condition |
|
| 516 |
+
|---|---|---|
|
| 517 |
+
| Win bonus | +20.0 | Episode won |
|
| 518 |
+
| BFS progress | +0.5 * (prev_dist - curr_dist) | Closer to staircase |
|
| 519 |
+
| Exploration | +0.05 | New tile visited |
|
| 520 |
+
| Step penalty | -0.01 | Every step |
|
| 521 |
+
|
| 522 |
+
---
|
| 523 |
+
|
| 524 |
+
## Project Structure
|
| 525 |
+
|
| 526 |
+
```
|
| 527 |
+
minihack-ReMDM-planner/
|
| 528 |
+
├── configs/
|
| 529 |
+
│ ├── defaults.yaml Base hyperparameters
|
| 530 |
+
│ ├── smoke.yaml Smoke test overrides
|
| 531 |
+
│ ├── ucl_gpu_bigger_model.yaml UCL GPU (larger model: 384D, 6 heads)
|
| 532 |
+
│ ├── ucl_gpu_learning_behaviour.yaml UCL GPU learning-behaviour study
|
| 533 |
+
│ ├── final_qmul_gpu.yaml Paper run: DAgger + fair offline BC (QMUL H200)
|
| 534 |
+
│ └── final_ucl_gpu.yaml Paper run: DAgger + fair offline BC (UCL 3090 Ti)
|
| 535 |
+
├── environments/ Custom .des scenario files
|
| 536 |
+
├── src/
|
| 537 |
+
│ ├── config.py YAML config loader with CLI overrides
|
| 538 |
+
│ ├── buffer.py ReplayBuffer with offline-protected FIFO
|
| 539 |
+
│ ├── curriculum.py DynamicCurriculum + efficiency_filter
|
| 540 |
+
│ ├── diffusion/
|
| 541 |
+
│ │ ├── schedules.py Linear and cosine noise schedules
|
| 542 |
+
│ │ ├── forward.py Forward masking process q(z_t | x_0)
|
| 543 |
+
│ │ ├── loss.py MDLM ELBO + auxiliary goal loss
|
| 544 |
+
│ │ └── sampling.py ReMDM reverse sampling with remasking
|
| 545 |
+
│ ├── models/
|
| 546 |
+
│ │ └── denoiser.py LocalDiffusionPlannerWithGlobal + ModelEMA
|
| 547 |
+
│ ├── envs/
|
| 548 |
+
│ │ ├── minihack_env.py AdvancedObservationEnv + BFS oracle
|
| 549 |
+
│ │ └── discovery.py Env registry scanner + inference benchmark
|
| 550 |
+
│ └── planners/
|
| 551 |
+
│ ├── collect.py run_model_episode + DataCollector
|
| 552 |
+
│ ├── collect_oracle.py Standalone oracle data collection
|
| 553 |
+
│ ├── offline.py Offline BC trainer
|
| 554 |
+
│ ├── online.py DAgger Trainer + checkpointing
|
| 555 |
+
│ ├── inference.py Evaluator + result formatting
|
| 556 |
+
│ ├── baselines.py SB3 + Decision Transformer baselines
|
| 557 |
+
│ ├── smoke.py Smoke-test runner
|
| 558 |
+
│ └── logging.py Centralised W&B + stdout logging
|
| 559 |
+
├── experiments/
|
| 560 |
+
│ └── rl_finetuning/ RL fine-tuning ablation suite
|
| 561 |
+
│ ├── run_ablations.py CLI entry point
|
| 562 |
+
│ ├── configs/ Ablation config files
|
| 563 |
+
│ ├── ablations/ Loss, optimizer, registry, training
|
| 564 |
+
│ ├── diagnostics/ Gradient, representation, timestep metrics
|
| 565 |
+
│ └── analysis/ Plots, tables, reports
|
| 566 |
+
├── scripts/
|
| 567 |
+
│ ├── hf_upload.py HuggingFace Hub upload utility
|
| 568 |
+
│ └── profile_dagger.py DAgger iteration profiler
|
| 569 |
+
├── main.py CLI entry point (smoke/collect/offline/dagger/inference/baselines)
|
| 570 |
+
├── pyproject.toml PEP 621 project metadata + dependencies
|
| 571 |
+
├── uv.lock Deterministic lockfile
|
| 572 |
+
└── README.md
|
| 573 |
+
```
|
| 574 |
+
|
| 575 |
+
---
|
| 576 |
+
|
| 577 |
+
## W&B Metric Namespaces
|
| 578 |
+
|
| 579 |
+
| Namespace | Contents |
|
| 580 |
+
|---|---|
|
| 581 |
+
| `diffusion/` | `loss`, `loss_diff`, `loss_aux` |
|
| 582 |
+
| `train/` | `buffer_size`, `buffer_online_frac`, `model_won`, `added_to_buffer`, `episodes_collected`, `model_steps`, `oracle_steps`, `efficiency_ratio`, `lr`, `grad_norm`, `global_gate`, `env_steps`, `progress` |
|
| 583 |
+
| `speed/` | `iter_time_sec`, `collect_time_sec`, `train_step_time_sec`, `samples_per_sec`, `env_steps_per_sec`, `gpu_memory_mb` |
|
| 584 |
+
| `perf/` | `iter_time_s`, `collect_time_s`, `train_time_s`, `grad_steps_per_sec` (legacy compat) |
|
| 585 |
+
| `model/` | `param_norm`, `param_drift_from_init`, `ema_gate_value` (every 10 iters) |
|
| 586 |
+
| `eval_id/{env}/` | Per-environment win rate, avg steps, avg reward (in-distribution) |
|
| 587 |
+
| `eval_ood/{env}/` | Per-environment win rate, avg steps, avg reward (out-of-distribution) |
|
| 588 |
+
| `eval_id/` | `mean_win_rate` |
|
| 589 |
+
| `eval_ood/` | `mean_win_rate` |
|
| 590 |
+
| `curriculum/{env}/` | `win_rate` per training environment |
|
| 591 |
+
| `ckpt_eval_id/`, `ckpt_eval_ood/` | Per-env metrics at checkpoint time |
|
| 592 |
+
| `ckpt_eval/` | `id_winrate`, `ood_winrate` |
|
| 593 |
+
| `offline/` | `final_loss`, `total_steps`, `total_timesteps` (summary only) |
|
| 594 |
+
|
| 595 |
+
Both DAgger and offline BC emit to `eval_id/` and `eval_ood/` namespaces.
|
| 596 |
+
Offline mode reuses the same `Evaluator` and EMA-weight evaluation path as
|
| 597 |
+
DAgger, so curves are directly comparable across modes.
|
| 598 |
+
|
| 599 |
+
---
|
| 600 |
+
|
| 601 |
+
## Checkpoint Format
|
| 602 |
+
|
| 603 |
+
**DAgger checkpoint:**
|
| 604 |
+
|
| 605 |
+
```python
|
| 606 |
+
{
|
| 607 |
+
"model_state_dict": ...,
|
| 608 |
+
"ema_state_dict": ...,
|
| 609 |
+
"optimizer_state_dict": ...,
|
| 610 |
+
"scheduler_state_dict": ...,
|
| 611 |
+
"curriculum_state": {...},
|
| 612 |
+
"iteration": int,
|
| 613 |
+
"env_steps": int, # cumulative env.step() calls so far
|
| 614 |
+
"wandb_run_id": str | None,
|
| 615 |
+
"rng_states": {"torch", "numpy", "python"},
|
| 616 |
+
}
|
| 617 |
+
```
|
| 618 |
+
|
| 619 |
+
**Offline BC checkpoint** (step-level, file `offline_step{N}.pth`, saved when
|
| 620 |
+
`checkpoint_every_timesteps > 0`):
|
| 621 |
+
|
| 622 |
+
```python
|
| 623 |
+
{
|
| 624 |
+
"model_state_dict": ...,
|
| 625 |
+
"ema_state_dict": ...,
|
| 626 |
+
"optimizer_state_dict": ...,
|
| 627 |
+
"scheduler_state_dict": ...,
|
| 628 |
+
"step": int,
|
| 629 |
+
"env_steps": int, # step * offline_batch_size
|
| 630 |
+
"wandb_run_id": str | None,
|
| 631 |
+
}
|
| 632 |
+
```
|
| 633 |
+
|
| 634 |
+
**Offline final checkpoint** (saved at the end of offline training):
|
| 635 |
+
|
| 636 |
+
```python
|
| 637 |
+
{
|
| 638 |
+
"model_state_dict": ...,
|
| 639 |
+
"ema_state_dict": ...,
|
| 640 |
+
"wandb_run_id": str | None,
|
| 641 |
+
}
|
| 642 |
+
```
|
| 643 |
+
|
| 644 |
+
Inference uses EMA weights by default. Pass `--no-ema` to use training weights.
|
| 645 |
+
|
| 646 |
+
### W&B Artifacts
|
| 647 |
+
|
| 648 |
+
Checkpoints are automatically uploaded as versioned W&B artifacts (type `"model"`) at each checkpoint save. Each artifact contains the `.pth` weights and a `config.yaml` snapshot of all hyperparameters used.
|
| 649 |
+
|
| 650 |
+
To resume from an artifact:
|
| 651 |
+
|
| 652 |
+
```bash
|
| 653 |
+
# DAgger resume
|
| 654 |
+
python main.py --mode dagger \
|
| 655 |
+
--wandb-artifact entity/project/checkpoint-iter3000:latest
|
| 656 |
+
|
| 657 |
+
# Inference
|
| 658 |
+
python main.py --mode inference \
|
| 659 |
+
--wandb-artifact entity/project/checkpoint-iter8000:v2
|
| 660 |
+
```
|
| 661 |
+
|
| 662 |
+
The artifact reference format is `entity/project/artifact-name:version` where version is `latest`, `v0`, `v1`, etc.
|
| 663 |
+
|
| 664 |
+
### W&B Run Resumption
|
| 665 |
+
|
| 666 |
+
All training loops save the W&B run ID in their checkpoints. When resuming from a checkpoint, the run ID is automatically extracted and passed to `wandb.init(resume="must")`, so metrics continue on the same W&B curves with no gaps.
|
| 667 |
+
|
| 668 |
+
```bash
|
| 669 |
+
# DAgger: automatic -- run ID is read from the checkpoint
|
| 670 |
+
python main.py --mode dagger --checkpoint checkpoints/iter2000.pth
|
| 671 |
+
|
| 672 |
+
# Offline BC: automatic
|
| 673 |
+
python main.py --mode offline --data dataset.pt \
|
| 674 |
+
--checkpoint checkpoints/offline_step2000.pth
|
| 675 |
+
|
| 676 |
+
# Manual override (e.g. checkpoint saved before this feature was added):
|
| 677 |
+
python main.py --mode dagger --checkpoint old_checkpoint.pth \
|
| 678 |
+
wandb_resume_id=abc123xyz
|
| 679 |
+
|
| 680 |
+
# Ablation suite:
|
| 681 |
+
python experiments/rl_finetuning/run_ablations.py \
|
| 682 |
+
--checkpoint path/to/ckpt.pth --all --use_wandb \
|
| 683 |
+
--wandb_resume_id abc123xyz
|
| 684 |
+
```
|
| 685 |
+
|
| 686 |
+
The run ID is visible in the W&B dashboard URL: `wandb.ai/.../runs/<run-id>`.
|
| 687 |
+
|
| 688 |
+
---
|
| 689 |
+
|
| 690 |
+
## Performance Tuning
|
| 691 |
+
|
| 692 |
+
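Three config keys control performance optimisations. In `defaults.yaml`, mixed precision and `torch.compile` are off and collection uses 8 workers; the GPU-specific configs enable the GPU optimisations. Override as needed for your hardware.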
|
| 693 |
+
|
| 694 |
+
### Mixed precision (`use_amp: true`)
|
| 695 |
+
|
| 696 |
+
Wraps training forward/backward in `torch.amp.autocast("cuda")` with `GradScaler`. Active in both offline BC and DAgger training.
|
| 697 |
+
|
| 698 |
+
- **Measured speedup:** 2.2x on gradient steps, 1.7x on full smoke test wall-clock
|
| 699 |
+
- **Memory:** peak GPU stays ~16 GB at B=3584 (same as FP32 due to embedding-heavy model)
|
| 700 |
+
- **Correctness:** loss trajectory and win rates statistically equivalent to FP32
|
| 701 |
+
- **When to use:** always on GPU. No effect on CPU (autocast is a no-op)
|
| 702 |
+
- **Default:** `false` in `defaults.yaml`; enabled in GPU-specific configs
|
| 703 |
+
|
| 704 |
+
### torch.compile (`torch_compile: true`)
|
| 705 |
+
|
| 706 |
+
Applies `torch.compile(model, mode="default")` before training. Falls back gracefully if no C compiler is found (common on managed GPU nodes).
|
| 707 |
+
|
| 708 |
+
- **Measured speedup:** none beyond AMP alone. Not recommended for primary training.
|
| 709 |
+
- **Default:** `false` in `defaults.yaml`; opt in via the `final_*_gpu.yaml` configs.
|
| 710 |
+
- **When to use:** experimental only. May help on future PyTorch versions with better dynamic shape support.
|
| 711 |
+
|
| 712 |
+
### Parallel collection (`num_collection_workers: N`)
|
| 713 |
+
|
| 714 |
+
DAgger episode collection supports three strategies (auto-selected):
|
| 715 |
+
1. **GPU-batched** (default on CUDA with `episodes_per_iteration > 1`): all envs in lockstep
|
| 716 |
+
2. **Threaded CPU** (fallback when `num_collection_workers > 0`): `ThreadPoolExecutor` with CPU model copies
|
| 717 |
+
3. **Sequential** (reference behaviour): one episode at a time
|
| 718 |
+
|
| 719 |
+
- **Default:** `8` workers in `defaults.yaml`
|
| 720 |
+
- **When to use:** GPU-batched is preferred; workers primarily affect the CPU fallback path
|
| 721 |
+
|
| 722 |
+
### Profiling
|
| 723 |
+
|
| 724 |
+
Run `python scripts/profile_dagger.py [key=value ...]` to profile DAgger iteration components. Supports all config overrides (e.g., `use_amp=true`).
|
| 725 |
+
|
| 726 |
+
---
|
| 727 |
+
|
| 728 |
+
## Implementation Notes
|
| 729 |
+
|
| 730 |
+
- **MDLM loss** returns `0.0` (not NaN) when no masked positions exist in the batch. Uses global averaging by default; SUBS importance weighting is opt-in via `use_importance_weighting: true`.
|
| 731 |
+
- **PAD tokens** are never masked during the forward process and are excluded from the loss.
|
| 732 |
+
- **Sampling paths:** Evaluation uses stochastic ReMDM sampling (temperature, top-K, remasking) with `diffusion_steps_eval` (default 10) steps. DAgger collection uses greedy argmax sampling (deterministic, no remasking) with `diffusion_steps_collect` (default 5) steps for faster rollouts.
|
| 733 |
+
- **`remdm_sample`** guarantees a fully committed output (no MASK tokens) via a final-step commit and an assertion check. A min-keep 10% safety net prevents degenerate all-masked states.
|
| 734 |
+
- **EMA** shadow weights are updated after every gradient step (not per iteration). The `DataCollector` syncs the latest EMA weights before each rollout.
|
| 735 |
+
- **Curriculum** initialises with a 50/50 prior per environment (configurable via `curriculum_preseed`) and uses bucket-based weights over the rolling win-rate: low `[0, 0.15)` → 0.2, medium `[0.15, 0.85)` → 1.0, high `[0.85, 1.0]` → 0.1.
|
| 736 |
+
- **Replay buffer** pins offline data at the front; only online samples are FIFO-evicted. Returns `None` on empty buffer (callers handle gracefully).
|
| 737 |
+
- **Global gate** initialises at `sigmoid(-3.0) ~ 0.047`, starting nearly closed to prevent the global stream from destabilising early training.
|
| 738 |
+
- **Dropout** is set to 0.0 by default. The discrete diffusion forward masking already regularises; dropout on top is redundant.
|
| 739 |
+
- **DAgger warm-start:** On iteration 0, the buffer is seeded with 3 oracle trajectories per ID environment (12 total), giving the curriculum and training loop data to work with immediately.
|
ablation_assets/diagnosis_decision_tree.png
ADDED
|
ablation_assets/grad_alignment.png
ADDED
|
Git LFS Details
|
ablation_assets/gradient_conflict_map.png
ADDED
|
ablation_assets/group_comparison.png
ADDED
|
ablation_assets/group_summary.csv
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Group,N,Mean,Best,Worst,StdDev
|
| 2 |
+
Baseline,1,0.5625,0.5625,0.5625,0.0
|
| 3 |
+
A,6,0.6021,0.6667,0.5583,0.0358
|
| 4 |
+
B,7,0.4988,0.6542,0.0625,0.1834
|
| 5 |
+
C,7,0.6125,0.6458,0.5833,0.0184
|
ablation_assets/hypothesis_verdicts.csv
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Method,Group,Score,Delta_Baseline,Verdict,Hypothesis
|
| 2 |
+
advantage_clip,B,0.4958,-0.0667,NEUTRAL,If clipping helps: large advantage magnitudes destabilise training
|
| 3 |
+
attention_only,C,0.6167,0.0542,IMPROVEMENT,"If attention-only works: model needs routing updates, not feature updates"
|
| 4 |
+
baseline_rl,Baseline,0.5625,0.0,NEUTRAL,Diagnoses whether the RL signal alone causes collapse
|
| 5 |
+
bc_wins,B,0.5708,0.0083,NEUTRAL,If BC on wins helps: the return weighting is the specific cause
|
| 6 |
+
entropy_bonus,B,0.5708,0.0083,NEUTRAL,If entropy bonus helps: collapse is mode-collapse; not a gradient problem
|
| 7 |
+
ewc,A,0.6667,0.1042,IMPROVEMENT,If EWC helps: forgetting pretrained representations is the proximate cause
|
| 8 |
+
ffn_only,C,0.6083,0.0458,NEUTRAL,If FFN-only works: stored knowledge (FFN as memory) needs updating; not attention
|
| 9 |
+
frozen_backbone,C,0.6167,0.0542,IMPROVEMENT,If frozen backbone helps: deep gradient flow into backbone causes collapse
|
| 10 |
+
gradient_surgery,B,0.6542,0.0917,IMPROVEMENT,If PCGrad helps: gradients are conflicting and resolvable by projection
|
| 11 |
+
head_only,C,0.5958,0.0333,NEUTRAL,If head-only works: backbone representations are fine; only decision boundary needs updating
|
| 12 |
+
kl_penalty,A,0.5583,-0.0042,NEUTRAL,If this helps: catastrophic forgetting is the primary cause; soft regularisation suffices
|
| 13 |
+
layer_ablation_top1,C,0.6208,0.0583,IMPROVEMENT,Minimal unfrozen depth needed; collapse depth correlates with gradient flow depth
|
| 14 |
+
layer_ablation_top2,C,0.6458,0.0833,IMPROVEMENT,Minimal unfrozen depth needed; collapse depth correlates with gradient flow depth
|
| 15 |
+
layer_ablation_top3,C,0.5833,0.0208,NEUTRAL,Minimal unfrozen depth needed; collapse depth correlates with gradient flow depth
|
| 16 |
+
llrd,A,0.625,0.0625,IMPROVEMENT,If LLRD helps: deep gradient flow into early layers corrupts representations
|
| 17 |
+
lora,A,0.6042,0.0417,NEUTRAL,If LoRA works: too many unconstrained degrees of freedom cause collapse
|
| 18 |
+
low_t,B,0.55,-0.0125,NEUTRAL,If low-t helps: high-t (coarse-structure) gradients are biased
|
| 19 |
+
mixed_replay,A,0.5833,0.0208,NEUTRAL,If mixed replay helps: online data distribution alone is too corrupted
|
| 20 |
+
normalized_adv,B,0.0625,-0.5,COLLAPSE,If std normalisation helps: simple mean normalisation is too loose
|
| 21 |
+
t_curriculum,B,0.5875,0.025,NEUTRAL,If curriculum helps: ordering of learning signals matters
|
| 22 |
+
trust_region_kl,A,0.575,0.0125,NEUTRAL,If hard constraint helps: soft KL is insufficient -- a hard boundary is needed
|
ablation_assets/main_results.csv
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Method,Group,Score,Delta_Pretrained,Delta_Baseline,Verdict
|
| 2 |
+
ewc,A,0.6667,0.0792,0.1042,IMPROVEMENT
|
| 3 |
+
gradient_surgery,B,0.6542,0.0667,0.0917,IMPROVEMENT
|
| 4 |
+
layer_ablation_top2,C,0.6458,0.0583,0.0833,IMPROVEMENT
|
| 5 |
+
llrd,A,0.625,0.0375,0.0625,IMPROVEMENT
|
| 6 |
+
layer_ablation_top1,C,0.6208,0.0333,0.0583,IMPROVEMENT
|
| 7 |
+
frozen_backbone,C,0.6167,0.0292,0.0542,IMPROVEMENT
|
| 8 |
+
attention_only,C,0.6167,0.0292,0.0542,IMPROVEMENT
|
| 9 |
+
ffn_only,C,0.6083,0.0208,0.0458,NEUTRAL
|
| 10 |
+
lora,A,0.6042,0.0167,0.0417,NEUTRAL
|
| 11 |
+
head_only,C,0.5958,0.0083,0.0333,NEUTRAL
|
| 12 |
+
t_curriculum,B,0.5875,0.0,0.025,NEUTRAL
|
| 13 |
+
mixed_replay,A,0.5833,-0.0042,0.0208,NEUTRAL
|
| 14 |
+
layer_ablation_top3,C,0.5833,-0.0042,0.0208,NEUTRAL
|
| 15 |
+
trust_region_kl,A,0.575,-0.0125,0.0125,NEUTRAL
|
| 16 |
+
entropy_bonus,B,0.5708,-0.0167,0.0083,NEUTRAL
|
| 17 |
+
bc_wins,B,0.5708,-0.0167,0.0083,NEUTRAL
|
| 18 |
+
baseline_rl,Baseline,0.5625,-0.025,0.0,NEUTRAL
|
| 19 |
+
kl_penalty,A,0.5583,-0.0292,-0.0042,NEUTRAL
|
| 20 |
+
low_t,B,0.55,-0.0375,-0.0125,NEUTRAL
|
| 21 |
+
advantage_clip,B,0.4958,-0.0917,-0.0667,NEUTRAL
|
| 22 |
+
normalized_adv,B,0.0625,-0.525,-0.5,COLLAPSE
|
ablation_assets/per_env_delta.png
ADDED
|
Git LFS Details
|
ablation_assets/per_env_win_rates.csv
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Method,MiniHack-Room-Random-5x5-v0,MiniHack-Room-Random-15x15-v0,MiniHack-Corridor-R2-v0,MiniHack-MazeWalk-9x9-v0
|
| 2 |
+
advantage_clip,0.9,0.95,0.25,0.15
|
| 3 |
+
attention_only,1.0,0.9,0.5,0.35
|
| 4 |
+
baseline_rl,1.0,0.7,0.3,0.1
|
| 5 |
+
bc_wins,0.9,0.7,0.4,0.1
|
| 6 |
+
entropy_bonus,0.9,0.4,0.45,0.15
|
| 7 |
+
ewc,1.0,0.85,0.6,0.3
|
| 8 |
+
ffn_only,1.0,1.0,0.35,0.3
|
| 9 |
+
frozen_backbone,0.95,0.9,0.5,0.4
|
| 10 |
+
gradient_surgery,1.0,0.9,0.45,0.25
|
| 11 |
+
head_only,1.0,0.8,0.3,0.3
|
| 12 |
+
kl_penalty,0.9,1.0,0.2,0.45
|
| 13 |
+
layer_ablation_top1,0.9,0.75,0.2,0.2
|
| 14 |
+
layer_ablation_top2,0.95,0.9,0.35,0.25
|
| 15 |
+
layer_ablation_top3,0.95,0.75,0.4,0.45
|
| 16 |
+
llrd,0.8,0.9,0.4,0.25
|
| 17 |
+
lora,1.0,0.75,0.2,0.2
|
| 18 |
+
low_t,1.0,0.6,0.45,0.15
|
| 19 |
+
mixed_replay,0.95,0.75,0.45,0.2
|
| 20 |
+
normalized_adv,0.1,0.0,0.1,0.1
|
| 21 |
+
t_curriculum,1.0,0.8,0.3,0.15
|
| 22 |
+
trust_region_kl,0.95,0.75,0.45,0.25
|
ablation_assets/repr_drift.png
ADDED
|
ablation_assets/results.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ablation_assets/score_comparison.png
ADDED
|
Git LFS Details
|
ablation_assets/score_delta.png
ADDED
|
checkpoint_inference.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4619870dd5fcdb2f1575c4a458e128f3da31f9a75a73562d9d316f60f288df20
|
| 3 |
+
size 20991233
|
configs/defaults.yaml
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ── Environments ──────────────────────────────────────────────────────
|
| 2 |
+
id_envs:
|
| 3 |
+
- MiniHack-Room-Random-5x5-v0
|
| 4 |
+
- MiniHack-Room-Random-15x15-v0
|
| 5 |
+
- MiniHack-Corridor-R2-v0
|
| 6 |
+
- MiniHack-MazeWalk-9x9-v0
|
| 7 |
+
ood_envs:
|
| 8 |
+
- MiniHack-Room-Dark-15x15-v0
|
| 9 |
+
- MiniHack-Corridor-R5-v0
|
| 10 |
+
- MiniHack-MazeWalk-45x19-v0
|
| 11 |
+
crop_size: 9
|
| 12 |
+
map_h: 21
|
| 13 |
+
map_w: 79
|
| 14 |
+
action_dim: 12
|
| 15 |
+
mask_token: 12
|
| 16 |
+
pad_token: 13
|
| 17 |
+
|
| 18 |
+
# ── Model ─────────────────────────────────────────────────────────────
|
| 19 |
+
n_embd: 256
|
| 20 |
+
n_head: 4
|
| 21 |
+
n_layer: 4
|
| 22 |
+
n_global_tokens: 8
|
| 23 |
+
seq_len: 64
|
| 24 |
+
global_gate_init: -3.0
|
| 25 |
+
# Transformer dropout. 0.0 is deliberate — discrete diffusion forward masking
|
| 26 |
+
# already regularises; dropout on top is redundant.
|
| 27 |
+
dropout: 0.0
|
| 28 |
+
ema_decay: 0.999
|
| 29 |
+
|
| 30 |
+
# ── Diffusion (MDLM) ─────────────────────────────────────────────────
|
| 31 |
+
noise_schedule: linear
|
| 32 |
+
num_diffusion_steps: 100
|
| 33 |
+
loss_weight_clip: 1000.0
|
| 34 |
+
label_smoothing: 0.0
|
| 35 |
+
# Use SUBS importance weighting w(t) in loss. Off by default (flat average
|
| 36 |
+
# matching reference). Enable for MDLM ELBO experiments.
|
| 37 |
+
use_importance_weighting: false
|
| 38 |
+
# ReMDM stochastic remask base fraction
|
| 39 |
+
eta: 0.15
|
| 40 |
+
# Remasking strategy: rescale | cap | conf
|
| 41 |
+
remask_strategy: conf
|
| 42 |
+
|
| 43 |
+
# ── Inference ─────────────────────────────────────────────────────────
|
| 44 |
+
# Number of reverse denoising steps at inference.
|
| 45 |
+
# Reference uses 5 (aggressive). Higher = better quality, slower.
|
| 46 |
+
diffusion_steps_eval: 10
|
| 47 |
+
# Denoising steps during DAgger collection. Fewer than eval since
|
| 48 |
+
# collection only needs "good enough" plans for efficiency comparison.
|
| 49 |
+
diffusion_steps_collect: 5
|
| 50 |
+
temperature: 0.5
|
| 51 |
+
top_k: 4
|
| 52 |
+
replan_every: 16
|
| 53 |
+
# Soft-penalise hazardous cardinal actions during stochastic sampling.
|
| 54 |
+
# Not active in the reference evaluation pipeline; off by default.
|
| 55 |
+
physics_aware_sampling: false
|
| 56 |
+
|
| 57 |
+
# ── Training budget (unified) ────────────────────────────────────────
|
| 58 |
+
# Total environment-step budget for training. Matches the SB3
|
| 59 |
+
# `total_timesteps` convention so runs can be compared apples-to-apples
|
| 60 |
+
# across offline BC, DAgger, and SB3 baselines.
|
| 61 |
+
#
|
| 62 |
+
# • DAgger: cumulative env.step() calls across model + oracle rollouts.
|
| 63 |
+
# Training stops once this budget is exhausted. `episodes_per_iteration`
|
| 64 |
+
# and `grad_steps_per_iteration` control the collect/train ratio;
|
| 65 |
+
# they do NOT change the total compute.
|
| 66 |
+
# • Offline BC: each dataset sample corresponds to one env.step() that
|
| 67 |
+
# collected it. Total gradient steps = total_timesteps // batch_size,
|
| 68 |
+
# i.e. the training consumes ≈ `total_timesteps` samples (rounded down to
# a whole number of batches).
|
| 69 |
+
#
|
| 70 |
+
# Fairness invariant — parameters that scale AUTOMATICALLY with this
|
| 71 |
+
# budget:
|
| 72 |
+
# * offline LR cosine T_max (= total_timesteps / offline_batch_size)
|
| 73 |
+
# * id_eval_every_timesteps / ood_eval_every_timesteps (env-step cadence)
|
| 74 |
+
# * checkpoint_every_timesteps (env-step cadence; offline converts via
|
| 75 |
+
# / offline_batch_size)
|
| 76 |
+
#
|
| 77 |
+
# Parameters held FIXED across different budgets (tuning knobs, not
|
| 78 |
+
# fairness knobs):
|
| 79 |
+
# * offline_batch_size, dagger_batch_size — per-step SNR
|
| 80 |
+
# * offline_lr, dagger_lr — peak learning rate
|
| 81 |
+
# * weight_decay, *_grad_clip, efficiency_multiplier, aux_loss_weight,
|
| 82 |
+
# loss_weight_clip, label_smoothing — optimisation regularisers
|
| 83 |
+
# * episodes_per_iteration, grad_steps_per_iteration — the collect/train
|
| 84 |
+
# ratio is itself a design choice; scaling these would confound
|
| 85 |
+
# collection coverage with update density
|
| 86 |
+
# * curriculum_queue_size, buffer_capacity — in absolute units by design
|
| 87 |
+
#
|
| 88 |
+
# Fairness caveat — `ema_decay` is an absolute-update-count constant
|
| 89 |
+
# (time constant ≈ 1 / (1 − decay) updates; half-life ≈ 0.69 / (1 − decay)). If total_timesteps shifts by more
|
| 90 |
+
# than ~2x from the default, the fraction of training covered by the EMA
|
| 91 |
+
# window changes. For very short or very long runs, consider manually
|
| 92 |
+
# setting a matching decay (shorter run → lower decay, longer → higher).
|
| 93 |
+
total_timesteps: 2000000
|
| 94 |
+
|
| 95 |
+
# Evaluation + checkpoint cadence, in env-step units. These scale with
|
| 96 |
+
# total_timesteps so every run gets ~N eval points and ~M checkpoints
|
| 97 |
+
# regardless of budget. For offline BC, the cadence is converted to
|
| 98 |
+
# gradient-step intervals via `/ offline_batch_size`.
|
| 99 |
+
id_eval_every_timesteps: 25000
|
| 100 |
+
ood_eval_every_timesteps: 25000
|
| 101 |
+
checkpoint_every_timesteps: 125000
|
| 102 |
+
|
| 103 |
+
# ── Offline BC ────────────────────────────────────────────────────────
|
| 104 |
+
offline_lr: 0.0003
|
| 105 |
+
offline_batch_size: 3584
|
| 106 |
+
offline_grad_clip: 1.0
|
| 107 |
+
aux_loss_weight: 0.5
|
| 108 |
+
# ── Offline BC compute-match overrides (all opt-in, default null) ───
|
| 109 |
+
# These exist solely to support paper-fair comparisons against a
|
| 110 |
+
# specific DAgger iteration count, where the env-step / grad-step
|
| 111 |
+
# ratio between the two modes (~50x) makes a single shared
|
| 112 |
+
# `total_timesteps` budget unfair to one side. When null, offline
|
| 113 |
+
# falls back to the env-step-derived defaults.
|
| 114 |
+
#
|
| 115 |
+
# offline_total_grad_steps: pin gradient budget (e.g. 60000 to match
|
| 116 |
+
# 600 DAgger iters × 100 grad_steps_per_iter).
|
| 117 |
+
# offline_eval_every_grad_steps: ID/OOD eval cadence in grad-step
|
| 118 |
+
# units. Without this, dense env-step cadence yields ~500 evals.
|
| 119 |
+
# offline_checkpoint_every_grad_steps: checkpoint cadence in grad-step
|
| 120 |
+
# units. Same motivation as eval cadence.
|
| 121 |
+
# offline_buffer_capacity: distinct from `buffer_capacity` (which is
|
| 122 |
+
# sized for DAgger's small FIFO buffer). The full BC dataset has
|
| 123 |
+
# ~500k–1M sliding windows; using DAgger's cap silently truncates.
|
| 124 |
+
offline_total_grad_steps: null
|
| 125 |
+
offline_eval_every_grad_steps: null
|
| 126 |
+
offline_checkpoint_every_grad_steps: null
|
| 127 |
+
offline_buffer_capacity: null
|
| 128 |
+
|
| 129 |
+
# ── DAgger ────────────────────────────────────────────────────────────
|
| 130 |
+
dagger_lr: 0.00003
|
| 131 |
+
dagger_batch_size: 3584
|
| 132 |
+
dagger_grad_clip: 1.0
|
| 133 |
+
weight_decay: 0.0001
|
| 134 |
+
buffer_capacity: 10000
|
| 135 |
+
episodes_per_iteration: 30
|
| 136 |
+
grad_steps_per_iteration: 100
|
| 137 |
+
efficiency_multiplier: 1.5
|
| 138 |
+
curriculum_queue_size: 100
|
| 139 |
+
# Pre-seed curriculum queues with 50/50 prior for uniform early sampling.
|
| 140 |
+
curriculum_preseed: true
|
| 141 |
+
eval_episodes_per_env: 50
|
| 142 |
+
checkpoint_eval_episodes: 50
|
| 143 |
+
|
| 144 |
+
# ── Performance ──────────────────────────────────────────────────────
|
| 145 |
+
# Mixed-precision (FP16) training via torch.amp (autocast + GradScaler).
|
| 146 |
+
# Speeds up forward/backward ~1.5-2x on GPU. No effect on CPU.
|
| 147 |
+
use_amp: false
|
| 148 |
+
# torch.compile the model for fused kernels (experimental).
|
| 149 |
+
# May cause slow first iteration due to compilation. No effect on CPU.
|
| 150 |
+
torch_compile: false
|
| 151 |
+
# Number of parallel workers for DAgger episode collection.
|
| 152 |
+
# 0 = sequential (reference behaviour). Recommended: 4-8 on multi-core.
|
| 153 |
+
num_collection_workers: 8
|
| 154 |
+
|
| 155 |
+
# ── Data Collection ─────────────────────────────────────────────────
|
| 156 |
+
# Oracle episodes per ID environment for --mode collect.
|
| 157 |
+
collect_episodes_per_env: 5000
|
| 158 |
+
# Parallel environment workers for collection.
|
| 159 |
+
collect_num_workers: 8
|
| 160 |
+
# Output path for collected dataset.
|
| 161 |
+
collect_output: "data/dataset.pt"
|
| 162 |
+
|
| 163 |
+
# ── Checkpointing & Logging ──────────────────────────────────────────
|
| 164 |
+
checkpoint_dir: checkpoints
|
| 165 |
+
save_policy: true
|
| 166 |
+
hub_run_id: null
|
| 167 |
+
hub_repo_id: null
|
| 168 |
+
use_wandb: true
|
| 169 |
+
wandb_project: remdm-minihack
|
| 170 |
+
wandb_entity: "mathis-weil-university-college-london-ucl-"
|
| 171 |
+
wandb_run_name: null
|
| 172 |
+
wandb_resume_id: null
|
| 173 |
+
offline_log_every: 10
|
| 174 |
+
seed: null
|
| 175 |
+
|
| 176 |
+
# ── SB3 / DT baselines ───────────────────────────────────────────────
|
| 177 |
+
# Baselines compared head-to-head against the diffusion planner.
|
| 178 |
+
# Entry point:
|
| 179 |
+
# python main.py --mode baselines --algo {ppo,dqn,a2c,ppo-rnn,bc,dt}
|
| 180 |
+
#
|
| 181 |
+
# Algorithm families:
|
| 182 |
+
# * SB3 RL (ppo, a2c, dqn, ppo-rnn): consume `cfg.total_timesteps` as
|
| 183 |
+
# the env-step training budget — same convention as DAgger / offline
|
| 184 |
+
# BC. Use a custom MiniHack CNN feature extractor over the dict
|
| 185 |
+
# observation {"local": (1,9,9), "global": (1,21,79)}.
|
| 186 |
+
# * Behavioural Cloning (bc): collects oracle trajectories, trains an
|
| 187 |
+
# SB3 ActorCriticPolicy with a native PyTorch CE loop, evaluates on
|
| 188 |
+
# ID + OOD environments.
|
| 189 |
+
# * Decision Transformer (dt): collects oracle trajectories with
|
| 190 |
+
# return-to-go labels, trains a small causal transformer over
|
| 191 |
+
# interleaved (R, s, a) tokens, evaluates with target-return
|
| 192 |
+
# conditioning on ID + OOD environments.
|
| 193 |
+
#
|
| 194 |
+
# Number of parallel SB3 SubprocVecEnv workers per ID environment.
|
| 195 |
+
# Effective n_envs = baselines_n_envs_per_id * len(id_envs). Default = 2
|
| 196 |
+
# → 8 parallel envs over the 4 ID maps.
|
| 197 |
+
baselines_n_envs_per_id: 2
|
| 198 |
+
# DQN replay buffer capacity (transitions). Used only for --algo dqn.
|
| 199 |
+
baselines_dqn_buffer_size: 100000
|
| 200 |
+
# SB3 EvalCallback cadence in env-steps. Independent from
|
| 201 |
+
# id/ood_eval_every_timesteps because SB3's eval pipeline is per
|
| 202 |
+
# vector-env tick, not shared with the diffusion planner's evaluator.
|
| 203 |
+
baselines_eval_freq_env_steps: 10000
|
| 204 |
+
# Episodes per env at every eval trigger AND at the final BC / DT
|
| 205 |
+
# manual evaluation pass. Falls back to eval_episodes_per_env (50) when
|
| 206 |
+
# null so the comparison stays apples-to-apples with DAgger evals.
|
| 207 |
+
baselines_eval_episodes_per_env: null
|
| 208 |
+
|
| 209 |
+
# ── BC baseline ──────────────────────────────────────────────────────
|
| 210 |
+
# Oracle trajectories collected per ID environment (seeds 0..N-1).
|
| 211 |
+
# 5000 matches the offline BC dataset scale used by ReMDM.
|
| 212 |
+
baselines_bc_oracle_episodes_per_env: 5000
|
| 213 |
+
baselines_bc_epochs: 50
|
| 214 |
+
baselines_bc_batch_size: 256
|
| 215 |
+
baselines_bc_lr: 0.0003
|
| 216 |
+
|
| 217 |
+
# ── Decision Transformer baseline ────────────────────────────────────
|
| 218 |
+
# 5000 trajectories per ID env to match the BC / ReMDM data scale.
|
| 219 |
+
baselines_dt_oracle_episodes_per_env: 5000
|
| 220 |
+
baselines_dt_epochs: 50
|
| 221 |
+
baselines_dt_context_len: 64
|
| 222 |
+
baselines_dt_embed_dim: 256
|
| 223 |
+
baselines_dt_n_layers: 4
|
| 224 |
+
baselines_dt_n_heads: 4
|
| 225 |
+
baselines_dt_lr: 0.0003
|
| 226 |
+
baselines_dt_batch_size: 256
|
| 227 |
+
# Maximum episode length covered by DT positional embeddings. MUST be
|
| 228 |
+
# >= the longest oracle trajectory observed during data collection.
|
| 229 |
+
# Aligned with baselines_dt_eval_max_steps so positional embeddings
|
| 230 |
+
# cover the full eval-cap horizon.
|
| 231 |
+
baselines_dt_max_ep_len: 200
|
| 232 |
+
# DT eval rollout cap (steps before truncating an episode as a loss).
|
| 233 |
+
baselines_dt_eval_max_steps: 200
|
| 234 |
+
|
| 235 |
+
# ── Output / W&B ─────────────────────────────────────────────────────
|
| 236 |
+
# Separate W&B project for baselines (kept distinct from the main
|
| 237 |
+
# remdm-minihack project so baseline runs don't pollute training
|
| 238 |
+
# leaderboards). Set to null to fall back to wandb_project.
|
| 239 |
+
baselines_wandb_project: remdm-baselines
|
| 240 |
+
# Where per-seed checkpoints, SB3 logs, and aggregated results JSON
|
| 241 |
+
# are written. Resolved relative to the project root unless absolute.
|
| 242 |
+
baselines_output_dir: outputs/baselines
|
configs/final_qmul_gpu.yaml
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# QMUL H200 GPU — final paper run config
|
| 3 |
+
# =============================================================================
|
| 4 |
+
#
|
| 5 |
+
# This single config drives BOTH the final DAgger run that produced
|
| 6 |
+
# `checkpoint_final/online/final.pth` AND the compute-matched offline
|
| 7 |
+
# BC baseline used for the paper comparison.
|
| 8 |
+
#
|
| 9 |
+
# --mode dagger → reproduces the iter600 DAgger checkpoint recipe
|
| 10 |
+
# --mode offline → trains a fair offline BC baseline against it
|
| 11 |
+
#
|
| 12 |
+
# ── Fairness analysis ───────────────────────────────────────────────
|
| 13 |
+
#
|
| 14 |
+
# DAgger compute at iter600 (the checkpointed model):
|
| 15 |
+
# 600 iters × 100 grad_steps_per_iter × 2048 batch_size
|
| 16 |
+
# = 60,000 AdamW updates (600 × 100), each on a batch of 2048
|
| 17 |
+
# = 122,880,000 sample-equivalents
|
| 18 |
+
#
|
| 19 |
+
# The fair offline BC baseline matches this exactly:
|
| 20 |
+
# offline_total_grad_steps = 60,000 (override; pinned)
|
| 21 |
+
# offline_batch_size = 2048 (matches DAgger; same SNR)
|
| 22 |
+
# weight_decay, grad_clip, aux_loss_weight, model arch, diffusion
|
| 23 |
+
# params: all matched. Model is identical between modes.
|
| 24 |
+
#
|
| 25 |
+
# LR strategy follows "best-of-each-method" rather than identical
|
| 26 |
+
# optimisers — DAgger's 3e-5 constant is tuned for online refinement,
|
| 27 |
+
# offline's 3e-4 cosine→3e-5 is BC standard from-scratch. Both
|
| 28 |
+
# converge to the same effective late-training LR.
|
| 29 |
+
#
|
| 30 |
+
# Eval/checkpoint cadence is matched in *count* across modes (12 evals,
|
| 31 |
+
# 6 checkpoints per run) via the offline_*_every_grad_steps overrides,
|
| 32 |
+
# because the env-step→grad-step ratio differs by ~50× between modes.
|
| 33 |
+
#
|
| 34 |
+
# ── Hardware ─────────────────────────────────────────────────────────
|
| 35 |
+
#
|
| 36 |
+
# QMUL H200 (constrained VRAM allocation). The DAgger checkpoint was
|
| 37 |
+
# produced on this hardware, so batch_size and AMP settings must
|
| 38 |
+
# stay identical to the original run. AMP + torch.compile + 32-worker
|
| 39 |
+
# collection are the original perf settings.
|
| 40 |
+
|
| 41 |
+
# ── Environments ─────────────────────────────────────────────────────
|
| 42 |
+
id_envs:
|
| 43 |
+
- MiniHack-Room-Random-5x5-v0
|
| 44 |
+
- MiniHack-Room-Random-15x15-v0
|
| 45 |
+
- MiniHack-Corridor-R2-v0
|
| 46 |
+
- MiniHack-MazeWalk-9x9-v0
|
| 47 |
+
ood_envs:
|
| 48 |
+
- MiniHack-Room-Dark-15x15-v0
|
| 49 |
+
- MiniHack-Corridor-R5-v0
|
| 50 |
+
- MiniHack-MazeWalk-45x19-v0
|
| 51 |
+
crop_size: 9
|
| 52 |
+
map_h: 21
|
| 53 |
+
map_w: 79
|
| 54 |
+
action_dim: 12
|
| 55 |
+
mask_token: 12
|
| 56 |
+
pad_token: 13
|
| 57 |
+
|
| 58 |
+
# ── Model (matches checkpoint) ───────────────────────────────────────
|
| 59 |
+
n_embd: 256
|
| 60 |
+
n_head: 4
|
| 61 |
+
n_layer: 4
|
| 62 |
+
n_global_tokens: 8
|
| 63 |
+
seq_len: 64
|
| 64 |
+
global_gate_init: -3.0
|
| 65 |
+
dropout: 0.0
|
| 66 |
+
ema_decay: 0.999
|
| 67 |
+
|
| 68 |
+
# ── Diffusion (MDLM) — matches checkpoint ────────────────────────────
|
| 69 |
+
noise_schedule: linear
|
| 70 |
+
num_diffusion_steps: 100
|
| 71 |
+
loss_weight_clip: 1000.0
|
| 72 |
+
label_smoothing: 0.0
|
| 73 |
+
use_importance_weighting: false
|
| 74 |
+
eta: 0.15
|
| 75 |
+
remask_strategy: conf
|
| 76 |
+
|
| 77 |
+
# ── Inference / sampling — matches checkpoint ────────────────────────
|
| 78 |
+
diffusion_steps_eval: 10
|
| 79 |
+
diffusion_steps_collect: 5
|
| 80 |
+
temperature: 0.5
|
| 81 |
+
top_k: 4
|
| 82 |
+
replan_every: 16
|
| 83 |
+
physics_aware_sampling: false
|
| 84 |
+
|
| 85 |
+
# ── Env-step training budget (DAgger only) ───────────────────────────
|
| 86 |
+
# 5.65M env-steps reproduces the env-step budget consumed at iter600
|
| 87 |
+
# of the original DAgger run. This figure is calibrated against a real
|
| 88 |
+
# DAgger run with the same recipe (`p7wfp67q`, episodes_per_iteration=30,
|
| 89 |
+
# grad_steps_per_iteration=100): summing the per-iter env steps over
|
| 90 |
+
# the first 600 iterations gives 30 × 600 × mean(model_steps + oracle_steps)
|
| 91 |
+
# ≈ 30 × 600 × (198 + 116) ≈ 5.65 M real env.step() calls.
|
| 92 |
+
# (The earlier 3M figure was based on the buggy single-episode env-step
|
| 93 |
+
# accounting in `online.py:155-169` — fixed in the same commit as this
|
| 94 |
+
# config bump.) Used by `--mode dagger` only. Offline mode bypasses
|
| 95 |
+
# this via `offline_total_grad_steps` below — the unified env-step
|
| 96 |
+
# budget is fundamentally unfair when the sample-to-grad-step ratio
|
| 97 |
+
# differs by ~50× between modes.
|
| 98 |
+
total_timesteps: 5650000
|
| 99 |
+
|
| 100 |
+
# Eval/checkpoint cadence in env-step units (DAgger mode).
|
| 101 |
+
# Scaled with the corrected total_timesteps so the run still produces
|
| 102 |
+
# ~12 ID/OOD evals and ~6 checkpoints over its full duration.
|
| 103 |
+
# 470k → ~12 evals; 940k → ~6 checkpoints.
|
| 104 |
+
id_eval_every_timesteps: 470000
|
| 105 |
+
ood_eval_every_timesteps: 470000
|
| 106 |
+
checkpoint_every_timesteps: 940000
|
| 107 |
+
|
| 108 |
+
# Final-eval episode count (used by both ID/OOD eval triggers and
|
| 109 |
+
# checkpoint-time evals; matches the original DAgger run).
|
| 110 |
+
eval_episodes_per_env: 50
|
| 111 |
+
checkpoint_eval_episodes: 50
|
| 112 |
+
weight_decay: 0.0001
|
| 113 |
+
aux_loss_weight: 0.5
|
| 114 |
+
|
| 115 |
+
# ── DAgger (matches checkpoint_final/online/config_iter600.yaml) ─────
|
| 116 |
+
dagger_lr: 0.00003
|
| 117 |
+
dagger_batch_size: 2048
|
| 118 |
+
dagger_grad_clip: 1.0
|
| 119 |
+
buffer_capacity: 10000
|
| 120 |
+
episodes_per_iteration: 30
|
| 121 |
+
grad_steps_per_iteration: 100
|
| 122 |
+
efficiency_multiplier: 1.5
|
| 123 |
+
curriculum_queue_size: 100
|
| 124 |
+
curriculum_preseed: true
|
| 125 |
+
|
| 126 |
+
# ── Offline BC (compute-matched fair baseline) ───────────────────────
|
| 127 |
+
# Per the fairness analysis above:
|
| 128 |
+
# * Same gradient compute as DAgger (60k AdamW updates × 2048 batch)
|
| 129 |
+
# * Same model, diffusion, weight_decay, grad_clip, aux_loss
|
| 130 |
+
# * BC-tuned LR + cosine schedule (best practice from-scratch)
|
| 131 |
+
# * Eval/checkpoint counts matched to DAgger via grad-step overrides
|
| 132 |
+
offline_lr: 0.0003
|
| 133 |
+
offline_batch_size: 2048
|
| 134 |
+
offline_grad_clip: 1.0
|
| 135 |
+
|
| 136 |
+
# Compute pin: 60,000 AdamW updates = exactly DAgger@iter600.
|
| 137 |
+
offline_total_grad_steps: 60000
|
| 138 |
+
|
| 139 |
+
# Eval cadence: 5,000 grad steps → 12 evals (matches DAgger eval count).
|
| 140 |
+
offline_eval_every_grad_steps: 5000
|
| 141 |
+
|
| 142 |
+
# Checkpoint cadence: 10,000 grad steps → 6 checkpoints (matches DAgger).
|
| 143 |
+
offline_checkpoint_every_grad_steps: 10000
|
| 144 |
+
|
| 145 |
+
# Buffer cap for offline mode only — must hold the full pre-collected
|
| 146 |
+
# dataset (~1M sliding windows from 20k oracle trajectories). DAgger's
|
| 147 |
+
# `buffer_capacity: 10000` would silently FIFO-evict 99% of the data.
|
| 148 |
+
offline_buffer_capacity: 1500000
|
| 149 |
+
|
| 150 |
+
# ── Performance (cluster-tuned, matches original DAgger run) ─────────
|
| 151 |
+
use_amp: true
|
| 152 |
+
torch_compile: true
|
| 153 |
+
num_collection_workers: 32
|
| 154 |
+
|
| 155 |
+
# ── Data collection (for offline BC dataset) ─────────────────────────
|
| 156 |
+
# 5000 eps × 4 ID envs = 20k oracle trajectories. Strictly more than
|
| 157 |
+
# the ~7k unique trajectories DAgger had in its filtered buffer at
|
| 158 |
+
# iter600 — offline always gets a richer pre-collected pool, which is
|
| 159 |
+
# the standard fairness asymmetry in BC vs DAgger comparisons.
|
| 160 |
+
collect_episodes_per_env: 5000
|
| 161 |
+
collect_num_workers: 32
|
| 162 |
+
collect_output: data/oracle_bc_qmul.pt
|
| 163 |
+
|
| 164 |
+
# ── Checkpointing & Logging ──────────────────────────────────────────
|
| 165 |
+
checkpoint_dir: checkpoints_qmul
|
| 166 |
+
save_policy: true
|
| 167 |
+
hub_run_id: null
|
| 168 |
+
hub_repo_id: null
|
| 169 |
+
use_wandb: true
|
| 170 |
+
wandb_project: remdm-minihack
|
| 171 |
+
wandb_entity: "mathis-weil-university-college-london-ucl-"
|
| 172 |
+
wandb_run_name: null
|
| 173 |
+
# wandb_resume_id intentionally omitted — fresh runs by default.
|
| 174 |
+
# Override on the CLI (`wandb_resume_id=...`) to continue an existing run.
|
| 175 |
+
offline_log_every: 50
|
| 176 |
+
seed: null
|
configs/final_ucl_gpu.yaml
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# UCL 3090 Ti GPU — final paper run config
|
| 3 |
+
# =============================================================================
|
| 4 |
+
#
|
| 5 |
+
# This single config drives BOTH the final DAgger run and the
|
| 6 |
+
# compute-matched offline BC baseline used for the paper comparison.
|
| 7 |
+
#
|
| 8 |
+
# --mode dagger → reproduces the iter600 DAgger checkpoint recipe
|
| 9 |
+
# --mode offline → trains a fair offline BC baseline against it
|
| 10 |
+
#
|
| 11 |
+
# All training hyperparameters are IDENTICAL to `final_qmul_gpu.yaml`
|
| 12 |
+
# so cross-cluster runs produce directly comparable results. The only
|
| 13 |
+
# differences are hardware-specific perf knobs (collection workers) and output paths (collect_output, checkpoint_dir).
|
| 14 |
+
# See the QMUL config header for the full fairness analysis.
|
| 15 |
+
#
|
| 16 |
+
# ── Hardware ─────────────────────────────────────────────────────────
|
| 17 |
+
#
|
| 18 |
+
# UCL 3090 Ti — 24 GB VRAM. The 4-layer × 256-dim model with
|
| 19 |
+
# batch=2048 and AMP fits with comfortable headroom (~6-8 GB peak).
|
| 20 |
+
# Lower core count than the QMUL cluster, so collection workers
|
| 21 |
+
# capped at 8.
|
| 22 |
+
|
| 23 |
+
# ── Environments ─────────────────────────────────────────────────────
|
| 24 |
+
id_envs:
|
| 25 |
+
- MiniHack-Room-Random-5x5-v0
|
| 26 |
+
- MiniHack-Room-Random-15x15-v0
|
| 27 |
+
- MiniHack-Corridor-R2-v0
|
| 28 |
+
- MiniHack-MazeWalk-9x9-v0
|
| 29 |
+
ood_envs:
|
| 30 |
+
- MiniHack-Room-Dark-15x15-v0
|
| 31 |
+
- MiniHack-Corridor-R5-v0
|
| 32 |
+
- MiniHack-MazeWalk-45x19-v0
|
| 33 |
+
crop_size: 9
|
| 34 |
+
map_h: 21
|
| 35 |
+
map_w: 79
|
| 36 |
+
action_dim: 12
|
| 37 |
+
mask_token: 12
|
| 38 |
+
pad_token: 13
|
| 39 |
+
|
| 40 |
+
# ── Model (matches checkpoint) ───────────────────────────────────────
|
| 41 |
+
n_embd: 256
|
| 42 |
+
n_head: 4
|
| 43 |
+
n_layer: 4
|
| 44 |
+
n_global_tokens: 8
|
| 45 |
+
seq_len: 64
|
| 46 |
+
global_gate_init: -3.0
|
| 47 |
+
dropout: 0.0
|
| 48 |
+
ema_decay: 0.999
|
| 49 |
+
|
| 50 |
+
# ── Diffusion (MDLM) — matches checkpoint ────────────────────────────
|
| 51 |
+
noise_schedule: linear
|
| 52 |
+
num_diffusion_steps: 100
|
| 53 |
+
loss_weight_clip: 1000.0
|
| 54 |
+
label_smoothing: 0.0
|
| 55 |
+
use_importance_weighting: false
|
| 56 |
+
eta: 0.15
|
| 57 |
+
remask_strategy: conf
|
| 58 |
+
|
| 59 |
+
# ── Inference / sampling — matches checkpoint ────────────────────────
|
| 60 |
+
diffusion_steps_eval: 10
|
| 61 |
+
diffusion_steps_collect: 5
|
| 62 |
+
temperature: 0.5
|
| 63 |
+
top_k: 4
|
| 64 |
+
replan_every: 16
|
| 65 |
+
physics_aware_sampling: false
|
| 66 |
+
|
| 67 |
+
# ── Env-step training budget (DAgger only) ───────────────────────────
|
| 68 |
+
# 5.65M env-steps reproduces the env-step budget consumed at iter600
|
| 69 |
+
# of the original DAgger run. Calibrated against a real DAgger run
|
| 70 |
+
# with the same recipe (see QMUL config header for the full derivation).
|
| 71 |
+
# The earlier 3M figure was based on the buggy single-episode env-step
|
| 72 |
+
# accounting in `online.py:155-169` — fixed in the same commit as this
|
| 73 |
+
# config bump. Used by `--mode dagger` only. Offline mode bypasses
|
| 74 |
+
# this via `offline_total_grad_steps` below.
|
| 75 |
+
total_timesteps: 5650000
|
| 76 |
+
|
| 77 |
+
# Eval/checkpoint cadence in env-step units (DAgger mode).
|
| 78 |
+
# Scaled with the corrected total_timesteps so the run still produces
|
| 79 |
+
# ~12 ID/OOD evals and ~6 checkpoints over its full duration.
|
| 80 |
+
id_eval_every_timesteps: 470000
|
| 81 |
+
ood_eval_every_timesteps: 470000
|
| 82 |
+
checkpoint_every_timesteps: 940000
|
| 83 |
+
|
| 84 |
+
# Final-eval episode count (used by both ID/OOD eval triggers and
|
| 85 |
+
# checkpoint-time evals; matches the original DAgger run).
|
| 86 |
+
eval_episodes_per_env: 50
|
| 87 |
+
checkpoint_eval_episodes: 50
|
| 88 |
+
weight_decay: 0.0001
|
| 89 |
+
aux_loss_weight: 0.5
|
| 90 |
+
|
| 91 |
+
# ── DAgger (matches checkpoint_final/online/config_iter600.yaml) ─────
|
| 92 |
+
dagger_lr: 0.00003
|
| 93 |
+
dagger_batch_size: 2048
|
| 94 |
+
dagger_grad_clip: 1.0
|
| 95 |
+
buffer_capacity: 10000
|
| 96 |
+
episodes_per_iteration: 30
|
| 97 |
+
grad_steps_per_iteration: 100
|
| 98 |
+
efficiency_multiplier: 1.5
|
| 99 |
+
curriculum_queue_size: 100
|
| 100 |
+
curriculum_preseed: true
|
| 101 |
+
|
| 102 |
+
# ── Offline BC (compute-matched fair baseline) ───────────────────────
|
| 103 |
+
# Per the fairness analysis (see QMUL config header):
|
| 104 |
+
# * Same gradient compute as DAgger (60k AdamW updates × 2048 batch)
|
| 105 |
+
# * Same model, diffusion, weight_decay, grad_clip, aux_loss
|
| 106 |
+
# * BC-tuned LR + cosine schedule (best practice from-scratch)
|
| 107 |
+
# * Eval/checkpoint counts matched to DAgger via grad-step overrides
|
| 108 |
+
#
|
| 109 |
+
# `offline_batch_size: 2048` is matched to DAgger (NOT the 4096 the
|
| 110 |
+
# previous UCL config used) so per-update SNR is identical between
|
| 111 |
+
# modes — this is the cleanest apples-to-apples optimisation
|
| 112 |
+
# comparison. The 24 GB VRAM can hold a larger batch but using one
|
| 113 |
+
# would confound the comparison.
|
| 114 |
+
offline_lr: 0.0003
|
| 115 |
+
offline_batch_size: 2048
|
| 116 |
+
offline_grad_clip: 1.0
|
| 117 |
+
|
| 118 |
+
# Compute pin: 60,000 AdamW updates = exactly DAgger@iter600.
|
| 119 |
+
offline_total_grad_steps: 60000
|
| 120 |
+
|
| 121 |
+
# Eval cadence: 5,000 grad steps → 12 evals (matches DAgger eval count).
|
| 122 |
+
offline_eval_every_grad_steps: 5000
|
| 123 |
+
|
| 124 |
+
# Checkpoint cadence: 10,000 grad steps → 6 checkpoints (matches DAgger).
|
| 125 |
+
offline_checkpoint_every_grad_steps: 10000
|
| 126 |
+
|
| 127 |
+
# Buffer cap for offline mode only — must hold the full pre-collected
|
| 128 |
+
# dataset (~1M sliding windows from 20k oracle trajectories). DAgger's
|
| 129 |
+
# `buffer_capacity: 10000` would silently FIFO-evict 99% of the data.
|
| 130 |
+
offline_buffer_capacity: 1500000
|
| 131 |
+
|
| 132 |
+
# ── Performance (cluster-tuned for 3090 Ti) ──────────────────────────
|
| 133 |
+
use_amp: true
|
| 134 |
+
torch_compile: true
|
| 135 |
+
num_collection_workers: 8
|
| 136 |
+
|
| 137 |
+
# ── Data collection (for offline BC dataset) ─────────────────────────
|
| 138 |
+
# 5000 eps × 4 ID envs = 20k oracle trajectories. Strictly more than
|
| 139 |
+
# the ~7k unique trajectories DAgger had in its filtered buffer at
|
| 140 |
+
# iter600 — offline always gets a richer pre-collected pool, which is
|
| 141 |
+
# the standard fairness asymmetry in BC vs DAgger comparisons.
|
| 142 |
+
collect_episodes_per_env: 5000
|
| 143 |
+
collect_num_workers: 8
|
| 144 |
+
collect_output: data/oracle_bc_ucl.pt
|
| 145 |
+
|
| 146 |
+
# ── Checkpointing & Logging ──────────────────────────────────────────
|
| 147 |
+
checkpoint_dir: checkpoints_ucl
|
| 148 |
+
save_policy: true
|
| 149 |
+
hub_run_id: null
|
| 150 |
+
hub_repo_id: null
|
| 151 |
+
use_wandb: true
|
| 152 |
+
wandb_project: remdm-minihack
|
| 153 |
+
wandb_entity: "mathis-weil-university-college-london-ucl-"
|
| 154 |
+
wandb_run_name: null
|
| 155 |
+
# wandb_resume_id intentionally omitted — fresh runs by default.
|
| 156 |
+
# Override on the CLI (`wandb_resume_id=...`) to continue an existing run.
|
| 157 |
+
offline_log_every: 50
|
| 158 |
+
seed: null
|
configs/smoke.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Smoke test overrides — fast end-to-end sanity check on CPU.
|
| 2 |
+
# With total_timesteps=5000 and ~2 eps × ~30 avg steps × 2 (model+oracle)
|
| 3 |
+
# = ~120 env steps/iter → ~40 iters → a few seconds per iter on CPU.
|
| 4 |
+
buffer_capacity: 50
|
| 5 |
+
dagger_batch_size: 256
|
| 6 |
+
offline_batch_size: 256
|
| 7 |
+
total_timesteps: 5000
|
| 8 |
+
id_eval_every_timesteps: 2500
|
| 9 |
+
ood_eval_every_timesteps: 2500
|
| 10 |
+
checkpoint_every_timesteps: 2500
|
| 11 |
+
episodes_per_iteration: 2
|
| 12 |
+
grad_steps_per_iteration: 5
|
| 13 |
+
eval_episodes_per_env: 2
|
| 14 |
+
checkpoint_eval_episodes: 2
|
| 15 |
+
num_collection_workers: 0
|
| 16 |
+
use_wandb: false
|
configs/ucl_gpu_bigger_model.yaml
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ── Environments ──────────────────────────────────────────────────────
|
| 2 |
+
id_envs:
|
| 3 |
+
- MiniHack-Room-Random-5x5-v0
|
| 4 |
+
- MiniHack-Room-Random-15x15-v0
|
| 5 |
+
- MiniHack-Corridor-R2-v0
|
| 6 |
+
- MiniHack-MazeWalk-9x9-v0
|
| 7 |
+
ood_envs:
|
| 8 |
+
- MiniHack-Room-Dark-15x15-v0
|
| 9 |
+
- MiniHack-Corridor-R5-v0
|
| 10 |
+
- MiniHack-MazeWalk-45x19-v0
|
| 11 |
+
crop_size: 9
|
| 12 |
+
map_h: 21
|
| 13 |
+
map_w: 79
|
| 14 |
+
action_dim: 12
|
| 15 |
+
mask_token: 12
|
| 16 |
+
pad_token: 13
|
| 17 |
+
|
| 18 |
+
# ── Model ─────────────────────────────────────────────────────────────
|
| 19 |
+
n_embd: 384
|
| 20 |
+
n_head: 6
|
| 21 |
+
n_layer: 4
|
| 22 |
+
n_global_tokens: 8
|
| 23 |
+
seq_len: 64
|
| 24 |
+
global_gate_init: -3.0
|
| 25 |
+
# Transformer dropout. 0.0 is deliberate — discrete diffusion forward masking
|
| 26 |
+
# already regularises; dropout on top is redundant.
|
| 27 |
+
dropout: 0.0
|
| 28 |
+
ema_decay: 0.999
|
| 29 |
+
|
| 30 |
+
# ── Diffusion (MDLM) ─────────────────────────────────────────────────
|
| 31 |
+
noise_schedule: linear
|
| 32 |
+
num_diffusion_steps: 100
|
| 33 |
+
loss_weight_clip: 1000.0
|
| 34 |
+
label_smoothing: 0.0
|
| 35 |
+
# Use SUBS importance weighting w(t) in loss. Off by default (flat average
|
| 36 |
+
# matching reference). Enable for MDLM ELBO experiments.
|
| 37 |
+
use_importance_weighting: false
|
| 38 |
+
# ReMDM stochastic remask base fraction
|
| 39 |
+
eta: 0.15
|
| 40 |
+
# Remasking strategy: rescale | cap | conf
|
| 41 |
+
remask_strategy: conf
|
| 42 |
+
|
| 43 |
+
# ── Inference ─────────────────────────────────────────────────────────
|
| 44 |
+
# Number of reverse denoising steps at inference.
|
| 45 |
+
# Reference uses 5 (aggressive). Higher = better quality, slower.
|
| 46 |
+
diffusion_steps_eval: 10
|
| 47 |
+
diffusion_steps_collect: 5
|
| 48 |
+
temperature: 0.5
|
| 49 |
+
top_k: 4
|
| 50 |
+
replan_every: 16
|
| 51 |
+
# Soft-penalise hazardous cardinal actions during stochastic sampling.
|
| 52 |
+
# Not active in the reference evaluation pipeline; off by default.
|
| 53 |
+
physics_aware_sampling: false
|
| 54 |
+
|
| 55 |
+
# ── Training budget (unified) ────────────────────────────────────────
|
| 56 |
+
total_timesteps: 20000000
|
| 57 |
+
id_eval_every_timesteps: 250000
|
| 58 |
+
ood_eval_every_timesteps: 250000
|
| 59 |
+
checkpoint_every_timesteps: 1250000
|
| 60 |
+
|
| 61 |
+
# ── Offline BC ────────────────────────────────────────────────────────
|
| 62 |
+
offline_lr: 0.0003
|
| 63 |
+
offline_batch_size: 4608
|
| 64 |
+
offline_grad_clip: 1.0
|
| 65 |
+
aux_loss_weight: 0.5
|
| 66 |
+
|
| 67 |
+
# ── DAgger ────────────────────────────────────────────────────────────
|
| 68 |
+
dagger_lr: 0.00003
|
| 69 |
+
dagger_batch_size: 4608
|
| 70 |
+
dagger_grad_clip: 1.0
|
| 71 |
+
weight_decay: 0.0001
|
| 72 |
+
buffer_capacity: 10000
|
| 73 |
+
episodes_per_iteration: 30
|
| 74 |
+
grad_steps_per_iteration: 100
|
| 75 |
+
efficiency_multiplier: 1.5
|
| 76 |
+
curriculum_queue_size: 100
|
| 77 |
+
# Pre-seed curriculum queues with 50/50 prior for uniform early sampling.
|
| 78 |
+
curriculum_preseed: true
|
| 79 |
+
eval_episodes_per_env: 50
|
| 80 |
+
checkpoint_eval_episodes: 50
|
| 81 |
+
|
| 82 |
+
# ── Performance ──────────────────────────────────────────────────────
|
| 83 |
+
# Mixed-precision (FP16) training via torch.cuda.amp.
|
| 84 |
+
# Speeds up forward/backward ~1.5-2x on GPU. No effect on CPU.
|
| 85 |
+
use_amp: true
|
| 86 |
+
# torch.compile the model for fused kernels (experimental).
|
| 87 |
+
# May cause slow first iteration due to compilation. No effect on CPU.
|
| 88 |
+
torch_compile: true
|
| 89 |
+
# Number of parallel workers for DAgger episode collection.
|
| 90 |
+
# 0 = sequential (reference behaviour). Recommended: 4-8 on multi-core.
|
| 91 |
+
num_collection_workers: 8
|
| 92 |
+
|
| 93 |
+
# ── Checkpointing & Logging ──────────────────────────────────────────
|
| 94 |
+
checkpoint_dir: checkpoints_ucl_bigger_model
|
| 95 |
+
save_policy: true
|
| 96 |
+
hub_run_id: null
|
| 97 |
+
hub_repo_id: null
|
| 98 |
+
use_wandb: true
|
| 99 |
+
wandb_project: remdm-minihack
|
| 100 |
+
wandb_entity: "mathis-weil-university-college-london-ucl-"
|
| 101 |
+
wandb_run_name: null
|
| 102 |
+
offline_log_every: 10
|
| 103 |
+
seed: null
|
configs/ucl_gpu_learning_behaviour.yaml
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ── Environments ──────────────────────────────────────────────────────
|
| 2 |
+
id_envs:
|
| 3 |
+
- MiniHack-Room-Random-5x5-v0
|
| 4 |
+
- MiniHack-Room-Random-15x15-v0
|
| 5 |
+
- MiniHack-Corridor-R2-v0
|
| 6 |
+
- MiniHack-MazeWalk-9x9-v0
|
| 7 |
+
ood_envs:
|
| 8 |
+
- MiniHack-Room-Dark-15x15-v0
|
| 9 |
+
- MiniHack-Corridor-R5-v0
|
| 10 |
+
- MiniHack-MazeWalk-45x19-v0
|
| 11 |
+
crop_size: 9
|
| 12 |
+
map_h: 21
|
| 13 |
+
map_w: 79
|
| 14 |
+
action_dim: 12
|
| 15 |
+
mask_token: 12
|
| 16 |
+
pad_token: 13
|
| 17 |
+
|
| 18 |
+
# ── Model ─────────────────────────────────────────────────────────────
|
| 19 |
+
n_embd: 256
|
| 20 |
+
n_head: 4
|
| 21 |
+
n_layer: 4
|
| 22 |
+
n_global_tokens: 8
|
| 23 |
+
seq_len: 64
|
| 24 |
+
global_gate_init: -3.0
|
| 25 |
+
# Transformer dropout. 0.0 is deliberate — discrete diffusion forward masking
|
| 26 |
+
# already regularises; dropout on top is redundant.
|
| 27 |
+
dropout: 0.0
|
| 28 |
+
ema_decay: 0.999
|
| 29 |
+
|
| 30 |
+
# ── Diffusion (MDLM) ─────────────────────────────────────────────────
|
| 31 |
+
noise_schedule: linear
|
| 32 |
+
num_diffusion_steps: 100
|
| 33 |
+
loss_weight_clip: 1000.0
|
| 34 |
+
label_smoothing: 0.0
|
| 35 |
+
# Use SUBS importance weighting w(t) in loss. Off by default (flat average
|
| 36 |
+
# matching reference). Enable for MDLM ELBO experiments.
|
| 37 |
+
use_importance_weighting: false
|
| 38 |
+
# ReMDM stochastic remask base fraction
|
| 39 |
+
eta: 0.18
|
| 40 |
+
# Remasking strategy: rescale | cap | conf
|
| 41 |
+
remask_strategy: conf
|
| 42 |
+
|
| 43 |
+
# ── Inference ─────────────────────────────────────────────────────────
|
| 44 |
+
# Number of reverse denoising steps at inference.
|
| 45 |
+
# Reference uses 5 (aggressive). Higher = better quality, slower.
|
| 46 |
+
diffusion_steps_eval: 10
|
| 47 |
+
diffusion_steps_collect: 5
|
| 48 |
+
temperature: 0.5
|
| 49 |
+
top_k: 4
|
| 50 |
+
replan_every: 16
|
| 51 |
+
# Soft-penalise hazardous cardinal actions during stochastic sampling.
|
| 52 |
+
# Not active in the reference evaluation pipeline; off by default.
|
| 53 |
+
physics_aware_sampling: false
|
| 54 |
+
|
| 55 |
+
# ── Training budget (unified) ────────────────────────────────────────
|
| 56 |
+
total_timesteps: 20000000
|
| 57 |
+
id_eval_every_timesteps: 250000
|
| 58 |
+
ood_eval_every_timesteps: 250000
|
| 59 |
+
checkpoint_every_timesteps: 1250000
|
| 60 |
+
|
| 61 |
+
# ── Offline BC ────────────────────────────────────────────────────────
|
| 62 |
+
offline_lr: 0.0003
|
| 63 |
+
offline_batch_size: 6144
|
| 64 |
+
offline_grad_clip: 1.0
|
| 65 |
+
aux_loss_weight: 0.5
|
| 66 |
+
|
| 67 |
+
# ── DAgger ────────────────────────────────────────────────────────────
|
| 68 |
+
dagger_lr: 0.00003
|
| 69 |
+
dagger_batch_size: 6144
|
| 70 |
+
dagger_grad_clip: 1.0
|
| 71 |
+
weight_decay: 0.0001
|
| 72 |
+
buffer_capacity: 10000
|
| 73 |
+
episodes_per_iteration: 30
|
| 74 |
+
grad_steps_per_iteration: 100
|
| 75 |
+
efficiency_multiplier: 1.5
|
| 76 |
+
curriculum_queue_size: 100
|
| 77 |
+
# Pre-seed curriculum queues with 50/50 prior for uniform early sampling.
|
| 78 |
+
curriculum_preseed: true
|
| 79 |
+
eval_episodes_per_env: 50
|
| 80 |
+
checkpoint_eval_episodes: 50
|
| 81 |
+
|
| 82 |
+
# ── Performance ──────────────────────────────────────────────────────
|
| 83 |
+
# Mixed-precision (FP16) training via torch.cuda.amp.
|
| 84 |
+
# Speeds up forward/backward ~1.5-2x on GPU. No effect on CPU.
|
| 85 |
+
use_amp: true
|
| 86 |
+
# torch.compile the model for fused kernels (experimental).
|
| 87 |
+
# May cause slow first iteration due to compilation. No effect on CPU.
|
| 88 |
+
torch_compile: true
|
| 89 |
+
# Number of parallel workers for DAgger episode collection.
|
| 90 |
+
# 0 = sequential (reference behaviour). Recommended: 4-8 on multi-core.
|
| 91 |
+
num_collection_workers: 8
|
| 92 |
+
|
| 93 |
+
# ── Checkpointing & Logging ──────────────────────────────────────────
|
| 94 |
+
checkpoint_dir: checkpoints_ucl_learning_behaviour
|
| 95 |
+
save_policy: true
|
| 96 |
+
hub_run_id: null
|
| 97 |
+
hub_repo_id: null
|
| 98 |
+
use_wandb: true
|
| 99 |
+
wandb_project: remdm-minihack
|
| 100 |
+
wandb_entity: "mathis-weil-university-college-london-ucl-"
|
| 101 |
+
wandb_run_name: null
|
| 102 |
+
offline_log_every: 10
|
| 103 |
+
seed: null
|
environments/.gitkeep
ADDED
|
File without changes
|
main.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import logging
|
| 5 |
+
import random
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
from src.config import load_config
|
| 13 |
+
from src.planners.baselines import ALL_BASELINE_ALGOS, run_baselines
|
| 14 |
+
from src.planners.logging import Logger
|
| 15 |
+
from src.planners.offline import run_offline
|
| 16 |
+
from src.planners.online import run_dagger
|
| 17 |
+
from src.planners.inference import run_inference
|
| 18 |
+
from src.planners.collect_oracle import run_collect
|
| 19 |
+
from src.planners.smoke import run_smoke
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# =============================================================================
|
| 23 |
+
# Logging
|
| 24 |
+
# =============================================================================
|
| 25 |
+
|
| 26 |
+
# Configure the root logger at import time: INFO level, with each record
# prefixed by timestamp, level name and logger name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# =============================================================================
|
| 34 |
+
# Utils
|
| 35 |
+
# =============================================================================
|
| 36 |
+
|
| 37 |
+
def _parse_overrides(extras: list[str]) -> dict[str, Any]:
|
| 38 |
+
return {
|
| 39 |
+
k.lstrip("-"): v
|
| 40 |
+
for item in extras if "=" in item
|
| 41 |
+
for k, v in [item.split("=", 1)]
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _set_seed(seed: int | None) -> int:
|
| 46 |
+
if seed is None:
|
| 47 |
+
seed = random.randint(0, 2**31 - 1)
|
| 48 |
+
|
| 49 |
+
random.seed(seed)
|
| 50 |
+
np.random.seed(seed)
|
| 51 |
+
torch.manual_seed(seed)
|
| 52 |
+
if torch.cuda.is_available():
|
| 53 |
+
torch.cuda.manual_seed_all(seed)
|
| 54 |
+
|
| 55 |
+
return seed
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# =============================================================================
|
| 59 |
+
# CLI
|
| 60 |
+
# =============================================================================
|
| 61 |
+
|
| 62 |
+
def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Parse the command line.

    Uses ``parse_known_args`` so that unrecognised tokens are returned
    to the caller instead of causing an error.

    Returns:
        ``(namespace, extras)`` where *extras* are the unrecognised
        command-line tokens.
    """
    cli = argparse.ArgumentParser(
        description="ReMDM-MiniHack: Masked Diffusion Planner",
    )
    add = cli.add_argument

    # Run selection.
    add(
        "--mode",
        required=True,
        choices=[
            "smoke", "offline", "dagger", "inference", "collect", "baselines",
        ],
    )
    add("--config", default="configs/defaults.yaml")

    # Baseline-only options.
    add(
        "--algo",
        default=None,
        choices=list(ALL_BASELINE_ALGOS),
        help="Baseline algorithm (required for --mode baselines)",
    )
    add(
        "--seeds",
        type=int,
        nargs="+",
        default=None,
        help=(
            "Explicit list of seeds for --mode baselines "
            "(e.g. --seeds 0 1 2)."
        ),
    )
    add(
        "--n-seeds",
        type=int,
        default=None,
        help=(
            "Number of seeds starting from 0 (alternative to --seeds; "
            "only used by --mode baselines)."
        ),
    )

    # Data and checkpoint sources.
    add("--data", default=None)
    add("--checkpoint", default=None)
    add(
        "--wandb-artifact",
        default=None,
        help=(
            "W&B artifact reference to download as checkpoint, e.g. "
            "'entity/project/checkpoint-iter1000:latest'"
        ),
    )
    add("--no-warm-start", action="store_true")
    add("--no-ema", action="store_true")

    # Evaluation options.
    add("--envs", nargs="+", default=None)
    add(
        "--des",
        nargs="+",
        default=None,
        help="Paths to .des scenario files for custom environment evaluation",
    )
    add("--episodes", type=int, default=50)
    add("--output", default=None)
    add(
        "--blind-global",
        action="store_true",
        help="Zero out global map observations (local-only ablation)",
    )

    return cli.parse_known_args()
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# =============================================================================
|
| 122 |
+
# Config
|
| 123 |
+
# =============================================================================
|
| 124 |
+
|
| 125 |
+
def build_config(args, extras):
    """Load the run configuration and seed the random generators.

    When the mode is ``"smoke"`` and ``--config`` was left at
    ``configs/defaults.yaml``, ``configs/smoke.yaml`` is loaded instead.

    Args:
        args: Parsed CLI namespace (reads ``mode`` and ``config``).
        extras: Unrecognised CLI tokens, parsed as ``key=value`` overrides.

    Returns:
        The config object returned by ``load_config``.
    """
    use_smoke_config = (
        args.mode == "smoke" and args.config == "configs/defaults.yaml"
    )
    config_path = "configs/smoke.yaml" if use_smoke_config else args.config

    overrides = _parse_overrides(extras)
    cfg = load_config(config_path, overrides)

    applied_seed = _set_seed(cfg.seed)
    logger.info(f"Seed: {applied_seed}")

    return cfg
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# =============================================================================
|
| 139 |
+
# Validation
|
| 140 |
+
# =============================================================================
|
| 141 |
+
|
| 142 |
+
def validate(args) -> None:
    """Check mode-specific CLI requirements.

    Args:
        args: Parsed CLI namespace.

    Raises:
        ValueError: If inference mode has neither ``--checkpoint`` nor
            ``--wandb-artifact``, or baselines mode has no ``--algo``.
    """
    mode = args.mode

    if mode == "inference":
        has_checkpoint = bool(args.checkpoint) or bool(args.wandb_artifact)
        if not has_checkpoint:
            raise ValueError(
                "--checkpoint or --wandb-artifact required for inference mode"
            )

    if mode == "baselines" and args.algo is None:
        raise ValueError(
            "--algo is required for --mode baselines "
            f"(choose one of {list(ALL_BASELINE_ALGOS)})"
        )
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _resolve_seeds(args, cfg) -> list[int]:
|
| 155 |
+
"""Build the seed list for --mode baselines."""
|
| 156 |
+
if args.seeds is not None:
|
| 157 |
+
return list(args.seeds)
|
| 158 |
+
if args.n_seeds is not None:
|
| 159 |
+
return list(range(int(args.n_seeds)))
|
| 160 |
+
return [cfg.seed if cfg.seed is not None else 0]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# =============================================================================
|
| 164 |
+
# Dispatch (no lambdas, cleaner)
|
| 165 |
+
# =============================================================================
|
| 166 |
+
|
| 167 |
+
def _resolve_path(p: str | None) -> str | None:
|
| 168 |
+
"""Resolve a user-provided path to absolute, or return None."""
|
| 169 |
+
if p is None:
|
| 170 |
+
return None
|
| 171 |
+
return str(Path(p).resolve())
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _resolve_checkpoint(args, cfg) -> str | None:
|
| 175 |
+
"""Return a local checkpoint path from --checkpoint or --wandb-artifact."""
|
| 176 |
+
if args.checkpoint:
|
| 177 |
+
return _resolve_path(args.checkpoint)
|
| 178 |
+
artifact_ref = args.wandb_artifact
|
| 179 |
+
if artifact_ref:
|
| 180 |
+
from src.planners.logging import download_artifact
|
| 181 |
+
path = download_artifact(artifact_ref)
|
| 182 |
+
if path is None:
|
| 183 |
+
raise RuntimeError(
|
| 184 |
+
f"Failed to download W&B artifact: {artifact_ref}"
|
| 185 |
+
)
|
| 186 |
+
return path
|
| 187 |
+
return None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def run_mode(mode: str, cfg, args) -> None:
    """Dispatch to the runner that implements *mode*.

    User-supplied paths (``--data``, ``--output``, ``--des``) are
    resolved to absolute paths before being handed to a runner. A mode
    that matches none of the branches below does nothing.

    Args:
        mode: One of ``"smoke"``, ``"offline"``, ``"dagger"``,
            ``"collect"``, ``"baselines"`` or ``"inference"``.
        cfg: Run configuration passed through to the runner.
        args: Parsed CLI namespace.

    Raises:
        ValueError: In inference mode when no checkpoint can be resolved.
    """
    data_path = _resolve_path(args.data)
    output_path = _resolve_path(args.output)
    des_files = (
        [_resolve_path(d) for d in args.des]
        if args.des else None
    )

    if mode == "smoke":
        run_smoke(cfg)

    elif mode == "offline":
        ckpt = _resolve_checkpoint(args, cfg)
        run_offline(cfg, data_path, checkpoint_path=ckpt)

    elif mode == "dagger":
        ckpt = _resolve_checkpoint(args, cfg)
        run_dagger(cfg, ckpt, args.no_warm_start)

    elif mode == "collect":
        run_collect(cfg)

    elif mode == "baselines":
        run_baselines(
            cfg,
            algo=args.algo,
            seeds=_resolve_seeds(args, cfg),
            output_path=output_path,
        )

    elif mode == "inference":
        ckpt = _resolve_checkpoint(args, cfg)
        if ckpt is None:
            raise ValueError(
                "--checkpoint or --wandb-artifact required for inference"
            )
        log = Logger(cfg)
        try:
            run_inference(
                cfg,
                ckpt,
                args.envs,
                args.episodes,
                output_path,
                not args.no_ema,
                log=log,
                des_files=des_files,
                blind_global=args.blind_global,
            )
        finally:
            # Close the logger even when inference raises, so the run is
            # not left open.
            log.finish()
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# =============================================================================
|
| 242 |
+
# Entry point
|
| 243 |
+
# =============================================================================
|
| 244 |
+
|
| 245 |
+
def main() -> None:
    """Run the CLI: parse, validate, build the config, then dispatch."""
    parsed, leftover = parse_args()
    validate(parsed)
    config = build_config(parsed, leftover)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.set_float32_matmul_precision("high")

    run_mode(parsed.mode, config, parsed)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "minihack-remdm-planner"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "ReMDM masked diffusion planner for MiniHack"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"huggingface-hub>=1.8.0",
|
| 9 |
+
"ipython>=9.12.0",
|
| 10 |
+
"matplotlib>=3.10.8",
|
| 11 |
+
"minihack>=1.0.2",
|
| 12 |
+
"nle>=1.2.0",
|
| 13 |
+
"numpy>=2.4.4",
|
| 14 |
+
"orjson>=3.11.8",
|
| 15 |
+
"polars>=1.39.3",
|
| 16 |
+
"pyyaml>=6.0.3",
|
| 17 |
+
"sb3-contrib>=2.8.0",
|
| 18 |
+
"scipy>=1.17.1",
|
| 19 |
+
"stable-baselines3>=2.8.0",
|
| 20 |
+
"torch>=2.11.0",
|
| 21 |
+
"wandb>=0.25.1",
|
| 22 |
+
]
|
src/__init__.py
ADDED
|
File without changes
|
src/buffer.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Replay buffer with offline-protected FIFO eviction.
|
| 2 |
+
|
| 3 |
+
Ported from minihack_reference/src/buffer.py. Stores observation-action
|
| 4 |
+
windows of fixed length ``seq_len``. Offline data is pinned at the front
|
| 5 |
+
and never evicted; online samples use FIFO.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ReplayBuffer:
    """Fixed-capacity buffer with offline-protected FIFO eviction.

    Offline samples (loaded via ``load_offline_data``) are pinned
    and never evicted. Online samples added via ``add`` are FIFO-evicted
    when the total count exceeds ``capacity``.

    Args:
        capacity: Maximum total number of windows.
        seq_len: Action-sequence window length.
        pad_token: Token used to pad short sequences.
    """

    def __init__(
        self, capacity: int, seq_len: int, pad_token: int,
    ) -> None:
        self._capacity = capacity
        self._seq_len = seq_len
        self._pad_token = pad_token

        # Each element: (local [9,9], global [21,79], actions [seq_len])
        self._offline: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []
        self._online: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []

        # Stacked array cache for fast sampling
        self._cache_valid = False
        self._cached_local: np.ndarray | None = None
        self._cached_global: np.ndarray | None = None
        self._cached_actions: np.ndarray | None = None

    # ── Offline data ─────────────────────────────────────────────

    def load_offline_data(
        self,
        data: dict | list,
        allowed_envs: list[str],
        metadata: dict | None = None,
    ) -> None:
        """Load pre-collected trajectories and slice into windows.

        Supports two dataset formats:

        **New format** (dict): ``{"trajectories": [...]}`` where each entry
        is a dict with ``"local"``, ``"global"``, ``"actions"``, ``"env_id"``.
        A dict without a ``"trajectories"`` key is treated as a single
        trajectory.

        **Legacy format** (list): Flat list of ``((local, global), action_seq)``
        tuples produced by the reference pipeline (pre-windowed, already
        ``seq_len``-length). Env filtering uses an optional *metadata* dict
        with a ``"samples_per_env"`` key mapping env IDs to sample counts.

        Offline windows beyond ``capacity`` are dropped.

        Args:
            data: Dataset in new dict format or legacy list format.
            allowed_envs: Only samples from these env IDs are kept.
            metadata: Optional sidecar metadata for legacy format env
                filtering. Ignored for the new format.
        """
        if isinstance(data, list):
            self._load_legacy_offline_data(data, allowed_envs, metadata)
            return

        # Build the lookup once instead of scanning the list per trajectory.
        allowed = set(allowed_envs)
        trajectories = data.get("trajectories", [data])
        for traj in trajectories:
            if traj.get("env_id", "") not in allowed:
                continue
            self._offline.extend(self._slice_trajectory(traj))

        # Truncate to capacity
        if len(self._offline) > self._capacity:
            self._offline = self._offline[: self._capacity]
        self._invalidate_cache()

    def _load_legacy_offline_data(
        self,
        data: list,
        allowed_envs: list[str],
        metadata: dict | None = None,
    ) -> None:
        """Load reference-format datasets (pre-windowed tuples).

        Args:
            data: List of ``((local_crop, global_map), action_seq)`` tuples.
                ``local_crop`` is ``[9, 9]``, ``global_map`` is ``[21, 79]``,
                ``action_seq`` is a sequence of length ``seq_len``.
            allowed_envs: Env IDs to retain.
            metadata: Optional dict with ``"samples_per_env"`` key mapping
                env IDs to per-env sample counts for precise filtering.
        """
        allowed = set(allowed_envs)

        if metadata and "samples_per_env" in metadata:
            # Build a per-sample env_id index from the metadata ordering.
            # NOTE(review): this assumes samples in *data* are grouped by
            # env ID in sorted order; confirm against the dataset writer.
            counts = metadata["samples_per_env"]
            sample_to_env: list[str] = []
            for env_id in sorted(counts.keys()):
                sample_to_env.extend([env_id] * counts[env_id])

            for i, sample in enumerate(data):
                # Samples beyond the metadata counts have no known env
                # and are kept.
                env_id = (
                    sample_to_env[i] if i < len(sample_to_env) else None
                )
                if env_id is None or env_id in allowed:
                    self._offline.append(self._unpack_legacy_sample(sample))
        else:
            # No metadata: keep all samples (caller is responsible for
            # pre-filtering)
            for sample in data:
                self._offline.append(self._unpack_legacy_sample(sample))

        if len(self._offline) > self._capacity:
            self._offline = self._offline[: self._capacity]
        self._invalidate_cache()

    @staticmethod
    def _unpack_legacy_sample(
        sample: tuple,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Convert a legacy ``((local, global), action_seq)`` sample.

        Args:
            sample: Tuple of ``(state, action_seq)`` where state is
                ``(local_crop, global_map)``.

        Returns:
            ``(local, global, actions)`` with the observations as int16
            arrays and the actions as an int64 array.
        """
        (local, glb), action_seq = sample
        return (
            np.asarray(local, dtype=np.int16),
            np.asarray(glb, dtype=np.int16),
            np.asarray(action_seq, dtype=np.int64),
        )

    # ── Cache ────────────────────────────────────────────────────

    def _invalidate_cache(self) -> None:
        """Mark the stacked array cache as stale."""
        self._cache_valid = False

    def _ensure_cache(self) -> None:
        """Rebuild stacked arrays from offline + online windows."""
        if self._cache_valid:
            return
        combined = self._offline + self._online
        if not combined:
            # Nothing to stack; drop any arrays left from an earlier build.
            self._cached_local = None
            self._cached_global = None
            self._cached_actions = None
            return
        n = len(combined)
        # Shapes and dtypes are taken from the first window; later
        # windows are cast to them on assignment.
        l0, g0, a0 = combined[0]
        self._cached_local = np.empty(
            (n, *l0.shape), dtype=l0.dtype,
        )
        self._cached_global = np.empty(
            (n, *g0.shape), dtype=g0.dtype,
        )
        self._cached_actions = np.empty(
            (n, *a0.shape), dtype=a0.dtype,
        )
        for i, (l, g, a) in enumerate(combined):
            self._cached_local[i] = l
            self._cached_global[i] = g
            self._cached_actions[i] = a
        self._cache_valid = True

    # ── Online data ──────────────────────────────────────────────

    def add(self, trajectory: dict) -> None:
        """Add a trajectory, sliced into overlapping windows.

        FIFO-evicts oldest online samples when over capacity.

        Args:
            trajectory: Dict with ``"local"`` ``[T,9,9]``,
                ``"global"`` ``[T,21,79]``, ``"actions"`` ``[T]``.
        """
        windows = self._slice_trajectory(trajectory)
        self._online.extend(windows)
        # Offline windows are pinned, so only the remainder of the
        # capacity is available to online windows.
        max_online = self._capacity - len(self._offline)
        if len(self._online) > max_online:
            excess = len(self._online) - max_online
            self._online = self._online[excess:]
        self._invalidate_cache()

    # ── Sampling ─────────────────────────────────────────────────

    def sample(
        self, batch_size: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray] | None:
        """Random sample (with replacement) from offline + online combined.

        Args:
            batch_size: Number of windows to sample.

        Returns:
            ``(local [B,9,9], global [B,21,79], actions [B,seq_len])``
            as numpy arrays, or ``None`` if the buffer is empty.
        """
        if len(self) == 0:
            return None
        self._ensure_cache()
        if self._cached_local is None:
            return None
        indices = np.random.randint(0, len(self), size=batch_size)
        return (
            self._cached_local[indices],
            self._cached_global[indices],
            self._cached_actions[indices],
        )

    # ── Properties ───────────────────────────────────────────────

    def __len__(self) -> int:
        """Total number of windows (offline + online)."""
        return len(self._offline) + len(self._online)

    @property
    def n_offline(self) -> int:
        """Number of pinned offline windows."""
        return len(self._offline)

    @property
    def offline_size(self) -> int:
        """Number of pinned offline windows (alias)."""
        return len(self._offline)

    # ── Internals ────────────────────────────────────────────────

    def _slice_trajectory(
        self, traj: dict,
    ) -> list[tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Slice a trajectory into overlapping seq_len windows.

        One window is produced per action index. Windows that run past
        the end of the trajectory are right-padded with ``pad_token``.
        Every returned array is an independent copy, so later mutation of
        the caller's trajectory does not affect stored windows, and every
        action window is int64 regardless of the input dtype.

        Args:
            traj: Trajectory dict with ``"local"``, ``"global"``,
                ``"actions"`` arrays.

        Returns:
            List of ``(local, global, actions)`` tuples.
        """
        local_arr = np.asarray(traj["local"])
        global_arr = np.asarray(traj["global"])
        actions_arr = np.asarray(traj["actions"])
        T = len(actions_arr)
        windows: list[tuple[np.ndarray, np.ndarray, np.ndarray]] = []

        for start in range(T):
            # Start from an all-pad window and overwrite the real prefix.
            # This yields a fresh int64 array for both full and short
            # windows rather than a view into the caller's array.
            a = np.full(self._seq_len, self._pad_token, dtype=np.int64)
            chunk = actions_arr[start:start + self._seq_len]
            a[: len(chunk)] = chunk

            # Use the observation at the window start (clamped to the
            # last available observation).
            l = local_arr[min(start, len(local_arr) - 1)]
            g = global_arr[min(start, len(global_arr) - 1)]
            windows.append((l.copy(), g.copy(), a))

        return windows
|
src/config.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration loader for ReMDM-MiniHack.
|
| 2 |
+
|
| 3 |
+
Loads YAML configs with deep-merge and CLI override support,
|
| 4 |
+
following the Craftax config pattern.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import secrets
|
| 12 |
+
from datetime import datetime, timezone
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from types import SimpleNamespace
|
| 15 |
+
|
| 16 |
+
import yaml
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
_PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _deep_merge(base: dict, override: dict) -> dict:
|
| 25 |
+
"""Recursively merge *override* into *base* (mutates *base*).
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
base: Base dictionary to merge into.
|
| 29 |
+
override: Dictionary whose values take precedence.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
The merged dictionary (same object as *base*).
|
| 33 |
+
"""
|
| 34 |
+
for key, value in override.items():
|
| 35 |
+
if (
|
| 36 |
+
key in base
|
| 37 |
+
and isinstance(base[key], dict)
|
| 38 |
+
and isinstance(value, dict)
|
| 39 |
+
):
|
| 40 |
+
_deep_merge(base[key], value)
|
| 41 |
+
else:
|
| 42 |
+
base[key] = value
|
| 43 |
+
return base
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _cast_value(value: str) -> int | float | bool | str | None:
|
| 47 |
+
"""Best-effort cast of a CLI string to a Python scalar.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
value: Raw string from the command line.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Parsed Python value (int, float, bool, str, or None).
|
| 54 |
+
"""
|
| 55 |
+
if value.lower() in ("true", "yes"):
|
| 56 |
+
return True
|
| 57 |
+
if value.lower() in ("false", "no"):
|
| 58 |
+
return False
|
| 59 |
+
if value.lower() == "null":
|
| 60 |
+
return None
|
| 61 |
+
try:
|
| 62 |
+
return int(value)
|
| 63 |
+
except ValueError:
|
| 64 |
+
pass
|
| 65 |
+
try:
|
| 66 |
+
return float(value)
|
| 67 |
+
except ValueError:
|
| 68 |
+
pass
|
| 69 |
+
return value
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def load_config(
    config_path: str | None = None,
    cli_overrides: dict | None = None,
) -> SimpleNamespace:
    """Load configuration from YAML with optional overrides.

    1. Load ``configs/defaults.yaml``.
    2. Deep-merge *config_path* on top (if provided and different from defaults).
    3. Apply *cli_overrides* key=value pairs.
    4. Select the device: the ``DEVICE`` env-var wins; otherwise, if the
       config has no ``device`` key, use ``cuda`` when available, else ``cpu``.
    5. Validate invariants.

    Args:
        config_path: Path to a YAML file merged on top of defaults.
            Relative paths are resolved against the project root.
            ``None`` uses defaults only.
        cli_overrides: ``{key: value}`` pairs applied last. String values
            are cast with ``_cast_value``; keys are top-level only.

    Returns:
        A ``SimpleNamespace`` containing all hyperparameters.

    Raises:
        AssertionError: If ``mask_token != action_dim`` or
            ``pad_token != action_dim + 1``.
    """
    if cli_overrides is None:
        cli_overrides = {}

    defaults_path = _PROJECT_ROOT / "configs" / "defaults.yaml"
    with open(defaults_path, "r", encoding="utf-8") as fh:
        # An empty YAML file parses to None; fall back to an empty dict
        # so the merge and override steps below still work.
        cfg = yaml.safe_load(fh) or {}

    if config_path is not None:
        config_path_resolved = Path(config_path)
        if not config_path_resolved.is_absolute():
            config_path_resolved = _PROJECT_ROOT / config_path_resolved
        # Skip the merge when the requested file is the defaults file.
        if config_path_resolved.resolve() != defaults_path.resolve():
            with open(config_path_resolved, "r", encoding="utf-8") as fh:
                overrides = yaml.safe_load(fh) or {}
            _deep_merge(cfg, overrides)

    for key, value in cli_overrides.items():
        if isinstance(value, str):
            value = _cast_value(value)
        cfg[key] = value

    # Device selection
    env_device = os.environ.get("DEVICE")
    if env_device:
        cfg["device"] = env_device
    elif "device" not in cfg:
        try:
            import torch
            cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            cfg["device"] = "cpu"

    ns = SimpleNamespace(**cfg)

    # Validation. Raised explicitly rather than via ``assert`` so the
    # checks still run under ``python -O``; the exception type is kept
    # as AssertionError to match the documented contract.
    if ns.mask_token != ns.action_dim:
        raise AssertionError(
            f"mask_token ({ns.mask_token}) must equal action_dim ({ns.action_dim})"
        )
    if ns.pad_token != ns.action_dim + 1:
        raise AssertionError(
            f"pad_token ({ns.pad_token}) must equal action_dim + 1 "
            f"({ns.action_dim + 1})"
        )

    return ns
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def make_run_dir(cfg: SimpleNamespace, tag: str = "run") -> Path:
    """Create a unique run subdirectory under ``cfg.checkpoint_dir``.

    The directory is named ``{tag}_{YYYYMMDD}_{HHMMSS}_{hex4}`` (UTC
    timestamp plus four random hex characters) so that concurrent runs
    do not overwrite each other's checkpoints. ``cfg.checkpoint_dir`` is
    updated in place to point at the new directory.

    Args:
        cfg: Config namespace (``checkpoint_dir`` is mutated).
        tag: Prefix for the directory name (e.g. ``"dagger"``,
            ``"offline"``).

    Returns:
        The created directory path.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    rand_hex = secrets.token_hex(2)
    parent = Path(cfg.checkpoint_dir).resolve()

    run_dir = parent / f"{tag}_{stamp}_{rand_hex}"
    run_dir.mkdir(parents=True, exist_ok=True)

    cfg.checkpoint_dir = str(run_dir)
    logger.info("Checkpoint directory: %s", run_dir)
    return run_dir
|
src/curriculum.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dynamic environment curriculum and efficiency filter.
|
| 2 |
+
|
| 3 |
+
Ported from minihack_reference/src/curriculum.py. Tracks per-environment
|
| 4 |
+
win rates in a rolling window and uses bucket-based sampling weights to
|
| 5 |
+
focus training on environments where the model is struggling.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
from collections import deque
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DynamicCurriculum:
    """Rolling-window curriculum with bucket-based sampling weights.

    Each environment keeps a bounded deque of recent win/loss outcomes.
    ``sample_env`` maps every tracked environment's rolling win rate to one
    of three buckets and samples proportionally to the bucket weight:

    * win rate below ``_LOW_THRESHOLD``  -> ``_WEIGHT_LOW``
    * win rate above ``_HIGH_THRESHOLD`` -> ``_WEIGHT_HIGH``
    * anything in between                -> ``_WEIGHT_MID``

    With the weights defined below, environments in the middle bucket are
    sampled most often; both nearly-always-lost and nearly-always-won
    environments are down-weighted.

    Args:
        env_ids: List of environment IDs to track (and sample from).
        queue_size: Rolling window size per environment.
        preseed: If ``True``, start every window with an equal number of
            wins and losses so that early sampling is uniform.
    """

    # Bucket thresholds (on the rolling win rate) and sampling weights.
    _LOW_THRESHOLD = 0.15
    _HIGH_THRESHOLD = 0.85
    _WEIGHT_LOW = 0.2
    _WEIGHT_MID = 1.0
    _WEIGHT_HIGH = 0.1
    # Maximum number of wins (and of losses) injected by the preseed prior.
    _PRESEED_HALF = 50

    def __init__(
        self,
        env_ids: list[str],
        queue_size: int = 100,
        preseed: bool = True,
    ) -> None:
        self._env_ids = list(env_ids)
        self._queue_size = queue_size
        self._queues: dict[str, deque[bool]] = {}

        # The prior must fit inside the window. Seeding 50 wins followed by
        # 50 losses into a window shorter than 100 would evict the wins and
        # leave a prior biased towards (or entirely made of) losses.
        half = 0
        if preseed:
            half = self._PRESEED_HALF
            if queue_size is not None:
                half = min(half, queue_size // 2)
        prior = [True] * half + [False] * half

        for eid in self._env_ids:
            self._queues[eid] = deque(prior, maxlen=queue_size)

    def update(self, env_id: str, won: bool) -> None:
        """Record an episode outcome.

        An ``env_id`` that was not passed to the constructor gets a fresh
        (un-seeded) window, but it is not added to the set of environments
        that ``sample_env`` draws from.

        Args:
            env_id: Environment ID.
            won: Whether the episode was won.
        """
        if env_id not in self._queues:
            self._queues[env_id] = deque(maxlen=self._queue_size)
        self._queues[env_id].append(won)

    def win_rate(self, env_id: str) -> float:
        """Rolling win rate for an environment.

        Args:
            env_id: Environment ID.

        Returns:
            Win rate in ``[0, 1]``. ``0.5`` if the environment is unknown
            or its window is empty.
        """
        q = self._queues.get(env_id)
        if not q:
            return 0.5
        return sum(q) / len(q)

    def _bucket_weight(self, rate: float) -> float:
        """Map a win rate to the sampling weight of its bucket."""
        if rate < self._LOW_THRESHOLD:
            return self._WEIGHT_LOW
        if rate > self._HIGH_THRESHOLD:
            return self._WEIGHT_HIGH
        return self._WEIGHT_MID

    def sample_env(self) -> str:
        """Sample an environment ID using bucket-weighted probabilities.

        Only the IDs given to the constructor are candidates.

        Returns:
            Sampled environment ID.
        """
        weights = [
            self._bucket_weight(self.win_rate(eid)) for eid in self._env_ids
        ]
        return random.choices(self._env_ids, weights=weights, k=1)[0]

    def state_dict(self) -> dict:
        """Serialise curriculum state.

        Returns:
            Dict with ``env_ids``, ``queue_size``, and per-env queues.
            All containers are copies, so mutating the result does not
            affect the curriculum.
        """
        return {
            "env_ids": list(self._env_ids),
            "queue_size": self._queue_size,
            "queues": {eid: list(q) for eid, q in self._queues.items()},
        }

    def load_state_dict(self, sd: dict) -> None:
        """Restore curriculum state.

        The set of sampled environment IDs is left as configured at
        construction time; only the window size and the recorded outcomes
        are restored.

        Args:
            sd: State dict from ``state_dict()``.
        """
        self._queue_size = sd.get("queue_size", self._queue_size)

        # Keep windows that the state dict does not overwrite consistent
        # with the (possibly changed) window size.
        for eid, q in list(self._queues.items()):
            if q.maxlen != self._queue_size:
                self._queues[eid] = deque(q, maxlen=self._queue_size)

        for eid, items in sd.get("queues", {}).items():
            self._queues[eid] = deque(items, maxlen=self._queue_size)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def efficiency_filter(
    model_won: bool,
    model_steps: int,
    oracle_steps: int,
    multiplier: float = 1.5,
) -> bool:
    """Decide whether the oracle trajectory should go into the buffer.

    Oracle data is requested in two situations: the model lost the
    episode, or it won but needed more than ``multiplier`` times the
    oracle's step count.

    Args:
        model_won: Whether the model solved the episode.
        model_steps: Steps the model took.
        oracle_steps: Steps the oracle took.
        multiplier: Efficiency threshold multiplier.

    Returns:
        ``True`` if oracle data should be added to the buffer.
    """
    if model_won:
        # A win only counts as "good enough" when it stays within budget.
        step_budget = multiplier * oracle_steps
        return model_steps > step_budget
    return True
|
src/diffusion/__init__.py
ADDED
|
File without changes
|
src/diffusion/forward.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Forward masking process q(z_t | x_0).
|
| 2 |
+
|
| 3 |
+
Ported from the Craftax JAX implementation (src/diffusion/forward.py).
|
| 4 |
+
Each token is independently replaced with mask_token with probability
|
| 5 |
+
sigma_t = 1 - alpha_t. PAD positions are never masked.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Callable
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def q_sample(
    x0: Tensor,
    t: Tensor,
    mask_token: int,
    pad_token: int,
    schedule_fn: Callable[[Tensor], Tensor],
) -> Tensor:
    """Draw a noisy sequence z_t from the forward masking process.

    Every non-PAD token is replaced by ``mask_token`` independently with
    probability ``1 - schedule_fn(t)``; PAD tokens are always kept.

    Args:
        x0: Clean action sequences. Shape ``[B, L]``, dtype int64.
        t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
        mask_token: Integer ID of the MASK token.
        pad_token: Integer ID of the PAD token.
        schedule_fn: Noise schedule returning alpha(t).

    Returns:
        Noisy sequence z_t. Shape ``[B, L]``, same dtype as ``x0``.
        PAD positions are preserved unchanged.
    """
    # Per-sample masking probability, broadcast over the sequence axis.
    mask_prob = (1.0 - schedule_fn(t)).unsqueeze(-1)  # [B, 1]

    # One uniform draw per position decides whether it gets masked.
    uniform = torch.rand_like(x0, dtype=torch.float32)  # [B, L]

    # Padding is excluded from masking up front instead of being restored
    # afterwards.
    is_pad = x0 == pad_token  # [B, L]
    replace = (uniform < mask_prob) & ~is_pad  # [B, L]

    return x0.masked_fill(replace, mask_token)
|
src/diffusion/loss.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MDLM ELBO loss with SUBS parameterisation.
|
| 2 |
+
|
| 3 |
+
Ported from the Craftax JAX implementation (src/diffusion/loss.py).
|
| 4 |
+
Computes continuous-time loss on masked positions only, with analytic
|
| 5 |
+
SUBS weighting clipped for numerical stability.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Callable
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
|
| 16 |
+
from src.diffusion.schedules import alpha_prime
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Default upper clamp for the SUBS importance weight.
_MAX_WEIGHT: float = 1000.0


def mdlm_loss(
    logits: Tensor,
    x0: Tensor,
    zt: Tensor,
    t: Tensor,
    mask_token: int,
    pad_token: int,
    schedule_fn: Callable[[Tensor], Tensor],
    weight_clip: float = _MAX_WEIGHT,
    label_smoothing: float = 0.0,
    use_importance_weighting: bool = False,
) -> Tensor:
    """Compute masked diffusion loss.

    By default this is a plain cross-entropy averaged over all masked,
    non-PAD positions in the batch. When ``use_importance_weighting=True``
    the cross-entropy is first averaged per sample over its masked
    positions, multiplied by the SUBS weight
    ``w(t) = -alpha'(t) / (1 - alpha_t)`` (clamped to
    ``[0, weight_clip]``), and then averaged over the batch.

    Args:
        logits: Model output. Shape ``[B, L, vocab]``.
        x0: Clean action sequences. Shape ``[B, L]``, int64.
        zt: Noisy sequences. Shape ``[B, L]``, int64.
        t: Per-sample diffusion time in [0, 1]. Shape ``[B]``.
        mask_token: MASK token ID.
        pad_token: PAD token ID.
        schedule_fn: Noise schedule returning alpha(t).
        weight_clip: Upper clamp for SUBS weight (default 1000).
        label_smoothing: Smoothing epsilon for cross-entropy.
        use_importance_weighting: If ``True``, apply SUBS w(t) per sample.

    Returns:
        Scalar loss. When no masked positions exist the result is an
        exact zero that is still connected to ``logits`` in the autograd
        graph, so calling ``backward()`` on it is safe and yields zero
        gradients.
    """
    B, L, V = logits.shape

    # Loss is only defined on positions that are masked in z_t and are not
    # padding in x_0.
    is_masked = (zt == mask_token) & (x0 != pad_token)  # [B, L]

    if not is_masked.any():
        # A detached constant would make ``loss.backward()`` raise when
        # this is the only loss term; multiply through instead.
        return (logits * 0.0).sum()

    # Per-position cross-entropy. Targets outside the vocab range (PAD,
    # MASK) are clamped so the call is valid; their contribution is zeroed
    # out by ``is_masked`` below.
    safe_targets = x0.clamp(0, V - 1)  # [B, L]
    ce = F.cross_entropy(
        logits.reshape(-1, V),
        safe_targets.reshape(-1),
        reduction="none",
        label_smoothing=label_smoothing,
    ).reshape(B, L)  # [B, L]

    mask_f = is_masked.float()  # [B, L]
    ce = ce * mask_f

    if not use_importance_weighting:
        # Global average over all masked positions (matches reference).
        return ce.sum() / mask_f.sum().clamp(min=1.0)

    # SUBS weight: w_t = -alpha'(t) / (1 - alpha_t + eps), clamped.
    alpha_t = schedule_fn(t)  # [B]
    d_alpha = alpha_prime(t, schedule_fn)  # [B]
    w_t = ((-d_alpha) / (1.0 - alpha_t + 1e-8)).clamp(0.0, weight_clip)

    # Per-sample mean over that sample's masked positions, then weighted.
    n_masked_per = mask_f.sum(dim=1).clamp(min=1.0)  # [B]
    per_sample = ce.sum(dim=1) / n_masked_per  # [B]
    return (per_sample * w_t).mean()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def auxiliary_goal_loss(
    goal_pred: Tensor,
    global_obs: Tensor,
    pad_value: float = -1.0,
) -> Tensor:
    """MSE loss for auxiliary staircase-coordinate prediction.

    Targets come from ``find_staircase_from_glyphs``; a sample is
    supervised only when its first target coordinate differs from
    ``pad_value``.

    Args:
        goal_pred: Predicted normalised staircase coords. Shape ``[B, 2]``.
        global_obs: Full map glyphs. Shape ``[B, 21, 79]``, int.
        pad_value: Coordinate value used when staircase is not visible.

    Returns:
        Scalar MSE loss over samples where the staircase is visible.
        When no staircase is visible in the batch the result is an exact
        zero that stays connected to ``goal_pred`` in the autograd graph,
        so ``backward()`` remains valid.
    """
    targets = find_staircase_from_glyphs(global_obs)  # [B, 2]
    targets = targets.to(goal_pred.device, dtype=goal_pred.dtype)

    # Only supervise where the staircase is visible.
    valid = targets[:, 0] != pad_value  # [B]
    if not valid.any():
        # Avoid a detached constant: it would break ``backward()`` if this
        # were the only active loss term.
        return (goal_pred * 0.0).sum()

    diff = (goal_pred[valid] - targets[valid]) ** 2  # [N, 2]
    return diff.mean()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def find_staircase_from_glyphs(global_obs: Tensor) -> Tensor:
    """Locate the staircase in the global glyph map.

    A cell counts as a staircase when its value is one of the codes
    listed below. When several cells match, the first one in row-major
    order is used. Coordinates are normalised as
    ``(row / max(1, H - 1), col / max(1, W - 1))``.

    Args:
        global_obs: Glyph map. Shape ``[B, H, W]`` or ``[H, W]``, int.

    Returns:
        Normalised coordinates. Shape ``[B, 2]`` (float32), on the same
        device as ``global_obs``. Rows without a staircase are
        ``(-1, -1)``.
    """
    if global_obs.ndim == 2:
        global_obs = global_obs.unsqueeze(0)

    B, H, W = global_obs.shape
    not_found = torch.full(
        (B, 2), -1.0, dtype=torch.float32, device=global_obs.device
    )
    n_cells = H * W
    if B == 0 or n_cells == 0:
        return not_found

    # Staircase codes: ord('>') = 62 plus three NLE glyph IDs.
    # NOTE(review): the original comment labelled these 2310 (S_dnstair),
    # 2368 (S_dnstairs) and 2383 (S_vodoor); the last label names a door
    # symbol, so the labels should be verified against the NLE glyph table.
    is_stair = (
        (global_obs == 62)
        | (global_obs == 2310)
        | (global_obs == 2368)
        | (global_obs == 2383)
    )

    # Vectorised "first hit per row": non-matching cells get the sentinel
    # index ``n_cells`` so the row-wise minimum is the first match in
    # row-major order, or the sentinel when nothing matched. This avoids a
    # Python loop with one device synchronisation per batch element.
    flat = is_stair.reshape(B, n_cells)
    cell_idx = torch.arange(n_cells, device=global_obs.device)
    sentinel = torch.full_like(cell_idx, n_cells)
    first = torch.where(flat, cell_idx, sentinel).min(dim=1).values  # [B]
    found = first < n_cells  # [B]

    rows = torch.div(first, W, rounding_mode="floor").float() / max(1, H - 1)
    cols = (first % W).float() / max(1, W - 1)
    located = torch.stack((rows, cols), dim=1)  # [B, 2]

    return torch.where(found.unsqueeze(1), located, not_found)
|
src/diffusion/sampling.py
ADDED
|
@@ -0,0 +1,398 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ReMDM reverse denoising with remasking strategies.
|
| 2 |
+
|
| 3 |
+
Ported from the Craftax JAX implementation (src/diffusion/sampling.py).
|
| 4 |
+
Implements MaskGIT-style progressive unmasking with optional stochastic
|
| 5 |
+
remasking (ReMDM) using three strategy variants.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from types import SimpleNamespace
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
from torch.distributions import Categorical
|
| 17 |
+
|
| 18 |
+
from src.diffusion.schedules import get_schedule
|
| 19 |
+
|
| 20 |
+
# NLE hazard glyph IDs and char codes (walls, locked doors, lava, water)
|
| 21 |
+
_HAZARD_GLYPHS: frozenset[int] = frozenset({2359, 2360, 2389, 2390})
|
| 22 |
+
_HAZARD_CHARS: frozenset[int] = frozenset(
|
| 23 |
+
{ord("|"), ord("-"), ord("+"), ord("L"), ord("W")}
|
| 24 |
+
)
|
| 25 |
+
# Cardinal action → (dy, dx) offsets
|
| 26 |
+
_CARDINAL_OFFSETS: dict[int, tuple[int, int]] = {
|
| 27 |
+
0: (-1, 0), 1: (0, 1), 2: (1, 0), 3: (0, -1),
|
| 28 |
+
}
|
| 29 |
+
_N_PHYSICS_CHECK = 8 # only inspect the first N plan positions
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _check_hazard(local_crop: np.ndarray, action: int) -> bool:
|
| 33 |
+
"""Return True if *action* from the agent's centre steps into a hazard.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
local_crop: ``[crop_size, crop_size]`` glyph array.
|
| 37 |
+
action: Cardinal action index (0=N, 1=E, 2=S, 3=W).
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
``True`` when the target cell contains a hazard glyph.
|
| 41 |
+
"""
|
| 42 |
+
if action not in _CARDINAL_OFFSETS:
|
| 43 |
+
return False
|
| 44 |
+
cs = local_crop.shape[0]
|
| 45 |
+
cy, cx = cs // 2, cs // 2
|
| 46 |
+
dy, dx = _CARDINAL_OFFSETS[action]
|
| 47 |
+
ny, nx = cy + dy, cx + dx
|
| 48 |
+
if not (0 <= ny < cs and 0 <= nx < cs):
|
| 49 |
+
return True
|
| 50 |
+
glyph = int(local_crop[ny, nx])
|
| 51 |
+
return glyph in _HAZARD_GLYPHS or glyph in _HAZARD_CHARS
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def top_k_filter(logits: Tensor, k: int) -> Tensor:
    """Suppress every logit that is below the k-th largest of its row.

    Entries tied with the k-th largest value are kept, so more than ``k``
    entries can survive. The input is returned untouched (same object)
    when ``k`` is non-positive or not smaller than the vocabulary size.

    Args:
        logits: Raw logits. Shape ``[..., V]``.
        k: Number of top entries to keep.

    Returns:
        Filtered logits with non-top-k entries set to ``-inf``.
    """
    vocab = logits.shape[-1]
    if 0 < k < vocab:
        # Smallest value that still belongs to the top-k of each row.
        kth_largest = torch.topk(logits, k, dim=-1).values[..., k - 1 : k]
        return logits.masked_fill(logits < kth_largest, float("-inf"))
    return logits
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _compute_remask_prob(
|
| 72 |
+
strategy: str,
|
| 73 |
+
eta: float,
|
| 74 |
+
sigma_max: float,
|
| 75 |
+
confidence: Tensor | None,
|
| 76 |
+
) -> Tensor | float:
|
| 77 |
+
"""Compute per-token remasking probability.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
strategy: One of ``"rescale"``, ``"cap"``, ``"conf"``.
|
| 81 |
+
eta: Base remasking strength hyperparameter.
|
| 82 |
+
sigma_max: ``1 - alpha_t(ratio)`` at current step.
|
| 83 |
+
confidence: Per-token confidence scores. Shape ``[B, L]``.
|
| 84 |
+
Required only for the ``"conf"`` strategy.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
Scalar or ``[B, L]`` tensor of remasking probabilities.
|
| 88 |
+
"""
|
| 89 |
+
if strategy == "rescale":
|
| 90 |
+
return eta * sigma_max
|
| 91 |
+
if strategy == "cap":
|
| 92 |
+
return min(eta, sigma_max)
|
| 93 |
+
if strategy == "conf":
|
| 94 |
+
assert confidence is not None, "conf strategy requires confidence"
|
| 95 |
+
return eta * sigma_max * (1.0 - confidence)
|
| 96 |
+
raise ValueError(f"Unknown remask strategy: {strategy}")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _build_hazard_table(local_obs: Tensor, vocab_size: int) -> Tensor:
    """Precompute which action tokens step into a hazard, per batch element.

    Entry ``[b, a]`` is ``_check_hazard(local_obs[b], a)`` for every
    cardinal action ``a`` below ``vocab_size`` and ``False`` otherwise.
    """
    crops = local_obs.cpu().numpy()  # [B, crop, crop]
    table = np.zeros((crops.shape[0], vocab_size), dtype=bool)
    for b, crop in enumerate(crops):
        for action in _CARDINAL_OFFSETS:
            if 0 <= action < vocab_size:
                table[b, action] = _check_hazard(np.asarray(crop), action)
    return torch.from_numpy(table).to(local_obs.device)


@torch.no_grad()
def remdm_sample(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    physics_aware: bool = True,
    blind_global: bool = False,
    return_analytics: bool = False,
    num_steps: int | None = None,
) -> Tensor | tuple[Tensor, list, list[float], list[int]]:
    """Generate action sequences via iterative ReMDM denoising.

    The sequence starts fully masked. At every step but the last, the
    most confident masked positions are committed and then committed
    positions may be stochastically re-masked; the last step commits every
    position that is still masked.

    Args:
        model: Denoising model with forward signature
            ``(local_obs, global_obs, action_seq, t_discrete) -> dict``.
            Its ``"actions"`` output is modified in place (logits of
            tokens ``>= action_dim`` are set to ``-inf``).
        local_obs: Local crop observations. Shape ``[B, 9, 9]``.
        global_obs: Global map observations. Shape ``[B, 21, 79]``.
        cfg: Config namespace with ``seq_len``, ``mask_token``,
            ``action_dim``, ``diffusion_steps_eval``,
            ``num_diffusion_steps``, ``temperature``, ``top_k``, ``eta``,
            ``remask_strategy``, ``noise_schedule``.
        device: Torch device.
        physics_aware: If ``True``, soft-penalise hazardous cardinal actions
            by overriding their confidence to ``0.001`` before commitment
            ranking. Only the first ``_N_PHYSICS_CHECK`` plan positions
            are checked, and each of them is tested as a move from the
            centre of the current local crop.
        blind_global: If ``True``, zero out the global map observation
            (local-only ablation).
        return_analytics: If ``True``, also return per-step analytics as
            ``(seq, path_per_step, tracking_confidence, tracking_masked)``.
        num_steps: Override number of denoising steps (default uses
            ``cfg.diffusion_steps_eval``).

    Returns:
        When ``return_analytics=False`` (default): fully committed action
        sequence of shape ``[B, seq_len]``, int64, with no MASK tokens.

        When ``return_analytics=True``: tuple
        ``(seq, path_per_step, tracking_confidence, tracking_masked_count)``
        where ``path_per_step`` is a list of ``[seq_len]`` numpy arrays,
        ``tracking_confidence`` a list of per-step avg unmasked confidence
        floats, and ``tracking_masked_count`` a list of masked-token counts.
        Analytics are recorded for batch element 0 only.

    Raises:
        ValueError: If the number of denoising steps is smaller than 1.
    """
    B = local_obs.shape[0]
    seq_len = cfg.seq_len
    mask_token = cfg.mask_token
    action_dim = cfg.action_dim
    K = num_steps if num_steps is not None else cfg.diffusion_steps_eval
    if K < 1:
        # With no steps the sequence would stay fully masked.
        raise ValueError(f"remdm_sample needs at least 1 denoising step, got {K}")
    schedule_fn = get_schedule(cfg.noise_schedule)
    min_keep = max(1, int(seq_len * 0.10))  # Safety net: always unmask >= 10%

    local_obs = local_obs.to(device)
    global_obs = global_obs.to(device)

    if blind_global:
        global_obs = torch.zeros_like(global_obs)

    # The hazard verdict depends only on (batch element, action), not on
    # the denoising step or plan position, so it is computed once (lazily,
    # when the vocab size is known) and looked up inside the loop.
    n_check = min(_N_PHYSICS_CHECK, seq_len)
    hazard_table: Tensor | None = None  # [B, vocab] bool

    # Analytics buffers (only populated when return_analytics=True)
    path_per_step: list[np.ndarray] = []
    tracking_confidence: list[float] = []
    tracking_masked_count: list[int] = []

    # Start fully masked
    seq = torch.full(
        (B, seq_len), mask_token, dtype=torch.long, device=device
    )

    for k in range(1, K + 1):
        ratio = k / K
        # Pass as tensor (not Python int) to avoid torch.compile recompilation
        t_discrete = torch.full(
            (B,), int(cfg.num_diffusion_steps * (1.0 - ratio)),
            dtype=torch.long, device=device,
        )

        # Forward pass
        out = model(local_obs, global_obs, seq, t_discrete)
        logits = out["actions"]  # [B, seq_len, vocab]

        # Mask invalid action tokens (indices >= action_dim)
        logits[:, :, action_dim:] = float("-inf")

        # Temperature scaling, then top-K filtering
        logits = logits / cfg.temperature
        logits = top_k_filter(logits, cfg.top_k)

        # Sample predictions
        probs = F.softmax(logits, dim=-1)  # [B, seq_len, vocab]
        preds = Categorical(probs=probs).sample()  # [B, seq_len]

        # Confidence: probability of the sampled token
        conf = probs.gather(-1, preds.unsqueeze(-1)).squeeze(-1)  # [B, seq_len]

        # Physics softener: demote hazardous cardinal actions to conf=0.001
        if physics_aware and n_check > 0:
            if hazard_table is None:
                hazard_table = _build_hazard_table(local_obs, probs.shape[-1])
            hazardous = hazard_table.gather(1, preds[:, :n_check])
            softened = conf[:, :n_check].masked_fill(hazardous, 0.001)
            conf = torch.cat((softened, conf[:, n_check:]), dim=1)

        is_masked = seq == mask_token  # [B, seq_len]

        if k < K:
            # MaskGIT progressive unmasking with min-keep guarantee
            n_unmask = max(min_keep, max(1, int(seq_len * ratio)))

            # Already-committed positions get score -1 so they rank below
            # every masked position.
            unmask_scores = conf.clone()
            unmask_scores[~is_masked] = -1.0

            # For each batch element, pick the top-confidence positions
            _, topk_indices = unmask_scores.topk(n_unmask, dim=-1)

            # Build scatter mask for positions to unmask
            unmask_mask = torch.zeros_like(seq, dtype=torch.bool)
            unmask_mask.scatter_(1, topk_indices, True)
            unmask_mask = unmask_mask & is_masked  # only unmask masked pos

            seq = torch.where(unmask_mask, preds, seq)

            # ReMDM stochastic remasking of committed (non-masked) positions
            is_committed = seq != mask_token  # [B, seq_len]
            alpha_t_ratio = schedule_fn(torch.tensor(ratio, device=device))
            sigma_max = (1.0 - alpha_t_ratio).item()

            remask_prob = _compute_remask_prob(
                cfg.remask_strategy, cfg.eta, sigma_max, conf
            )
            if isinstance(remask_prob, Tensor):
                do_remask = (
                    torch.rand_like(conf) < remask_prob
                ) & is_committed
            else:
                do_remask = (
                    torch.rand(B, seq_len, device=device) < remask_prob
                ) & is_committed
            seq = torch.where(do_remask, mask_token, seq)
        else:
            # Final step: commit all remaining MASK tokens
            seq = torch.where(is_masked, preds, seq)

        # Analytics tracking (batch element 0 only)
        if return_analytics:
            path_per_step.append(seq[0].cpu().numpy().copy())
            still_masked = (seq[0] == mask_token)
            unmasked_conf = conf[0][~still_masked]
            avg_conf = (
                unmasked_conf.mean().item()
                if unmasked_conf.numel() > 0 else 0.0
            )
            tracking_confidence.append(avg_conf)
            tracking_masked_count.append(int(still_masked.sum().item()))

    assert (seq != mask_token).all(), (
        "remdm_sample produced MASK tokens in final output"
    )
    if return_analytics:
        return seq, path_per_step, tracking_confidence, tracking_masked_count
    return seq
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@torch.no_grad()
def greedy_sample(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    blind_global: bool = False,
    num_steps: int | None = None,
) -> Tensor:
    """Deterministic MaskGIT decoding: argmax only, no sampling tricks.

    No temperature, top-K or remasking is applied. According to the
    original notes this is what ``DataCollector`` uses for DAgger
    rollouts, mirroring the reference ``run_model_episode``.

    Args:
        model: Denoising model. Its ``"actions"`` output is modified in
            place (logits of tokens ``>= action_dim`` become ``-inf``).
        local_obs: Shape ``[B, 9, 9]``.
        global_obs: Shape ``[B, 21, 79]``.
        cfg: Config namespace.
        device: Torch device.
        blind_global: Zero out global map (local-only ablation).
        num_steps: Number of denoising steps; falls back to
            ``cfg.diffusion_steps_eval`` when ``None``.

    Returns:
        Fully committed action sequence ``[B, seq_len]``, int64.
    """
    batch = local_obs.shape[0]
    plan_len = cfg.seq_len
    mask_id = cfg.mask_token
    n_actions = cfg.action_dim
    total_steps = cfg.diffusion_steps_eval if num_steps is None else num_steps

    local_obs = local_obs.to(device)
    global_obs = global_obs.to(device)
    if blind_global:
        global_obs = torch.zeros_like(global_obs)

    # Every position starts out as MASK.
    plan = torch.full(
        (batch, plan_len), mask_id, dtype=torch.long, device=device,
    )

    for step in range(1, total_steps + 1):
        progress = step / total_steps
        timestep = torch.full(
            (batch,), int(cfg.num_diffusion_steps * (1.0 - progress)),
            dtype=torch.long, device=device,
        )

        logits = model(local_obs, global_obs, plan, timestep)["actions"]
        logits[:, :, n_actions:] = float("-inf")  # forbid non-action tokens

        # Best token per position and the probability it received.
        confidences, choices = F.softmax(logits, dim=-1).max(dim=-1)

        masked = plan == mask_id  # [B, seq_len]
        budget = max(1, int(plan_len * progress))

        # Committed positions are pushed to the bottom of the ranking.
        ranked = confidences.masked_fill(~masked, -1.0)
        picked = ranked.topk(budget, dim=-1).indices  # [B, budget]

        commit = torch.zeros_like(plan, dtype=torch.bool)
        commit.scatter_(1, picked, True)
        commit = commit & masked

        plan = torch.where(commit, choices, plan)

    # Anything still masked is filled by one extra argmax pass at t = 0.
    leftover = plan == mask_id
    if leftover.any():
        zero_t = torch.zeros(batch, dtype=torch.long, device=device)
        final_logits = model(local_obs, global_obs, plan, zero_t)["actions"]
        final_logits[:, :, n_actions:] = float("-inf")
        plan = torch.where(leftover, final_logits.argmax(dim=-1), plan)

    return plan
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def select_action(
    model: torch.nn.Module,
    local_obs: Tensor,
    global_obs: Tensor,
    cfg: SimpleNamespace,
    device: torch.device | str,
    physics_aware: bool = True,
    blind_global: bool = False,
) -> int:
    """Plan for a single observation and return the first planned action.

    Two-dimensional observations are given a leading batch axis of size one
    before being handed to ``remdm_sample``.

    Args:
        model: Denoising model.
        local_obs: Local crop, either 2-D or already batched with a leading
            axis (documented upstream as ``[9, 9]`` or ``[1, 9, 9]``).
        global_obs: Global map, either 2-D or already batched (documented
            upstream as ``[21, 79]`` or ``[1, 21, 79]``).
        cfg: Config namespace, passed through unchanged.
        device: Torch device, passed through unchanged.
        physics_aware: Forwarded to ``remdm_sample``.
        blind_global: Forwarded to ``remdm_sample``.

    Returns:
        Element ``[0, 0]`` of the sampled sequence as a Python scalar, i.e.
        the first action of the first (only) plan in the batch.
    """
    batched_local = local_obs.unsqueeze(0) if local_obs.ndim == 2 else local_obs
    batched_global = (
        global_obs.unsqueeze(0) if global_obs.ndim == 2 else global_obs
    )
    plan = remdm_sample(
        model,
        batched_local,
        batched_global,
        cfg,
        device,
        physics_aware=physics_aware,
        blind_global=blind_global,
    )
    first_action = plan[0, 0]
    return first_action.item()
|
src/diffusion/schedules.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Noise schedule functions for MDLM diffusion.
|
| 2 |
+
|
| 3 |
+
Ported from the Craftax JAX implementation (src/diffusion/schedules.py).
|
| 4 |
+
All functions operate on PyTorch tensors and are pure (no global state).
|
| 5 |
+
|
| 6 |
+
Convention: alpha(t) is the fraction of tokens that remain *unmasked*.
|
| 7 |
+
- alpha(0) = 1.0 (fully clean)
|
| 8 |
+
- alpha(1) = 0.0 (fully masked)
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
from typing import Callable
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch import Tensor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def linear_schedule(t: Tensor) -> Tensor:
    """Return the linear retention probability ``alpha(t) = 1 - t``.

    Args:
        t: Diffusion time, expected in [0, 1]. Any shape.

    Returns:
        ``alpha_t`` with the same shape as *t*; 1 at ``t = 0`` and 0 at
        ``t = 1``.
    """
    fully_clean = 1.0
    return fully_clean - t
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def cosine_schedule(t: Tensor) -> Tensor:
    """Return the cosine retention probability ``alpha(t) = cos(pi/2 * t)^2``.

    Args:
        t: Diffusion time, expected in [0, 1]. Any shape.

    Returns:
        ``alpha_t`` with the same shape as *t*.
    """
    quarter_turn = math.pi / 2.0
    cosine = torch.cos(t * quarter_turn)
    return cosine ** 2
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Registry of schedule names accepted by ``get_schedule``.
_SCHEDULE_MAP: dict[str, Callable[[Tensor], Tensor]] = dict(
    linear=linear_schedule,
    cosine=cosine_schedule,
)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_schedule(name: str) -> Callable[[Tensor], Tensor]:
    """Resolve a registered noise schedule from its name.

    Args:
        name: Key into the schedule registry (``"linear"`` or ``"cosine"``).

    Returns:
        The schedule function ``alpha(t)`` registered under *name*.

    Raises:
        KeyError: If *name* is not a registered schedule; the message lists
            the available names.
    """
    schedule_fn = _SCHEDULE_MAP.get(name)
    if schedule_fn is None:
        available = list(_SCHEDULE_MAP)
        raise KeyError(
            f"Unknown schedule '{name}'. "
            f"Available: {available}"
        )
    return schedule_fn
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def alpha_prime(
    t: Tensor,
    schedule_fn: Callable[[Tensor], Tensor],
    eps: float = 1e-5,
) -> Tensor:
    """Numerical derivative d(alpha)/dt via central difference.

    The stencil ``(alpha(t + eps) - alpha(t - eps)) / (2 * eps)`` subtracts
    two nearly equal numbers. With the default ``eps`` that difference is
    close to float32 resolution and below half-precision resolution, so for
    floating-point inputs the stencil is evaluated in float64 and the result
    is cast back to the dtype of *t*.

    Args:
        t: Diffusion time in [0, 1]. Any shape.
        schedule_fn: Noise schedule returning alpha(t).
        eps: Half-width for finite-difference stencil. *t* is clamped to
            ``[eps, 1 - eps]`` so both stencil points stay inside [0, 1].

    Returns:
        Approximate derivative, same shape as *t* (and same dtype when *t*
        is floating point).
    """
    # The MPS backend has no float64 support, so keep the caller's dtype
    # there; non-floating inputs also follow the original promotion path.
    upcast = t.is_floating_point() and t.device.type != "mps"
    work = t.to(torch.float64) if upcast else t

    centre = work.clamp(eps, 1.0 - eps)
    slope = (schedule_fn(centre + eps) - schedule_fn(centre - eps)) / (
        2.0 * eps
    )
    return slope.to(t.dtype) if upcast else slope
|
src/envs/__init__.py
ADDED
|
File without changes
|
src/envs/discovery.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MiniHack environment discovery and diagnostic utilities.
|
| 2 |
+
|
| 3 |
+
Provides tools for scanning the gymnasium registry, validating action-space
|
| 4 |
+
consistency across environments, and benchmarking inference throughput.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import time
|
| 11 |
+
from types import SimpleNamespace
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)

# A registry ID must contain at least one of these substrings to count as a
# navigation task.
_NAV_KEYWORDS = (
    "Room",
    "Corridor",
    "Maze",
    "River",
)
# IDs containing any of these substrings are dropped even if they match above.
_EXCLUDED_KEYWORDS = ("KeyRoom",)
# Environment whose action list serves as the baseline for consistency checks.
_REFERENCE_ENV_ID = "MiniHack-Room-15x15-v0"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def list_working_minihack_tasks() -> list[str]:
    """Scan the gymnasium registry for working MiniHack navigation tasks.

    A registry ID is a candidate when it contains ``"MiniHack"``, at least
    one navigation keyword, and none of the excluded keywords. Each candidate
    is instantiated once; it counts as working when ``gym.make`` succeeds.
    A failure while *closing* an already-created environment is logged but
    does not reclassify it as broken, so every candidate is counted exactly
    once in the summary.

    Returns:
        Sorted list of working MiniHack navigation environment IDs.
    """
    import gymnasium as gym
    import minihack  # noqa: F401 — registers envs

    candidates = sorted(
        env_id
        for env_id in gym.envs.registry.keys()
        if "MiniHack" in env_id
        and any(keyword in env_id for keyword in _NAV_KEYWORDS)
        and not any(keyword in env_id for keyword in _EXCLUDED_KEYWORDS)
    )

    working: list[str] = []
    broken: list[str] = []
    for env_id in candidates:
        # Broad catch is deliberate: this is a best-effort probe and any
        # construction failure simply marks the environment as broken.
        try:
            env = gym.make(env_id)
        except Exception:
            logger.debug(f"Could not instantiate {env_id}", exc_info=True)
            broken.append(env_id)
            continue

        working.append(env_id)
        try:
            env.close()
        except Exception:
            logger.debug(f"Closing {env_id} raised", exc_info=True)

    logger.info(
        f"MiniHack navigation tasks — working: {len(working)}, "
        f"broken: {len(broken)}"
    )
    return working
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _classify_action_lists(reference_actions, env_actions) -> str:
    """Label how *env_actions* relates to *reference_actions*.

    The two lists are compared position by position over their shared
    prefix. Any mismatch there is a ``CONFLICT`` regardless of length; only
    a matching prefix can yield ``EXACT``, ``SUPERSET (+N)`` or
    ``SUBSET (-N)``.
    """
    prefix_matches = all(
        ref == got for ref, got in zip(reference_actions, env_actions)
    )
    if not prefix_matches:
        return "CONFLICT"
    diff = len(env_actions) - len(reference_actions)
    if diff == 0:
        return "EXACT"
    if diff > 0:
        return f"SUPERSET (+{diff})"
    return f"SUBSET ({diff})"  # diff is negative, so this renders as (-N)


def check_action_consistency_with_fixed_ref(
    env_list: list[str],
) -> list[tuple[str, str, int]]:
    """Validate action-space ordering against a fixed reference environment.

    Compares each environment's action list against
    ``MiniHack-Room-15x15-v0`` and classifies the relationship as one of:
    ``REFERENCE``, ``EXACT``, ``SUPERSET (+N)``, ``SUBSET (-N)``,
    ``CONFLICT``, or ``CRASHED``. A list that is longer than the reference
    but disagrees on the shared prefix is a ``CONFLICT``, not a superset.

    Args:
        env_list: MiniHack environment IDs to check.

    Returns:
        List of ``(env_id, status, action_space_size)`` tuples, exactly one
        per entry of *env_list*, in sorted ID order. ``CRASHED`` entries
        report a size of 0.
    """
    import gymnasium as gym
    import minihack  # noqa: F401

    ref_env = gym.make(_REFERENCE_ENV_ID)
    try:
        reference_actions = ref_env.unwrapped.actions  # type: ignore[attr-defined]
    finally:
        ref_env.close()

    results: list[tuple[str, str, int]] = []
    for env_id in sorted(env_list):
        if env_id == _REFERENCE_ENV_ID:
            results.append((env_id, "REFERENCE", len(reference_actions)))
            continue

        # Build the row first and append once afterwards, so a failure in
        # close() cannot leave two rows for the same environment.
        try:
            env = gym.make(env_id)
            try:
                env_actions = env.unwrapped.actions  # type: ignore[attr-defined]
                row = (
                    env_id,
                    _classify_action_lists(reference_actions, env_actions),
                    len(env_actions),
                )
            finally:
                env.close()
        except Exception:
            logger.debug(f"Action check failed for {env_id}", exc_info=True)
            row = (env_id, "CRASHED", 0)
        results.append(row)

    for name, status, size in results:
        logger.info(f"  {name:<40} | {status:<14} | n_actions={size}")
    return results
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def benchmark_inference(
    model: torch.nn.Module,
    cfg: SimpleNamespace,
    device: torch.device | str,
    n_actions: int = 100,
) -> tuple[float, float]:
    """Time repeated ``remdm_sample`` calls on all-zero observations.

    The model is switched to eval mode, then ``n_actions`` planning calls
    are issued back to back on a batch of one. When CUDA is available the
    device is synchronised before the clock starts and again before it
    stops.

    Args:
        model: Denoising model; ``model.eval()`` is called on it.
        cfg: Config namespace. Reads ``crop_size``, ``map_h``, ``map_w`` and
            ``diffusion_steps_eval`` here and is forwarded to the sampler.
        device: Torch device for the dummy observations.
        n_actions: Number of planning calls to time.

    Returns:
        ``(diffusion_steps_per_sec, actions_per_sec)``; both are ``0.0`` if
        no time elapsed.
    """
    from src.diffusion.sampling import remdm_sample

    def _wait_for_gpu() -> None:
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    model.eval()
    crop = cfg.crop_size
    local_dummy = torch.zeros(
        (1, crop, crop), dtype=torch.long, device=device,
    )
    global_dummy = torch.zeros(
        (1, cfg.map_h, cfg.map_w), dtype=torch.long, device=device,
    )

    _wait_for_gpu()
    started = time.perf_counter()
    for _ in range(n_actions):
        remdm_sample(model, local_dummy, global_dummy, cfg, device)
    _wait_for_gpu()
    elapsed = time.perf_counter() - started

    total_steps = n_actions * cfg.diffusion_steps_eval
    if elapsed > 0:
        steps_per_sec = total_steps / elapsed
        actions_per_sec = n_actions / elapsed
    else:
        steps_per_sec = 0.0
        actions_per_sec = 0.0

    logger.info(
        f"Benchmark ({n_actions} actions): "
        f"{steps_per_sec:.1f} diffusion-steps/s | "
        f"{actions_per_sec:.1f} actions/s"
    )
    return steps_per_sec, actions_per_sec
|
src/envs/minihack_env.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MiniHack environment wrapper with BFS oracle and shaped rewards.
|
| 2 |
+
|
| 3 |
+
Ported from minihack_reference/src/env.py. Provides dual-stream
|
| 4 |
+
observations (9x9 local crop + 21x79 global map), a multi-tier BFS
|
| 5 |
+
oracle, and reward shaping (win bonus, BFS progress, exploration, step
|
| 6 |
+
penalty).
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import collections
|
| 12 |
+
import logging
|
| 13 |
+
from types import SimpleNamespace
|
| 14 |
+
|
| 15 |
+
import gymnasium as gym
|
| 16 |
+
import minihack # noqa: F401 — registers MiniHack envs
|
| 17 |
+
import numpy as np
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)

# NLE emits "Not saving any NLE data." at INFO level each time an env is
# created; raise that logger's threshold so the message is dropped.
logging.getLogger("nle.env.base").setLevel(logging.WARNING)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ── Staircase detection ──────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def find_staircase_from_glyphs(global_obs: np.ndarray) -> np.ndarray:
    """Locate the first staircase cell in each global glyph map.

    A cell counts as a staircase when its value is one of 62, 2310, 2368 or
    2383. Cells are scanned in row-major order and the first hit is used.

    Args:
        global_obs: Glyph map, shape ``[B, H, W]`` or ``[H, W]``. A 2-D map
            is treated as a batch of one.

    Returns:
        Float32 array of shape ``[B, 2]`` (``[1, 2]`` for 2-D input) holding
        ``(row / (H - 1), col / (W - 1))``; the divisor is floored at 1.
        Rows are ``(-1, -1)`` when no staircase value is present.
    """
    # NOTE(review): presumably 62 is ord('>') and the others are glyph ids
    # for staircase tiles — confirm against the NLE glyph table.
    stair_values = (62, 2310, 2368, 2383)

    maps = global_obs[np.newaxis] if global_obs.ndim == 2 else global_obs
    batch, height, width = maps.shape
    row_scale = max(1, height - 1)
    col_scale = max(1, width - 1)

    coords = np.full((batch, 2), -1.0, dtype=np.float32)
    for idx, grid in enumerate(maps):
        hits = np.argwhere(np.isin(grid, stair_values))
        if len(hits) == 0:
            continue
        first_row, first_col = hits[0]
        coords[idx, 0] = first_row / row_scale
        coords[idx, 1] = first_col / col_scale
    return coords
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ── Environment wrapper ──────────────────────────────────────────────
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class AdvancedObservationEnv(gym.Env):
    """MiniHack wrapper with dual-stream obs, BFS oracle, shaped rewards.

    ``reset`` and ``step`` return observations as a ``(local_crop,
    global_map)`` pair: ``local_crop`` is a ``[crop_size, crop_size]`` glyph
    window centred on the agent and ``global_map`` is the inner env's full
    glyph grid, both as int16.

    NOTE(review): ``observation_space`` is declared as a single ``Box`` with
    the crop's shape although observations are 2-tuples — confirm whether
    callers rely on that before changing it.

    Args:
        env_id: MiniHack registry ID.
        des_file: Optional ``.des`` file content (for custom levels). When
            given, ``MiniHack-Navigation-Custom-v0`` is created instead of
            *env_id*.
        cfg: Configuration namespace with ``crop_size``, ``action_dim``,
            ``pad_token``, ``map_h``, ``map_w``.
    """

    # Character codes that block movement in the BFS searches.
    _UNWALKABLE = frozenset({32, 45, 124, 125})  # space, -, |, }
    _CLOSED_DOOR = 43  # '+'
    # (d_row, d_col) -> action index for the four cardinal moves.
    _DIR_MAP = {(-1, 0): 0, (0, 1): 1, (1, 0): 2, (0, -1): 3}
    _CARDINAL = [(-1, 0), (0, 1), (1, 0), (0, -1)]

    def __init__(
        self,
        env_id: str,
        des_file: str | None,
        cfg: SimpleNamespace,
    ) -> None:
        super().__init__()
        self.env_id = env_id
        self._cfg = cfg
        self._crop_half = cfg.crop_size // 2

        obs_keys = ("glyphs", "chars", "pixel")
        if des_file is not None:
            self._inner = gym.make(
                "MiniHack-Navigation-Custom-v0",
                des_file=des_file,
                observation_keys=obs_keys,
            )
        else:
            self._inner = gym.make(
                env_id, observation_keys=obs_keys,
            )

        self.observation_space = gym.spaces.Box(
            low=0, high=6000,
            shape=(cfg.crop_size, cfg.crop_size),
            dtype=np.int16,
        )
        self.action_space: gym.spaces.Discrete = gym.spaces.Discrete(cfg.action_dim)

        # Per-episode tracking state, re-initialised by reset().
        self._visited: set[tuple[int, int]] = set()
        self._prev_bfs_dist: int | None = None
        self.last_raw_obs: dict | None = None

    # ── gym.Env interface ────────────────────────────────────────────

    def reset(
        self, seed: int | None = None, options: dict | None = None,
    ) -> tuple[tuple[np.ndarray, np.ndarray], dict]:
        """Reset environment and tracking state.

        Args:
            seed: Optional RNG seed, forwarded to the inner env.
            options: Passed through to the inner env.

        Returns:
            ``((local_crop, global_map), info)``
        """
        obs, info = self._inner.reset(seed=seed, options=options)
        self.last_raw_obs = obs
        self._prev_bfs_dist = self._get_bfs_distance(obs)
        self._visited = set()
        agent_pos = self._get_agent_pos(obs)
        if agent_pos is not None:
            self._visited.add(agent_pos)
        return self._get_obs(obs), info

    def step(
        self, action: int,
    ) -> tuple[tuple[np.ndarray, np.ndarray], float, bool, bool, dict]:
        """Execute one environment step with shaped reward.

        Reward shaping, added on top of the inner env's reward:
            - Win bonus: ``+20.0`` (terminated with positive raw reward)
            - BFS progress toward staircase: ``+0.5 * (prev - curr)``
            - New-tile exploration: ``+0.05``
            - Step penalty: ``-0.01``

        ``info["won"]`` is set to whether the win bonus was granted.

        Args:
            action: Integer action. Indices at or above the inner env's
                action count are wrapped with modulo before stepping.

        Returns:
            ``(obs, shaped_reward, terminated, truncated, info)``
        """
        inner_n = self._inner.action_space.n
        if action >= inner_n:
            action = action % inner_n

        obs, raw_reward, terminated, truncated, info = self._inner.step(action)
        self.last_raw_obs = obs
        reward = float(raw_reward)

        # Win bonus
        won = bool(terminated and reward > 0)
        info["won"] = won
        if won:
            reward += 20.0

        # BFS shaping: only when the distance is known on both sides of the
        # step. The baseline is refreshed unconditionally so shaping resumes
        # once the staircase becomes visible.
        curr_dist = self._get_bfs_distance(obs)
        if curr_dist is not None and self._prev_bfs_dist is not None:
            reward += (self._prev_bfs_dist - curr_dist) * 0.5
        self._prev_bfs_dist = curr_dist

        # Exploration bonus
        agent_pos = self._get_agent_pos(obs)
        if agent_pos is not None and agent_pos not in self._visited:
            reward += 0.05
            self._visited.add(agent_pos)

        # Step penalty
        reward -= 0.01

        return self._get_obs(obs), reward, terminated, truncated, info

    @property
    def unwrapped(self):
        """Access the inner MiniHack env."""
        return self._inner.unwrapped

    def close(self) -> None:
        """Close the inner environment."""
        self._inner.close()

    # ── Observation helpers ──────────────────────────────────────────

    def _get_obs(
        self, obs: dict,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Extract dual-stream observation.

        Args:
            obs: Raw NLE observation dict.

        Returns:
            ``(local_crop [crop,crop], global_map [H,W])`` as int16. The
            global map is a fresh array, not a view of ``obs["glyphs"]``.
        """
        # astype() copies by default, so no separate copy() is needed.
        return self._get_crop(obs), obs["glyphs"].astype(np.int16)

    def _get_crop(self, obs: dict) -> np.ndarray:
        """Crop local glyph window centred on agent.

        Args:
            obs: Raw NLE observation dict.

        Returns:
            ``[crop_size, crop_size]`` int16 array. Cells outside the map
            hold ``cfg.pad_token``; if no agent is found the whole crop is
            padding.
        """
        glyphs = obs["glyphs"]
        cs = self._cfg.crop_size
        agent_pos = self._get_agent_pos(obs)
        if agent_pos is None:
            return np.full((cs, cs), self._cfg.pad_token, dtype=np.int16)
        row, col = agent_pos
        # Padding by the half-width shifts indices so that the window
        # starting at (row, col) in the padded grid is centred on the agent.
        padded = np.pad(
            glyphs, self._crop_half, mode="constant",
            constant_values=self._cfg.pad_token,
        )
        return padded[row:row + cs, col:col + cs].astype(np.int16)

    def _get_agent_pos(self, obs: dict) -> tuple[int, int] | None:
        """Find agent '@' position in the chars grid.

        Args:
            obs: Raw NLE observation dict.

        Returns:
            ``(row, col)`` of the first '@' in row-major order, or ``None``.
        """
        hits = np.argwhere(obs["chars"] == ord("@"))
        return tuple(hits[0]) if len(hits) > 0 else None

    def _in_bounds(self, row: int, col: int) -> bool:
        """Return whether ``(row, col)`` lies inside the configured map."""
        return 0 <= row < self._cfg.map_h and 0 <= col < self._cfg.map_w

    def _get_bfs_distance(self, obs: dict) -> int | None:
        """BFS shortest-path distance from agent to staircase.

        Moves are cardinal only. Tiles in ``_UNWALKABLE`` block the search;
        closed doors do not.

        Args:
            obs: Raw NLE observation dict.

        Returns:
            Integer distance or ``None`` if unreachable / not visible.
        """
        chars = obs["chars"]
        start = self._get_agent_pos(obs)
        stairs = np.argwhere(chars == ord(">"))
        if start is None or len(stairs) == 0:
            return None
        target = tuple(stairs[0])
        if start == target:
            return 0

        queue: collections.deque = collections.deque([(start, 0)])
        seen = {start}
        while queue:
            (r, c), dist = queue.popleft()
            if (r, c) == target:
                return dist
            for dr, dc in self._CARDINAL:
                nxt = (r + dr, c + dc)
                if (
                    self._in_bounds(*nxt)
                    and nxt not in seen
                    and chars[nxt] not in self._UNWALKABLE
                ):
                    seen.add(nxt)
                    queue.append((nxt, dist + 1))
        return None

    # ── BFS Oracle ───────────────────────────────────────────────────

    def get_oracle_action(self, obs: dict) -> int:
        """5-tier BFS oracle action.

        Priority:
            1. Kick adjacent closed door.
            2. BFS to staircase '>'.
            3. BFS to frontier (adjacent to unexplored space).
            4. BFS to farthest reachable tile.
            5. Random cardinal direction.

        The search records, for every reachable tile, only the first move of
        its shortest path and its depth, which is all the tiers consume.

        Args:
            obs: Raw NLE observation dict (needs ``'chars'`` key).

        Returns:
            Action index: 0-3 for a cardinal move, 11 for the kick tier, or
            0 when *obs* is ``None``.
        """
        if obs is None:
            return 0
        chars = obs["chars"]
        start = self._get_agent_pos(obs)
        if start is None:
            return np.random.randint(0, 4)

        # 1. Adjacent closed door → kick
        # NOTE(review): 11 is assumed to be KICK in the project's action
        # vocabulary; step() wraps indices >= the inner action count with
        # modulo, so confirm the inner env really exposes kick at this index.
        for dr, dc in self._CARDINAL:
            nr, nc = start[0] + dr, start[1] + dc
            if self._in_bounds(nr, nc) and chars[nr, nc] == self._CLOSED_DOOR:
                return 11  # KICK

        stairs = {tuple(pos) for pos in np.argwhere(chars == ord(">"))}

        # BFS over walkable, non-door tiles. Entries are
        # (tile, first_move, depth); first_move is None only for the start.
        queue: collections.deque = collections.deque([(start, None, 0)])
        seen = {start}
        reachable: list[tuple[tuple[int, int], tuple[int, int] | None, int]] = []
        stair_move: tuple[int, int] | None = None

        while queue:
            tile, first_move, depth = queue.popleft()
            reachable.append((tile, first_move, depth))
            if tile in stairs:
                stair_move = first_move
                break
            r, c = tile
            for move in self._CARDINAL:
                nxt = (r + move[0], c + move[1])
                if not self._in_bounds(*nxt) or nxt in seen:
                    continue
                ch = chars[nxt]
                if ch in self._UNWALKABLE or ch == self._CLOSED_DOOR:
                    continue
                seen.add(nxt)
                # Children inherit the first move of their parent's path.
                queue.append(
                    (nxt, move if first_move is None else first_move, depth + 1)
                )

        # 2. Path to staircase
        if stair_move is not None:
            return self._DIR_MAP.get(stair_move, 0)

        # 3. Frontier exploration — tiles adjacent to unexplored space.
        # `reachable` is in BFS order (non-decreasing depth), so the first
        # frontier tile found is the nearest one.
        for (r, c), first_move, _depth in reachable:
            if first_move is None:
                continue
            touches_unknown = any(
                self._in_bounds(r + dr, c + dc) and chars[r + dr, c + dc] == 32
                for dr, dc in self._CARDINAL
            )
            if touches_unknown:
                return self._DIR_MAP.get(first_move, 0)

        # 4. Farthest reachable tile (earliest in BFS order among ties)
        if reachable:
            _tile, far_move, _depth = max(reachable, key=lambda entry: entry[2])
            if far_move is not None:
                return self._DIR_MAP.get(far_move, 0)

        # 5. Random cardinal
        return np.random.randint(0, 4)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# ── Factory ──────────────────────────────────────────────────────────
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def make_env(
    env_id: str,
    des_file: str | None,
    cfg: SimpleNamespace,
) -> AdvancedObservationEnv:
    """Build an ``AdvancedObservationEnv`` from the given settings.

    Args:
        env_id: MiniHack registry ID.
        des_file: Optional ``.des`` file content.
        cfg: Configuration namespace.

    Returns:
        The newly constructed wrapper.
    """
    wrapped = AdvancedObservationEnv(env_id, des_file, cfg)
    return wrapped
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def collect_oracle_trajectory(
    env_id: str,
    seed: int,
    cfg: SimpleNamespace,
    max_steps: int = 500,
) -> dict | None:
    """Record one episode driven by the environment's BFS oracle.

    Each stored observation is the one the oracle acted on, so the arrays
    hold one observation per action; the observation produced by the final
    step is not stored. The environment is always closed before returning.

    Args:
        env_id: MiniHack registry ID.
        seed: RNG seed passed to ``reset``.
        cfg: Configuration namespace.
        max_steps: Upper bound on the number of oracle actions.

    Returns:
        ``{"local": [T,9,9], "global": [T,21,79],
        "actions": [T], "env_id": str}`` on success (int16 observations,
        int64 actions), or ``None`` if anything raised; the failure is
        logged with its traceback.
    """
    env = make_env(env_id, None, cfg)
    try:
        (local, glb), _info = env.reset(seed=seed)
        local_frames = []
        global_frames = []
        actions: list[int] = []

        for _ in range(max_steps):
            # Store the observation the oracle is about to act on.
            local_frames.append(local)
            global_frames.append(glb)

            action = env.get_oracle_action(env.last_raw_obs)
            actions.append(action)
            (local, glb), _reward, terminated, truncated, _info = env.step(
                action
            )
            if terminated or truncated:
                break

        return {
            "local": np.stack(local_frames, axis=0).astype(np.int16),
            "global": np.stack(global_frames, axis=0).astype(np.int16),
            "actions": np.array(actions, dtype=np.int64),
            "env_id": env_id,
        }
    except Exception:
        logger.error(
            f"Oracle trajectory failed for {env_id} seed={seed}",
            exc_info=True,
        )
        return None
    finally:
        env.close()
|
src/models/__init__.py
ADDED
|
File without changes
|
src/models/denoiser.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dual-stream denoising transformer for MiniHack.
|
| 2 |
+
|
| 3 |
+
Ported from minihack_reference/src/model.py. Architecture follows the
|
| 4 |
+
Craftax denoiser conventions (forward return format, obs-encoder pattern)
|
| 5 |
+
while using the MiniHack dual-stream design (local CNN + gated global
|
| 6 |
+
CNN + auxiliary goal head).
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import copy
|
| 12 |
+
import logging
|
| 13 |
+
import shutil
|
| 14 |
+
from types import SimpleNamespace
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from torch import Tensor
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class LocalDiffusionPlannerWithGlobal(nn.Module):
    """Dual-stream transformer for masked diffusion action planning.

    Combines a local 9x9 glyph crop with a gated global 21x79 map
    context.  Produces action logits and an auxiliary staircase-coordinate
    prediction.

    Architecture:
        Local stream:  Embedding(6000,64) -> CNN(64->32->64) -> Linear -> 1 token
        Global stream: Embedding(6000,32) -> CNN(32->32->64) -> Pool(2,4)
                       -> Linear -> 8 tokens, gated by sigmoid(learnable scalar)
        Goal head:     mean(global_tokens) -> MLP -> [B,2]  (before gate)
        Action stream: Embedding(action_dim + 2, n_embd) + timestep + position
        Transformer:   concat all -> TransformerEncoder -> last seq_len tokens
                       -> head

    Args:
        cfg: Config namespace with ``action_dim``, ``n_embd``, ``n_head``,
            ``n_layer``, ``n_global_tokens``, ``seq_len``,
            ``global_gate_init``, ``num_diffusion_steps`` and an optional
            ``dropout`` (defaults to 0.0).

    Raises:
        ValueError: If ``cfg.n_global_tokens`` does not equal the number of
            cells in the fixed global pooling grid.
    """

    # Output grid of the adaptive pool over the global feature map; each
    # cell becomes one global context token.
    _GLOBAL_POOL_SHAPE = (2, 4)

    def __init__(self, cfg: SimpleNamespace) -> None:
        super().__init__()
        action_dim = cfg.action_dim
        n_embd = cfg.n_embd
        n_head = cfg.n_head
        n_layer = cfg.n_layer
        n_global_tokens = cfg.n_global_tokens
        seq_len = cfg.seq_len

        assert n_embd % n_head == 0, (
            f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
        )

        # The pool grid is fixed, so any other token count would only fail
        # later inside forward() with an opaque shape error.  Fail early.
        pool_h, pool_w = self._GLOBAL_POOL_SHAPE
        if n_global_tokens != pool_h * pool_w:
            raise ValueError(
                f"n_global_tokens ({n_global_tokens}) must equal "
                f"{pool_h * pool_w} (global pool grid {pool_h}x{pool_w})"
            )

        self.n_global_tokens = n_global_tokens

        # ── Local stream: 9x9 crop -> 1 token ──────────────────────
        self.embedding = nn.Embedding(6000, 64)
        self.cnn = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.GELU(),
            nn.Flatten(),
            nn.Linear(64 * 9 * 9, n_embd),
        )

        # ── Action stream ──────────────────────────────────────────
        # NOTE(review): the two extra embedding rows are presumably
        # special (mask / pad) tokens — confirm against the diffusion code.
        self.action_emb = nn.Embedding(action_dim + 2, n_embd)
        self.timestep_emb = nn.Embedding(
            cfg.num_diffusion_steps, n_embd,
        )
        self.pos_emb = nn.Embedding(seq_len, n_embd)

        # ── Transformer ───────────────────────────────────────────
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=n_head,
            dim_feedforward=n_embd * 4,
            dropout=getattr(cfg, "dropout", 0.0),
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=n_layer, enable_nested_tensor=False,
        )
        self.head = nn.Linear(n_embd, action_dim)

        # ── Global stream: 21x79 map -> 8 tokens ──────────────────
        self.global_embedding = nn.Embedding(6000, 32)
        self.global_cnn = nn.Sequential(
            nn.Conv2d(32, 32, 5, stride=2, padding=2),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.GELU(),
        )
        self.global_pool = nn.AdaptiveAvgPool2d(self._GLOBAL_POOL_SHAPE)
        self.global_proj = nn.Linear(64, n_embd)
        # float() guards against integer config values (e.g. YAML ``0``):
        # an integer tensor cannot require grad, so nn.Parameter would raise.
        self.global_gate = nn.Parameter(
            torch.tensor(float(cfg.global_gate_init))
        )

        # ── Auxiliary goal head (before gate) ──────────────────────
        self.goal_head = nn.Sequential(
            nn.Linear(n_embd, 128),
            nn.GELU(),
            nn.Linear(128, 2),
        )

    @staticmethod
    def _timestep_batch(
        t_discrete: int | Tensor, batch_size: int, device: torch.device,
    ) -> Tensor:
        """Normalise a timestep argument to a long tensor usable per batch row.

        Accepts a Python/NumPy integer, a 0-dim tensor, or a ``[B]`` tensor.
        """
        if isinstance(t_discrete, Tensor):
            t_tensor = t_discrete.long().to(device)
            if t_tensor.dim() == 0:
                # A scalar tensor would otherwise mis-broadcast after the
                # embedding lookup; treat it like a plain int.
                t_tensor = t_tensor.expand(batch_size)
            return t_tensor
        return torch.full(
            (batch_size,), int(t_discrete), dtype=torch.long, device=device,
        )

    def forward(
        self,
        local_obs: Tensor,
        global_obs: Tensor,
        action_seq: Tensor,
        t_discrete: int | Tensor,
    ) -> dict[str, Tensor]:
        """Forward pass producing action logits and goal prediction.

        Args:
            local_obs: Local glyph crop.  Shape ``[B, 9, 9]``, int.
            global_obs: Full glyph map.  Shape ``[B, 21, 79]``, int.
            action_seq: Noisy action sequence.  Shape ``[B, seq_len]``, int.
            t_discrete: Discrete timestep index (scalar int, 0-dim tensor,
                or ``[B]`` tensor).

        Returns:
            Dict with keys:
            - ``"actions"``: ``[B, seq_len, action_dim]`` logits.
            - ``"goal_pred"``: ``[B, 2]`` normalised staircase coords.
        """
        B, Seq = action_seq.shape
        device = local_obs.device

        # Local stream -> [B, 1, n_embd]
        x_local = self.embedding(local_obs)             # [B, 9, 9, 64]
        x_local = x_local.permute(0, 3, 1, 2)           # [B, 64, 9, 9]
        local_token = self.cnn(x_local).unsqueeze(1)    # [B, 1, n_embd]

        # Global stream -> [B, 8, n_embd]
        x_global = self.global_embedding(global_obs)    # [B, 21, 79, 32]
        x_global = x_global.permute(0, 3, 1, 2)         # [B, 32, 21, 79]
        gf = self.global_cnn(x_global)                  # [B, 64, H', W']
        gf = self.global_pool(gf)                       # [B, 64, 2, 4]
        global_tokens = gf.permute(0, 2, 3, 1)          # [B, 2, 4, 64]
        global_tokens = global_tokens.reshape(
            B, self.n_global_tokens, -1
        )                                               # [B, 8, 64]
        global_tokens = self.global_proj(global_tokens)  # [B, 8, n_embd]

        # Aux goal head (before gate for direct gradient to CNN)
        goal_pred = self.goal_head(
            global_tokens.mean(dim=1)
        )                                               # [B, 2]

        # Apply gate
        gate = torch.sigmoid(self.global_gate)
        global_tokens = global_tokens * gate            # [B, 8, n_embd]

        # Action stream -> [B, seq_len, n_embd]
        positions = torch.arange(
            Seq, device=device,
        ).unsqueeze(0).expand(B, -1)                    # [B, seq_len]

        t_tensor = self._timestep_batch(t_discrete, B, device)

        seq_emb = (
            self.action_emb(action_seq)
            + self.timestep_emb(t_tensor).unsqueeze(1)
            + self.pos_emb(positions)
        )                                               # [B, seq_len, n_embd]

        # Concatenate: [local(1), global(8), actions(seq_len)]
        x = torch.cat(
            [local_token, global_tokens, seq_emb], dim=1,
        )                                               # [B, 1+8+seq_len, n_embd]

        # Transformer
        out = self.transformer(x)                       # [B, 1+8+seq_len, n_embd]

        # Drop the conditioning prefix; only action positions are decoded.
        n_prefix = 1 + self.n_global_tokens
        action_logits = self.head(
            out[:, n_prefix:, :]
        )                                               # [B, seq_len, action_dim]

        return {"actions": action_logits, "goal_pred": goal_pred}
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class LocalDiffusionPlanner(nn.Module):
    """Local-only ablation model (no global stream, no goal head).

    The transformer sees one token from the 9x9 local crop followed by the
    embedded action sequence; only the action positions are decoded.

    Args:
        cfg: Config namespace with ``action_dim``, ``n_embd``, ``n_head``,
            ``n_layer``, ``seq_len``, ``num_diffusion_steps`` and an
            optional ``dropout`` (defaults to 0.0).
    """

    def __init__(self, cfg: SimpleNamespace) -> None:
        super().__init__()
        action_dim = cfg.action_dim
        n_embd = cfg.n_embd
        seq_len = cfg.seq_len

        self.embedding = nn.Embedding(6000, 64)
        self.cnn = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.GELU(),
            nn.Flatten(),
            nn.Linear(64 * 9 * 9, n_embd),
        )
        self.action_emb = nn.Embedding(action_dim + 2, n_embd)
        self.timestep_emb = nn.Embedding(cfg.num_diffusion_steps, n_embd)
        self.pos_emb = nn.Embedding(seq_len, n_embd)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=n_embd,
            nhead=cfg.n_head,
            dim_feedforward=n_embd * 4,
            dropout=getattr(cfg, "dropout", 0.0),
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        # Nested tensors are unusable with norm_first=True; disabling them
        # explicitly matches the dual-stream model and avoids the PyTorch
        # construction-time warning.
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=cfg.n_layer, enable_nested_tensor=False,
        )
        self.head = nn.Linear(n_embd, action_dim)

    @staticmethod
    def _timestep_batch(
        t_discrete: int | Tensor, batch_size: int, device: torch.device,
    ) -> Tensor:
        """Normalise a timestep argument to a long tensor usable per batch row.

        Accepts a Python/NumPy integer, a 0-dim tensor, or a ``[B]`` tensor.
        """
        if isinstance(t_discrete, Tensor):
            t_tensor = t_discrete.long().to(device)
            if t_tensor.dim() == 0:
                # A scalar tensor would otherwise mis-broadcast after the
                # embedding lookup; treat it like a plain int.
                t_tensor = t_tensor.expand(batch_size)
            return t_tensor
        return torch.full(
            (batch_size,), int(t_discrete), dtype=torch.long, device=device,
        )

    def forward(
        self,
        local_obs: Tensor,
        global_obs: Tensor,
        action_seq: Tensor,
        t_discrete: int | Tensor,
    ) -> dict[str, Tensor]:
        """Forward pass (ignores global_obs).

        Args:
            local_obs: ``[B, 9, 9]`` int.
            global_obs: ``[B, 21, 79]`` int (ignored; kept so both models
                share one call signature).
            action_seq: ``[B, seq_len]`` int.
            t_discrete: Timestep index (scalar int, 0-dim tensor, or
                ``[B]`` tensor).

        Returns:
            Dict with ``"actions"`` key only (no goal_pred); the logits have
            shape ``[B, seq_len, action_dim]``.
        """
        B, Seq = action_seq.shape
        device = local_obs.device

        x_state = self.embedding(local_obs).permute(0, 3, 1, 2)
        state_emb = self.cnn(x_state).unsqueeze(1)  # [B, 1, n_embd]

        positions = torch.arange(
            Seq, device=device,
        ).unsqueeze(0).expand(B, -1)

        t_tensor = self._timestep_batch(t_discrete, B, device)

        seq_emb = (
            self.action_emb(action_seq)
            + self.timestep_emb(t_tensor).unsqueeze(1)
            + self.pos_emb(positions)
        )
        x = torch.cat([state_emb, seq_emb], dim=1)
        out = self.transformer(x)
        # Position 0 is the state token; decode only the action positions.
        return {"actions": self.head(out[:, 1:, :])}
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# ── Factory ──────────────────────────────────────────────────────────
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def make_model(cfg: SimpleNamespace) -> nn.Module:
    """Build the default MiniHack denoiser.

    Args:
        cfg: Config namespace forwarded unchanged to the model constructor.

    Returns:
        A freshly constructed ``LocalDiffusionPlannerWithGlobal``.
    """
    denoiser = LocalDiffusionPlannerWithGlobal(cfg)
    return denoiser
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _has_c_compiler() -> bool:
|
| 296 |
+
"""Check whether a C compiler is reachable by Triton.
|
| 297 |
+
|
| 298 |
+
Checks the ``CC`` env var (set by conda activation scripts),
|
| 299 |
+
then falls back to ``cc`` and ``gcc`` on ``PATH``.
|
| 300 |
+
"""
|
| 301 |
+
import os
|
| 302 |
+
cc_env = os.environ.get("CC")
|
| 303 |
+
if cc_env and shutil.which(cc_env):
|
| 304 |
+
return True
|
| 305 |
+
return shutil.which("cc") is not None or shutil.which("gcc") is not None
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def try_compile(model: nn.Module, cfg: SimpleNamespace) -> nn.Module:
    """Return a ``torch.compile``-wrapped *model* when that is possible.

    Compilation is attempted only if ``cfg.torch_compile`` is truthy, the
    installed torch exposes ``torch.compile``, and a C compiler can be
    located.  In every other case the eager model is handed back untouched;
    a warning is logged only for the missing-compiler case.

    Args:
        model: The raw (uncompiled) model.
        cfg: Config namespace; the optional ``torch_compile`` flag is read
            and treated as ``False`` when absent.

    Returns:
        The compiled wrapper, or *model* itself on fallback.
    """
    requested = getattr(cfg, "torch_compile", False)
    if not requested or not hasattr(torch, "compile"):
        return model

    if _has_c_compiler():
        logger.info("Compiling model with torch.compile")
        return torch.compile(model, mode="default")  # type: ignore[return-value]

    logger.warning(
        "torch.compile requested but no C compiler found "
        "(CC env var, cc, gcc); falling back to eager mode"
    )
    return model
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# ── EMA ──────────────────────────────────────────────────────────────
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class ModelEMA:
    """Exponential moving average over a model's named parameters.

    A detached shadow tensor is kept for every entry of
    ``model.named_parameters()`` and blended on each update as
    ``theta_ema <- decay * theta_ema + (1 - decay) * theta``.
    Buffers are not tracked.

    Args:
        model: Source model whose current parameters seed the shadows.
        decay: EMA decay factor (default 0.999).
    """

    def __init__(self, model: nn.Module, decay: float = 0.999) -> None:
        self._decay = decay
        self._shadow: dict[str, Tensor] = {
            key: weight.data.clone()
            for key, weight in model.named_parameters()
        }

    def update(self, model: nn.Module) -> None:
        """Blend the current parameters of *model* into the shadows.

        Args:
            model: Source model; its parameter names must match the ones
                seen at construction time.
        """
        keep = self._decay
        with torch.no_grad():
            for key, weight in model.named_parameters():
                ema = self._shadow[key]
                ema.mul_(keep)
                ema.add_(weight.data, alpha=1.0 - keep)

    def apply_to(self, model: nn.Module) -> None:
        """Overwrite the parameters of *model* in place with the EMA values.

        Args:
            model: Target model to overwrite (for inference).
        """
        for key, weight in model.named_parameters():
            weight.data.copy_(self._shadow[key])

    def state_dict(self) -> dict[str, Tensor]:
        """Snapshot the shadows for serialisation.

        Returns:
            Dict mapping parameter names to cloned EMA tensors.
        """
        snapshot: dict[str, Tensor] = {}
        for key, ema in self._shadow.items():
            snapshot[key] = ema.clone()
        return snapshot

    def load_state_dict(self, sd: dict[str, Tensor]) -> None:
        """Restore shadows from *sd*.

        Entries whose name is not tracked by this EMA are ignored.

        Args:
            sd: State dict from a prior ``state_dict()`` call.
        """
        for key, value in sd.items():
            if key not in self._shadow:
                continue
            self._shadow[key].copy_(value)

    def parameters(self):
        """Iterate over the shadow tensors.

        Yields:
            EMA parameter tensors, in tracking order.
        """
        for ema in self._shadow.values():
            yield ema

    def make_eval_model(self, model: nn.Module) -> nn.Module:
        """Build an eval-mode deep copy of *model* carrying the EMA weights.

        Args:
            model: Template model (architecture); left unmodified.

        Returns:
            New model with shadow parameters, switched to ``eval()``.
        """
        clone = copy.deepcopy(model)
        self.apply_to(clone)
        clone.eval()
        return clone
|
src/planners/__init__.py
ADDED
|
File without changes
|
src/planners/baselines.py
ADDED
|
@@ -0,0 +1,1247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SB3 + Decision Transformer baselines for the ReMDM diffusion planner.
|
| 2 |
+
|
| 3 |
+
This module wraps standard discrete-action RL baselines (PPO, A2C, DQN,
|
| 4 |
+
recurrent PPO) plus two imitation baselines (Behavioural Cloning and
|
| 5 |
+
Decision Transformer) into the project's unified config + dispatch
|
| 6 |
+
surface so they can be compared head-to-head against the DAgger /
|
| 7 |
+
offline-BC diffusion planner on the same MiniHack environments.
|
| 8 |
+
|
| 9 |
+
Entry point: :func:`run_baselines`.
|
| 10 |
+
|
| 11 |
+
Hyperparameters live in ``configs/defaults.yaml`` under the
|
| 12 |
+
``baselines_*`` namespace; the unified env-step training budget
|
| 13 |
+
(``cfg.total_timesteps``) is shared with DAgger and offline BC.
|
| 14 |
+
|
| 15 |
+
W&B logging routes through the project's :class:`Logger` (with the W&B
|
| 16 |
+
project temporarily swapped to ``cfg.baselines_wandb_project``); SB3's
|
| 17 |
+
standard ``WandbCallback`` piggybacks on the active run and syncs its
|
| 18 |
+
tensorboard scalars automatically. No file in this module calls
|
| 19 |
+
``wandb.log(...)`` directly.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import logging
|
| 25 |
+
import os
|
| 26 |
+
import random
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
from types import SimpleNamespace
|
| 29 |
+
from typing import Any
|
| 30 |
+
|
| 31 |
+
import gymnasium as gym
|
| 32 |
+
import numpy as np
|
| 33 |
+
import orjson
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn as nn
|
| 36 |
+
from sb3_contrib import RecurrentPPO
|
| 37 |
+
from stable_baselines3 import A2C, DQN, PPO
|
| 38 |
+
from stable_baselines3.common.callbacks import CallbackList, EvalCallback
|
| 39 |
+
from stable_baselines3.common.monitor import Monitor
|
| 40 |
+
from stable_baselines3.common.policies import ActorCriticPolicy
|
| 41 |
+
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
|
| 42 |
+
from stable_baselines3.common.vec_env import SubprocVecEnv
|
| 43 |
+
from torch.utils.data import DataLoader, Dataset
|
| 44 |
+
from wandb.integration.sb3 import WandbCallback
|
| 45 |
+
|
| 46 |
+
from src.envs.minihack_env import (
|
| 47 |
+
AdvancedObservationEnv,
|
| 48 |
+
collect_oracle_trajectory,
|
| 49 |
+
)
|
| 50 |
+
from src.planners.logging import Logger
|
| 51 |
+
|
| 52 |
+
logger = logging.getLogger(__name__)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# RL baselines (PPO, A2C, DQN, recurrent PPO).
SB3_RL_ALGOS: tuple[str, ...] = ("ppo", "a2c", "dqn", "ppo-rnn")
# Imitation baselines (Behavioural Cloning, Decision Transformer).
IMITATION_ALGOS: tuple[str, ...] = ("bc", "dt")
# Every baseline identifier, RL first, then imitation.
ALL_BASELINE_ALGOS: tuple[str, ...] = (*SB3_RL_ALGOS, *IMITATION_ALGOS)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# =============================================================================
|
| 61 |
+
# Observation wrapper for SB3 dict-policies
|
| 62 |
+
# =============================================================================
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class _SB3MiniHackWrapper(gym.Wrapper):
    """Expose ``AdvancedObservationEnv`` tuple observations as an SB3 dict.

    The wrapped env yields a ``(local_crop, global_map)`` pair.  This
    wrapper prepends a channel axis to each array and returns them under
    the ``"local"`` / ``"global"`` keys of a ``Dict`` space, which is the
    form SB3's ``MultiInputPolicy`` consumes.  On every step it also copies
    ``info["won"]`` to ``info["is_success"]`` (when present) so SB3's
    success tracking reports the win rate.
    """

    def __init__(self, env: AdvancedObservationEnv) -> None:
        super().__init__(env)
        crop_h, crop_w = env.observation_space.shape
        env_cfg = env._cfg  # AdvancedObservationEnv stores cfg here

        def glyph_box(height: int, width: int) -> gym.spaces.Box:
            # Single-channel integer grid bounded by [0, 6000].
            return gym.spaces.Box(
                low=0, high=6000, shape=(1, height, width), dtype=np.int16,
            )

        self.observation_space = gym.spaces.Dict(
            {
                "local": glyph_box(crop_h, crop_w),
                "global": glyph_box(env_cfg.map_h, env_cfg.map_w),
            }
        )

    def reset(self, **kwargs: Any) -> tuple[dict[str, np.ndarray], dict]:
        """Reset the wrapped env and return the packed dict observation."""
        (local, glob), info = self.env.reset(**kwargs)
        packed = self._pack(local, glob)
        return packed, info

    def step(
        self, action: int,
    ) -> tuple[dict[str, np.ndarray], float, bool, bool, dict]:
        """Step the wrapped env, packing the obs and mirroring the win flag."""
        (local, glob), reward, terminated, truncated, info = self.env.step(action)
        if "won" in info:
            info["is_success"] = info["won"]
        packed = self._pack(local, glob)
        return packed, reward, terminated, truncated, info

    @staticmethod
    def _pack(
        local: np.ndarray, glob: np.ndarray,
    ) -> dict[str, np.ndarray]:
        """Add a leading channel axis to both arrays and key them by stream."""
        streams = (("local", local), ("global", glob))
        # [crop, crop] -> [1, crop, crop]; [H, W] -> [1, H, W]
        return {key: np.expand_dims(grid, axis=0) for key, grid in streams}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# =============================================================================
|
| 113 |
+
# CNN feature extractor (shared by SB3 RL + BC)
|
| 114 |
+
# =============================================================================
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class _MiniHackCNN(BaseFeaturesExtractor):
    """Two-branch CNN feature extractor for the SB3 dict observation.

    Local branch:  ``Conv(1->16, 3, pad 1) -> Conv(16->32, 3, pad 1)``.
    Global branch: ``Conv(1->16, 5, stride 2) -> Conv(16->32, 3, stride 2)``.
    Each branch ends in a flatten; the two flat vectors are concatenated
    and mapped to ``features_dim`` by one linear layer followed by ReLU.

    Args:
        observation_space: Dict space with ``"local"`` and ``"global"``
            entries whose shapes include the channel axis.
        features_dim: Size of the output feature vector.
    """

    def __init__(
        self, observation_space: gym.spaces.Dict, features_dim: int = 256,
    ) -> None:
        super().__init__(observation_space, features_dim)

        local_layers = [
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.local_cnn = nn.Sequential(*local_layers)

        global_layers = [
            nn.Conv2d(1, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.global_cnn = nn.Sequential(*global_layers)

        # Probe each branch with an all-zero observation to learn its
        # flattened width; the sum sizes the projection layer.
        flat_widths = []
        with torch.no_grad():
            for key, branch in (
                ("local", self.local_cnn),
                ("global", self.global_cnn),
            ):
                probe = torch.zeros(1, *observation_space[key].shape)
                flat_widths.append(branch(probe).shape[1])
        n_flatten = sum(flat_widths)

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(
        self, observations: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Encode a batch of dict observations into ``[B, features_dim]``."""
        local_feat = self.local_cnn(observations["local"].float())     # [B, F_l]
        global_feat = self.global_cnn(observations["global"].float())  # [B, F_g]
        merged = torch.cat([local_feat, global_feat], dim=1)
        return self.linear(merged)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# =============================================================================
|
| 162 |
+
# Decision Transformer
|
| 163 |
+
# =============================================================================
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class _MiniHackStateEncoder(nn.Module):
    """CNN encoder mapping a (local, global) obs pair to a state embedding.

    The local crop goes through two padded 3x3 convolutions, the global map
    through a 5x5 and a 3x3 convolution (both stride 2). The flattened
    outputs are concatenated and linearly projected to ``embed_dim``.

    Args:
        embed_dim: Size of the returned embedding.
        crop_h: Height of the local crop, used to size the projection.
        crop_w: Width of the local crop.
        map_h: Height of the global map.
        map_w: Width of the global map.
    """

    def __init__(
        self,
        embed_dim: int = 128,
        crop_h: int = 9,
        crop_w: int = 9,
        map_h: int = 21,
        map_w: int = 79,
    ) -> None:
        super().__init__()
        self.local_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.global_cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Probe both branches with zero inputs to learn their flattened sizes.
        with torch.no_grad():
            dummy_loc = torch.zeros(1, 1, crop_h, crop_w)
            dummy_glob = torch.zeros(1, 1, map_h, map_w)
            local_flat = self.local_cnn(dummy_loc).shape[1]
            global_flat = self.global_cnn(dummy_glob).shape[1]
        self.proj = nn.Linear(local_flat + global_flat, embed_dim)

    def forward(
        self, local_obs: torch.Tensor, global_obs: torch.Tensor,
    ) -> torch.Tensor:
        """Embed a batch of observations.

        Args:
            local_obs: ``(B, T, 1, H, W)`` or ``(B, 1, H, W)`` local crops.
            global_obs: Global maps with the same leading dimensions as
                ``local_obs``.

        Returns:
            ``(B, T, D)`` when ``local_obs`` is 5-D, otherwise ``(B, D)``.
        """
        has_time_axis = local_obs.dim() == 5
        if has_time_axis:
            B, T = local_obs.shape[:2]
            # ``reshape`` rather than ``view``: ``view`` raises on
            # non-contiguous inputs (e.g. after a transpose/permute), while
            # ``reshape`` copies only when it has to.
            local_obs = local_obs.reshape(B * T, *local_obs.shape[2:])
            global_obs = global_obs.reshape(B * T, *global_obs.shape[2:])

        loc_feat = self.local_cnn(local_obs.float())      # [B*T, F_l]
        glob_feat = self.global_cnn(global_obs.float())   # [B*T, F_g]
        out = self.proj(torch.cat([loc_feat, glob_feat], dim=-1))  # [B*T, D]
        if has_time_axis:
            out = out.reshape(B, T, -1)
        return out
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class _DecisionTransformer(nn.Module):
    """Causal transformer over interleaved (return-to-go, state, action) tokens.

    Every timestep contributes three tokens in the order ``R_t, s_t, a_t``.
    Action logits are read from the hidden state at each state token.
    """

    def __init__(
        self,
        n_actions: int,
        embed_dim: int = 128,
        n_heads: int = 4,
        n_layers: int = 3,
        context_len: int = 30,
        max_ep_len: int = 500,
        dropout: float = 0.1,
        crop_h: int = 9,
        crop_w: int = 9,
        map_h: int = 21,
        map_w: int = 79,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.context_len = context_len
        self.n_actions = n_actions
        self.max_ep_len = max_ep_len

        # Per-modality embeddings.
        self.state_encoder = _MiniHackStateEncoder(
            embed_dim, crop_h, crop_w, map_h, map_w,
        )
        self.action_embed = nn.Embedding(n_actions + 1, embed_dim)  # +1 for pad
        self.return_embed = nn.Linear(1, embed_dim)
        # Shared additive embeddings: episode timestep and token kind (R/s/a).
        self.pos_embed = nn.Embedding(max_ep_len, embed_dim)
        self.token_type_embed = nn.Embedding(3, embed_dim)
        self.embed_ln = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

        block_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(block_layer, num_layers=n_layers)
        self.action_head = nn.Linear(embed_dim, n_actions)

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        """Normal(0, 0.02) weights for Linear/Embedding; identity LayerNorm."""
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)

    def forward(
        self,
        returns_to_go: torch.Tensor,                # [B, T, 1]
        local_obs: torch.Tensor,                    # [B, T, 1, H_l, W_l]
        global_obs: torch.Tensor,                   # [B, T, 1, H_g, W_g]
        actions: torch.Tensor,                      # [B, T]
        timesteps: torch.Tensor,                    # [B, T]
        attention_mask: torch.Tensor | None = None,  # [B, T]
    ) -> torch.Tensor:
        """Return action logits ``[B, T, A]``, one row per state token."""
        batch, steps = returns_to_go.shape[:2]
        n_tokens = 3 * steps

        return_tok = self.return_embed(returns_to_go)            # [B, T, D]
        state_tok = self.state_encoder(local_obs, global_obs)    # [B, T, D]
        action_tok = self.action_embed(actions)                  # [B, T, D]

        time_tok = self.pos_embed(timesteps)                     # [B, T, D]
        kind = self.token_type_embed.weight
        return_tok = return_tok + time_tok + kind[0]
        state_tok = state_tok + time_tok + kind[1]
        action_tok = action_tok + time_tok + kind[2]

        # Stacking on a new axis 2 and flattening yields the order
        # (R_0, s_0, a_0, R_1, s_1, a_1, ...) -> [B, 3T, D].
        sequence = torch.stack([return_tok, state_tok, action_tok], dim=2)
        sequence = sequence.view(batch, n_tokens, self.embed_dim)
        sequence = self.dropout(self.embed_ln(sequence))

        # True above the diagonal: a token may not attend to later tokens.
        future_mask = torch.triu(
            torch.ones(n_tokens, n_tokens, device=returns_to_go.device),
            diagonal=1,
        ).bool()

        if attention_mask is None:
            padding_mask = None
        else:
            # Each timestep's flag covers its three tokens; True marks padding.
            per_token = attention_mask.unsqueeze(-1).repeat(1, 1, 3)
            padding_mask = per_token.view(batch, n_tokens) == 0

        hidden = self.transformer(
            sequence, mask=future_mask, src_key_padding_mask=padding_mask,
        )
        # State tokens sit at indices 1, 4, 7, ... of the interleaved sequence.
        return self.action_head(hidden[:, 1::3, :])              # [B, T, A]

    @torch.no_grad()
    def get_action(
        self,
        returns_to_go: torch.Tensor,
        local_obs: torch.Tensor,
        global_obs: torch.Tensor,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Greedy action for the last timestep; switches the module to eval."""
        self.eval()
        all_logits = self.forward(
            returns_to_go, local_obs, global_obs, actions, timesteps,
        )
        last_step_logits = all_logits[:, -1, :]
        return last_step_logits.argmax(dim=-1)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class _DTDataset(Dataset):
    """Fixed-length windows over Decision Transformer trajectories.

    One sample exists for every (trajectory, start index) pair. A window that
    reaches the end of its trajectory is right-padded with zeros up to
    ``context_len`` and flagged through ``attention_mask``.
    """

    def __init__(
        self,
        trajectories: list[dict[str, np.ndarray]],
        context_len: int,
        max_ep_len: int,
        n_actions: int,
    ) -> None:
        self.trajectories = trajectories
        self.context_len = context_len
        self.max_ep_len = max_ep_len
        self.n_actions = n_actions
        self.indices: list[tuple[int, int]] = []
        for traj_idx, traj in enumerate(trajectories):
            for start in range(len(traj["actions"])):
                self.indices.append((traj_idx, start))

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        traj_idx, start = self.indices[idx]
        traj = self.trajectories[traj_idx]
        stop = min(start + self.context_len, len(traj["actions"]))
        window = slice(start, stop)
        n_valid = stop - start
        n_pad = self.context_len - n_valid

        # Keep indices inside the embedding tables of the model.
        steps = np.clip(np.arange(start, stop), 0, self.max_ep_len - 1)
        acts = np.clip(traj["actions"][window], 0, self.n_actions - 1)
        local = traj["local"][window]
        glob = traj["global"][window]
        rtg = traj["returns_to_go"][window]

        if n_pad > 0:
            # Pad only along the time axis; observations are (T, 1, H, W).
            obs_widths = ((0, n_pad), (0, 0), (0, 0), (0, 0))
            seq_widths = (0, n_pad)
            local = np.pad(local, obs_widths, mode="constant")
            glob = np.pad(glob, obs_widths, mode="constant")
            acts = np.pad(acts, seq_widths, mode="constant")
            rtg = np.pad(rtg, seq_widths, mode="constant")
            steps = np.pad(steps, seq_widths, mode="constant")

        # 1.0 for real timesteps, 0.0 for right padding.
        valid = (np.arange(self.context_len) < n_valid).astype(np.float32)

        return {
            "local": torch.tensor(local, dtype=torch.float32),
            "global": torch.tensor(glob, dtype=torch.float32),
            "actions": torch.tensor(acts, dtype=torch.long),
            "returns_to_go": torch.tensor(rtg, dtype=torch.float32).unsqueeze(-1),
            "timesteps": torch.tensor(steps, dtype=torch.long),
            "attention_mask": torch.tensor(valid, dtype=torch.float32),
        }
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# =============================================================================
|
| 403 |
+
# SB3 callbacks + env factory
|
| 404 |
+
# =============================================================================
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class _PrefixedEvalCallback(EvalCallback):
    """``EvalCallback`` that also logs its latest results under ``prefix``.

    After each step it records ``<prefix>/mean_reward`` and
    ``<prefix>/avg_steps`` from the most recent evaluation, plus
    ``<prefix>/win_rate`` when success flags were collected.

    Rationale kept from the original author: SB3 truncates metric names at
    36 chars, which collides on long MiniHack env IDs, so callers pass a
    short prefix with ``MiniHack-`` / ``-v0`` stripped.
    """

    def __init__(
        self, eval_env: SubprocVecEnv, prefix: str, **kwargs: Any,
    ) -> None:
        super().__init__(eval_env, **kwargs)
        self.prefix = prefix

    def _on_step(self) -> bool:
        keep_training = super()._on_step()
        # Nothing to report until the parent has run at least one evaluation.
        if not self.evaluations_results:
            return keep_training

        latest_reward = float(np.mean(self.evaluations_results[-1]))
        self.logger.record(f"{self.prefix}/mean_reward", latest_reward)
        latest_length = float(np.mean(self.evaluations_length[-1]))
        self.logger.record(f"{self.prefix}/avg_steps", latest_length)
        if self.evaluations_successes:
            latest_success = float(np.mean(self.evaluations_successes[-1]))
            self.logger.record(f"{self.prefix}/win_rate", latest_success)
        return keep_training
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def _make_sb3_env_fn(env_id: str, cfg: SimpleNamespace, log_dir: str):
    """Return a zero-argument factory for one wrapped, monitored env.

    Nothing is built until the factory is called. At that point it creates
    ``log_dir`` if needed, builds ``AdvancedObservationEnv`` for ``env_id``,
    wraps it in ``_SB3MiniHackWrapper`` and then in ``Monitor`` writing to
    ``log_dir``.
    """

    def _init() -> Monitor:
        os.makedirs(log_dir, exist_ok=True)
        base_env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
        return Monitor(_SB3MiniHackWrapper(base_env), log_dir)

    return _init
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
# =============================================================================
|
| 452 |
+
# Helpers
|
| 453 |
+
# =============================================================================
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def _short(env_id: str) -> str:
    """Strip every ``MiniHack-`` and ``-v0`` substring from ``env_id``."""
    trimmed = env_id
    for token in ("MiniHack-", "-v0"):
        trimmed = trimmed.replace(token, "")
    return trimmed
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def _eval_episodes_per_env(cfg: SimpleNamespace) -> int:
    """Return the evaluation episode count per env for baseline runs.

    ``cfg.baselines_eval_episodes_per_env`` wins when it is set (not
    ``None``); otherwise ``cfg.eval_episodes_per_env`` is used.
    """
    baseline_value = getattr(cfg, "baselines_eval_episodes_per_env", None)
    chosen = cfg.eval_episodes_per_env if baseline_value is None else baseline_value
    return int(chosen)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def _seed_everything(seed: int) -> None:
    """Seed Python ``random``, NumPy and PyTorch (CPU, plus CUDA if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def _resolve_output_dir(cfg: SimpleNamespace, override: str | None) -> Path:
    """Pick the baselines output directory, create it, and return it.

    A non-empty ``override`` takes precedence over
    ``cfg.baselines_output_dir``. Missing parent directories are created.
    """
    out_dir = Path(override or cfg.baselines_output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    return out_dir
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _init_baseline_logger(
    cfg: SimpleNamespace, run_name: str,
) -> Logger:
    """Build the project ``Logger`` configured for a baseline run.

    ``cfg`` is modified in place before ``Logger(cfg)`` is constructed:
    ``wandb_project`` is replaced by ``cfg.baselines_wandb_project`` when
    that is set and truthy, ``wandb_run_name`` becomes ``run_name``, and
    ``wandb_resume_id`` is cleared. The previous values are deliberately not
    restored. Per the original author, each baseline seed reuses this helper
    and main.py exits after ``run_baselines`` returns.
    """
    baseline_project = getattr(cfg, "baselines_wandb_project", None)
    if baseline_project:
        cfg.wandb_project = baseline_project
    for field, value in (("wandb_run_name", run_name), ("wandb_resume_id", None)):
        setattr(cfg, field, value)
    return Logger(cfg)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
# =============================================================================
|
| 506 |
+
# BC training
|
| 507 |
+
# =============================================================================
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def _collect_bc_dataset(
    cfg: SimpleNamespace,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Gather oracle demonstrations on the ID envs as flat (state, action) arrays.

    For every env in ``cfg.id_envs``, seeds ``0 .. n-1`` (with
    ``n = cfg.baselines_bc_oracle_episodes_per_env``) are passed to
    ``collect_oracle_trajectory``; seeds that return ``None`` are skipped.

    Returns:
        ``(local, global, actions)`` concatenated over all kept trajectories
        along the first axis. Observations gain a channel axis at position 1.

    Raises:
        RuntimeError: If no trajectory was collected at all.
    """
    episodes_per_env = int(cfg.baselines_bc_oracle_episodes_per_env)
    kept = []
    for env_id in cfg.id_envs:
        for traj_seed in range(episodes_per_env):
            traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
            if traj is not None:
                kept.append(traj)

    if not kept:
        raise RuntimeError("BC oracle collection produced zero trajectories")

    # (T, H, W) -> (T, 1, H, W) for both observation streams.
    local_parts = [np.expand_dims(traj["local"], axis=1) for traj in kept]
    global_parts = [np.expand_dims(traj["global"], axis=1) for traj in kept]
    action_parts = [traj["actions"] for traj in kept]
    return (
        np.concatenate(local_parts, axis=0),
        np.concatenate(global_parts, axis=0),
        np.concatenate(action_parts, axis=0),
    )
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class _BCDataset(Dataset):
    """In-memory (observation, action) pairs for behavioural cloning.

    The arrays are copied into tensors once at construction: observations as
    ``float32``, actions as ``int64``.
    """

    def __init__(
        self, loc: np.ndarray, glob: np.ndarray, acts: np.ndarray,
    ) -> None:
        obs_dtype = torch.float32
        self.loc = torch.tensor(loc, dtype=obs_dtype)
        self.glob = torch.tensor(glob, dtype=obs_dtype)
        self.acts = torch.tensor(acts, dtype=torch.int64)

    def __len__(self) -> int:
        return len(self.acts)

    def __getitem__(
        self, idx: int,
    ) -> dict[str, dict[str, torch.Tensor] | torch.Tensor]:
        observation = {"local": self.loc[idx], "global": self.glob[idx]}
        return {"obs": observation, "acts": self.acts[idx]}
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def _eval_sb3_policy_manually(
    policy: ActorCriticPolicy,
    env_id: str,
    cfg: SimpleNamespace,
    log_dir: str,
    n_episodes: int,
) -> tuple[float, float]:
    """Evaluate ``policy`` with deterministic ``predict`` on one env.

    A single-worker ``SubprocVecEnv`` built via ``_make_sb3_env_fn`` is
    stepped until ``n_episodes`` episodes have finished. A win is an episode
    whose final ``info`` has a truthy ``"won"``; the episode length is read
    from ``info["episode"]["l"]``.

    Returns:
        ``(win_rate, avg_steps)``, both averaged over ``n_episodes``.
    """
    vec_env = SubprocVecEnv([_make_sb3_env_fn(env_id, cfg, log_dir)])
    try:
        obs = vec_env.reset()
        wins = 0
        total_steps = 0
        completed = 0
        while completed < n_episodes:
            action, _ = policy.predict(obs, deterministic=True)
            obs, _rewards, dones, infos = vec_env.step(action)
            if not dones[0]:
                continue
            final_info = infos[0]
            completed += 1
            if final_info.get("won", False):
                wins += 1
            total_steps += final_info["episode"]["l"]
    finally:
        # Always shut the worker process down, even if a rollout raised.
        vec_env.close()
    return wins / n_episodes, total_steps / n_episodes
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
def _train_bc(
    cfg: SimpleNamespace,
    train_env: SubprocVecEnv,
    log: Logger,
    log_dir: str,
    seed: int,
) -> tuple[ActorCriticPolicy, dict[str, float]]:
    """Train a Behavioural Cloning baseline and evaluate it on ID/OOD envs.

    Oracle demonstrations are collected, an ``ActorCriticPolicy`` with the
    ``_MiniHackCNN`` extractor is fitted by maximising the log-probability of
    the demonstrated actions (AdamW + cosine LR schedule, gradient norm
    clipped at 1.0), and the result is evaluated on every env in
    ``cfg.id_envs`` and ``cfg.ood_envs``.

    Args:
        cfg: Run configuration.
        train_env: Only its observation and action spaces are used here.
        log: Project logger that receives training and evaluation metrics.
        log_dir: Root directory for the per-env evaluation monitor logs.
        seed: Used only in log messages.

    Returns:
        ``(policy, seed_metrics)`` where ``seed_metrics`` maps
        ``"<split>/<env>/win_rate"`` (in percent) and
        ``"<split>/<env>/avg_steps"`` to their values.
    """
    device = torch.device(cfg.device)
    n_eval = _eval_episodes_per_env(cfg)

    logger.info("Collecting oracle demonstrations for BC...")
    loc_arr, glob_arr, acts_arr = _collect_bc_dataset(cfg)
    logger.info("BC dataset: %d transitions", len(acts_arr))

    demo_dataset = _BCDataset(loc_arr, glob_arr, acts_arr)
    bc_loader = DataLoader(
        demo_dataset,
        batch_size=int(cfg.baselines_bc_batch_size),
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )

    lr = float(cfg.baselines_bc_lr)
    policy = ActorCriticPolicy(
        observation_space=train_env.observation_space,
        action_space=train_env.action_space,
        lr_schedule=lambda _progress: lr,
        features_extractor_class=_MiniHackCNN,
        features_extractor_kwargs={"features_dim": 256},
    )
    policy = policy.to(device)

    n_epochs = int(cfg.baselines_bc_epochs)
    optimizer = torch.optim.AdamW(
        policy.parameters(), lr=lr, weight_decay=float(cfg.weight_decay),
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs,
    )

    policy.train()
    for epoch_no in range(1, n_epochs + 1):
        running_loss = 0.0
        for batch in bc_loader:
            obs = {key: value.to(policy.device) for key, value in batch["obs"].items()}
            acts = batch["acts"].to(policy.device)
            # Negative log-likelihood of the demonstrated actions.
            _values, log_prob, _entropy = policy.evaluate_actions(obs, acts)
            loss = -log_prob.mean()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

        avg_loss = running_loss / max(1, len(bc_loader))
        # Read after scheduler.step(), as in the original logging.
        current_lr = scheduler.get_last_lr()[0]
        epoch_record = {
            "train/bc_loss": avg_loss,
            "train/lr": current_lr,
            "train/epoch": epoch_no,
        }
        log.log(epoch_record, step=epoch_no)
        logger.info(
            "BC epoch %02d/%02d | loss=%.4f | lr=%.2e",
            epoch_no, n_epochs, avg_loss, current_lr,
        )

    seed_metrics: dict[str, float] = {}
    splits = (("ID", cfg.id_envs), ("OOD", cfg.ood_envs))
    for split, env_list in splits:
        logger.info("--- BC %s evaluation (seed=%d) ---", split, seed)
        for env_id in env_list:
            short = _short(env_id)
            monitor_dir = f"{log_dir}/eval_{split.lower()}/{env_id}"
            win_rate, avg_steps = _eval_sb3_policy_manually(
                policy, env_id, cfg, monitor_dir, n_eval,
            )
            win_pct = win_rate * 100
            seed_metrics[f"{split}/{short}/win_rate"] = win_pct
            seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
            logger.info(
                "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
                short, win_pct, avg_steps,
            )
    log.log(seed_metrics, step=n_epochs + 1)
    return policy, seed_metrics
|
| 675 |
+
|
| 676 |
+
|
| 677 |
+
# =============================================================================
|
| 678 |
+
# Decision Transformer training
|
| 679 |
+
# =============================================================================
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def _collect_dt_trajectories(
    cfg: SimpleNamespace,
) -> list[dict[str, np.ndarray]]:
    """Collect oracle trajectories labelled with sparse reward and return-to-go.

    For every env in ``cfg.id_envs``, seeds ``0 .. n-1`` (with
    ``n = cfg.baselines_dt_oracle_episodes_per_env``) are rolled out via
    ``collect_oracle_trajectory``; ``None`` results are skipped. Each kept
    trajectory gets reward 1.0 on its final step and 0 elsewhere, and
    ``returns_to_go[t]`` is the sum of rewards from step ``t`` onward.
    """
    episodes_per_env = int(cfg.baselines_dt_oracle_episodes_per_env)
    collected: list[dict[str, np.ndarray]] = []
    for env_id in cfg.id_envs:
        for traj_seed in range(episodes_per_env):
            traj = collect_oracle_trajectory(env_id, traj_seed, cfg)
            if traj is None:
                continue

            n_steps = len(traj["actions"])
            rewards = np.zeros(n_steps, dtype=np.float32)
            rewards[-1] = 1.0  # sparse goal reward
            # Suffix sums: accumulate over the reversed rewards, flip back.
            returns_to_go = np.cumsum(rewards[::-1], dtype=np.float32)[::-1].copy()

            record = {
                "local": np.expand_dims(traj["local"], axis=1),
                "global": np.expand_dims(traj["global"], axis=1),
                "actions": traj["actions"],
                "rewards": rewards,
                "returns_to_go": returns_to_go,
            }
            collected.append(record)
    return collected
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def _eval_dt(
    model: _DecisionTransformer,
    env_id: str,
    cfg: SimpleNamespace,
    target_return: float,
    n_episodes: int,
    max_ep_len: int,
    eval_max_steps: int,
    context_len: int,
) -> tuple[float, float]:
    """Roll out a trained Decision Transformer with target-return conditioning.

    At every step the last ``context_len`` (return-to-go, observation,
    action, timestep) entries are passed to ``model.get_action``. The chosen
    action is clamped to ``[0, cfg.action_dim - 1]`` and executed. The
    return-to-go starts at ``target_return`` and is reduced by each reward.

    Args:
        model: Trained model; switched to eval mode.
        env_id: Environment to evaluate on.
        cfg: Run configuration (``device`` and ``action_dim`` are read here).
        target_return: Initial return-to-go fed to the model.
        n_episodes: Number of episodes to run.
        max_ep_len: Timestep indices are capped at ``max_ep_len - 1``.
        eval_max_steps: Hard cap on steps per episode.
        context_len: Maximum number of timesteps fed to the model.

    Returns:
        ``(win_rate, avg_steps)`` averaged over ``n_episodes``. An episode
        counts as won when its last ``info`` has a truthy ``"won"``.
    """
    device = torch.device(cfg.device)
    env = AdvancedObservationEnv(env_id, des_file=None, cfg=cfg)
    env = _SB3MiniHackWrapper(env)
    model.eval()
    wins = 0
    total_steps = 0
    try:
        for _ep in range(n_episodes):
            obs, _ = env.reset()
            done = False

            local_hist: list[np.ndarray] = []
            global_hist: list[np.ndarray] = []
            action_hist: list[int] = []
            rtg_hist: list[float] = []
            ts_hist: list[int] = []

            current_rtg = float(target_return)
            t = 0
            info: dict = {}
            while not done and t < eval_max_steps:
                local_hist.append(obs["local"])
                global_hist.append(obs["global"])
                rtg_hist.append(current_rtg)
                ts_hist.append(min(t, max_ep_len - 1))

                ctx = min(len(local_hist), context_len)
                local_in = np.stack(local_hist[-ctx:], axis=0)
                global_in = np.stack(global_hist[-ctx:], axis=0)
                rtg_in = np.array(rtg_hist[-ctx:], dtype=np.float32)
                ts_in = np.array(ts_hist[-ctx:], dtype=np.int64)
                # Keep a_i in the same slot as s_i, matching training. The
                # current step has no action yet, so a 0 placeholder goes in
                # the LAST slot. The causal mask hides that action token from
                # the state token whose logits are read, so its value does
                # not affect the prediction. (Previously the placeholder was
                # put at the front, pairing every state with the action
                # taken one step earlier.)
                act_in = np.array((action_hist + [0])[-ctx:], dtype=np.int64)

                local_t = torch.tensor(local_in, dtype=torch.float32).unsqueeze(0).to(device)
                global_t = torch.tensor(global_in, dtype=torch.float32).unsqueeze(0).to(device)
                rtg_t = torch.tensor(rtg_in, dtype=torch.float32).unsqueeze(0).unsqueeze(-1).to(device)
                act_t = torch.tensor(act_in, dtype=torch.long).unsqueeze(0).to(device)
                ts_t = torch.tensor(ts_in, dtype=torch.long).unsqueeze(0).to(device)

                with torch.no_grad():
                    action = int(
                        model.get_action(rtg_t, local_t, global_t, act_t, ts_t).item()
                    )
                action = max(0, min(action, int(cfg.action_dim) - 1))
                action_hist.append(action)

                obs, reward, terminated, truncated, info = env.step(action)
                done = terminated or truncated
                current_rtg -= float(reward)
                t += 1

            if info.get("won", False):
                wins += 1
            total_steps += t
    finally:
        # Release the environment even if a rollout raised.
        env.close()

    return wins / n_episodes, total_steps / n_episodes
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
def _train_dt(
    cfg: SimpleNamespace,
    log: Logger,
    log_dir: str,
    seed: int,
) -> tuple[_DecisionTransformer, dict[str, float]]:
    """Train a Decision Transformer baseline and evaluate it on ID/OOD envs.

    Oracle trajectories are collected, windowed by ``_DTDataset``, and used
    to fit a ``_DecisionTransformer`` with a masked cross-entropy loss on the
    demonstrated actions (AdamW + cosine LR schedule, gradient norm clipped
    at 1.0). Evaluation conditions on the largest initial return-to-go found
    in the dataset.

    Args:
        cfg: Run configuration.
        log: Project logger that receives training and evaluation metrics.
        log_dir: Not used by this function.
        seed: Used only in log messages.

    Returns:
        ``(model, seed_metrics)`` where ``seed_metrics`` maps
        ``"<split>/<env>/win_rate"`` (in percent) and
        ``"<split>/<env>/avg_steps"`` to their values.

    Raises:
        RuntimeError: If the oracle produced no trajectories.
    """
    device = torch.device(cfg.device)
    context_len = int(cfg.baselines_dt_context_len)
    max_ep_len = int(cfg.baselines_dt_max_ep_len)
    eval_max_steps = int(cfg.baselines_dt_eval_max_steps)
    n_eval = _eval_episodes_per_env(cfg)
    n_epochs = int(cfg.baselines_dt_epochs)

    logger.info("Collecting oracle demonstrations for DT...")
    trajectories = _collect_dt_trajectories(cfg)
    if not trajectories:
        raise RuntimeError("DT oracle collection produced zero trajectories")

    lengths = [len(traj["actions"]) for traj in trajectories]
    shortest, longest = min(lengths), max(lengths)
    logger.info(
        "DT dataset: %d trajectories, %d transitions (len: min=%d max=%d mean=%.1f)",
        len(trajectories),
        sum(lengths),
        shortest,
        longest,
        float(np.mean(lengths)),
    )
    if longest > max_ep_len:
        logger.warning(
            "Longest oracle trajectory (%d) exceeds baselines_dt_max_ep_len (%d); "
            "positions will be clamped.",
            longest,
            max_ep_len,
        )

    # Condition evaluation on the best return seen at the start of a demo.
    target_return = float(max(traj["returns_to_go"][0] for traj in trajectories))

    n_actions = int(cfg.action_dim)
    dataset = _DTDataset(
        trajectories,
        context_len=context_len,
        max_ep_len=max_ep_len,
        n_actions=n_actions,
    )
    loader = DataLoader(
        dataset,
        batch_size=int(cfg.baselines_dt_batch_size),
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )

    model = _DecisionTransformer(
        n_actions=n_actions,
        embed_dim=int(cfg.baselines_dt_embed_dim),
        n_heads=int(cfg.baselines_dt_n_heads),
        n_layers=int(cfg.baselines_dt_n_layers),
        context_len=context_len,
        max_ep_len=max_ep_len,
        crop_h=int(cfg.crop_size),
        crop_w=int(cfg.crop_size),
        map_h=int(cfg.map_h),
        map_w=int(cfg.map_w),
    ).to(device)
    n_params = sum(param.numel() for param in model.parameters())
    logger.info("DT parameters: %d", n_params)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg.baselines_dt_lr),
        weight_decay=float(cfg.weight_decay),
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=n_epochs,
    )

    for epoch_no in range(1, n_epochs + 1):
        model.train()
        running_loss = 0.0
        batches_seen = 0
        for batch in loader:
            local = batch["local"].to(device)
            glob = batch["global"].to(device)
            actions = batch["actions"].to(device)
            rtg = batch["returns_to_go"].to(device)
            timesteps = batch["timesteps"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            logits = model(rtg, local, glob, actions, timesteps, attention_mask)
            per_token_ce = nn.functional.cross_entropy(
                logits.reshape(-1, n_actions),
                actions.reshape(-1),
                reduction="none",
            )
            # Average over real timesteps only; padded slots carry weight 0.
            valid = attention_mask.reshape(-1)
            loss = (per_token_ce * valid).sum() / valid.sum().clamp(min=1.0)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()
            batches_seen += 1
        scheduler.step()

        avg_loss = running_loss / max(1, batches_seen)
        # Read after scheduler.step(), as in the original logging.
        current_lr = float(scheduler.get_last_lr()[0])
        log.log(
            {
                "train/dt_loss": avg_loss,
                "train/lr": current_lr,
                "train/epoch": epoch_no,
            },
            step=epoch_no,
        )
        logger.info(
            "DT epoch %02d/%02d | loss=%.4f | lr=%.2e",
            epoch_no,
            n_epochs,
            avg_loss,
            current_lr,
        )

    seed_metrics: dict[str, float] = {}
    logger.info("DT eval target return = %.2f", target_return)
    splits = (("ID", cfg.id_envs), ("OOD", cfg.ood_envs))
    for split, env_list in splits:
        logger.info("--- DT %s evaluation (seed=%d) ---", split, seed)
        for env_id in env_list:
            short = _short(env_id)
            win_rate, avg_steps = _eval_dt(
                model,
                env_id,
                cfg,
                target_return=target_return,
                n_episodes=n_eval,
                max_ep_len=max_ep_len,
                eval_max_steps=eval_max_steps,
                context_len=context_len,
            )
            win_pct = win_rate * 100
            seed_metrics[f"{split}/{short}/win_rate"] = win_pct
            seed_metrics[f"{split}/{short}/avg_steps"] = avg_steps
            logger.info(
                "%-30s | win_rate=%5.1f%% | avg_steps=%5.1f",
                short, win_pct, avg_steps,
            )
    log.log(seed_metrics, step=n_epochs + 1)
    return model, seed_metrics
|
| 937 |
+
|
| 938 |
+
|
| 939 |
+
# =============================================================================
|
| 940 |
+
# SB3 RL training
|
| 941 |
+
# =============================================================================
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def _build_sb3_model(
    algo: str,
    train_env: SubprocVecEnv,
    cfg: SimpleNamespace,
    seed: int,
    tb_log_dir: str,
):
    """Instantiate the requested SB3 learner with the MiniHack CNN extractor.

    Args:
        algo: One of ``ppo``, ``ppo-rnn``, ``a2c`` or ``dqn``.
        train_env: Vectorised training environment handed to the learner.
        cfg: Config namespace. Only ``baselines_dqn_buffer_size`` is read
            here, and only when ``algo == "dqn"``.
        seed: Seed forwarded to the SB3 constructor.
        tb_log_dir: TensorBoard log directory forwarded to the constructor.

    Returns:
        The freshly constructed model (no training has happened yet).

    Raises:
        ValueError: If ``algo`` is not one of the four names above.
    """
    # Keyword arguments that every algorithm receives unchanged.
    shared = {
        "policy_kwargs": {
            "features_extractor_class": _MiniHackCNN,
            "features_extractor_kwargs": {"features_dim": 256},
        },
        "verbose": 1,
        "tensorboard_log": tb_log_dir,
        "seed": seed,
    }

    if algo == "dqn":
        # DQN is the only learner that takes a replay-buffer size.
        return DQN(
            "MultiInputPolicy",
            train_env,
            buffer_size=int(cfg.baselines_dqn_buffer_size),
            **shared,
        )
    if algo == "a2c":
        return A2C("MultiInputPolicy", train_env, **shared)
    if algo == "ppo-rnn":
        # The recurrent variant is the only one using the LSTM policy.
        return RecurrentPPO("MultiInputLstmPolicy", train_env, **shared)
    if algo == "ppo":
        return PPO("MultiInputPolicy", train_env, **shared)

    raise ValueError(f"Unknown SB3 algo: {algo!r}")
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def _build_sb3_callbacks(
    cfg: SimpleNamespace,
    train_env: SubprocVecEnv,
    log_dir: str,
    model_dir: str,
) -> CallbackList:
    """Assemble the W&B callback plus one eval callback per ID/OOD env.

    Every environment in ``cfg.id_envs`` and then ``cfg.ood_envs`` gets its
    own single-env ``SubprocVecEnv`` and a ``_PrefixedEvalCallback``. Only
    the ID callbacks are given a ``best_model_save_path``; OOD callbacks
    pass ``None``.

    Args:
        cfg: Config namespace (``id_envs``, ``ood_envs``,
            ``baselines_eval_freq_env_steps`` are read here).
        train_env: Training vec-env; only ``num_envs`` is used.
        log_dir: Root directory for per-env evaluation logs.
        model_dir: Root directory for W&B and best-model checkpoints.

    Returns:
        A ``CallbackList`` ordered as: W&B, ID evals, OOD evals.
    """
    # NOTE(review): if this is wandb's SB3 ``WandbCallback`` it presumably
    # needs an active W&B run -- confirm behaviour when W&B is disabled.
    callbacks: list = [WandbCallback(model_save_path=model_dir)]
    episodes_per_env = _eval_episodes_per_env(cfg)
    # The configured frequency is in env-steps; dividing by the number of
    # parallel envs converts it to the unit handed to the eval callbacks.
    every_n_calls = max(
        1, int(cfg.baselines_eval_freq_env_steps) // train_env.num_envs,
    )

    # (metric prefix, log sub-directory, cfg attribute, keep best model?)
    splits = (
        ("ID", "eval_id", "id_envs", True),
        ("OOD", "eval_ood", "ood_envs", False),
    )
    for split, subdir, envs_attr, keep_best in splits:
        # Attribute is looked up lazily so OOD envs are only read after
        # all ID callbacks exist, as in the sequential formulation.
        for env_id in getattr(cfg, envs_attr):
            short = _short(env_id)
            env_log_dir = f"{log_dir}/{subdir}/{env_id}"
            eval_env = SubprocVecEnv(
                [_make_sb3_env_fn(env_id, cfg, env_log_dir)]
            )
            best_path = f"{model_dir}/best_{env_id}/" if keep_best else None
            callbacks.append(
                _PrefixedEvalCallback(
                    eval_env,
                    prefix=f"{split}/{short}",
                    best_model_save_path=best_path,
                    log_path=f"{env_log_dir}/",
                    eval_freq=every_n_calls,
                    n_eval_episodes=episodes_per_env,
                    deterministic=True,
                )
            )
    return CallbackList(callbacks)
|
| 1025 |
+
|
| 1026 |
+
|
| 1027 |
+
# =============================================================================
|
| 1028 |
+
# Aggregation
|
| 1029 |
+
# =============================================================================
|
| 1030 |
+
|
| 1031 |
+
|
| 1032 |
+
def _aggregate(
|
| 1033 |
+
all_seed_results: list[dict[str, Any]],
|
| 1034 |
+
) -> dict[str, dict[str, float | list[float]]]:
|
| 1035 |
+
"""Compute mean/std across seeds for every shared metric key."""
|
| 1036 |
+
|
| 1037 |
+
if not all_seed_results:
|
| 1038 |
+
return {}
|
| 1039 |
+
metric_keys = [k for k in all_seed_results[0].keys() if k != "seed"]
|
| 1040 |
+
agg: dict[str, dict[str, float | list[float]]] = {}
|
| 1041 |
+
for key in metric_keys:
|
| 1042 |
+
values = [r[key] for r in all_seed_results if key in r]
|
| 1043 |
+
if values:
|
| 1044 |
+
agg[key] = {
|
| 1045 |
+
"mean": float(np.mean(values)),
|
| 1046 |
+
"std": float(np.std(values)),
|
| 1047 |
+
"values": [float(v) for v in values],
|
| 1048 |
+
}
|
| 1049 |
+
return agg
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
def _print_aggregated(seeds: list[int], agg: dict[str, dict[str, Any]]) -> None:
    """Log the cross-seed summary, grouped by split and environment.

    Keys of ``agg`` are expected to have the form
    ``"<split>/<env>/<metric>"`` with ``split`` being ``ID`` or ``OOD``.
    Keys that start with a split prefix but do not have all three
    components are skipped instead of aborting the report, because this
    function runs before the aggregated results are written to disk.

    Args:
        seeds: Seeds that were run (only used for the header line).
        agg: Output of ``_aggregate``.
    """
    if not agg:
        logger.info("No per-environment metrics to aggregate (RL eval is callback-driven)")
        return

    logger.info("Aggregated results across %d seeds: %s", len(seeds), seeds)
    for split in ("ID", "OOD"):
        prefix = f"{split}/"
        env_metrics: dict[str, dict[str, dict[str, Any]]] = {}
        for key, stats in agg.items():
            if not key.startswith(prefix):
                continue
            parts = key.split("/", 2)
            if len(parts) != 3:
                # e.g. "ID/foo": no metric component, nothing to tabulate.
                continue
            _split, env_name, metric_name = parts
            env_metrics.setdefault(env_name, {})[metric_name] = stats
        if not env_metrics:
            continue

        logger.info("--- %s environments ---", split)
        for env_name, metrics in sorted(env_metrics.items()):
            # A metric that was never reported is printed as 0.0.
            wr = metrics.get("win_rate", {})
            steps = metrics.get("avg_steps", {})
            logger.info(
                "%-30s | win_rate=%5.1f%% +/- %4.1f | avg_steps=%5.1f +/- %4.1f",
                env_name,
                wr.get("mean", 0.0),
                wr.get("std", 0.0),
                steps.get("mean", 0.0),
                steps.get("std", 0.0),
            )
|
| 1078 |
+
|
| 1079 |
+
|
| 1080 |
+
def _save_aggregated(
    out_path: Path,
    algo: str,
    seeds: list[int],
    all_seed_results: list[dict[str, Any]],
    agg: dict[str, dict[str, Any]],
) -> None:
    """Write the per-seed and aggregated metrics to ``out_path`` as JSON.

    Only ``mean`` and ``std`` of each aggregated metric are stored; the
    raw values remain available through ``per_seed_results``.

    Args:
        out_path: Destination file (overwritten).
        algo: Algorithm name recorded in the payload.
        seeds: Seeds that were run.
        all_seed_results: One metrics dict per seed.
        agg: Output of ``_aggregate``.
    """
    condensed = {}
    for metric, stats in agg.items():
        condensed[metric] = {"mean": stats["mean"], "std": stats["std"]}

    document = {
        "algorithm": algo,
        "seeds": seeds,
        "n_seeds": len(seeds),
        "per_seed_results": all_seed_results,
        "aggregated": condensed,
    }
    encoded = orjson.dumps(document, option=orjson.OPT_INDENT_2)
    out_path.write_bytes(encoded)
    logger.info("Aggregated results written to %s", out_path)
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
# =============================================================================
|
| 1101 |
+
# Public entry point
|
| 1102 |
+
# =============================================================================
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
def run_baselines(
    cfg: SimpleNamespace,
    algo: str,
    seeds: list[int] | None = None,
    output_path: str | None = None,
) -> None:
    """Run one baseline end-to-end for each seed, then aggregate the results.

    Per seed: seed the RNGs, open a per-seed logger, train the selected
    algorithm, save the final model under ``<out_dir>/models/<run_id>/``
    and close the logger. After all seeds, metrics are aggregated, logged,
    and (when any exist) written to JSON and to a summary logger run.

    Args:
        cfg: Project config namespace (must contain ``baselines_*`` keys).
        algo: One of ``ppo``, ``a2c``, ``dqn``, ``ppo-rnn``, ``bc``, ``dt``.
        seeds: Seeds to run. ``None`` means ``[cfg.seed]``, or ``[0]`` when
            ``cfg.seed`` is ``None``.
        output_path: Optional destination for the aggregated-results JSON.
            When ``None`` the file is placed in the resolved output
            directory.

    Raises:
        ValueError: If ``algo`` is unknown or ``seeds`` is empty.
    """
    if algo not in ALL_BASELINE_ALGOS:
        raise ValueError(
            f"Unknown algo {algo!r}. Choose one of {ALL_BASELINE_ALGOS}."
        )

    if seeds is None:
        seeds = [0 if cfg.seed is None else cfg.seed]
    if not seeds:
        raise ValueError("seeds must be non-empty")

    out_dir = _resolve_output_dir(cfg, None)
    if output_path is None:
        agg_json_path = out_dir / f"results_{algo}_{len(seeds)}seeds.json"
    else:
        agg_json_path = Path(output_path)
        agg_json_path.parent.mkdir(parents=True, exist_ok=True)

    logger.info(
        "Running baseline %s on %d seed(s): %s (output -> %s)",
        algo, len(seeds), seeds, agg_json_path,
    )

    all_seed_results: list[dict[str, Any]] = []
    n_envs_per_id = int(cfg.baselines_n_envs_per_id)

    for ordinal, seed in enumerate(seeds, start=1):
        logger.info(
            "============================================================\n"
            " %s seed %d (%d/%d)\n"
            "============================================================",
            algo.upper(), seed, ordinal, len(seeds),
        )
        _seed_everything(seed)

        log = _init_baseline_logger(cfg, f"{algo}-multitask-seed{seed}")
        # Use the W&B run id when one exists so local folders line up with
        # the remote run; otherwise fall back to a deterministic local id.
        if log._use_wandb and log._run is not None:
            run_id = log._run.id  # type: ignore[union-attr]
        else:
            run_id = f"local-{algo}-seed{seed}"
        log_dir = str(out_dir / "logs" / run_id)
        model_dir = str(out_dir / "models" / run_id)
        for directory in (log_dir, model_dir):
            os.makedirs(directory, exist_ok=True)

        seed_results: dict[str, Any] = {"seed": seed}
        try:
            if algo == "dt":
                model, dt_metrics = _train_dt(cfg, log, log_dir, seed)
                seed_results.update(dt_metrics)
                # Architecture hyper-parameters are stored next to the
                # weights so the checkpoint carries its own config.
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "config": {
                        "n_actions": int(cfg.action_dim),
                        "embed_dim": int(cfg.baselines_dt_embed_dim),
                        "n_heads": int(cfg.baselines_dt_n_heads),
                        "n_layers": int(cfg.baselines_dt_n_layers),
                        "context_len": int(cfg.baselines_dt_context_len),
                        "max_ep_len": int(cfg.baselines_dt_max_ep_len),
                    },
                }
                torch.save(checkpoint, f"{model_dir}/dt_final_seed{seed}.pt")
            else:
                # SB3 RL families and BC both need the parallel train env:
                # ``n_envs_per_id`` copies of every ID environment.
                train_env = SubprocVecEnv(
                    [
                        _make_sb3_env_fn(env_id, cfg, log_dir)
                        for env_id in list(cfg.id_envs) * n_envs_per_id
                    ]
                )
                try:
                    if algo == "bc":
                        policy, bc_metrics = _train_bc(
                            cfg, train_env, log, log_dir, seed,
                        )
                        seed_results.update(bc_metrics)
                        policy.save(f"{model_dir}/bc_final_seed{seed}")
                    else:
                        sb3_model = _build_sb3_model(
                            algo, train_env, cfg, seed,
                            tb_log_dir=str(out_dir / "tb" / run_id),
                        )
                        callbacks = _build_sb3_callbacks(
                            cfg, train_env, log_dir, model_dir,
                        )
                        logger.info(
                            "Training %s for %d env-steps across %d ID maps "
                            "(%d parallel envs)...",
                            algo.upper(),
                            int(cfg.total_timesteps),
                            len(cfg.id_envs),
                            train_env.num_envs,
                        )
                        sb3_model.learn(
                            total_timesteps=int(cfg.total_timesteps),
                            callback=callbacks,
                        )
                        sb3_model.save(f"{model_dir}/{algo}_final_seed{seed}")
                finally:
                    # Always tear down the worker processes.
                    train_env.close()

            all_seed_results.append(seed_results)
        finally:
            log.finish()
        logger.info("%s seed %d complete.", algo.upper(), seed)

    agg = _aggregate(all_seed_results)
    _print_aggregated(seeds, agg)
    if agg:
        _save_aggregated(agg_json_path, algo, seeds, all_seed_results, agg)
        # Final summary write to the project Logger so the aggregated
        # numbers land on a dedicated run.
        summary_log = _init_baseline_logger(cfg, f"{algo}-multitask-summary")
        try:
            summary_payload: dict[str, float] = {}
            for key, stats in agg.items():
                summary_payload[f"summary/{key}/mean"] = stats["mean"]
                summary_payload[f"summary/{key}/std"] = stats["std"]
            summary_log.log_summary(summary_payload)
        finally:
            summary_log.finish()
    logger.info("All %d seed(s) complete.", len(seeds))
|
src/planners/collect.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Data collection with DAgger and oracle replay.
|
| 2 |
+
|
| 3 |
+
Implements model episode rollout with replanning and DAgger-style
|
| 4 |
+
data collection using the BFS oracle and efficiency filter.
|
| 5 |
+
Supports parallel episode collection via ``ThreadPoolExecutor``.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import copy
|
| 11 |
+
import logging
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
import time
|
| 15 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 16 |
+
from typing import TYPE_CHECKING
|
| 17 |
+
from types import SimpleNamespace
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from src.buffer import ReplayBuffer
|
| 23 |
+
from src.curriculum import DynamicCurriculum, efficiency_filter
|
| 24 |
+
from src.diffusion.sampling import greedy_sample, remdm_sample
|
| 25 |
+
from src.envs.minihack_env import collect_oracle_trajectory, make_env
|
| 26 |
+
|
| 27 |
+
if TYPE_CHECKING:
|
| 28 |
+
from src.models.denoiser import ModelEMA
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@torch.no_grad()
def run_model_episode(
    model: torch.nn.Module,
    env_id: str,
    cfg: SimpleNamespace,
    device: torch.device | str,
    seed: int | None = None,
    max_steps: int = 500,
    des_file: str | None = None,
    blind_global: bool = False,
    stochastic: bool = False,
) -> dict:
    """Roll out the diffusion model on a single episode.

    A plan is sampled from the current observation and executed one action
    at a time. A new plan is sampled when there is none yet, after
    ``cfg.replan_every`` executed steps, or when the current plan has no
    actions left (so ``replan_every`` larger than the plan length cannot
    index past the end of the plan).

    Args:
        model: Denoising model; switched to eval mode by this function.
        env_id: MiniHack registry ID.
        cfg: Config namespace (``replan_every`` and ``action_dim`` are read
            here; it is also forwarded to the env factory and samplers).
        device: Torch device the observations are moved to for planning.
        seed: Environment seed. A random one is drawn when ``None``.
        max_steps: Maximum number of environment steps.
        des_file: Optional ``.des`` file content for custom scenarios.
        blind_global: Forwarded to the sampler (local-only ablation).
        stochastic: If ``True`` plan with ``remdm_sample``; otherwise
            (default) with ``greedy_sample``.

    Returns:
        Dict with ``"local"`` and ``"global"`` observation arrays of length
        ``T`` (int16; the observation preceding each action), ``"actions"``
        ``[T]`` (int64), ``"won"`` bool, ``"steps"`` int (``T``),
        ``"total_reward"`` float and the ``"seed"`` that was used. With
        ``max_steps == 0`` the arrays are empty along the first axis.
    """
    if seed is None:
        seed = random.randint(0, 2**31 - 1)

    env = make_env(env_id, des_file, cfg)
    try:
        (local, glb), _info = env.reset(seed=seed)

        locals_list = [local]
        globals_list = [glb]
        actions_list: list[int] = []
        won = False
        total_reward = 0.0
        plan: torch.Tensor | None = None
        step_in_plan = 0

        model.eval()
        for _ in range(max_steps):
            # Replan on schedule, and also whenever the plan is used up.
            if (
                plan is None
                or step_in_plan >= cfg.replan_every
                or step_in_plan >= plan.shape[-1]
            ):
                # Add a leading batch axis of size 1 for the sampler.
                local_t = torch.from_numpy(local[np.newaxis]).long().to(device)
                glb_t = torch.from_numpy(glb[np.newaxis]).long().to(device)
                if stochastic:
                    plan = remdm_sample(
                        model, local_t, glb_t, cfg, device,
                        physics_aware=getattr(
                            cfg, "physics_aware_sampling", False,
                        ),
                        blind_global=blind_global,
                    )
                else:
                    plan = greedy_sample(
                        model, local_t, glb_t, cfg, device,
                        blind_global=blind_global,
                    )  # [1, seq_len]
                step_in_plan = 0

            action = plan[0, step_in_plan].item()
            # Clamp into the valid action range before acting.
            action = max(0, min(action, cfg.action_dim - 1))
            actions_list.append(action)
            step_in_plan += 1

            (local, glb), reward, terminated, truncated, info = env.step(
                action,
            )
            total_reward += float(reward)
            locals_list.append(local)
            globals_list.append(glb)

            if info.get("won", False):
                won = True
            if terminated or truncated:
                break
    finally:
        env.close()

    # The lists hold one more observation than actions (the final one).
    # Stack first, then drop it, so a zero-step episode yields empty
    # arrays instead of failing on an empty stack.
    locals_arr = np.stack(locals_list, axis=0)[:-1].astype(np.int16)
    globals_arr = np.stack(globals_list, axis=0)[:-1].astype(np.int16)
    actions_arr = np.array(actions_list, dtype=np.int64)

    return {
        "local": locals_arr,
        "global": globals_arr,
        "actions": actions_arr,
        "won": won,
        "steps": len(actions_list),
        "total_reward": total_reward,
        "seed": seed,
    }
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def _collect_episode_thread(
    model: torch.nn.Module,
    env_id: str,
    seed: int,
    cfg: SimpleNamespace,
) -> dict | None:
    """Thread worker: run one paired (model + oracle) episode.

    Runs the model on CPU with greedy planning and then the oracle on the
    same ``(env_id, seed)``. Any exception is logged with its traceback and
    turned into a ``None`` result so one failed episode does not abort the
    whole batch.

    NOTE(review): the caller assigns models round-robin, so two concurrent
    calls may receive the same model object -- confirm this is safe for the
    model in use.

    Args:
        model: CPU-resident model used for the rollout.
        env_id: MiniHack environment ID.
        seed: RNG seed shared by the model and oracle rollouts.
        cfg: Config namespace.

    Returns:
        Dict with ``"env_id"``, ``"seed"``, ``"model_won"``,
        ``"model_steps"``, ``"oracle_steps"`` and ``"oracle_result"``, or
        ``None`` on failure.
    """
    try:
        model_result = run_model_episode(
            model, env_id, cfg, "cpu", seed,
        )
        oracle_result = collect_oracle_trajectory(env_id, seed, cfg)
        # 999 is the placeholder step count when the oracle has no result.
        oracle_steps = (
            len(oracle_result["actions"]) if oracle_result else 999
        )
        return {
            "env_id": env_id,
            "seed": seed,
            "model_won": model_result["won"],
            "model_steps": model_result["steps"],
            "oracle_steps": oracle_steps,
            "oracle_result": oracle_result,
        }
    except Exception:
        # Best-effort worker: record the traceback and report failure.
        logger.exception(
            "Thread worker failed for %s seed=%s", env_id, seed,
        )
        return None
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class DataCollector:
|
| 189 |
+
"""DAgger-style data collector.
|
| 190 |
+
|
| 191 |
+
Each iteration: sample an environment from the curriculum, run the
|
| 192 |
+
model, run the oracle on the same seed, apply efficiency filter, and
|
| 193 |
+
optionally add the oracle trajectory to the buffer.
|
| 194 |
+
|
| 195 |
+
Supports parallel episode collection via ``cfg.num_collection_workers``.
|
| 196 |
+
|
| 197 |
+
Uses a live reference to the ``ModelEMA`` object so the collector
|
| 198 |
+
always uses the latest EMA weights (synced before each rollout).
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
ema: EMA tracker holding shadow weights.
|
| 202 |
+
model: Training model (architecture template for EMA snapshot).
|
| 203 |
+
buffer: Replay buffer to populate.
|
| 204 |
+
curriculum: Dynamic environment curriculum.
|
| 205 |
+
cfg: Config namespace.
|
| 206 |
+
device: Torch device.
|
| 207 |
+
"""
|
| 208 |
+
|
| 209 |
+
def __init__(
|
| 210 |
+
self,
|
| 211 |
+
ema: "ModelEMA",
|
| 212 |
+
model: torch.nn.Module,
|
| 213 |
+
buffer: ReplayBuffer,
|
| 214 |
+
curriculum: DynamicCurriculum,
|
| 215 |
+
cfg: SimpleNamespace,
|
| 216 |
+
device: torch.device | str,
|
| 217 |
+
) -> None:
|
| 218 |
+
self._ema = ema
|
| 219 |
+
self._model_template = model
|
| 220 |
+
# Materialise an eval-mode copy; refreshed before each rollout
|
| 221 |
+
self.ema_model = ema.make_eval_model(model)
|
| 222 |
+
self.buffer = buffer
|
| 223 |
+
self.curriculum = curriculum
|
| 224 |
+
self.cfg = cfg
|
| 225 |
+
self.device = device
|
| 226 |
+
self._num_workers = getattr(cfg, "num_collection_workers", 0)
|
| 227 |
+
self._last_profile: dict[str, float] = {}
|
| 228 |
+
self._thread_pool: ThreadPoolExecutor | None = None
|
| 229 |
+
self._thread_models: list[torch.nn.Module] = []
|
| 230 |
+
if self._num_workers > 0:
|
| 231 |
+
n = min(self._num_workers, os.cpu_count() or 4)
|
| 232 |
+
self._thread_pool = ThreadPoolExecutor(max_workers=n)
|
| 233 |
+
# Create one CPU model copy per thread
|
| 234 |
+
for _ in range(n):
|
| 235 |
+
m = copy.deepcopy(model).cpu()
|
| 236 |
+
m.eval()
|
| 237 |
+
self._thread_models.append(m)
|
| 238 |
+
|
| 239 |
+
def _sync_ema(self) -> None:
|
| 240 |
+
"""Copy latest EMA shadow weights into the eval model."""
|
| 241 |
+
self._ema.apply_to(self.ema_model)
|
| 242 |
+
self.ema_model.eval()
|
| 243 |
+
|
| 244 |
+
def collect_one_iteration(self) -> dict:
    """Collect a single DAgger episode pair and update buffer/curriculum.

    Steps: refresh EMA weights, sample an environment and a seed, roll out
    the model, roll out the oracle on the same seed, apply the efficiency
    filter, add the oracle trajectory to the buffer when the filter says
    so and the oracle produced one, and report the outcome to the
    curriculum.

    Returns:
        Stats dict with ``"env_id"``, ``"model_won"``, ``"model_steps"``,
        ``"oracle_steps"`` and ``"added_to_buffer"`` keys.
    """
    self._sync_ema()
    env_id = self.curriculum.sample_env()
    seed = random.randint(0, 2**31 - 1)

    rollout = run_model_episode(
        self.ema_model, env_id, self.cfg, self.device, seed,
    )
    oracle_result = collect_oracle_trajectory(env_id, seed, self.cfg)
    # 999 stands in for the step count when the oracle has no result.
    if oracle_result:
        oracle_steps = len(oracle_result["actions"])
    else:
        oracle_steps = 999

    add = efficiency_filter(
        rollout["won"],
        rollout["steps"],
        oracle_steps,
        self.cfg.efficiency_multiplier,
    )
    stored = add and oracle_result is not None
    if stored:
        self.buffer.add(oracle_result)

    self.curriculum.update(env_id, rollout["won"])

    return {
        "env_id": env_id,
        "model_won": rollout["won"],
        "model_steps": rollout["steps"],
        "oracle_steps": oracle_steps,
        "added_to_buffer": stored,
    }
|
| 289 |
+
|
| 290 |
+
def collect_batch_parallel(
|
| 291 |
+
self, n_episodes: int,
|
| 292 |
+
) -> list[dict]:
|
| 293 |
+
"""Collect multiple episodes in parallel using threads.
|
| 294 |
+
|
| 295 |
+
Both NLE env calls and PyTorch CPU inference release the GIL,
|
| 296 |
+
enabling true parallelism. Each thread uses a pre-allocated
|
| 297 |
+
CPU model copy. Weights are synced from EMA once per call.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
n_episodes: Number of episodes to collect.
|
| 301 |
+
|
| 302 |
+
Returns:
|
| 303 |
+
List of per-episode stats dicts.
|
| 304 |
+
"""
|
| 305 |
+
assert self._thread_pool is not None, (
|
| 306 |
+
"collect_batch_parallel requires num_collection_workers > 0"
|
| 307 |
+
)
|
| 308 |
+
self._sync_ema()
|
| 309 |
+
|
| 310 |
+
# Sync EMA weights to all thread-local CPU models
|
| 311 |
+
ema_sd = self.ema_model.state_dict()
|
| 312 |
+
cpu_sd = {k: v.cpu() for k, v in ema_sd.items()}
|
| 313 |
+
for tm in self._thread_models:
|
| 314 |
+
tm.load_state_dict(cpu_sd)
|
| 315 |
+
tm.eval()
|
| 316 |
+
|
| 317 |
+
# Build task list
|
| 318 |
+
tasks = []
|
| 319 |
+
for _ in range(n_episodes):
|
| 320 |
+
env_id = self.curriculum.sample_env()
|
| 321 |
+
seed = random.randint(0, 2**31 - 1)
|
| 322 |
+
tasks.append((env_id, seed))
|
| 323 |
+
|
| 324 |
+
# Round-robin assign models to tasks
|
| 325 |
+
n_models = len(self._thread_models)
|
| 326 |
+
futures = []
|
| 327 |
+
for i, (env_id, seed) in enumerate(tasks):
|
| 328 |
+
model = self._thread_models[i % n_models]
|
| 329 |
+
f = self._thread_pool.submit(
|
| 330 |
+
_collect_episode_thread, model, env_id, seed, self.cfg,
|
| 331 |
+
)
|
| 332 |
+
futures.append(f)
|
| 333 |
+
|
| 334 |
+
results = [f.result() for f in futures]
|
| 335 |
+
|
| 336 |
+
# Process results: efficiency filter + buffer add
|
| 337 |
+
stats_list = []
|
| 338 |
+
for res in results:
|
| 339 |
+
if res is None:
|
| 340 |
+
continue
|
| 341 |
+
|
| 342 |
+
add = efficiency_filter(
|
| 343 |
+
res["model_won"],
|
| 344 |
+
res["model_steps"],
|
| 345 |
+
res["oracle_steps"],
|
| 346 |
+
self.cfg.efficiency_multiplier,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
oracle_result = res["oracle_result"]
|
| 350 |
+
if add and oracle_result is not None:
|
| 351 |
+
self.buffer.add(oracle_result)
|
| 352 |
+
|
| 353 |
+
self.curriculum.update(res["env_id"], res["model_won"])
|
| 354 |
+
|
| 355 |
+
stats_list.append({
|
| 356 |
+
"env_id": res["env_id"],
|
| 357 |
+
"model_won": res["model_won"],
|
| 358 |
+
"model_steps": res["model_steps"],
|
| 359 |
+
"oracle_steps": res["oracle_steps"],
|
| 360 |
+
"added_to_buffer": add and oracle_result is not None,
|
| 361 |
+
})
|
| 362 |
+
|
| 363 |
+
return stats_list
|
| 364 |
+
|
| 365 |
+
# ── GPU-batched collection ──────────────────────────────────
|
| 366 |
+
|
| 367 |
+
def collect_batch_gpu(self, n_episodes: int) -> list[dict]:
|
| 368 |
+
"""Collect episodes with GPU-batched model inference.
|
| 369 |
+
|
| 370 |
+
Runs all model episodes with batched GPU forward passes
|
| 371 |
+
(B=n_episodes instead of B=1), then runs oracle rollouts
|
| 372 |
+
in parallel threads for efficiency filtering.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
n_episodes: Number of episodes to collect.
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
List of per-episode stats dicts.
|
| 379 |
+
"""
|
| 380 |
+
self._sync_ema()
|
| 381 |
+
cfg = self.cfg
|
| 382 |
+
self._last_profile = {}
|
| 383 |
+
|
| 384 |
+
tasks = [
|
| 385 |
+
(self.curriculum.sample_env(), random.randint(0, 2**31 - 1))
|
| 386 |
+
for _ in range(n_episodes)
|
| 387 |
+
]
|
| 388 |
+
|
| 389 |
+
# Phase 1: GPU-batched model rollouts
|
| 390 |
+
t0 = time.perf_counter()
|
| 391 |
+
model_results = self._run_model_episodes_batched(tasks)
|
| 392 |
+
model_time = time.perf_counter() - t0
|
| 393 |
+
|
| 394 |
+
# Phase 2: Oracle rollouts (threaded, CPU-only BFS)
|
| 395 |
+
t0 = time.perf_counter()
|
| 396 |
+
n_workers = min(n_episodes, os.cpu_count() or 4)
|
| 397 |
+
with ThreadPoolExecutor(max_workers=n_workers) as pool:
|
| 398 |
+
oracle_futures = [
|
| 399 |
+
pool.submit(
|
| 400 |
+
collect_oracle_trajectory, env_id, seed, cfg,
|
| 401 |
+
)
|
| 402 |
+
for env_id, seed in tasks
|
| 403 |
+
]
|
| 404 |
+
oracle_results = [f.result() for f in oracle_futures]
|
| 405 |
+
oracle_time = time.perf_counter() - t0
|
| 406 |
+
|
| 407 |
+
# Phase 3: Efficiency filter + buffer add
|
| 408 |
+
stats_list: list[dict] = []
|
| 409 |
+
for (env_id, _seed), m_res, o_res in zip(
|
| 410 |
+
tasks, model_results, oracle_results,
|
| 411 |
+
):
|
| 412 |
+
oracle_steps = (
|
| 413 |
+
len(o_res["actions"]) if o_res else 999
|
| 414 |
+
)
|
| 415 |
+
add = efficiency_filter(
|
| 416 |
+
m_res["won"],
|
| 417 |
+
m_res["steps"],
|
| 418 |
+
oracle_steps,
|
| 419 |
+
cfg.efficiency_multiplier,
|
| 420 |
+
)
|
| 421 |
+
if add and o_res is not None:
|
| 422 |
+
self.buffer.add(o_res)
|
| 423 |
+
self.curriculum.update(env_id, m_res["won"])
|
| 424 |
+
stats_list.append({
|
| 425 |
+
"env_id": env_id,
|
| 426 |
+
"model_won": m_res["won"],
|
| 427 |
+
"model_steps": m_res["steps"],
|
| 428 |
+
"oracle_steps": oracle_steps,
|
| 429 |
+
"added_to_buffer": add and o_res is not None,
|
| 430 |
+
})
|
| 431 |
+
|
| 432 |
+
self._last_profile["model_rollout_sec"] = model_time
|
| 433 |
+
self._last_profile["oracle_rollout_sec"] = oracle_time
|
| 434 |
+
return stats_list
|
| 435 |
+
|
| 436 |
+
@torch.no_grad()
def _run_model_episodes_batched(
    self,
    tasks: list[tuple[str, int]],
) -> list[dict]:
    """Run model episodes with batched GPU forward passes.

    Creates one env per episode, steps them in lockstep, and batches
    all replanning into single forward passes (B = number of active
    envs needing a replan).

    Side effects: puts ``self.ema_model`` into eval mode and updates
    ``self._last_profile`` with ``env_reset_sec``,
    ``gpu_inference_sec`` and ``env_step_sec``.

    Args:
        tasks: List of ``(env_id, seed)`` pairs.

    Returns:
        One dict per task, in task order, with keys ``"local"``,
        ``"global"``, ``"actions"``, ``"won"``, ``"steps"``,
        ``"total_reward"`` and ``"seed"``. ``"local"``/``"global"``
        hold the first ``steps`` observations, i.e. the observation
        each recorded action was taken from; the observation produced
        by the final step is not included.

    Raises:
        Exception: Anything raised by env creation, reset, stepping
            or sampling propagates; every env opened so far is closed
            first.
    """
    cfg = self.cfg
    device = self.device
    model = self.ema_model
    model.eval()
    n = len(tasks)
    max_steps = 500
    K = getattr(
        cfg, "diffusion_steps_collect", cfg.diffusion_steps_eval,
    )
    cs = cfg.crop_size
    # A plan only holds seq_len actions, so force a replan before the
    # cursor can run off its end even if replan_every is larger.
    replan_after = min(cfg.replan_every, cfg.seq_len)

    # Latest observation per episode
    cur_local = np.zeros((n, cs, cs), dtype=np.int16)
    cur_global = np.zeros(
        (n, cfg.map_h, cfg.map_w), dtype=np.int16,
    )

    # Pre-allocate history buffers (slot 0 is the reset observation)
    obs_local = np.zeros(
        (n, max_steps + 1, cs, cs), dtype=np.int16,
    )
    obs_global = np.zeros(
        (n, max_steps + 1, cfg.map_h, cfg.map_w),
        dtype=np.int16,
    )
    act_buf = np.zeros((n, max_steps), dtype=np.int64)

    # Per-episode state vectors
    plans = np.zeros((n, cfg.seq_len), dtype=np.int64)
    step_in_plan = np.zeros(n, dtype=np.int32)
    need_replan = np.ones(n, dtype=bool)
    done = np.zeros(n, dtype=bool)
    won = np.zeros(n, dtype=bool)
    total_reward = np.zeros(n, dtype=np.float64)
    n_steps = np.zeros(n, dtype=np.int32)

    reset_time = 0.0
    inference_time = 0.0
    env_step_time = 0.0

    envs: list = []
    try:
        # Create and reset all envs. This happens inside the try so
        # that a failure part-way through still closes the envs that
        # were already opened.
        t_reset = time.perf_counter()
        for i, (env_id, seed) in enumerate(tasks):
            env = make_env(env_id, None, cfg)
            # Register before reset so a failing reset is cleaned up.
            envs.append(env)
            (local, glb), _ = env.reset(seed=seed)
            cur_local[i] = local
            cur_global[i] = glb
        reset_time = time.perf_counter() - t_reset

        obs_local[:, 0] = cur_local
        obs_global[:, 0] = cur_global

        for _ in range(max_steps):
            # Batch replan on GPU
            replan_idx = np.where(need_replan & ~done)[0]
            if len(replan_idx) > 0:
                t0 = time.perf_counter()
                local_t = torch.from_numpy(
                    cur_local[replan_idx],
                ).long().to(device)
                glb_t = torch.from_numpy(
                    cur_global[replan_idx],
                ).long().to(device)
                batch_plans = greedy_sample(
                    model, local_t, glb_t, cfg, device,
                    num_steps=K,
                ).cpu().numpy()
                plans[replan_idx] = batch_plans
                step_in_plan[replan_idx] = 0
                need_replan[replan_idx] = False
                inference_time += time.perf_counter() - t0

            # Step all active envs
            t0 = time.perf_counter()
            any_active = False
            for i in range(n):
                if done[i]:
                    continue
                any_active = True

                # Clamp the sampled token into the valid action range.
                action = int(plans[i, step_in_plan[i]])
                action = max(
                    0, min(action, cfg.action_dim - 1),
                )
                act_buf[i, n_steps[i]] = action
                step_in_plan[i] += 1
                n_steps[i] += 1

                if step_in_plan[i] >= replan_after:
                    need_replan[i] = True

                obs, reward, term, trunc, info = (
                    envs[i].step(action)
                )
                local, glb = obs
                total_reward[i] += reward
                cur_local[i] = local
                cur_global[i] = glb
                obs_local[i, n_steps[i]] = local
                obs_global[i, n_steps[i]] = glb

                if info.get("won", False):
                    won[i] = True
                if term or trunc:
                    done[i] = True
            env_step_time += time.perf_counter() - t0

            if not any_active:
                break
    finally:
        for env in envs:
            env.close()

    # Build result dicts
    results: list[dict] = []
    for i in range(n):
        T = int(n_steps[i])
        results.append({
            "local": obs_local[i, :T].copy(),
            "global": obs_global[i, :T].copy(),
            "actions": act_buf[i, :T].copy(),
            "won": bool(won[i]),
            "steps": T,
            "total_reward": float(total_reward[i]),
            "seed": tasks[i][1],
        })

    self._last_profile.update({
        "env_reset_sec": reset_time,
        "gpu_inference_sec": inference_time,
        "env_step_sec": env_step_time,
    })
    return results
|
src/planners/collect_oracle.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Standalone BFS oracle data collection for offline training datasets.
|
| 2 |
+
|
| 3 |
+
Runs the BFS oracle across in-distribution MiniHack environments using
|
| 4 |
+
multiprocessing and saves the resulting trajectories in the dict format
|
| 5 |
+
expected by ``ReplayBuffer.load_offline_data()``.
|
| 6 |
+
|
| 7 |
+
Usage::
|
| 8 |
+
|
| 9 |
+
python main.py --mode collect
|
| 10 |
+
python main.py --mode collect collect_episodes_per_env=2000
|
| 11 |
+
python main.py --mode collect collect_output=data/small.pt
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import time
|
| 19 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from types import SimpleNamespace
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from src.envs.minihack_env import collect_oracle_trajectory
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _collect_single(
    args: tuple[str, int, SimpleNamespace],
) -> dict | None:
    """Collect one oracle trajectory inside a process-pool worker.

    Kept at module level (rather than as a closure or lambda) so that
    ``ProcessPoolExecutor`` can pickle it.

    Args:
        args: A single ``(env_id, seed, cfg)`` tuple; packed into one
            argument so the worker takes exactly one positional value.

    Returns:
        Whatever ``collect_oracle_trajectory`` returns for that task;
        callers in this module treat ``None`` as a failed episode and
        otherwise read the ``"actions"`` entry of the dict.
    """
    (env_name, episode_seed, config) = args
    trajectory = collect_oracle_trajectory(env_name, episode_seed, config)
    return trajectory
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _format_eta(seconds: float) -> str:
|
| 49 |
+
"""Format seconds into a human-readable ETA string.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
seconds: Remaining time in seconds.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
Formatted string like ``"2m 30s"`` or ``"45s"``.
|
| 56 |
+
"""
|
| 57 |
+
if seconds < 60:
|
| 58 |
+
return f"{seconds:.0f}s"
|
| 59 |
+
minutes = int(seconds // 60)
|
| 60 |
+
secs = int(seconds % 60)
|
| 61 |
+
return f"{minutes}m {secs:02d}s"
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def run_collect(cfg: SimpleNamespace) -> None:
    """Collect BFS oracle demonstrations and save as a .pt dataset.

    Collects ``collect_episodes_per_env`` episodes per ID environment
    using ``ProcessPoolExecutor`` for parallelism, then saves the
    trajectories in the dict format consumed by
    ``ReplayBuffer.load_offline_data()``.

    Trajectories are saved in task order (environment by environment,
    seed ascending) regardless of the order in which workers finish,
    so repeated runs with the same config write the same sequence.

    The output file can be loaded directly by ``--mode offline``::

        python main.py --mode collect
        python main.py --mode offline --data data/dataset.pt

    Args:
        cfg: Config namespace. Reads ``collect_episodes_per_env``,
            ``collect_num_workers``, ``collect_output``, ``id_envs``,
            ``seed``.
    """
    eps_per_env: int = cfg.collect_episodes_per_env
    # ProcessPoolExecutor rejects max_workers <= 0; clamp to one.
    max_workers: int = max(
        1, min(cfg.collect_num_workers, os.cpu_count() or 4),
    )
    output_path: str = cfg.collect_output
    id_envs: list[str] = cfg.id_envs
    base_seed: int = cfg.seed if cfg.seed is not None else 0

    total_episodes = eps_per_env * len(id_envs)
    logger.info(
        "Collecting %d oracle episodes "
        "(%d per env, %d envs, %d workers)",
        total_episodes, eps_per_env, len(id_envs), max_workers,
    )

    # Deterministic task list: (env_id, seed, cfg) per episode
    tasks: list[tuple[str, int, SimpleNamespace]] = []
    for env_idx, env_id in enumerate(id_envs):
        for ep in range(eps_per_env):
            seed = base_seed + env_idx * eps_per_env + ep
            tasks.append((env_id, seed, cfg))

    # Short per-env labels for the progress line. Fall back to the full
    # id when it has fewer than two '-'-separated parts.
    short_label: dict[str, str] = {}
    for eid in id_envs:
        parts = eid.split("-")
        short_label[eid] = parts[-2] if len(parts) >= 2 else eid

    # Slot i receives the result of tasks[i], so the saved order does
    # not depend on worker completion order.
    results_by_task: list[dict | None] = [None] * len(tasks)
    per_env_count: dict[str, int] = {eid: 0 for eid in id_envs}
    per_env_steps: dict[str, int] = {eid: 0 for eid in id_envs}
    failures = 0
    completed = 0
    t_start = time.perf_counter()
    log_interval = max(1, total_episodes // 50)

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_task: dict = {
            executor.submit(_collect_single, task): (task_idx, task[0])
            for task_idx, task in enumerate(tasks)
        }

        for future in as_completed(future_to_task):
            task_idx, env_id = future_to_task[future]
            completed += 1

            try:
                result = future.result()
            except Exception:
                # Best effort: one crashed worker must not abort the
                # whole collection run; count it as a failure.
                logger.error(
                    "Worker crashed for %s", env_id, exc_info=True,
                )
                result = None

            if result is not None:
                results_by_task[task_idx] = result
                per_env_count[env_id] += 1
                per_env_steps[env_id] += len(result["actions"])
            else:
                failures += 1

            if (
                completed % log_interval == 0
                or completed == total_episodes
            ):
                elapsed = time.perf_counter() - t_start
                rate = completed / max(elapsed, 1e-6)
                eta = (total_episodes - completed) / max(rate, 1e-6)
                env_summary = " ".join(
                    f"{short_label[eid]}:{per_env_count[eid]}"
                    for eid in id_envs
                )
                logger.info(
                    " %d/%d (%.1f%%) %.1f eps/s ETA: %s | %s",
                    completed, total_episodes,
                    100 * completed / total_episodes,
                    rate, _format_eta(eta), env_summary,
                )

    elapsed = time.perf_counter() - t_start
    trajectories: list[dict] = [
        traj for traj in results_by_task if traj is not None
    ]

    # Summary
    total_steps = sum(per_env_steps.values())
    logger.info("Collection complete in %.1fs", elapsed)
    logger.info(
        " Trajectories: %d (%d failures)",
        len(trajectories), failures,
    )
    logger.info(" Total steps: %d", total_steps)
    for env_id in id_envs:
        n = per_env_count[env_id]
        s = per_env_steps[env_id]
        avg = s / max(n, 1)
        logger.info(
            " %s: %d eps, %d steps, avg %.1f steps/ep",
            env_id, n, s, avg,
        )

    # Save in the dict format expected by ReplayBuffer.load_offline_data()
    out = Path(output_path).resolve()
    out.parent.mkdir(parents=True, exist_ok=True)

    dataset: dict = {"trajectories": trajectories}
    torch.save(dataset, str(out))

    file_mb = out.stat().st_size / (1024 * 1024)
    logger.info(
        "Saved %d trajectories to %s (%.1f MB)",
        len(trajectories), out, file_mb,
    )
|
src/planners/inference.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Stateless evaluation runner.
|
| 2 |
+
|
| 3 |
+
Runs episodes using the diffusion model and collects per-environment
|
| 4 |
+
win rates, average rewards, and step counts. All episodes for a given
|
| 5 |
+
environment are rolled out in lockstep so that replanning calls are
|
| 6 |
+
batched into single GPU forward passes (B = n_episodes).
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
from datetime import datetime, timezone
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from types import SimpleNamespace
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from src.models.denoiser import ModelEMA, make_model
|
| 21 |
+
from src.planners.logging import Logger
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Evaluator:
    """Stateless evaluation runner.

    Runs the model on a set of environments and returns aggregate
    statistics per environment. Episodes within each environment are
    executed in lockstep so replanning calls are GPU-batched.
    """

    @torch.no_grad()
    def evaluate(
        self,
        env_ids: list[str],
        model: torch.nn.Module,
        n_episodes: int,
        cfg: SimpleNamespace,
        device: torch.device | str,
        des_files: list[str] | None = None,
        blind_global: bool = False,
    ) -> dict[str, dict]:
        """Evaluate *model* on each environment in *env_ids*.

        All *n_episodes* for a given environment run in lockstep so
        that replanning forward passes are batched (B = active envs
        needing a replan).

        Per-episode seeds are derived from a CRC32 of
        ``"{env_id}:{episode}"``. Unlike the built-in ``hash()``, which
        is salted per interpreter process for strings, this yields the
        same seeds on every run. Seeds lie in ``[0, 2**31)``.

        Args:
            env_ids: List of MiniHack environment IDs.
            model: Denoising model; switched to eval mode here.
            n_episodes: Episodes per environment.
            cfg: Config namespace.
            device: Torch device.
            des_files: Optional list of ``.des`` file paths for custom
                scenario evaluation. Each file yields one extra env
                entry keyed by its filename stem.
            blind_global: Forwarded to the sampler as its
                ``blind_global`` flag (local-only ablation mode).

        Returns:
            ``{env_id: {"win_rate", "wins", "avg_reward", "avg_steps",
            "n_episodes"}}``
        """
        # Local import: only needed for seed derivation in this method.
        import zlib

        model.eval()
        results: dict[str, dict] = {}

        # Build list of (env_id, des_content) pairs
        eval_targets: list[tuple[str, str | None]] = [
            (eid, None) for eid in env_ids
        ]
        for des_path in des_files or []:
            des_file = Path(des_path)
            eval_targets.append((des_file.stem, des_file.read_text()))

        for env_id, des_content in eval_targets:
            seeds = [
                (42 + zlib.crc32(f"{env_id}:{ep}".encode("utf-8")))
                % (2**31)
                for ep in range(n_episodes)
            ]
            ep_results = self._run_episodes_batched(
                model, env_id, n_episodes, cfg, device,
                seeds=seeds,
                des_content=des_content,
                blind_global=blind_global,
            )

            wins = sum(1 for r in ep_results if r["won"])
            total_reward = sum(r["total_reward"] for r in ep_results)
            total_steps = sum(r["steps"] for r in ep_results)
            # Guard the averages against an empty episode list.
            n = max(len(ep_results), 1)
            results[env_id] = {
                "win_rate": wins / n,
                "wins": wins,
                "avg_reward": total_reward / n,
                "avg_steps": total_steps / n,
                "n_episodes": len(ep_results),
            }

        return results

    @torch.no_grad()
    def _run_episodes_batched(
        self,
        model: torch.nn.Module,
        env_id: str,
        n_episodes: int,
        cfg: SimpleNamespace,
        device: torch.device | str,
        seeds: list[int],
        des_content: str | None = None,
        blind_global: bool = False,
    ) -> list[dict]:
        """Run episodes in lockstep with batched model inference.

        Creates one environment per episode, steps them in lockstep,
        and batches all replanning calls into single forward passes
        (B = number of active envs needing a replan at each step).

        Args:
            model: Denoising model (eval mode).
            env_id: MiniHack environment ID.
            n_episodes: Number of episodes to run.
            cfg: Config namespace.
            device: Torch device.
            seeds: Per-episode RNG seeds (length *n_episodes*).
            des_content: Optional ``.des`` file content for custom
                scenarios.
            blind_global: Forwarded to ``remdm_sample``.

        Returns:
            List of per-episode dicts with ``"won"``, ``"steps"``,
            ``"total_reward"`` keys. Episodes whose env could not be
            created or reset, or whose step raised, report whatever
            was accumulated before the failure (``won=False`` unless a
            win was already recorded).
        """
        from src.diffusion.sampling import remdm_sample
        from src.envs.minihack_env import make_env

        n = n_episodes
        max_steps = 500
        cs = cfg.crop_size
        # A plan only holds seq_len actions, so force a replan before
        # the cursor can run off its end even if replan_every is larger.
        replan_after = min(cfg.replan_every, cfg.seq_len)

        # Create and reset all envs
        envs: list = []
        cur_local = np.zeros((n, cs, cs), dtype=np.int16)
        cur_global = np.zeros(
            (n, cfg.map_h, cfg.map_w), dtype=np.int16,
        )
        failed = np.zeros(n, dtype=bool)

        for i in range(n):
            env = None
            try:
                env = make_env(env_id, des_content, cfg)
                (local, glb), _ = env.reset(seed=seeds[i])
                cur_local[i] = local
                cur_global[i] = glb
            except Exception:
                logger.warning(
                    "Failed to create env %s (ep %d)",
                    env_id, i, exc_info=True,
                )
                # Release a half-initialised env instead of leaking it.
                if env is not None:
                    try:
                        env.close()
                    except Exception:
                        logger.debug(
                            "Failed to close env %s (ep %d)",
                            env_id, i, exc_info=True,
                        )
                env = None
                failed[i] = True
            # Exactly one entry per episode keeps envs[i] aligned with
            # the per-episode state vectors below.
            envs.append(env)

        # Per-episode state vectors
        plans = np.zeros((n, cfg.seq_len), dtype=np.int64)
        step_in_plan = np.zeros(n, dtype=np.int32)
        need_replan = np.ones(n, dtype=bool)
        done = failed.copy()
        won = np.zeros(n, dtype=bool)
        total_reward = np.zeros(n, dtype=np.float64)
        n_steps = np.zeros(n, dtype=np.int32)

        try:
            for _ in range(max_steps):
                # Batch replan for active envs that need it
                replan_idx = np.where(need_replan & ~done)[0]
                if len(replan_idx) > 0:
                    local_t = torch.from_numpy(
                        cur_local[replan_idx],
                    ).long().to(device)  # [B_r, cs, cs]
                    glb_t = torch.from_numpy(
                        cur_global[replan_idx],
                    ).long().to(device)  # [B_r, map_h, map_w]
                    batch_plans = remdm_sample(
                        model, local_t, glb_t, cfg, device,
                        physics_aware=getattr(
                            cfg, "physics_aware_sampling", False,
                        ),
                        blind_global=blind_global,
                    ).cpu().numpy()  # [B_r, seq_len]
                    plans[replan_idx] = batch_plans
                    step_in_plan[replan_idx] = 0
                    need_replan[replan_idx] = False

                # Step all active envs
                any_active = False
                for i in range(n):
                    if done[i]:
                        continue
                    any_active = True

                    # Clamp the sampled token into the action range.
                    action = int(plans[i, step_in_plan[i]])
                    action = max(
                        0, min(action, cfg.action_dim - 1),
                    )
                    step_in_plan[i] += 1
                    n_steps[i] += 1

                    if step_in_plan[i] >= replan_after:
                        need_replan[i] = True

                    try:
                        obs, reward, term, trunc, info = (
                            envs[i].step(action)
                        )
                        local, glb = obs
                        total_reward[i] += reward
                        cur_local[i] = local
                        cur_global[i] = glb

                        if info.get("won", False):
                            won[i] = True
                        if term or trunc:
                            done[i] = True
                    except Exception:
                        # Best effort: retire this episode and keep
                        # evaluating the others.
                        logger.warning(
                            "Episode %d step failed for %s",
                            i, env_id, exc_info=True,
                        )
                        done[i] = True

                if not any_active:
                    break
        finally:
            for env in envs:
                if env is not None:
                    env.close()

        return [
            {
                "won": bool(won[i]),
                "steps": int(n_steps[i]),
                "total_reward": float(total_reward[i]),
            }
            for i in range(n)
        ]
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def format_eval_results(
    results: dict[str, dict], label: str = "Eval",
) -> str:
    """Render evaluation results as a fixed-width ASCII table.

    Args:
        results: Mapping of environment id to a stats dict; the
            ``"win_rate"`` and ``"avg_steps"`` entries are read.
        label: Text placed before "Results" in the table title.

    Returns:
        The table as one newline-joined string: a ruled title, a
        column header, one row per environment (in mapping order), and
        a closing rule.
    """
    rule = f"{'=' * 60}"
    title = f" {label} Results"
    column_header = f" {'Environment':<35} {'WinRate':>8} {'Steps':>8}"
    divider = f" {'-' * 53}"

    rows = []
    for env_id, stats in results.items():
        # Win rate as a percentage, mean step count to one decimal.
        wr = f"{stats['win_rate']:.2%}"
        st = f"{stats['avg_steps']:.1f}"
        rows.append(f" {env_id:<35} {wr:>8} {st:>8}")

    return "\n".join([rule, title, rule, column_header, divider, *rows, rule])
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def save_eval_json(
    results: dict,
    path: str,
    metadata: dict | None = None,
) -> None:
    """Write evaluation results to a JSON file.

    The document holds a UTC ISO-8601 ``"timestamp"``, the
    ``"results"`` as given, and ``"metadata"`` only when a non-empty
    mapping was supplied. Parent directories are created as needed.
    A failure while opening or writing the file is logged, not raised.

    Args:
        results: Evaluation results dict.
        path: Output file path.
        metadata: Optional extra metadata (e.g. iteration).
    """
    document = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "results": results,
    }
    if metadata:
        document["metadata"] = metadata

    destination = Path(path).resolve()
    resolved = str(destination)
    destination.parent.mkdir(parents=True, exist_ok=True)

    try:
        with open(resolved, "w") as handle:
            # default=str stringifies values json cannot encode natively.
            json.dump(document, handle, indent=2, default=str)
    except Exception:
        logger.error(f"Failed to save eval JSON to {resolved}", exc_info=True)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def run_inference(
    cfg,
    checkpoint_path: str,
    env_ids: list[str] | None,
    episodes: int,
    output_path: str | None,
    use_ema: bool,
    log: Logger | None = None,
    des_files: list[str] | None = None,
    blind_global: bool = False,
) -> None:
    """Evaluate a checkpoint on specified environments.

    Builds the model from *cfg*, loads weights from *checkpoint_path*,
    runs ``Evaluator.evaluate``, prints the results table, and
    optionally logs and saves the results.

    Args:
        cfg: Config namespace. Reads ``device``, ``ema_decay`` and,
            when *env_ids* is ``None``, ``id_envs`` and ``ood_envs``.
        checkpoint_path: Path to a checkpoint: either a dict with
            ``"model_state_dict"`` (and optionally
            ``"ema_state_dict"``) or a bare state dict.
        env_ids: Environments to evaluate; ``None`` means
            ``cfg.id_envs + cfg.ood_envs``.
        episodes: Episodes per environment.
        output_path: If truthy, results are written there as JSON.
        use_ema: Apply EMA weights when the checkpoint provides them.
            If it does not, a warning is logged and the raw weights
            are evaluated.
        log: Optional logger that receives the eval results and a
            per-environment win-rate summary.
        des_files: Optional ``.des`` scenario files, forwarded to the
            evaluator.
        blind_global: Forwarded to the evaluator.
    """
    device = cfg.device
    logger.info(f"Inference on {device}")

    model = make_model(cfg).to(device)
    # NOTE(security): weights_only=False unpickles arbitrary Python
    # objects. Only load checkpoints from trusted sources.
    ckpt = torch.load(
        checkpoint_path, map_location=device, weights_only=False,
    )

    ema_applied = False
    if "model_state_dict" in ckpt:
        model.load_state_dict(ckpt["model_state_dict"])
        if use_ema and "ema_state_dict" in ckpt:
            ema = ModelEMA(model, decay=cfg.ema_decay)
            ema.load_state_dict(ckpt["ema_state_dict"])
            ema.apply_to(model)
            ema_applied = True
    else:
        # Bare state dict: there is no EMA state to apply.
        model.load_state_dict(ckpt)

    if use_ema and not ema_applied:
        # Make the fallback visible rather than silently evaluating
        # different weights than the caller asked for.
        logger.warning(
            "EMA weights requested but not present in %s; "
            "evaluating raw model weights",
            checkpoint_path,
        )

    model.eval()

    if env_ids is None:
        env_ids = cfg.id_envs + cfg.ood_envs

    evaluator = Evaluator()
    results = evaluator.evaluate(
        env_ids, model, episodes, cfg, device,
        des_files=des_files, blind_global=blind_global,
    )

    print(format_eval_results(results, label="Inference"))

    if log is not None:
        log.log_eval(results, step=0, prefix="inference")
        log.log_summary(
            {f"inference/{env_id}/win_rate": stats["win_rate"]
             for env_id, stats in results.items()}
        )

    if output_path:
        save_eval_json(results, output_path)
        logger.info(f"Results saved to {output_path}")
|
src/planners/logging.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Centralised W&B and stdout logging.
|
| 2 |
+
|
| 3 |
+
Mirrors the Craftax logging conventions with metric namespaces:
|
| 4 |
+
``diffusion/``, ``train/``, ``eval_id/``, ``eval_ood/``.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import torch
|
| 11 |
+
from typing import TYPE_CHECKING
|
| 12 |
+
from types import SimpleNamespace
|
| 13 |
+
|
| 14 |
+
if TYPE_CHECKING:
|
| 15 |
+
from wandb.sdk.wandb_run import Run as _WandbRun
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def download_artifact(
    artifact_ref: str, dst_dir: str = "artifacts",
) -> str | None:
    """Download a W&B artifact via the public API (no active run needed).

    Only ``.pth`` files directly inside the downloaded directory are
    considered. If there are several, the lexicographically first path
    is returned and a warning is logged, so the choice is stable
    across runs and filesystems.

    Args:
        artifact_ref: Fully qualified artifact reference, e.g.
            ``"entity/project/checkpoint-iter1000:latest"``.
        dst_dir: Local directory to download into.

    Returns:
        Path to the ``.pth`` file inside the downloaded artifact
        directory, or ``None`` on failure (including ``wandb`` not
        being importable); failures are logged, not raised.
    """
    try:
        import wandb
        from pathlib import Path

        api = wandb.Api()
        artifact = api.artifact(artifact_ref)
        artifact_dir = artifact.download(root=dst_dir)
        # glob() order is filesystem-dependent; sort for a stable pick.
        pth_files = sorted(Path(artifact_dir).glob("*.pth"))
        if not pth_files:
            logger.error(
                f"No .pth file found in artifact {artifact_ref}"
            )
            return None
        if len(pth_files) > 1:
            logger.warning(
                "Artifact %s contains %d .pth files; using %s",
                artifact_ref, len(pth_files), pth_files[0],
            )
        path = str(pth_files[0])
        logger.info(f"Downloaded artifact {artifact_ref} -> {path}")
        return path
    except Exception:
        # Deliberate best effort: callers get None instead of a crash.
        logger.error(
            f"Failed to download artifact {artifact_ref}",
            exc_info=True,
        )
        return None
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _auto_run_name(cfg: SimpleNamespace) -> str:
    """Build a descriptive W&B run name out of the key hyperparameters.

    The mandatory part has the form
    ``seq{seq_len}_d{n_embd}_L{n_layer}_lr{dagger_lr}_bs{batch}_eta{eta}_{remask}``.
    It is followed by the optional tags ``subs`` (importance weighting
    enabled), ``phys`` (``physics_aware_sampling`` set and truthy) and
    ``s{seed}`` (seed is not ``None``), in that order.

    Args:
        cfg: Config namespace.

    Returns:
        A concise, human-readable run name.
    """
    tokens = [
        "seq" + format(cfg.seq_len),
        "d" + format(cfg.n_embd),
        "L" + format(cfg.n_layer),
        "lr" + format(cfg.dagger_lr, ".0e"),
        "bs" + format(cfg.dagger_batch_size),
        "eta" + format(cfg.eta),
        format(cfg.remask_strategy),
    ]

    # (flag, tag) pairs for the optional boolean markers.
    optional_tags = (
        (cfg.use_importance_weighting, "subs"),
        (getattr(cfg, "physics_aware_sampling", False), "phys"),
    )
    tokens.extend(tag for enabled, tag in optional_tags if enabled)

    seed = cfg.seed
    if seed is not None:
        tokens.append(f"s{seed}")

    return "_".join(tokens)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class Logger:
    """Centralised logger for W&B and stdout.

    W&B is strictly best-effort: if initialisation fails the instance
    silently degrades to stdout-only logging, and later W&B errors are
    reported through the module logger instead of being raised.

    Args:
        cfg: Config namespace with ``use_wandb``, ``wandb_project``,
            ``wandb_entity``, ``seed``. ``wandb_run_name`` and
            ``wandb_resume_id`` are optional.
    """

    # Flipped to True on the instance after the first failed ``wandb.log``
    # call, so a persistent outage produces one WARNING (with traceback)
    # and DEBUG records afterwards instead of one WARNING per step.
    _log_failure_reported = False

    def __init__(self, cfg: SimpleNamespace) -> None:
        self._use_wandb = cfg.use_wandb
        self._run: _WandbRun | None = None
        if self._use_wandb:
            try:
                import wandb
                run_name = getattr(cfg, "wandb_run_name", None)
                if not run_name:
                    run_name = _auto_run_name(cfg)
                resume_id = getattr(cfg, "wandb_resume_id", None)
                self._run = wandb.init(
                    project=cfg.wandb_project,
                    entity=cfg.wandb_entity or None,
                    name=run_name,
                    config=vars(cfg),
                    id=resume_id or None,
                    # Resuming requires the run to exist; a fresh run must
                    # never attach to an existing one.
                    resume="must" if resume_id else "never",
                )
                # Define custom metric x-axes
                wandb.define_metric("iteration")
                for ns in (
                    "diffusion/*", "train/*", "perf/*", "speed/*",
                    "model/*",
                    "eval_id/*", "eval_ood/*",
                    "curriculum/*",
                    "ckpt_eval_id/*", "ckpt_eval_ood/*", "ckpt_eval/*",
                    "inference/*",
                ):
                    wandb.define_metric(ns, step_metric="iteration")
            except Exception:
                # Fall back to stdout-only logging rather than aborting.
                logger.error("W&B init failed", exc_info=True)
                self._use_wandb = False

    def log_summary(self, metrics: dict) -> None:
        """Write key/value pairs to the wandb run summary (final aggregates).

        A failure to update the summary is logged as a warning and
        otherwise ignored.

        Args:
            metrics: Flat ``{key: value}`` dict.
        """
        if self._use_wandb and self._run is not None:
            try:
                self._run.summary.update(metrics)
            except Exception:
                logger.warning("W&B summary update failed", exc_info=True)

    def log(self, metrics: dict, step: int) -> None:
        """Log a dict of metrics.

        Metrics go to W&B when a run is active. Independently of W&B, a
        one-line summary is written to the module logger whenever
        ``step`` is a multiple of 10.

        Args:
            metrics: Flat ``{namespace/key: value}`` dict.
            step: Global step index.
        """
        if self._use_wandb and self._run is not None:
            try:
                import wandb
                # Include "iteration" so define_metric(step_metric="iteration") works
                wandb.log({**metrics, "iteration": step}, step=step)
            except Exception:
                if not self._log_failure_reported:
                    self._log_failure_reported = True
                    logger.warning(
                        "W&B log failed at step %s; further failures are "
                        "reported at DEBUG level",
                        step, exc_info=True,
                    )
                else:
                    logger.debug(
                        "W&B log failed at step %s", step, exc_info=True,
                    )

        # Stdout summary every 10 steps
        if step % 10 == 0:
            parts = [f"step={step}"]
            for k, v in metrics.items():
                if isinstance(v, float):
                    # Scientific notation for tiny non-zero values that
                    # would otherwise print as 0.0000.
                    if abs(v) < 1e-3 and v != 0.0:
                        parts.append(f"{k}={v:.2e}")
                    else:
                        parts.append(f"{k}={v:.4f}")
                else:
                    parts.append(f"{k}={v}")
            logger.info(" ".join(parts))

    def log_eval(
        self, results: dict[str, dict], step: int, prefix: str,
    ) -> None:
        """Flatten evaluation results and log them.

        Only ``int``/``float`` values are forwarded; every other entry of
        a per-environment stats dict is dropped. Keys are emitted as
        ``"{prefix}/{env_id}/{key}"``.

        Args:
            results: ``{env_id: {"win_rate", ...}}``
            step: Global step.
            prefix: Metric namespace prefix (e.g. ``"eval_id"``).
        """
        flat: dict[str, float] = {}
        for env_id, stats in results.items():
            for key, val in stats.items():
                if isinstance(val, (int, float)):
                    flat[f"{prefix}/{env_id}/{key}"] = val
        self.log(flat, step=step)

    def log_checkpoint_artifact(
        self,
        checkpoint_path: str,
        config_path: str | None,
        iteration: int,
        metadata: dict | None = None,
        artifact_name: str | None = None,
    ) -> None:
        """Upload a checkpoint as a W&B artifact with config attached.

        Does nothing when W&B is inactive. Upload errors are logged and
        swallowed so a failed upload never interrupts training.

        Args:
            checkpoint_path: Path to the ``.pth`` checkpoint file.
            config_path: Path to the YAML config snapshot to attach.
                If ``None``, only the checkpoint is uploaded.
            iteration: Iteration number (used in the default artifact
                name when ``artifact_name`` is not provided).
            metadata: Optional metadata dict stored on the artifact.
            artifact_name: Optional explicit artifact name. When
                ``None``, defaults to ``f"checkpoint-iter{iteration}"``.
                Offline BC passes a step-based name to avoid the
                misleading "iter" prefix.
        """
        if not self._use_wandb or self._run is None:
            return
        try:
            import wandb

            name = artifact_name or f"checkpoint-iter{iteration}"
            artifact = wandb.Artifact(
                name=name,
                type="model",
                metadata=metadata or {},
            )
            artifact.add_file(checkpoint_path)
            if config_path is not None:
                artifact.add_file(config_path, name="config.yaml")
            logged = self._run.log_artifact(artifact)  # type: ignore[union-attr]
            logged.wait()  # block until upload completes
            logger.info("W&B artifact uploaded: %s", name)
        except Exception:
            logger.error("W&B artifact upload failed", exc_info=True)

    def finish(self) -> None:
        """Close the W&B run if active.

        A failure while closing is logged as a warning and otherwise
        ignored.
        """
        if self._use_wandb and self._run is not None:
            try:
                import wandb
                wandb.finish()
            except Exception:
                logger.warning("W&B finish failed", exc_info=True)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# ---------------------------------------------------------------------------
|
| 238 |
+
# Metric helper functions (used by both src/ and experiments/)
|
| 239 |
+
# ---------------------------------------------------------------------------
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def gpu_memory_mb() -> float:
    """Report the peak allocated GPU memory, in MB, since the last reset.

    Returns:
        ``torch.cuda.max_memory_allocated()`` converted to MB, or 0.0
        if CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0.0
    bytes_per_mb = 1024 * 1024
    return torch.cuda.max_memory_allocated() / bytes_per_mb
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def reset_gpu_memory_stats() -> None:
    """Clear the CUDA peak-memory counters; a no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def compute_param_norm(model: torch.nn.Module) -> float:
    """Return the global L2 norm taken over every parameter of *model*.

    The per-tensor L2 norms are squared, summed in parameter order and
    the square root of that sum is returned.

    Args:
        model: The model.

    Returns:
        Total L2 norm as a float (0.0 for a model without parameters).
    """
    squared_norms = (
        param.data.norm(2).item() ** 2 for param in model.parameters()
    )
    return sum(squared_norms, 0.0) ** 0.5
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def compute_param_drift(
    model: torch.nn.Module,
    ref_state: dict[str, torch.Tensor],
) -> float:
    """Measure how far the model parameters are from a reference state.

    Only parameters whose name is present in ``ref_state`` contribute;
    parameters missing from the reference are ignored.

    Args:
        model: Current model.
        ref_state: Reference state_dict (e.g. pretrained weights).

    Returns:
        L2 distance as a float.
    """
    deltas = (
        param.data - ref_state[key]
        for key, param in model.named_parameters()
        if key in ref_state
    )
    return sum((delta.norm(2).item() ** 2 for delta in deltas), 0.0) ** 0.5
|
src/planners/offline.py
ADDED
|
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Offline behavioural cloning trainer.
|
| 2 |
+
|
| 3 |
+
Mirrors the Craftax ``make_train`` closure pattern. Trains the diffusion
|
| 4 |
+
model on pre-collected oracle demonstrations using the MDLM ELBO loss
|
| 5 |
+
with optional auxiliary goal loss.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import sys
|
| 11 |
+
import time
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import logging
|
| 14 |
+
from types import SimpleNamespace
|
| 15 |
+
from typing import Callable
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import yaml
|
| 20 |
+
|
| 21 |
+
from src.buffer import ReplayBuffer
|
| 22 |
+
from src.config import make_run_dir
|
| 23 |
+
from src.diffusion.forward import q_sample
|
| 24 |
+
from src.diffusion.loss import auxiliary_goal_loss, mdlm_loss
|
| 25 |
+
from src.diffusion.schedules import get_schedule
|
| 26 |
+
from src.models.denoiser import ModelEMA, make_model, try_compile
|
| 27 |
+
from src.planners.inference import Evaluator, save_eval_json
|
| 28 |
+
from src.planners.logging import (
|
| 29 |
+
Logger,
|
| 30 |
+
compute_param_drift,
|
| 31 |
+
compute_param_norm,
|
| 32 |
+
gpu_memory_mb,
|
| 33 |
+
reset_gpu_memory_stats,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def make_offline_trainer(cfg: SimpleNamespace) -> Callable:
    """Build the offline BC training closure.

    The noise schedule is resolved once from ``cfg.noise_schedule`` and
    captured by the returned closure.

    Args:
        cfg: Config namespace.

    Returns:
        ``train_offline(model, ema_model, buffer, cfg, device, ...) -> dict``
    """
    schedule_fn = get_schedule(cfg.noise_schedule)

    def train_offline(
        model: nn.Module,
        ema_model: ModelEMA,
        buffer: ReplayBuffer,
        cfg: SimpleNamespace,
        device: torch.device | str,
        log: Logger | None = None,
        raw_model: nn.Module | None = None,
        resume_state: dict | None = None,
        evaluator: Evaluator | None = None,
        id_envs: list[str] | None = None,
        ood_envs: list[str] | None = None,
    ) -> dict:
        """Run offline BC training.

        Args:
            model: Denoising model (may be torch.compiled).
            ema_model: EMA tracker.
            buffer: Replay buffer with offline data.
            cfg: Config namespace.
            device: Torch device.
            log: Optional Logger for wandb and stdout metrics.
            raw_model: Uncompiled model for EMA updates. If ``None``,
                uses *model* directly.
            resume_state: Checkpoint dict to resume from. If provided,
                restores optimizer, scheduler, and step state.
            evaluator: Optional ``Evaluator`` instance for periodic ID/OOD
                evaluation. When ``None``, no eval is run during training.
            id_envs: In-distribution environment IDs for periodic eval.
                Required (non-empty) if ``evaluator`` is provided and
                ``cfg.id_eval_every_timesteps > 0``.
            ood_envs: Out-of-distribution environment IDs for periodic
                eval. Required (non-empty) if ``evaluator`` is provided
                and ``cfg.ood_eval_every_timesteps > 0``.

        Returns:
            Dict with ``"final_loss"`` (0.0 when no step was taken) and
            ``"loss_history"`` (one entry per gradient step of this call).
        """
        _ema_source = raw_model if raw_model is not None else model
        model.train()
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=cfg.offline_lr,
            weight_decay=cfg.weight_decay,
        )

        # Unified budget: `total_timesteps` counts env.step()-equivalent
        # samples consumed during training. Each gradient step consumes
        # `offline_batch_size` samples, so total grad steps derives
        # directly from the budget and is independent of dataset size
        # — this is what gives offline / DAgger / SB3 runs a common
        # denominator when comparing curves.
        total_grad_steps = max(
            1, cfg.total_timesteps // cfg.offline_batch_size,
        )
        # Optional override: pin offline gradient budget independently
        # of `total_timesteps`. Used for paper-fair compute matching
        # against a specific DAgger iteration count, e.g.
        # `offline_total_grad_steps: 60000` to match 600 DAgger iters
        # × `grad_steps_per_iteration: 100` AdamW updates regardless of
        # what env-step budget DAgger consumed in those iters.
        _grad_override = getattr(cfg, "offline_total_grad_steps", None)
        if _grad_override is not None and _grad_override > 0:
            total_grad_steps = int(_grad_override)
            logger.info(
                "Offline grad budget pinned via offline_total_grad_steps="
                "%d (overrides total_timesteps)",
                total_grad_steps,
            )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=total_grad_steps,
            eta_min=cfg.offline_lr * 0.1,
        )
        # Checkpoint cadence — defaults to deriving from
        # `checkpoint_every_timesteps` (env-step units → grad-step units
        # via // batch_size). The optional `offline_checkpoint_every_grad_steps`
        # override is used when an offline run is pinned via
        # `offline_total_grad_steps` and needs an aligned cadence in
        # grad-step units (env-step cadence diverges wildly from grad-step
        # cadence between offline and DAgger because their sample-to-step
        # ratios differ by ~50x).
        _ckpt_grad_override = getattr(
            cfg, "offline_checkpoint_every_grad_steps", None,
        )
        if _ckpt_grad_override is not None and _ckpt_grad_override > 0:
            ckpt_every_step = int(_ckpt_grad_override)
        else:
            ckpt_every_step = (
                cfg.checkpoint_every_timesteps // cfg.offline_batch_size
                if cfg.checkpoint_every_timesteps > 0 else 0
            )
        # Eval cadence — same override pattern. Without this, an offline
        # run pinned at e.g. 60k grad steps with the default
        # `id_eval_every_timesteps=250000` would fire ~491 evals
        # (250000 // 2048 = 122 grad steps per eval), which is
        # impractically dense.
        _eval_grad_override = getattr(
            cfg, "offline_eval_every_grad_steps", None,
        )
        if _eval_grad_override is not None and _eval_grad_override > 0:
            id_eval_every_env_steps = (
                int(_eval_grad_override) * cfg.offline_batch_size
            )
            ood_eval_every_env_steps = id_eval_every_env_steps
        else:
            id_eval_every_env_steps = cfg.id_eval_every_timesteps
            ood_eval_every_env_steps = cfg.ood_eval_every_timesteps
        # Logging cadence. `offline_log_every` is the *minimum* cadence;
        # the actual `log_every` is clamped on both ends so the number of
        # log points stays in [~10, ~1000] regardless of run length:
        #
        #   * Lower bound (`floor`): on very long runs, force `log_every`
        #     up so total log points cap at ~1000. Without this, a 600k
        #     grad-step run with the default `offline_log_every=10` would
        #     emit 60,000 W&B points — silent log spam.
        #
        #   * Upper bound (`ceiling`): on very short runs (smoke, fast
        #     ablations) clamp `log_every` down so every run emits at
        #     least ~10 log points and curves stay comparable across
        #     budgets.
        #
        # When the configured value sits inside the [floor, ceiling]
        # window (the common case), it is used unchanged.
        _floor = max(1, total_grad_steps // 1000)
        _ceiling = max(1, total_grad_steps // 10)
        log_every = min(
            _ceiling, max(_floor, cfg.offline_log_every),
        )

        # Restore optimizer/scheduler state if resuming
        step = 0
        if resume_state is not None:
            if "optimizer_state_dict" in resume_state:
                optimizer.load_state_dict(
                    resume_state["optimizer_state_dict"],
                )
            if "scheduler_state_dict" in resume_state:
                scheduler.load_state_dict(
                    resume_state["scheduler_state_dict"],
                )
            step = resume_state.get("step", 0)
            logger.info(
                "Resumed offline training from step %s/%s",
                step, total_grad_steps,
            )

        # AMP: enabled when use_amp=true and on CUDA
        _use_amp = (
            getattr(cfg, "use_amp", False)
            and str(device).startswith("cuda")
        )
        scaler = torch.amp.GradScaler("cuda", enabled=_use_amp)

        loss_history: list[float] = []
        _batch_start = time.perf_counter()
        last_ckpt_step = step
        # Step at which the current logging window started. After a
        # resume the first window can be shorter than `log_every`, so
        # throughput is derived from the real step delta.
        last_log_step = step
        # Periodic eval anchors (env-step units, mirroring online.py).
        # Snapping to current env_steps avoids accumulated drift across
        # resumes; the next eval fires once another full interval has
        # been processed since the resume point.
        last_id_eval_env_steps = step * cfg.offline_batch_size
        last_ood_eval_env_steps = step * cfg.offline_batch_size

        # Snapshot of initial weights for `model/param_drift_from_init`.
        # Mirrors online.py:Trainer.__init__.
        _init_state = {
            k: v.detach().clone()
            for k, v in _ema_source.state_dict().items()
            if v.is_floating_point()
        }
        # Counts logging emissions (not raw grad steps), used to gate
        # the once-per-10-windows model health metrics analogously to
        # online.py's `iteration % 10 == 0` cadence.
        log_windows = 0
        reset_gpu_memory_stats()

        def _periodic_eval(
            envs: list[str], prefix: str, at_step: int,
        ) -> None:
            """Evaluate the EMA weights on *envs* and log under *prefix*."""
            eval_model = ema_model.make_eval_model(_ema_source)
            results = evaluator.evaluate(
                envs, eval_model, cfg.eval_episodes_per_env,
                cfg, device,
            )
            if log is None:
                return
            log.log_eval(results, step=at_step, prefix=prefix)
            mean_wr = (
                sum(s["win_rate"] for s in results.values())
                / len(results)
            ) if results else 0.0
            log.log(
                {f"{prefix}/mean_win_rate": mean_wr},
                step=at_step,
            )

        while step < total_grad_steps:
            batch = buffer.sample(cfg.offline_batch_size)
            if batch is None:
                # Make the early exit visible; the run otherwise looks
                # like it completed its full budget.
                logger.warning(
                    "Replay buffer returned no batch at step %s/%s; "
                    "stopping offline training early",
                    step, total_grad_steps,
                )
                break
            local_np, global_np, actions_np = batch
            local_t = torch.from_numpy(local_np).long().to(device)
            global_t = torch.from_numpy(global_np).long().to(device)
            actions_t = torch.from_numpy(actions_np).long().to(device)

            B = actions_t.shape[0]
            t = torch.rand(B, device=device)  # [B] in [0, 1)
            # Keep t strictly inside (0, 1).
            t = t.clamp(1e-5, 1.0 - 1e-5)

            zt = q_sample(
                actions_t, t, cfg.mask_token, cfg.pad_token,
                schedule_fn,
            )
            t_discrete = (
                t * cfg.num_diffusion_steps
            ).long().clamp(0, cfg.num_diffusion_steps - 1)  # [B]

            optimizer.zero_grad()
            with torch.amp.autocast("cuda", enabled=_use_amp):
                out = model(local_t, global_t, zt, t_discrete)

                loss_diff = mdlm_loss(
                    out["actions"], actions_t, zt, t,
                    cfg.mask_token, cfg.pad_token, schedule_fn,
                    weight_clip=cfg.loss_weight_clip,
                    label_smoothing=cfg.label_smoothing,
                    use_importance_weighting=cfg.use_importance_weighting,
                )

                loss_aux = torch.tensor(0.0, device=device)
                if "goal_pred" in out:
                    loss_aux = auxiliary_goal_loss(
                        out["goal_pred"], global_t,
                    )

                loss = loss_diff + cfg.aux_loss_weight * loss_aux

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to the
            # true gradient magnitudes.
            scaler.unscale_(optimizer)
            grad_norm = nn.utils.clip_grad_norm_(
                model.parameters(), cfg.offline_grad_clip,
            )
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            ema_model.update(_ema_source)
            loss_value = loss.item()
            loss_history.append(loss_value)
            step += 1

            # env-step equivalent: samples processed so far.
            env_steps = step * cfg.offline_batch_size

            if log is not None and step % log_every == 0:
                # NOTE(review): `_batch_start` is only reset after a log
                # emission, so time spent in the eval / checkpoint code
                # below is attributed to the following window.
                step_time = time.perf_counter() - _batch_start
                log_windows += 1
                window_steps = step - last_log_step

                # Buffer state — for offline mode `offline_size` always
                # equals `len(buffer)` (no online appends), so the
                # online fraction is always 0.0. Logged anyway for
                # symmetry with the DAgger curves.
                buf_total = len(buffer)
                buf_online_frac = (
                    (buf_total - buffer.offline_size) / max(buf_total, 1)
                    if hasattr(buffer, "offline_size")
                    else 0.0
                )

                # Throughput: samples processed in this logging window.
                samples_window = window_steps * cfg.offline_batch_size
                samples_per_sec = samples_window / max(step_time, 1e-6)

                metrics = {
                    "diffusion/loss": loss_value,
                    "diffusion/loss_diff": loss_diff.item(),
                    "diffusion/loss_aux": loss_aux.item(),
                    "train/buffer_size": buf_total,
                    "train/buffer_online_frac": buf_online_frac,
                    "train/lr": scheduler.get_last_lr()[0],
                    "train/env_steps": env_steps,
                    "train/progress": step / total_grad_steps,
                    "train/grad_norm": grad_norm.item(),
                    "speed/train_step_time_sec": step_time,
                    "speed/samples_per_sec": samples_per_sec,
                    "speed/gpu_memory_mb": gpu_memory_mb(),
                    # Legacy `perf/` mirror keys (kept for backward compat
                    # with existing dashboards / DAgger curves).
                    "perf/train_time_s": step_time,
                    "perf/grad_steps_per_sec": (
                        window_steps / max(step_time, 1e-6)
                    ),
                }
                if hasattr(_ema_source, "global_gate"):
                    gate_val = torch.sigmoid(
                        _ema_source.global_gate,
                    ).item()
                    metrics["train/global_gate"] = gate_val
                    metrics["model/ema_gate_value"] = gate_val

                # Model health (every 10 logging windows to keep overhead
                # low — matches online.py's `iteration % 10 == 0`).
                # `== 1` makes the very first window report as well.
                if log_windows % 10 == 1:
                    metrics["model/param_norm"] = compute_param_norm(
                        _ema_source,
                    )
                    metrics["model/param_drift_from_init"] = (
                        compute_param_drift(
                            _ema_source, _init_state,
                        )
                    )

                log.log(metrics, step=step)
                _batch_start = time.perf_counter()
                last_log_step = step
                reset_gpu_memory_stats()
                logger.info(
                    "step %s/%s (env_steps=%s) loss=%.4f",
                    step, total_grad_steps, env_steps, loss_value,
                )

            # Periodic ID eval — env-step delta-check (mirrors
            # online.py:277-305). Eval is opt-in: skipped entirely when
            # no Evaluator was threaded through. The cadence variable
            # already accounts for the optional
            # `offline_eval_every_grad_steps` override.
            if (
                evaluator is not None
                and id_envs
                and id_eval_every_env_steps > 0
                and env_steps - last_id_eval_env_steps
                >= id_eval_every_env_steps
            ):
                _periodic_eval(id_envs, "eval_id", step)
                last_id_eval_env_steps = env_steps

            # Periodic OOD eval — same delta-check pattern.
            if (
                evaluator is not None
                and ood_envs
                and ood_eval_every_env_steps > 0
                and env_steps - last_ood_eval_env_steps
                >= ood_eval_every_env_steps
            ):
                _periodic_eval(ood_envs, "eval_ood", step)
                last_ood_eval_env_steps = env_steps

            # Periodic step-level checkpoint (cadence derived from
            # checkpoint_every_timesteps)
            if (
                ckpt_every_step > 0
                and step - last_ckpt_step >= ckpt_every_step
            ):
                _save_offline_checkpoint(
                    _ema_source, ema_model, optimizer, scheduler,
                    step, cfg, log,
                    evaluator=evaluator,
                    id_envs=id_envs,
                    ood_envs=ood_envs,
                    device=device,
                )
                last_ckpt_step = step

        final_loss = loss_history[-1] if loss_history else 0.0

        if log is not None:
            log.log_summary({
                "offline/final_loss": final_loss,
                "offline/total_steps": step,
                "offline/total_timesteps": step * cfg.offline_batch_size,
            })

        return {
            "final_loss": final_loss,
            "loss_history": loss_history,
        }

    return train_offline
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def _save_offline_checkpoint(
    model: nn.Module,
    ema_model: ModelEMA,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    step: int,
    cfg: SimpleNamespace,
    log: Logger | None,
    evaluator: Evaluator | None = None,
    id_envs: list[str] | None = None,
    ood_envs: list[str] | None = None,
    device: torch.device | str | None = None,
) -> None:
    """Save an offline training checkpoint, eval, and W&B artifact.

    Mirrors the DAgger ``Trainer.save_checkpoint`` flow:
      1. Persist model + EMA + optimizer + scheduler state to disk.
      2. Save a YAML config snapshot alongside the checkpoint.
      3. Run an EMA-weight ID + OOD eval and emit ``ckpt_eval_*``
         metrics + an eval JSON sidecar.
      4. Upload the checkpoint + config snapshot via
         ``log.log_checkpoint_artifact``.

    Step 3 runs only when ``evaluator`` and ``device`` are given AND both
    ``id_envs`` and ``ood_envs`` are non-empty; otherwise it is skipped.
    Step 4 runs whenever ``log`` is not ``None``. Failures in steps 2 and 3
    are logged and swallowed so they never abort training; a failure in
    step 1 (``torch.save``) propagates to the caller.

    Args:
        model: Raw (uncompiled) model — used both for ``state_dict``
            persistence and as the source argument to
            ``ema_model.make_eval_model``.
        ema_model: EMA tracker.
        optimizer: Optimizer.
        scheduler: LR scheduler.
        step: Global gradient step count (used in filenames + metadata).
        cfg: Config namespace.
        log: Logger (used to extract W&B run ID, log eval metrics,
            and upload artifact).
        evaluator: Optional evaluator. When ``None``, the checkpoint
            eval is skipped.
        id_envs: ID env IDs for the checkpoint eval.
        ood_envs: OOD env IDs for the checkpoint eval.
        device: Torch device for the checkpoint eval.
    """
    # Stored in the checkpoint so a resumed run can re-attach to the same
    # W&B run.
    wandb_run_id: str | None = None
    if log is not None and log._use_wandb and log._run is not None:
        wandb_run_id = log._run.id

    ckpt_dir = Path(cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    path = ckpt_dir / f"offline_step{step}.pth"
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "ema_state_dict": ema_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "step": step,
            "env_steps": step * cfg.offline_batch_size,
            "wandb_run_id": wandb_run_id,
        },
        path,
    )
    logger.info(f"Offline checkpoint saved: {path}")

    # Save config snapshot alongside checkpoint (mirrors DAgger).
    # Best-effort: on failure config_path becomes None so the artifact
    # upload below does not reference a missing/partial file.
    config_path: Path | None = ckpt_dir / f"config_offline_step{step}.yaml"
    try:
        cfg_dict = {
            k: v for k, v in vars(cfg).items() if not k.startswith("_")
        }
        with open(config_path, "w") as f:
            yaml.dump(cfg_dict, f, default_flow_style=False)
    except Exception:
        logger.error("Failed to save config snapshot", exc_info=True)
        config_path = None

    # Checkpoint-time eval — mirrors Trainer.save_checkpoint in online.py.
    # Skipped when the caller did not thread an evaluator through.
    if (
        evaluator is not None
        and id_envs
        and ood_envs
        and device is not None
    ):
        try:
            eval_model = ema_model.make_eval_model(model)
            id_results = evaluator.evaluate(
                id_envs, eval_model, cfg.checkpoint_eval_episodes,
                cfg, device,
            )
            ood_results = evaluator.evaluate(
                ood_envs, eval_model, cfg.checkpoint_eval_episodes,
                cfg, device,
            )

            id_winrate = _ckpt_mean_win_rate(id_results)
            ood_winrate = _ckpt_mean_win_rate(ood_results)

            current_lr = scheduler.get_last_lr()[0]
            training_meta = {
                "step": step,
                "env_steps": step * cfg.offline_batch_size,
                "total_timesteps": cfg.total_timesteps,
                "lr": current_lr,
                "offline_batch_size": cfg.offline_batch_size,
                "aux_loss_weight": cfg.aux_loss_weight,
                "ema_decay": cfg.ema_decay,
                "id_winrate": id_winrate,
                "ood_winrate": ood_winrate,
                "per_env_id": _ckpt_per_env_summary(id_results),
                "per_env_ood": _ckpt_per_env_summary(ood_results),
            }

            json_path = ckpt_dir / f"eval_offline_step{step}.json"
            save_eval_json(
                {"id": id_results, "ood": ood_results},
                str(json_path),
                metadata=training_meta,
            )

            if log is not None:
                log.log_eval(
                    id_results, step=step, prefix="ckpt_eval_id",
                )
                log.log_eval(
                    ood_results, step=step, prefix="ckpt_eval_ood",
                )
                log.log(
                    {
                        "ckpt_eval/id_winrate": id_winrate,
                        "ckpt_eval/ood_winrate": ood_winrate,
                    },
                    step=step,
                )
                log.log_summary({
                    f"ckpt_offline_step{step}/id_winrate": id_winrate,
                    f"ckpt_offline_step{step}/ood_winrate": ood_winrate,
                })
        except Exception:
            # Eval is diagnostic only; the checkpoint itself is already
            # on disk, so log and carry on.
            logger.error(
                "Offline checkpoint eval failed", exc_info=True,
            )

    # W&B artifact upload (no-op when wandb is not initialised).
    if log is not None:
        log.log_checkpoint_artifact(
            checkpoint_path=str(path),
            config_path=str(config_path) if config_path else None,
            iteration=step,
            metadata={"step": step, "mode": "offline"},
            artifact_name=f"checkpoint-offline-step{step}",
        )


def _ckpt_mean_win_rate(results: dict) -> float:
    """Return the unweighted mean ``win_rate`` across envs (0.0 if empty)."""
    if not results:
        return 0.0
    return sum(s["win_rate"] for s in results.values()) / len(results)


def _ckpt_per_env_summary(results: dict) -> dict:
    """Reduce per-env eval stats to the fields kept in the eval metadata."""
    return {
        env_id: {
            "win_rate": s["win_rate"],
            "wins": s.get("wins", 0),
            "avg_reward": s["avg_reward"],
            "avg_steps": s["avg_steps"],
            "n_episodes": s["n_episodes"],
        }
        for env_id, s in results.items()
    }
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def load_offline_dataset(
|
| 612 |
+
path: str | None, cfg: SimpleNamespace,
|
| 613 |
+
) -> dict | None:
|
| 614 |
+
"""Load an offline dataset from disk.
|
| 615 |
+
|
| 616 |
+
Args:
|
| 617 |
+
path: Path to a ``.pt`` file, or ``None``.
|
| 618 |
+
cfg: Config namespace (unused, reserved for future).
|
| 619 |
+
|
| 620 |
+
Returns:
|
| 621 |
+
Loaded dict or ``None``.
|
| 622 |
+
"""
|
| 623 |
+
if path is None:
|
| 624 |
+
return None
|
| 625 |
+
try:
|
| 626 |
+
import torch as _torch
|
| 627 |
+
return _torch.load(path, map_location="cpu", weights_only=False)
|
| 628 |
+
except Exception:
|
| 629 |
+
logger.error(f"Failed to load dataset from {path}", exc_info=True)
|
| 630 |
+
return None
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def run_offline(
    cfg: SimpleNamespace,
    data_path: str | None,
    checkpoint_path: str | None = None,
) -> None:
    """Offline BC training on pre-collected data.

    Loads the dataset into a replay buffer, builds the model + EMA,
    optionally resumes from a checkpoint, runs the offline trainer, and
    writes ``offline_final.pth`` (model + EMA weights and W&B run ID) into
    ``cfg.checkpoint_dir``. Exits the process with status 1 when the
    dataset cannot be loaded or yields an empty buffer.

    Args:
        cfg: Config namespace.
        data_path: Path to ``.pt`` dataset file.
        checkpoint_path: Optional checkpoint to resume from. Model and EMA
            weights are restored here and the W&B run ID is recovered for
            curve continuity; the full checkpoint dict is handed to the
            trainer as ``resume_state`` (presumably for optimizer /
            scheduler restore — confirm in the trainer).
    """
    make_run_dir(cfg, tag="offline")

    device = cfg.device
    logger.info(f"Offline BC on {device}")

    data = load_offline_dataset(data_path, cfg)
    if data is None:
        logger.error("No dataset provided or failed to load. Exiting.")
        sys.exit(1)

    # Offline buffer must hold the full pre-collected dataset. DAgger's
    # `buffer_capacity` (typically 10k) would silently FIFO-evict 99% of
    # the dataset, so honour the optional `offline_buffer_capacity`
    # override when present.
    _offline_buf_cap = (
        getattr(cfg, "offline_buffer_capacity", None) or cfg.buffer_capacity
    )
    buffer = ReplayBuffer(_offline_buf_cap, cfg.seq_len, cfg.pad_token)
    buffer.load_offline_data(data, cfg.id_envs)
    logger.info(f"Loaded {len(buffer)} windows")

    if len(buffer) == 0:
        logger.error(
            "Buffer is empty after loading dataset — no trajectories matched "
            f"id_envs={cfg.id_envs}. Exiting."
        )
        sys.exit(1)

    raw_model = make_model(cfg).to(device)

    # torch.compile: wrap for training only; shares params with raw_model
    model = try_compile(raw_model, cfg)

    ema = ModelEMA(raw_model, decay=cfg.ema_decay)

    # If resuming, extract W&B run ID from checkpoint before Logger init
    resume_state: dict | None = None
    if checkpoint_path:
        # NOTE: weights_only=False unpickles arbitrary objects — resume
        # only from checkpoints written by this project.
        resume_state = torch.load(
            checkpoint_path, map_location=device, weights_only=False,
        )
        raw_model.load_state_dict(resume_state["model_state_dict"])
        ema.load_state_dict(resume_state["ema_state_dict"])
        # An explicit cfg.wandb_resume_id takes precedence over the ID
        # saved in the checkpoint.
        resume_id = getattr(cfg, "wandb_resume_id", None)
        if not resume_id:
            saved_id = resume_state.get("wandb_run_id")
            if saved_id:
                cfg.wandb_resume_id = saved_id
                logger.info(f"W&B run ID from checkpoint: {saved_id}")

    log = Logger(cfg)
    evaluator = Evaluator()
    train_fn = make_offline_trainer(cfg)
    result = train_fn(
        model, ema, buffer, cfg, device, log=log,
        raw_model=raw_model, resume_state=resume_state,
        evaluator=evaluator,
        id_envs=cfg.id_envs,
        ood_envs=cfg.ood_envs,
    )
    logger.info(
        f"Offline training done. Final loss: {result['final_loss']:.4f}"
    )

    # Save final checkpoint for downstream compatibility (DAgger, inference)
    wandb_run_id: str | None = None
    if log._use_wandb and log._run is not None:
        wandb_run_id = log._run.id

    ckpt_dir = Path(cfg.checkpoint_dir)
    # The periodic checkpoint path creates this directory itself, but it
    # may never have fired (short run / cadence disabled), so make sure
    # the directory exists before the final save.
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    path = ckpt_dir / "offline_final.pth"
    torch.save(
        {
            "model_state_dict": raw_model.state_dict(),
            "ema_state_dict": ema.state_dict(),
            "wandb_run_id": wandb_run_id,
        },
        path,
    )
    logger.info(f"Saved offline checkpoint: {path}")
    log.finish()
|
src/planners/online.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DAgger online training loop.
|
| 2 |
+
|
| 3 |
+
Orchestrates the full DAgger pipeline: collect data via model + oracle,
|
| 4 |
+
train on buffer, evaluate periodically, and checkpoint.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import random
|
| 11 |
+
import time
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from types import SimpleNamespace
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import yaml
|
| 19 |
+
|
| 20 |
+
from src.buffer import ReplayBuffer
|
| 21 |
+
from src.config import make_run_dir
|
| 22 |
+
from src.diffusion.forward import q_sample
|
| 23 |
+
from src.diffusion.loss import auxiliary_goal_loss, mdlm_loss
|
| 24 |
+
from src.diffusion.schedules import get_schedule
|
| 25 |
+
from src.models.denoiser import ModelEMA, make_model, try_compile
|
| 26 |
+
from src.planners.collect import DataCollector
|
| 27 |
+
from src.planners.inference import Evaluator, save_eval_json
|
| 28 |
+
from src.planners.logging import (
|
| 29 |
+
Logger, gpu_memory_mb, reset_gpu_memory_stats,
|
| 30 |
+
compute_param_norm, compute_param_drift,
|
| 31 |
+
)
|
| 32 |
+
from src.curriculum import DynamicCurriculum
|
| 33 |
+
from src.envs.minihack_env import collect_oracle_trajectory
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Trainer:
|
| 39 |
+
"""Full DAgger training loop.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
model: Denoising model.
|
| 43 |
+
ema_model: EMA tracker.
|
| 44 |
+
optimizer: Torch optimizer.
|
| 45 |
+
scheduler: Optional LR scheduler.
|
| 46 |
+
buffer: Replay buffer.
|
| 47 |
+
collector: DAgger data collector.
|
| 48 |
+
evaluator: Evaluation runner.
|
| 49 |
+
log: Centralised logger.
|
| 50 |
+
cfg: Config namespace.
|
| 51 |
+
device: Torch device.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
def __init__(
    self,
    model: nn.Module,
    ema_model: ModelEMA,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None,
    buffer: ReplayBuffer,
    collector: DataCollector,
    evaluator: Evaluator,
    log: Logger,
    cfg: SimpleNamespace,
    device: torch.device | str,
    raw_model: nn.Module | None = None,
) -> None:
    """Wire up collaborators and prepare drift-tracking and AMP state."""
    # Run-wide context.
    self.cfg = cfg
    self.device = device
    self.log = log

    # Training objects. ``model`` may be a torch.compile wrapper;
    # ``_raw_model`` is the uncompiled module used for eval deep-copies
    # and falls back to ``model`` itself when none is supplied.
    self.model = model
    self._raw_model = model if raw_model is None else raw_model
    self.ema_model = ema_model
    self.optimizer = optimizer
    self.scheduler = scheduler

    # Data + evaluation collaborators.
    self.buffer = buffer
    self.collector = collector
    self.evaluator = evaluator

    self._schedule_fn = get_schedule(cfg.noise_schedule)

    # Frozen copy of the starting floating-point weights; the reference
    # point for the parameter-drift metric.
    initial_weights = {}
    for name, tensor in self._raw_model.state_dict().items():
        if tensor.is_floating_point():
            initial_weights[name] = tensor.clone()
    self._init_state = initial_weights

    # Mixed precision is active only when requested via cfg.use_amp AND
    # the device is CUDA; otherwise the scaler is constructed disabled.
    on_cuda = str(device).startswith("cuda")
    self._use_amp = getattr(cfg, "use_amp", False) and on_cuda
    self._scaler = torch.amp.GradScaler("cuda", enabled=self._use_amp)
|
| 92 |
+
|
| 93 |
+
# ── Main loop ────────────────────────────────────────────────
|
| 94 |
+
|
| 95 |
+
def train(
    self, start_iter: int = 0, start_env_steps: int = 0,
) -> None:
    """Run the DAgger training loop.

    The budget is ``cfg.total_timesteps`` — total env.step() calls
    across model + oracle rollouts. Iteration count is derived; it
    depends on how many env steps each iteration consumes (which in
    turn depends on episode length and efficiency filter outcomes).

    Each iteration: collect episodes, run
    ``cfg.grad_steps_per_iteration`` gradient steps (EMA updated after
    each), log metrics, then run ID eval / OOD eval / checkpointing
    whenever the env-step delta since the last one reaches the
    configured cadence. A final checkpoint is written after the loop
    when ``cfg.save_policy`` is truthy.

    Args:
        start_iter: Iteration index to resume from (for logging).
        start_env_steps: Cumulative env steps already consumed.
    """
    cfg = self.cfg
    env_steps_total = start_env_steps
    iteration = start_iter
    last_id_eval_step = start_env_steps
    last_ood_eval_step = start_env_steps
    last_ckpt_step = start_env_steps

    # NOTE(review): the loop only advances when collection reports a
    # non-zero number of env steps; a collector that keeps returning
    # zero steps would spin here forever — confirm that cannot happen.
    while env_steps_total < cfg.total_timesteps:
        reset_gpu_memory_stats()
        iter_start = time.perf_counter()

        # 1. Collect N episodes per iteration
        n_eps = getattr(cfg, "episodes_per_iteration", 1)
        num_workers = getattr(cfg, "num_collection_workers", 0)

        collect_start = time.perf_counter()
        use_gpu_batch = (
            str(self.device).startswith("cuda") and n_eps > 1
        )
        if use_gpu_batch:
            # GPU-batched collection (all envs in lockstep)
            episode_stats = self.collector.collect_batch_gpu(n_eps)
        elif num_workers > 0 and n_eps > 1:
            # Threaded CPU collection (fallback)
            episode_stats = self.collector.collect_batch_parallel(
                n_eps,
            )
        else:
            # Sequential collection (reference behaviour)
            episode_stats = [
                self.collector.collect_one_iteration()
                for _ in range(n_eps)
            ]

        # Accumulators across all n_eps episodes — must be summed,
        # NOT taken from a single (last) episode, otherwise the
        # unified env-step budget undercounts by ~n_eps×.
        model_wins = 0
        added_total = 0
        model_steps_iter = 0
        oracle_steps_iter = 0
        last_env_id: str = ""
        for s in episode_stats:
            model_wins += int(s["model_won"])
            added_total += int(s["added_to_buffer"])
            model_steps_iter += int(s["model_steps"])
            oracle_steps_iter += int(s["oracle_steps"])
            last_env_id = s.get("env_id", last_env_id)
        collect_time = time.perf_counter() - collect_start

        collect_stats = {
            "env_id": last_env_id,
            "model_won": model_wins,
            "added_to_buffer": added_total,
            "model_steps": model_steps_iter,
            "oracle_steps": oracle_steps_iter,
        }

        # Advance the unified env-step budget. Both model and oracle
        # rollouts consume real env.step() calls (the oracle rollout
        # runs in its own env instance in collect_oracle_trajectory),
        # so both contribute to the budget.
        iter_env_steps = model_steps_iter + oracle_steps_iter
        env_steps_total += iter_env_steps

        # 2. Gradient steps (EMA updated after each step)
        self.model.train()
        step_metrics: list[dict[str, float]] = []
        train_start = time.perf_counter()
        for _ in range(cfg.grad_steps_per_iteration):
            m = self._train_step()
            step_metrics.append(m)
            self.ema_model.update(self._raw_model)
        train_time = time.perf_counter() - train_start

        iter_time = time.perf_counter() - iter_start

        # 3. Log
        # `or 1` guards the averages against division by zero when no
        # gradient steps were configured.
        n_steps = len(step_metrics) or 1
        avg_loss = sum(m["loss"] for m in step_metrics) / n_steps
        avg_loss_diff = sum(m["loss_diff"] for m in step_metrics) / n_steps
        avg_loss_aux = sum(m["loss_aux"] for m in step_metrics) / n_steps
        avg_grad_norm = sum(m["grad_norm"] for m in step_metrics) / n_steps
        current_lr = (
            self.scheduler.get_last_lr()[0]
            if self.scheduler is not None
            else self.cfg.dagger_lr
        )

        # Global gate value (how open is the global stream)
        gate_val = None
        if hasattr(self._raw_model, "global_gate"):
            gate_val = torch.sigmoid(
                self._raw_model.global_gate
            ).item()

        # Buffer online fraction
        buf_total = len(self.buffer)
        buf_online_frac = (
            (buf_total - self.buffer.offline_size) / max(buf_total, 1)
            if hasattr(self.buffer, "offline_size")
            else 0.0
        )

        # Samples per second
        total_samples = n_steps * cfg.dagger_batch_size
        samples_per_sec = total_samples / max(train_time, 1e-6)

        # Env steps per second (uses the iter-summed total, not a
        # single episode — same bug class as the env-step budget).
        env_steps_per_sec = iter_env_steps / max(collect_time, 1e-6)

        metrics = {
            "diffusion/loss": avg_loss,
            "diffusion/loss_diff": avg_loss_diff,
            "diffusion/loss_aux": avg_loss_aux,
            "train/buffer_size": buf_total,
            "train/buffer_online_frac": buf_online_frac,
            "train/model_won": int(collect_stats["model_won"]),
            "train/added_to_buffer": int(
                collect_stats["added_to_buffer"]
            ),
            "train/episodes_collected": n_eps,
            "train/model_steps": collect_stats["model_steps"],
            "train/oracle_steps": collect_stats["oracle_steps"],
            "train/efficiency_ratio": (
                collect_stats["model_steps"]
                / max(collect_stats["oracle_steps"], 1)
            ),
            "train/lr": current_lr,
            "train/grad_norm": avg_grad_norm,
            "train/env_steps": env_steps_total,
            "train/progress": env_steps_total / cfg.total_timesteps,
            "speed/iter_time_sec": iter_time,
            "speed/collect_time_sec": collect_time,
            "speed/train_step_time_sec": train_time,
            "speed/samples_per_sec": samples_per_sec,
            "speed/env_steps_per_sec": env_steps_per_sec,
            "speed/gpu_memory_mb": gpu_memory_mb(),
            # Keep old perf/ keys for backward compat
            "perf/iter_time_s": iter_time,
            "perf/collect_time_s": collect_time,
            "perf/train_time_s": train_time,
            "perf/grad_steps_per_sec": (
                cfg.grad_steps_per_iteration / max(train_time, 1e-6)
            ),
        }
        if gate_val is not None:
            metrics["train/global_gate"] = gate_val
            metrics["model/ema_gate_value"] = gate_val

        # Model health (every 10 iters to avoid overhead)
        if iteration % 10 == 0:
            metrics["model/param_norm"] = compute_param_norm(
                self._raw_model
            )
            metrics["model/param_drift_from_init"] = compute_param_drift(
                self._raw_model, self._init_state
            )

        # Profile breakdown from GPU-batched collection
        _profile = getattr(self.collector, "_last_profile", {})
        for _pk, _pv in _profile.items():
            metrics[f"profile/{_pk}"] = _pv

        self.log.log(metrics, step=iteration)

        # 4. ID eval — triggered when env-step delta crosses threshold
        if (
            cfg.id_eval_every_timesteps > 0
            and env_steps_total - last_id_eval_step
            >= cfg.id_eval_every_timesteps
        ):
            eval_model = self.ema_model.make_eval_model(self._raw_model)
            results = self.evaluator.evaluate(
                cfg.id_envs,
                eval_model,
                cfg.eval_episodes_per_env,
                cfg,
                self.device,
            )
            self.log.log_eval(results, step=iteration, prefix="eval_id")
            mean_id_wr = float(np.mean(
                [s["win_rate"] for s in results.values()]
            )) if results else 0.0
            self.log.log(
                {
                    "eval_id/mean_win_rate": mean_id_wr,
                    **{
                        f"curriculum/{env_id}/win_rate":
                            self.collector.curriculum.win_rate(env_id)
                        for env_id in self.cfg.id_envs
                    },
                },
                step=iteration,
            )
            last_id_eval_step = env_steps_total

        # 5. OOD eval — env-step-triggered
        if (
            cfg.ood_eval_every_timesteps > 0
            and env_steps_total - last_ood_eval_step
            >= cfg.ood_eval_every_timesteps
        ):
            eval_model = self.ema_model.make_eval_model(self._raw_model)
            results = self.evaluator.evaluate(
                cfg.ood_envs,
                eval_model,
                cfg.eval_episodes_per_env,
                cfg,
                self.device,
            )
            self.log.log_eval(results, step=iteration, prefix="eval_ood")
            mean_ood_wr = float(np.mean(
                [s["win_rate"] for s in results.values()]
            )) if results else 0.0
            self.log.log(
                {"eval_ood/mean_win_rate": mean_ood_wr}, step=iteration,
            )
            last_ood_eval_step = env_steps_total

        # 6. Checkpoint — env-step-triggered
        if (
            cfg.checkpoint_every_timesteps > 0
            and env_steps_total - last_ckpt_step
            >= cfg.checkpoint_every_timesteps
        ):
            self.save_checkpoint(iteration, env_steps_total)
            last_ckpt_step = env_steps_total

        iteration += 1

    # Final checkpoint
    if cfg.save_policy:
        self.save_checkpoint(iteration, env_steps_total)
|
| 352 |
+
|
| 353 |
+
# ── Single gradient step ─────────────────────────────────────
|
| 354 |
+
|
| 355 |
+
def _train_step(self) -> dict[str, float]:
|
| 356 |
+
"""One gradient step on a buffer sample.
|
| 357 |
+
|
| 358 |
+
Uses AMP (mixed precision) when ``cfg.use_amp`` is ``True``
|
| 359 |
+
and training on CUDA.
|
| 360 |
+
|
| 361 |
+
Returns:
|
| 362 |
+
Dict with ``"loss"``, ``"loss_diff"``, ``"loss_aux"``,
|
| 363 |
+
and ``"grad_norm"`` scalars.
|
| 364 |
+
"""
|
| 365 |
+
cfg = self.cfg
|
| 366 |
+
batch = self.buffer.sample(cfg.dagger_batch_size)
|
| 367 |
+
if batch is None:
|
| 368 |
+
return {"loss": 0.0, "loss_diff": 0.0,
|
| 369 |
+
"loss_aux": 0.0, "grad_norm": 0.0}
|
| 370 |
+
local_np, global_np, actions_np = batch
|
| 371 |
+
local_t = torch.from_numpy(local_np).long().to(self.device)
|
| 372 |
+
global_t = torch.from_numpy(global_np).long().to(self.device)
|
| 373 |
+
actions_t = torch.from_numpy(actions_np).long().to(self.device)
|
| 374 |
+
|
| 375 |
+
B = actions_t.shape[0]
|
| 376 |
+
t = torch.rand(B, device=self.device).clamp(1e-5, 1.0 - 1e-5)
|
| 377 |
+
|
| 378 |
+
zt = q_sample(
|
| 379 |
+
actions_t, t, cfg.mask_token, cfg.pad_token,
|
| 380 |
+
self._schedule_fn,
|
| 381 |
+
)
|
| 382 |
+
t_discrete = (t * cfg.num_diffusion_steps).long().clamp(
|
| 383 |
+
0, cfg.num_diffusion_steps - 1,
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
self.optimizer.zero_grad()
|
| 387 |
+
with torch.amp.autocast("cuda", enabled=self._use_amp):
|
| 388 |
+
out = self.model(local_t, global_t, zt, t_discrete)
|
| 389 |
+
|
| 390 |
+
loss_diff = mdlm_loss(
|
| 391 |
+
out["actions"], actions_t, zt, t,
|
| 392 |
+
cfg.mask_token, cfg.pad_token, self._schedule_fn,
|
| 393 |
+
weight_clip=cfg.loss_weight_clip,
|
| 394 |
+
label_smoothing=cfg.label_smoothing,
|
| 395 |
+
use_importance_weighting=cfg.use_importance_weighting,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
loss_aux = torch.tensor(0.0, device=self.device)
|
| 399 |
+
if "goal_pred" in out:
|
| 400 |
+
loss_aux = auxiliary_goal_loss(out["goal_pred"], global_t)
|
| 401 |
+
|
| 402 |
+
loss = loss_diff + cfg.aux_loss_weight * loss_aux
|
| 403 |
+
|
| 404 |
+
self._scaler.scale(loss).backward()
|
| 405 |
+
self._scaler.unscale_(self.optimizer)
|
| 406 |
+
grad_norm = nn.utils.clip_grad_norm_(
|
| 407 |
+
self.model.parameters(), cfg.dagger_grad_clip,
|
| 408 |
+
)
|
| 409 |
+
self._scaler.step(self.optimizer)
|
| 410 |
+
self._scaler.update()
|
| 411 |
+
if self.scheduler is not None:
|
| 412 |
+
self.scheduler.step()
|
| 413 |
+
|
| 414 |
+
return {
|
| 415 |
+
"loss": loss.item(),
|
| 416 |
+
"loss_diff": loss_diff.item(),
|
| 417 |
+
"loss_aux": loss_aux.item(),
|
| 418 |
+
"grad_norm": grad_norm.item(),
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
# ── Checkpointing ────────────────────────────────────────────
|
| 422 |
+
|
| 423 |
+
def save_checkpoint(
    self, iteration: int, env_steps: int,
) -> None:
    """Save a training checkpoint plus its companion artifacts.

    Steps, each of which is best-effort (failures are logged, not raised):

    1. Serialize model / EMA / optimizer / scheduler / curriculum state,
       counters, the W&B run ID and the CPU RNG states to
       ``<checkpoint_dir>/iter<iteration>.pth``.
    2. Dump the public (non-underscore) config attributes to
       ``config_iter<iteration>.yaml``.
    3. Evaluate the EMA model on ``cfg.id_envs`` and ``cfg.ood_envs``,
       write ``eval_iter<iteration>.json`` and log the results.
    4. Attempt a HuggingFace Hub upload of the checkpoint directory.
    5. Upload the checkpoint as a W&B artifact (only if step 1 succeeded).

    Args:
        iteration: Current iteration number (for filename + metadata).
        env_steps: Cumulative env.step() count consumed so far.
    """

    def _mean_win_rate(results) -> float:
        # Mean of the per-env "win_rate" entries; 0.0 for empty results.
        if not results:
            return 0.0
        return float(np.mean([s["win_rate"] for s in results.values()]))

    def _per_env_summary(results) -> dict:
        # Keep only the JSON-friendly headline stats for each env.
        # NOTE(review): assumes each value is a dict of stats keyed as
        # below — confirm against Evaluator.evaluate.
        return {
            env_id: {
                "win_rate": s["win_rate"],
                "wins": s.get("wins", 0),
                "avg_reward": s["avg_reward"],
                "avg_steps": s["avg_steps"],
                "n_episodes": s["n_episodes"],
            }
            for env_id, s in results.items()
        }

    ckpt_dir = Path(self.cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    path = ckpt_dir / f"iter{iteration}.pth"

    # Capture W&B run ID for seamless resumption
    wandb_run_id: str | None = None
    if self.log._use_wandb and self.log._run is not None:
        wandb_run_id = self.log._run.id

    state = {
        "model_state_dict": self._raw_model.state_dict(),
        "ema_state_dict": self.ema_model.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
        "scheduler_state_dict": (
            self.scheduler.state_dict()
            if self.scheduler is not None
            else None
        ),
        "curriculum_state": self.collector.curriculum.state_dict(),
        "iteration": iteration,
        "env_steps": env_steps,
        "wandb_run_id": wandb_run_id,
        "rng_states": {
            "torch": torch.get_rng_state(),
            "numpy": np.random.get_state(),
            "python": random.getstate(),
        },
    }

    # Write to a sibling temp file and rename it into place, so a failed
    # or interrupted save never leaves a truncated ``iterN.pth`` behind
    # (and never clobbers an existing good checkpoint of the same name).
    tmp_path = ckpt_dir / f"{path.name}.tmp"
    checkpoint_saved = False
    try:
        torch.save(state, tmp_path)
        tmp_path.replace(path)
        checkpoint_saved = True
        logger.info(f"Checkpoint saved: {path}")
    except Exception:
        logger.error(
            f"Failed to save checkpoint to {path}", exc_info=True,
        )
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            logger.warning(
                f"Could not remove partial checkpoint {tmp_path}",
                exc_info=True,
            )

    # Save config snapshot alongside checkpoint
    config_path = ckpt_dir / f"config_iter{iteration}.yaml"
    try:
        cfg_dict = {
            k: v for k, v in vars(self.cfg).items()
            if not k.startswith("_")
        }
        with open(config_path, "w") as f:
            yaml.dump(cfg_dict, f, default_flow_style=False)
    except Exception:
        logger.error("Failed to save config snapshot", exc_info=True)
        config_path = None

    # Run eval at checkpoint and save JSON
    try:
        eval_model = self.ema_model.make_eval_model(self._raw_model)
        id_results = self.evaluator.evaluate(
            self.cfg.id_envs, eval_model,
            self.cfg.checkpoint_eval_episodes,
            self.cfg, self.device,
        )
        ood_results = self.evaluator.evaluate(
            self.cfg.ood_envs, eval_model,
            self.cfg.checkpoint_eval_episodes,
            self.cfg, self.device,
        )

        id_winrate = _mean_win_rate(id_results)
        ood_winrate = _mean_win_rate(ood_results)
        # Without a scheduler the LR is whatever the config specified.
        current_lr = (
            self.scheduler.get_last_lr()[0]
            if self.scheduler is not None
            else self.cfg.dagger_lr
        )
        training_meta = {
            "iteration": iteration,
            "env_steps": env_steps,
            "total_timesteps": self.cfg.total_timesteps,
            "lr": current_lr,
            "dagger_batch_size": self.cfg.dagger_batch_size,
            "aux_loss_weight": self.cfg.aux_loss_weight,
            "buffer_size": len(self.buffer),
            "buffer_capacity": self.cfg.buffer_capacity,
            "ema_decay": self.cfg.ema_decay,
            "grad_steps_per_iteration": self.cfg.grad_steps_per_iteration,
            "episodes_per_iteration": getattr(
                self.cfg, "episodes_per_iteration", 1
            ),
            "id_winrate": id_winrate,
            "ood_winrate": ood_winrate,
            "per_env_id": _per_env_summary(id_results),
            "per_env_ood": _per_env_summary(ood_results),
        }

        json_path = ckpt_dir / f"eval_iter{iteration}.json"
        save_eval_json(
            {"id": id_results, "ood": ood_results},
            str(json_path),
            metadata=training_meta,
        )

        # W&B checkpoint log — per-env step metrics + aggregates
        self.log.log_eval(
            id_results, step=iteration, prefix="ckpt_eval_id",
        )
        self.log.log_eval(
            ood_results, step=iteration, prefix="ckpt_eval_ood",
        )
        self.log.log(
            {
                "ckpt_eval/id_winrate": id_winrate,
                "ckpt_eval/ood_winrate": ood_winrate,
            },
            step=iteration,
        )
        self.log.log_summary({
            f"ckpt_{iteration}/id_winrate": id_winrate,
            f"ckpt_{iteration}/ood_winrate": ood_winrate,
        })
    except Exception:
        logger.error("Checkpoint eval failed", exc_info=True)

    # HuggingFace Hub upload (no-op if HF_TOKEN or hub_run_id not set)
    try:
        from scripts.hf_upload import maybe_upload_checkpoint
        maybe_upload_checkpoint(
            str(ckpt_dir),
            getattr(self.cfg, "hub_run_id", None),
            getattr(self.cfg, "hub_repo_id", None),
        )
    except Exception:
        logger.error("HF Hub upload failed", exc_info=True)

    # W&B artifact upload — pointless (and outside any try block) when the
    # checkpoint file itself was never written, so skip it in that case.
    if not checkpoint_saved:
        logger.warning(
            "Skipping W&B artifact upload: checkpoint was not written",
        )
        return
    self.log.log_checkpoint_artifact(
        checkpoint_path=str(path),
        config_path=str(config_path) if config_path else None,
        iteration=iteration,
        metadata={
            "iteration": iteration,
            "buffer_size": len(self.buffer),
        },
    )
|
| 594 |
+
|
| 595 |
+
def load_checkpoint(self, path: str) -> tuple[int, int]:
    """Load a training checkpoint and restore trainer state from it.

    Restores the raw model, EMA model and optimizer unconditionally; the
    scheduler and curriculum only when both the trainer and the checkpoint
    have them. RNG states are restored best-effort and independently, so a
    failure on one generator does not prevent restoring the others.

    Args:
        path: Path to ``.pth`` checkpoint file.

    Returns:
        ``(start_iter, start_env_steps)`` — the iteration and
        cumulative env-step count to resume from. ``start_iter`` is the
        saved iteration plus one.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects, so
    # only load checkpoints from a trusted source.
    ckpt = torch.load(
        path, map_location=self.device, weights_only=False,
    )
    self._raw_model.load_state_dict(ckpt["model_state_dict"])
    self.ema_model.load_state_dict(ckpt["ema_state_dict"])
    self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    if (
        self.scheduler is not None
        and ckpt.get("scheduler_state_dict") is not None
    ):
        self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])

    if "curriculum_state" in ckpt:
        self.collector.curriculum.load_state_dict(
            ckpt["curriculum_state"],
        )

    # Restore RNG states (best-effort). ``map_location`` above moves every
    # tensor in the checkpoint — including the saved torch RNG state — to
    # ``self.device``, but ``torch.set_rng_state`` only accepts a CPU
    # ByteTensor, hence the explicit ``.cpu()``.
    rng = ckpt.get("rng_states") or {}
    restorers = (
        ("torch", lambda state: torch.set_rng_state(state.cpu())),
        ("numpy", np.random.set_state),
        ("python", random.setstate),
    )
    for name, restore in restorers:
        if name not in rng:
            continue
        try:
            restore(rng[name])
        except Exception:
            logger.warning(
                "RNG state restore failed; continuing with fresh state",
                exc_info=True,
            )

    iteration = ckpt.get("iteration", 0)
    env_steps = ckpt.get("env_steps", 0)
    resume_from = iteration + 1
    logger.info(
        f"Resumed from checkpoint: {path} (iter {iteration}, "
        f"env_steps={env_steps}), starting at iter {resume_from}"
    )
    return resume_from, env_steps
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def run_dagger(
    cfg: SimpleNamespace,
    checkpoint_path: str | None,
    no_warm_start: bool,
) -> None:
    """Build every training component and run the DAgger online loop.

    Args:
        cfg: Run configuration namespace.
        checkpoint_path: Optional checkpoint to resume from.
        no_warm_start: When true, ``checkpoint_path`` is ignored and
            training starts from scratch.
    """
    make_run_dir(cfg, tag="dagger")

    device = cfg.device
    logger.info(f"DAgger training on {device}")

    # A checkpoint is only honoured when warm-starting is allowed.
    resuming = bool(checkpoint_path) and not no_warm_start

    raw_model = make_model(cfg).to(device)

    # The EMA (and therefore eval) is built from the uncompiled model:
    # deep-copying a compiled model breaks FX tracing.
    ema = ModelEMA(raw_model, decay=cfg.ema_decay)

    # The compiled wrapper is for training only and shares parameters
    # with raw_model.
    model = try_compile(raw_model, cfg)

    optimizer = torch.optim.AdamW(
        raw_model.parameters(),
        lr=cfg.dagger_lr,
        weight_decay=cfg.weight_decay,
    )

    buffer = ReplayBuffer(cfg.buffer_capacity, cfg.seq_len, cfg.pad_token)
    curriculum = DynamicCurriculum(
        cfg.id_envs, cfg.curriculum_queue_size, cfg.curriculum_preseed,
    )

    # Seed the buffer with three oracle trajectories per in-distribution env.
    oracle_jobs = (
        (env_name, env_idx * 100 + offset)
        for env_idx, env_name in enumerate(cfg.id_envs)
        for offset in range(3)
    )
    for env_name, oracle_seed in oracle_jobs:
        trajectory = collect_oracle_trajectory(
            env_name, seed=oracle_seed, cfg=cfg,
        )
        if trajectory is not None:
            buffer.add(trajectory)
    logger.info(f"Buffer seeded with {len(buffer)} windows")

    # When resuming without an explicit W&B resume ID, peek at the checkpoint
    # before the Logger is created so the same W&B run is continued.
    if resuming and not getattr(cfg, "wandb_resume_id", None):
        ckpt_peek = torch.load(
            checkpoint_path, map_location="cpu", weights_only=False,
        )
        saved_id = ckpt_peek.get("wandb_run_id")
        if saved_id:
            cfg.wandb_resume_id = saved_id
            logger.info(
                f"W&B run ID from checkpoint: {saved_id}"
            )
        del ckpt_peek

    # DataCollector takes raw_model (not the compiled one) for eval copies.
    collector = DataCollector(ema, raw_model, buffer, curriculum, cfg, device)
    evaluator = Evaluator()
    log = Logger(cfg)

    trainer = Trainer(
        model, ema, optimizer, None, buffer, collector,
        evaluator, log, cfg, device, raw_model=raw_model,
    )

    if resuming:
        start_iter, start_env_steps = trainer.load_checkpoint(
            checkpoint_path,
        )
    else:
        start_iter, start_env_steps = 0, 0

    trainer.train(
        start_iter=start_iter, start_env_steps=start_env_steps,
    )
    log.finish()
|
src/planners/smoke.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from src.buffer import ReplayBuffer
|
| 6 |
+
from src.curriculum import DynamicCurriculum
|
| 7 |
+
from src.envs.minihack_env import collect_oracle_trajectory
|
| 8 |
+
from src.models.denoiser import ModelEMA, make_model, try_compile
|
| 9 |
+
from src.planners.collect import DataCollector
|
| 10 |
+
from src.planners.inference import Evaluator, format_eval_results
|
| 11 |
+
from src.planners.logging import Logger
|
| 12 |
+
from src.planners.online import Trainer
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def run_smoke(cfg) -> None:
    """Run the end-to-end smoke test: seed oracle data, train, evaluate."""

    device = cfg.device
    logger.info(f"Smoke test on {device}")

    # One oracle trajectory per in-distribution env goes into the buffer.
    buffer = ReplayBuffer(cfg.buffer_capacity, cfg.seq_len, cfg.pad_token)
    for env_idx, env_name in enumerate(cfg.id_envs):
        trajectory = collect_oracle_trajectory(env_name, seed=env_idx, cfg=cfg)
        if trajectory is None:
            continue
        buffer.add(trajectory)
    logger.info(f"Buffer seeded with {len(buffer)} windows")

    raw_model = make_model(cfg).to(device)

    model = try_compile(raw_model, cfg)

    ema = ModelEMA(raw_model, decay=cfg.ema_decay)
    optimizer = torch.optim.AdamW(
        raw_model.parameters(),
        lr=cfg.dagger_lr,
        weight_decay=cfg.weight_decay,
    )
    curriculum = DynamicCurriculum(
        cfg.id_envs, cfg.curriculum_queue_size, cfg.curriculum_preseed,
    )

    collector = DataCollector(ema, raw_model, buffer, curriculum, cfg, device)
    evaluator = Evaluator()
    log = Logger(cfg)

    trainer = Trainer(
        model, ema, optimizer, None, buffer, collector,
        evaluator, log, cfg, device, raw_model=raw_model,
    )
    trainer.train(start_iter=0)

    # Final eval on the in-distribution envs using the EMA weights.
    eval_model = ema.make_eval_model(raw_model)
    results = evaluator.evaluate(
        cfg.id_envs, eval_model, cfg.eval_episodes_per_env, cfg, device,
    )
    print(format_eval_results(results, label="Smoke"))
    log.log_eval(results, step=0, prefix="smoke_eval")

    # Average win rate across envs; 0.0 when nothing was evaluated.
    if results:
        total_win_rate = sum(stats["win_rate"] for stats in results.values())
        mean_wr = float(total_win_rate / len(results))
    else:
        mean_wr = 0.0
    log.log({"smoke_eval/mean_win_rate": mean_wr}, step=0)
    log.finish()
|