
ReMDM Planner — Discrete Diffusion Planning on Craftax

A JAX implementation of ReMDM (Remasking Discrete Diffusion Model) for action-sequence planning in the Craftax environment. A bidirectional transformer learns to generate action plans by iteratively denoising masked token sequences, conditioned on the current environment observation.


Description

The planner starts from a fully-masked action sequence and iteratively unmasks tokens over T denoising steps, producing a plan_horizon-length plan. The ReMDM framework extends standard Masked Diffusion Language Modelling (MDLM) with remasking strategies that allow committed tokens to be re-predicted, improving plan coherence.
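
The reverse process can be sketched in plain NumPy (a minimal illustration, not the repo's JAX implementation; `MASK_ID`, `predict_probs`, and the random unmasking order are simplifying assumptions, and the remasking step that distinguishes ReMDM from MDLM is omitted here):

```python
import numpy as np

MASK_ID = -1  # hypothetical mask-token id

def denoise_plan(predict_probs, plan_horizon, T, rng):
    """Start fully masked; over T steps, commit a growing set of
    randomly ordered positions to sampled actions (no remasking here)."""
    plan = np.full(plan_horizon, MASK_ID, dtype=np.int64)
    order = rng.permutation(plan_horizon)          # unmask order
    for step in range(T):
        n_known = int(round(plan_horizon * (step + 1) / T))
        probs = predict_probs(plan)                # [H, num_actions]
        for pos in order[:n_known]:
            if plan[pos] == MASK_ID:
                plan[pos] = rng.choice(probs.shape[1], p=probs[pos])
    return plan
```

After the final step every position is committed, yielding a complete H-length plan.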

Training follows a four-stage pipeline:

[Stage 1]  Train PPO agent          Craftax_Baselines/ppo_rnn.py | ppo_rnd.py
               |
               v  checkpoint
[Stage 2a] Collect trajectories     main.py --mode collect          (optional)
               |
               v  .npz file
[Stage 2b] Train offline            main.py --mode offline
               |  (live PPO rollouts)
               v  diffusion checkpoint
[Stage 3]  Online fine-tuning       main.py --mode online
               |
               v  fine-tuned checkpoint
[Stage 4]  Evaluate                 main.py --mode inference

Installation

Prerequisites (system-level)

uv manages Python packages only. The following must be installed at the OS level before running on a GPU node; they are not in pyproject.toml:

  • CUDA 13 driver and toolkit (libcuda.so, libcudnn)

On HPC clusters these are typically loaded via module load cuda/13.x.

1. Create the virtual environment

# CPU-only (local development / macOS)
uv sync

# NVIDIA CUDA 13 (GPU node, Linux only)
uv sync --extra cuda

# Activate
source .venv/bin/activate

uv sync reads pyproject.toml, resolves the full dependency graph into a reproducible lockfile (uv.lock), and installs it into .venv/. Commit uv.lock to pin the exact dependency graph.

2. Initialise the submodule

git submodule update --init --recursive

Dependencies

Package Version Role
jax >=0.9.2 JIT compilation and functional arrays
flax >=0.12.6 Neural network definitions
optax >=0.2.8 Adam optimiser and gradient clipping
craftax >=1.5.0 Procedurally-generated Minecraft-like environment
chex >=0.1.91 JAX testing and assertion utilities
distrax >=0.1.7 Probability distributions
orbax >=0.1.9 Model checkpointing
wandb >=0.25.1 Experiment logging
numpy >=2.4.4 Array operations
matplotlib >=3.10.8 Plotting
polars >=1.39.3 DataFrame analysis
orjson >=3.11.8 Fast JSON serialisation
pyyaml >=6.0.3 Config file parsing

Full specification in pyproject.toml. Exact transitive pins are in uv.lock.


Usage

All modes share the same entry point. Defaults are loaded from configs/defaults.yaml; any value can be overridden on the command line.

python main.py --mode <MODE> [--config PATH] [OVERRIDES...]

Pass --no-jit to disable JIT compilation (useful for debugging):

python main.py --mode offline --no-jit --num_envs 4

Stage 1 — Train a PPO agent

PPO training is handled by the Craftax_Baselines submodule and produces the checkpoint consumed by all downstream stages.

cd Craftax_Baselines

# PPO with GRU hidden state (recommended)
python ppo_rnn.py \
    --env_name Craftax-Classic-Symbolic-v1 \
    --total_timesteps 500000000 \
    --save_policy --use_wandb

# PPO with Random Network Distillation
python ppo_rnd.py \
    --env_name Craftax-Classic-Symbolic-v1 \
    --total_timesteps 500000000 \
    --save_policy --use_wandb

cd ..

Stage 2a — Collect trajectories to disk

Roll out the PPO checkpoint and save (obs, actions, rewards, dones) as a .npz file for reuse across multiple diffusion training runs.

python main.py --mode collect \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --offline_data_path data/trajectories.npz \
    --collect_num_steps 1000000 \
    --collect_num_envs 128

The file stores arrays shaped [num_envs, num_iters, ...], preserving per-environment contiguity so episode boundaries are respected during window sampling.
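
The layout can be reproduced with plain NumPy (sizes here are hypothetical placeholders; the real shapes follow the CLI values, and the exact array names in the saved file are an assumption):

```python
import numpy as np
import tempfile, os

# Hypothetical sizes for illustration; real runs use the CLI values.
num_envs, num_iters, steps, obs_dim = 8, 3, 16, 32

data = {
    "obs":     np.zeros((num_envs, num_iters, steps, obs_dim), np.float32),
    "actions": np.zeros((num_envs, num_iters, steps), np.int32),
    "rewards": np.zeros((num_envs, num_iters, steps), np.float32),
    "dones":   np.zeros((num_envs, num_iters, steps), bool),
}
path = os.path.join(tempfile.mkdtemp(), "trajectories.npz")
np.savez_compressed(path, **data)

with np.load(path) as loaded:
    # Axis 0 is the environment index: loaded["actions"][e] is one
    # env's contiguous stream, so episode boundaries stay intact.
    per_env_actions = loaded["actions"][0]
```

Keeping the environment axis first means window sampling can slice a single env's stream without ever mixing transitions from different environments.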

Stage 2b — Train offline from live PPO rollouts

Roll out the PPO agent live at each update step and train the diffusion model on the collected windows. Windows that cross episode boundaries are masked out; windows with higher cumulative reward receive proportionally larger gradient contributions (clipped to [0.1, return_weight_cap]).

python main.py --mode offline \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --total_timesteps 100000000 \
    --save_policy

Stage 3 — Online DAgger fine-tuning

The diffusion model (learner) is fine-tuned via DAgger (Dataset Aggregation). At each iteration a mixed policy blends the PPO expert and the diffusion learner (controlled by an exponentially decaying beta). The mixed policy rolls out trajectories; the expert labels every visited state with the action it would take. These (obs, expert_plan) pairs are appended to a growing circular replay buffer, and the diffusion model is retrained on the full buffer with the standard MDLM ELBO loss (pure behavioural cloning, no reward weighting).
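
The mixing schedule and per-step policy choice amount to a few lines (a sketch of the standard DAgger recipe, not the repo's exact code; the function names are hypothetical):

```python
import numpy as np

def dagger_beta(i, beta_init=1.0, decay=0.95):
    """Expert mixing probability at DAgger iteration i, following the
    exponential schedule described above: beta_init * decay**i."""
    return beta_init * decay ** i

def mixed_action(rng, beta, expert_act, learner_act):
    """With probability beta act with the expert, else with the learner."""
    return expert_act if rng.random() < beta else learner_act
```

Regardless of which policy acts, the expert's action is always what gets stored as the training label.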

# From scratch (requires PPO expert checkpoint)
python main.py --mode online \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --num_updates 1000 \
    --save_policy

# Warm-start from an offline checkpoint
python main.py --mode online \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --offline_checkpoint_path /path/to/offline_checkpoint \
    --num_updates 1000 \
    --save_policy

Stage 4 — Evaluate

python main.py --mode inference \
    --checkpoint_path /path/to/checkpoint \
    --eval_steps 10000 \
    --eval_num_envs 32

Prints mean episode return, completed episodes, steps per second, and per-achievement unlock counts. Uses historical inpainting: the first hist_len plan positions are locked to observed history.
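
Historical inpainting reduces to initialising the plan with locked positions (a minimal sketch under the same `MASK_ID` convention as above; the helper name is hypothetical):

```python
import numpy as np

MASK_ID = -1  # hypothetical mask-token id

def init_plan_with_history(history, plan_horizon):
    """Lock the first hist_len positions to observed actions; the rest
    start masked and are filled in by the reverse diffusion sampler."""
    hist_len = len(history)
    plan = np.full(plan_horizon, MASK_ID, dtype=np.int64)
    plan[:hist_len] = history
    frozen = np.zeros(plan_horizon, dtype=bool)
    frozen[:hist_len] = True   # the sampler must never remask these
    return plan, frozen
```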

Loading checkpoints from W&B artifacts

Any checkpoint path argument (--checkpoint_path, --offline_checkpoint_path, --ppo_checkpoint_path) accepts a W&B artifact reference prefixed with wandb:. The artifact is downloaded automatically before training or evaluation begins.

# Fully qualified: entity/project/artifact_name:version_or_alias
python main.py --mode inference \
    --checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:latest

# Online fine-tuning from a W&B offline checkpoint
python main.py --mode online \
    --offline_checkpoint_path wandb:my-team/remdm-craftax/Craftax-Classic-Symbolic-v1-policy:v3

# PPO checkpoint from W&B
python main.py --mode offline \
    --ppo_checkpoint_path wandb:my-team/ppo-craftax/ppo-rnn-policy:best

Control the download location with --wandb_download_dir (defaults to ./artifacts/).

Resuming a Training Run

A completed training checkpoint can be used as the starting point for a new run that continues where the previous one left off. This is useful when extending the training budget or when a preempted job needs to be restarted.

Offline resume:

# Auto-detect step and wandb run ID from checkpoint metadata
python main.py --mode offline \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --resume_checkpoint_path /path/to/completed_offline_checkpoint \
    --total_timesteps 200000000 \
    --save_policy

# Explicit step and wandb run ID override
python main.py --mode offline \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --resume_checkpoint_path /path/to/completed_offline_checkpoint \
    --resume_step 1525 \
    --resume_wandb_run_id abc123xyz \
    --total_timesteps 200000000 \
    --save_policy

# Resume from a W&B artifact
python main.py --mode offline \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --resume_checkpoint_path wandb:my-team/remdm-craftax/policy:latest \
    --total_timesteps 200000000 \
    --save_policy

Online resume:

python main.py --mode online \
    --ppo_checkpoint_path /path/to/ppo_checkpoint \
    --resume_checkpoint_path /path/to/completed_online_checkpoint \
    --num_updates 2000 \
    --save_policy

Notes:

  • The DAgger replay buffer is not persisted across resumes. It starts empty and refills within the first few iterations.
  • JIT compilation is fully preserved. Resume only affects initialisation outside jax.jit (loading checkpoint, setting the optimizer step counter, adjusting scan length).
  • The cosine LR schedule is constructed for the full num_updates range. The optimizer step counter is set to the resume offset so the learning rate picks up exactly where the previous run stopped.
  • When resume_checkpoint_path points to a checkpoint with a metadata sidecar, resume_step and resume_wandb_run_id are auto-detected. Explicit CLI flags override the metadata values.
  • Checkpoints without a metadata sidecar (created before this feature) still load; provide --resume_step explicitly.
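
The LR-resume behaviour follows from evaluating one fixed schedule at an offset step (a sketch of a cosine decay to 10% of the peak as described above; the exact optax schedule shape may differ):

```python
import math

def cosine_lr(step, total_steps, lr=3e-4, final_frac=0.1):
    """Cosine decay from lr down to lr * final_frac over total_steps."""
    progress = min(step / total_steps, 1.0)
    return lr * (final_frac
                 + (1 - final_frac) * 0.5 * (1 + math.cos(math.pi * progress)))

# Setting the optimizer step counter to the resume offset means the
# resumed run evaluates the same schedule at the point the old run stopped:
resume_step = 1525
lr_at_resume = cosine_lr(resume_step, total_steps=3050)
```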

Configuration

All hyperparameters are in configs/defaults.yaml. Override any value on the command line:

python main.py --mode offline --lr 1e-4 --plan_horizon 64 --num_minibatches 16

Point to a custom config file:

python main.py --mode online --config configs/my_experiment.yaml

Preset configs for larger runs are provided in configs/:

File Purpose
configs/defaults.yaml Base defaults for all modes
configs/big_diffusion_offline.yaml Larger model for offline training
configs/big_diffusion_online.yaml Larger model for online training
configs/A100_diffusion_offline.yaml A100-tuned offline config
configs/A100_diffusion_online.yaml A100-tuned online config
configs/ucl_4090_3090.yaml UCL RTX 4090/3090 preset
configs/ucl_4070.yaml UCL RTX 4070 preset
configs/ali_gpu.yaml Ali GPU preset
configs/qmul_h200.yaml QMUL H200 preset
configs/ablations.yaml RL fine-tuning ablation hyperparameters (loaded by experiments/, not main.py)

Key hyperparameters

Environment

Parameter Default Description
env_name Craftax-Classic-Symbolic-v1 Craftax environment ID
use_optimistic_resets false Use OptimisticResetVecEnvWrapper instead of AutoResetEnvWrapper
optimistic_reset_ratio 16 Ratio of environments to fresh reset states per step (roughly num_envs / ratio resets) when optimistic resets are enabled

Diffusion model

Parameter Default Description
plan_horizon 32 Action plan length H
diffusion_steps 15 Denoising steps T at inference
diffusion_schedule cosine Noise schedule: cosine or linear
remask_strategy rescale Remasking strategy: rescale, cap, or conf
train_sigma 0.0 Per-token remasking correction during training (0 = standard MDLM)
label_smoothing 0.0 Cross-entropy label smoothing epsilon (0 = exact ELBO)
eta 0.5 Remasking strength
use_loop true Three-phase loop remasking (Algorithm 3)
t_on / t_off 0.7 / 0.3 Time window boundaries for loop remasking
temperature 0.5 Softmax temperature for token sampling
top_p 0.95 Nucleus sampling threshold
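
Token sampling combines the temperature and top_p parameters above; a NumPy sketch of temperature-scaled nucleus sampling (an illustration of the standard technique, not the repo's exact implementation):

```python
import numpy as np

def sample_token(logits, rng, temperature=0.5, top_p=0.95):
    """Keep the smallest set of tokens whose cumulative probability
    reaches top_p, then sample from that renormalised nucleus."""
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))
    probs /= probs.sum()
    order = np.argsort(-probs)                 # most probable first
    cum = np.cumsum(probs[order])
    cutoff = np.searchsorted(cum, top_p) + 1   # nucleus size
    keep = order[:cutoff]
    p = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=p))
```

A low temperature (0.5 by default) sharpens the distribution before the nucleus cut, so confident predictions become near-deterministic.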

Transformer architecture

Parameter Default Description
d_model 256 Hidden dimension
n_heads 4 Attention heads
n_layers 4 Transformer blocks
d_ff 512 FFN inner dimension
obs_encoder_layers 2 MLP layers in the observation encoder
obs_encoder_width 512 Observation encoder hidden width
dropout_rate 0.1 Dropout rate (disabled at inference)

Offline training

Parameter Default Description
total_timesteps 1e8 Total environment steps for live-PPO data collection
num_envs 1024 Parallel environments
num_steps 64 Environment steps collected per update
num_minibatches 8 Gradient minibatches per epoch
update_epochs 4 SGD epochs per update step
num_repeats 1 Independent training seeds (vmapped)
lr 3e-4 Adam learning rate (cosine-decayed to 10% over all gradient steps)
lr_warmup_steps 0 Linear warm-up steps before cosine decay (0 = disabled)
max_grad_norm 1.0 Global gradient clipping norm
batch_size 768 Minibatch size
return_weight_cap 5.0 Clip ceiling for per-window return weights
collect_temperature 1.0 Softmax temperature on PPO logits during live data collection
val_interval 50 Validation frequency in update steps
val_diffusion_steps 50 Denoising steps used during validation rollouts
val_replan_every 4 Environment steps executed per diffusion plan during validation
val_steps 128 Total environment steps per validation rollout

Online DAgger training

Parameter Default Description
num_updates 1000 Outer DAgger iterations
replan_every 4 Environment steps executed per diffusion plan during rollouts
dagger_beta_init 1.0 Initial expert mixing probability beta_1 (1.0 = pure expert first iteration)
dagger_beta_decay 0.95 Exponential decay: beta_i = beta_1 * decay^i
dagger_buffer_max 100000 Max samples in the DAgger replay buffer (circular eviction when full)

Data collection

Parameter Default Description
collect_num_steps 10000000 Total environment steps to collect
collect_num_envs 128 Parallel environments during collection
ppo_model_type ppo_rnn PPO architecture: ppo, ppo_rnn, or ppo_rnd
layer_size 512 PPO actor-critic hidden layer width

Inference

Parameter Default Description
eval_steps 10000 Environment steps for evaluation
eval_num_envs 32 Parallel agents during evaluation (independent of num_envs)
diffusion_steps_eval 10 Denoising steps T used at evaluation time

Checkpointing

Parameter Default Description
checkpoint_dir checkpoints_online Directory for periodic checkpoints
save_policy true Save final checkpoint at end of training

Resume

Parameter Default Description
resume_checkpoint_path null Path to a completed checkpoint to resume from (accepts wandb: refs)
resume_wandb_run_id null W&B run ID to resume logging into (auto-read from checkpoint metadata)
resume_step null Update step the checkpoint was saved at (auto-read from checkpoint metadata)

Logging

Parameter Default Description
use_wandb true Enable Weights & Biases logging
wandb_project remdm-craftax W&B project name
wandb_entity "mathis-weil-university-college-london-ucl-" W&B entity (team or username)
wandb_download_dir null Download directory for W&B artifacts; null = ./artifacts/
seed null RNG seed (random if null)

Remasking Strategies

Controlled by --remask_strategy. All strategies operate on top of the three-phase loop controlled by --use_loop, --t_on, and --t_off.

Strategy Formula Description
rescale sigma = eta * sigma_max Scales maximum remasking probability proportionally
cap sigma = min(eta, sigma_max) Caps remasking at a fixed rate
conf sigma = eta * sigma_max * (1 - confidence) High-confidence tokens are remasked less
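
The three formulas translate directly into code (a sketch of the table above; the function name and signature are assumptions):

```python
import numpy as np

def remask_sigma(strategy, eta, sigma_max, confidence=None):
    """Per-token remasking probability for the strategies in the table."""
    if strategy == "rescale":
        return eta * sigma_max
    if strategy == "cap":
        return np.minimum(eta, sigma_max)
    if strategy == "conf":
        return eta * sigma_max * (1.0 - confidence)
    raise ValueError(f"unknown strategy: {strategy}")
```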

Environment Wrappers

From Craftax_Baselines/wrappers.py (submodule):

Wrapper Purpose
LogWrapper Tracks episode returns and lengths; adds stats to the info dict
AutoResetEnvWrapper Automatically resets episodes on done
BatchEnvWrapper Vmaps reset and step over num_envs environments
OptimisticResetVecEnvWrapper Batched resets with reduced overhead; enable via --use_optimistic_resets

From src/envs/wrappers.py:

Wrapper Purpose
SequenceHistoryWrapper Maintains a sliding window of past observations and actions in the env state
DiscreteTokenizationWrapper Quantizes continuous observations into discrete token indices
PlannerWrapper Manages the plan/replan cycle for the diffusion planner
OfflineTrajectoryWrapper Accumulates transitions into a fixed-size circular replay buffer

Wrapper stacks:

Training:   env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper
Inference:  env -> LogWrapper -> AutoResetEnvWrapper -> BatchEnvWrapper

Project Structure

craftax-ReMDM-planner/
├── Craftax_Baselines/             # Git submodule — PPO agents and standard wrappers
│   ├── wrappers.py                # LogWrapper, BatchEnvWrapper, AutoResetEnvWrapper, etc.
│   ├── ppo_rnn.py                 # PPO-RNN training script
│   ├── ppo_rnd.py                 # PPO-RND training script
│   ├── ppo.py                     # PPO model definitions
│   └── models/
│       ├── actor_critic.py        # ActorCritic variants
│       ├── rnd.py                 # RND network
│       └── icm.py                 # ICM encoder, forward, and inverse networks
├── configs/
│   ├── defaults.yaml              # Base hyperparameters (CLI-overridable)
│   ├── big_diffusion_offline.yaml
│   ├── big_diffusion_online.yaml
│   ├── A100_diffusion_offline.yaml
│   ├── A100_diffusion_online.yaml
│   ├── ucl_4090_3090.yaml         # UCL RTX 4090/3090 preset
│   ├── ucl_4070.yaml              # UCL RTX 4070 preset
│   ├── ali_gpu.yaml               # Ali GPU preset
│   ├── qmul_h200.yaml             # QMUL H200 preset
│   └── ablations.yaml             # RL fine-tuning ablation hyperparameters
├── src/
│   ├── diffusion/
│   │   ├── forward.py             # Forward masking process q(z_t | x_0)
│   │   ├── loss.py                # Continuous-time MDLM ELBO loss
│   │   ├── sampling.py            # Reverse diffusion with ReMDM remasking
│   │   └── schedules.py           # Linear and cosine noise schedules
│   ├── models/
│   │   └── denoiser.py            # DenoisingTransformer (obs encoder + transformer)
│   ├── envs/
│   │   └── wrappers.py            # Sequence, tokenization, planner, and trajectory wrappers
│   └── planners/
│       ├── collect.py             # --mode collect: PPO rollouts -> .npz
│       ├── common.py              # Shared utilities
│       ├── env.py                 # Environment construction
│       ├── inference.py           # --mode inference: MPC evaluation with inpainting
│       ├── logging.py             # Centralised W&B logging utilities
│       ├── model.py               # Diffusion model lifecycle
│       ├── offline.py             # --mode offline: make_train (live PPO rollouts)
│       ├── online.py              # --mode online: DAgger fine-tuning
│       └── ppo.py                 # PPO agent adapter and checkpoint loading utilities
├── experiments/
│   └── rl_finetuning/             # RL fine-tuning ablation suite (see experiments/README.md)
│       ├── run_ablations.py       # CLI entry point
│       ├── ablations/             # Loss, optimizer, and registry modules
│       ├── diagnostics/           # Gradient, representation, and timestep diagnostics
│       ├── analysis/              # Plots, tables, and report generation
│       └── configs/               # ablations_default.yaml, ablations_fast.yaml
├── main.py                        # CLI entry point
├── pyproject.toml                 # uv project — direct deps + tool config
└── uv.lock                        # Reproducible lockfile (commit this)

Implementation Notes

JAX functional purity: training closures (make_train, make_train_dagger) are fully JIT-compatible. Environment construction and checkpoint I/O happen outside jax.jit.

Offline training: --mode offline rolls out the PPO agent live at each update step via make_train. Use --mode collect to save a trajectory .npz for inspection or analysis. Re-feeding it to --mode offline is not supported; pass --ppo_checkpoint_path instead.

Episode-boundary masking: the offline sampler pre-computes a validity mask over all (env, time) positions. A window at (e, t) is valid only if dones[e, t+1:t+H-1] are all False.
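
A direct transcription of that rule (a sketch; the repo's vectorised JAX version will differ in form but should agree on the mask):

```python
import numpy as np

def window_validity(dones, H):
    """Validity mask per the rule above: window (e, t) is valid only if
    dones[e, t+1 : t+H-1] are all False (no episode boundary inside)."""
    num_envs, T = dones.shape
    valid = np.zeros((num_envs, T), dtype=bool)
    for t in range(T - H + 1):   # windows past the buffer end stay invalid
        valid[:, t] = ~dones[:, t + 1 : t + H - 1].any(axis=1)
    return valid
```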

Return weighting: valid windows are weighted by their cumulative reward, normalised by the batch mean and clipped to [0.1, RETURN_WEIGHT_CAP]. Weights are passed as per-sample multipliers into the MDLM loss before reduction, so they correctly scale each sample's gradient contribution.
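
The weighting step looks roughly like this (a sketch of the description above; the floor of 0.1 comes from the clip range stated earlier, and the epsilon guard is an assumption):

```python
import numpy as np

def window_weights(returns, valid, cap=5.0, floor=0.1):
    """Per-window loss weights: cumulative return normalised by the batch
    mean over valid windows, then clipped to [floor, cap]."""
    mean = returns[valid].mean() if valid.any() else 1.0
    w = np.clip(returns / max(mean, 1e-8), floor, cap)
    return np.where(valid, w, 0.0)   # invalid windows contribute nothing
```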

LR schedule: cosine decay from lr to lr * 0.1 over all gradient steps. Set lr_warmup_steps > 0 to prepend a linear warm-up phase.

Loss weight clipping: the MDLM SUBS weight -alpha'(t) / (1 - alpha_t) is clipped to 1000 to prevent numerical instability when alpha_t ≈ 1.
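
For concreteness, under a cosine schedule alpha(t) = cos(pi*t/2)^2 (one common choice; the repo's schedules.py may define it differently) the weight diverges as t approaches 0, which is exactly what the clip guards against:

```python
import math

def subs_weight(t, clip=1000.0):
    """MDLM SUBS loss weight -alpha'(t) / (1 - alpha_t) under the assumed
    cosine schedule, clipped where alpha_t -> 1 near t = 0."""
    half = math.pi * t / 2
    alpha = math.cos(half) ** 2
    dalpha = -math.pi * math.cos(half) * math.sin(half)
    return min(-dalpha / max(1.0 - alpha, 1e-12), clip)
```

Under this schedule the weight simplifies to pi * cot(pi*t/2), so it is about pi at t = 0.5 and hits the clip ceiling as t goes to 0.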

Validation rollouts: during offline training, a held-out rollout runs every val_interval steps. It uses the same sampling parameters as inference (remask_strategy, eta, use_loop, t_on, t_off, temperature, top_p) with val_diffusion_steps denoising steps and val_replan_every env steps per plan, for a total of val_steps environment steps.

W&B logging: all metric aggregation is centralised in src/planners/logging.py. Metric namespaces: diffusion/ (loss, accuracy), train/ (data quality, throughput), env/ (episode returns, achievements), val/ (validation rollouts, emitted every val_interval steps), dagger/ (online DAgger training: beta, buffer fill, reward mean, valid fraction). train/sps (environment frames/sec) is only logged in modes that perform live environment interaction.

DAgger dataset aggregation: online training (--mode online) implements DAgger (Ross et al., 2011). A circular replay buffer accumulates (obs, expert_plan) pairs across all iterations. Each update samples uniformly from the full buffer, not just the latest batch. Training samples that cross episode boundaries (any done within the plan-horizon window) are marked invalid. The expert (PPO agent) receives correct done flags so its RNN hidden state resets on episode boundaries.

Denoising step indexing: the reverse scan runs from step_idx = 0 to T-1, mapping to diffusion time t = (T - step_idx) / T (high noise to low noise).
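
In code, the mapping is simply:

```python
def diffusion_time(step_idx, T):
    """Map reverse-scan index 0..T-1 to diffusion time t = (T - step_idx) / T,
    so the scan runs from t = 1 (fully masked) down to t = 1/T."""
    return (T - step_idx) / T
```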

Submodule PPO agents: PPO training lives entirely in Craftax_Baselines/. Planner scripts only consume pre-trained checkpoints via --ppo_checkpoint_path.
