# remdm-minihack / configs / final_qmul_gpu.yaml
# (Provenance, from the Hugging Face upload page: uploader MathisW78,
#  commit f748552 "Demo notebook payload (source + checkpoint + assets)".
#  These lines were raw page text and have been converted to comments so
#  the file parses as YAML.)
# =============================================================================
# QMUL H200 GPU — final paper run config
# =============================================================================
#
# This single config drives BOTH the final DAgger run that produced
# `checkpoint_final/online/final.pth` AND the compute-matched offline
# BC baseline used for the paper comparison.
#
# --mode dagger  → reproduces the iter600 DAgger checkpoint recipe
# --mode offline → trains a fair offline BC baseline against it
#
# ── Fairness analysis ───────────────────────────────────────────────
#
# DAgger compute at iter600 (the checkpointed model):
# 600 iters × 100 grad_steps_per_iter × 2048 batch_size
# = 60,000 AdamW updates
# = 122,880,000 sample-equivalents
#
# The fair offline BC baseline matches this exactly:
# offline_total_grad_steps = 60,000 (override; pinned)
# offline_batch_size = 2048 (matches DAgger; same SNR)
# weight_decay, grad_clip, aux_loss_weight, model arch, diffusion
# params: all matched. Model is identical between modes.
#
# LR strategy follows "best-of-each-method" rather than identical
# optimisers — DAgger's 3e-5 constant is tuned for online refinement,
# offline's 3e-4 cosine→3e-5 is BC standard from-scratch. Both
# converge to the same effective late-training LR.
#
# Eval/checkpoint cadence is matched in *count* across modes (12 evals,
# 6 checkpoints per run) via the offline_*_every_grad_steps overrides,
# because the env-step→grad-step ratio differs by ~50× between modes.
#
# ── Hardware ─────────────────────────────────────────────────────────
#
# QMUL H200 (constrained VRAM allocation). The DAgger checkpoint was
# produced on this hardware, so batch_size and AMP settings must
# stay identical to the original run. AMP + torch.compile + 32-worker
# collection are the original perf settings.
# ── Environments ─────────────────────────────────────────────────────
# ID envs: the training distribution (used for collection and ID eval).
id_envs:
- MiniHack-Room-Random-5x5-v0
- MiniHack-Room-Random-15x15-v0
- MiniHack-Corridor-R2-v0
- MiniHack-MazeWalk-9x9-v0
# OOD envs: held out for generalisation eval; never trained on.
ood_envs:
- MiniHack-Room-Dark-15x15-v0
- MiniHack-Corridor-R5-v0
- MiniHack-MazeWalk-45x19-v0
# Observation geometry and discrete token space.
crop_size: 9     # side of the square map crop (presumably egocentric — confirm in env wrapper)
map_h: 21        # full map height (the standard NetHack map is 21x79)
map_w: 79        # full map width
action_dim: 12   # discrete actions occupy token ids 0..11
mask_token: 12   # [MASK] token id = action_dim (first id past the actions)
pad_token: 13    # padding token id (implies an effective vocab of 14)
# ── Model (matches checkpoint) ───────────────────────────────────────
# Architecture is pinned to the iter600 checkpoint; changing any of
# these keys will break checkpoint loading.
n_embd: 256             # transformer embedding width
n_head: 4               # attention heads per layer
n_layer: 4              # transformer layers
n_global_tokens: 8      # extra global tokens (presumably learned registers — confirm in model code)
seq_len: 64             # token sequence length the model operates on
global_gate_init: -3.0  # initial global-token gate logit (sigmoid(-3) ≈ 0.047, i.e. nearly closed)
dropout: 0.0            # no dropout, matching the checkpointed run
ema_decay: 0.999        # decay for the EMA copy of the weights
# ── Diffusion (MDLM) — matches checkpoint ────────────────────────────
noise_schedule: linear           # diffusion noise/masking schedule
num_diffusion_steps: 100         # discretisation of the training-time noise process
loss_weight_clip: 1000.0         # upper bound on per-step loss weights
label_smoothing: 0.0             # disabled, matching the checkpointed run
use_importance_weighting: false
eta: 0.15                        # ReMDM remasking intensity
remask_strategy: conf            # 'conf' — presumably confidence-based remasking; confirm in sampler
# ── Inference / sampling — matches checkpoint ────────────────────────
diffusion_steps_eval: 10      # denoising steps per sample at eval time
diffusion_steps_collect: 5    # fewer steps during DAgger collection (cheaper rollouts)
temperature: 0.5              # sampling temperature
top_k: 4                      # restrict sampling to the top-4 logits
replan_every: 16              # presumably env steps between replans — confirm in rollout loop
physics_aware_sampling: false
# ── Shared training budget (DAgger only) ─────────────────────────────
# 5.65M env-steps reproduces the env-step budget consumed at iter600
# of the original DAgger run. This figure is calibrated against a real
# DAgger run with the same recipe (`p7wfp67q`, episodes_per_iteration=30,
# grad_steps_per_iteration=100): summing the per-iter env steps over
# the first 600 iterations gives 30 × 600 × mean(model_steps + oracle_steps)
# ≈ 30 × 600 × (198 + 116) ≈ 5.65 M real env.step() calls.
# (The earlier 3M figure was based on the buggy single-episode env-step
# accounting in `online.py:155-169` — fixed in the same commit as this
# config bump.) Used by `--mode dagger` only. Offline mode bypasses
# this via `offline_total_grad_steps` below — the unified env-step
# budget is fundamentally unfair when the sample-to-grad-step ratio
# differs by ~50× between modes.
total_timesteps: 5650000
# Eval/checkpoint cadence in env-step units (DAgger mode).
# Scaled with the corrected total_timesteps so the run still produces
# ~12 ID/OOD evals and ~6 checkpoints over its full duration.
# 470k → ~12 evals; 940k → ~6 checkpoints.
id_eval_every_timesteps: 470000
ood_eval_every_timesteps: 470000
checkpoint_every_timesteps: 940000
# Final-eval episode count (used by both ID/OOD eval triggers and
# checkpoint-time evals; matches the original DAgger run).
eval_episodes_per_env: 50
checkpoint_eval_episodes: 50
# Optimiser settings shared by BOTH modes (part of the fairness matching).
weight_decay: 0.0001
aux_loss_weight: 0.5
# ── DAgger (matches checkpoint_final/online/config_iter600.yaml) ─────
dagger_lr: 0.00003             # 3e-5 constant LR, tuned for online refinement (see header)
dagger_batch_size: 2048        # mirrored by offline_batch_size for the fairness match
dagger_grad_clip: 1.0
buffer_capacity: 10000         # online replay buffer cap (offline mode uses offline_buffer_capacity)
episodes_per_iteration: 30
grad_steps_per_iteration: 100  # 600 iters × 100 = the 60k-update compute budget above
efficiency_multiplier: 1.5     # presumably an oracle-relative episode-filter threshold — confirm
curriculum_queue_size: 100
curriculum_preseed: true
# ── Offline BC (compute-matched fair baseline) ───────────────────────
# Per the fairness analysis above:
# * Same gradient compute as DAgger (60k AdamW updates × 2048 batch)
# * Same model, diffusion, weight_decay, grad_clip, aux_loss
# * BC-tuned LR + cosine schedule (best practice from-scratch)
# * Eval/checkpoint counts matched to DAgger via grad-step overrides
offline_lr: 0.0003            # 3e-4 peak; cosine-annealed toward 3e-5 (see header)
offline_batch_size: 2048      # matches dagger_batch_size (same gradient SNR)
offline_grad_clip: 1.0
# Compute pin: 60,000 AdamW updates = exactly DAgger@iter600.
offline_total_grad_steps: 60000
# Eval cadence: 5,000 grad steps → 12 evals (matches DAgger eval count).
offline_eval_every_grad_steps: 5000
# Checkpoint cadence: 10,000 grad steps → 6 checkpoints (matches DAgger).
offline_checkpoint_every_grad_steps: 10000
# Buffer cap for offline mode only — must hold the full pre-collected
# dataset (~1M sliding windows from 20k oracle trajectories). DAgger's
# `buffer_capacity: 10000` would silently FIFO-evict 99% of the data.
offline_buffer_capacity: 1500000
# ── Performance (cluster-tuned, matches original DAgger run) ─────────
use_amp: true               # mixed precision — must match the original run (see Hardware note)
torch_compile: true
num_collection_workers: 32  # parallel env workers during DAgger collection
# ── Data collection (for offline BC dataset) ─────────────────────────
# 5000 eps × 4 ID envs = 20k oracle trajectories. Strictly more than
# the ~7k unique trajectories DAgger had in its filtered buffer at
# iter600 — offline always gets a richer pre-collected pool, which is
# the standard fairness asymmetry in BC vs DAgger comparisons.
collect_episodes_per_env: 5000
collect_num_workers: 32
collect_output: data/oracle_bc_qmul.pt  # output path for the serialized oracle dataset
# ── Checkpointing & Logging ──────────────────────────────────────────
checkpoint_dir: checkpoints_qmul
save_policy: true
hub_run_id: null      # presumably enables Hub upload when set — confirm in trainer
hub_repo_id: null
use_wandb: true
wandb_project: remdm-minihack
wandb_entity: "mathis-weil-university-college-london-ucl-"
wandb_run_name: null  # null → let wandb auto-generate a run name
# wandb_resume_id intentionally omitted — fresh runs by default.
# Override on the CLI (`wandb_resume_id=...`) to continue an existing run.
offline_log_every: 50  # grad steps between log lines in offline mode
seed: null             # null → unseeded run; set an int for reproducibility