remdm-craftax / src /planners /common.py
MathisW78's picture
Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)
6140064 verified
"""Shared config resolvers, gradient-step factory, validation rollout, and action diagnostics.
Both :mod:`src.planners.offline` and :mod:`src.planners.online` use the same
config resolution (parameterised by ``mode``), gradient update, and validation
logic. Centralising it here eliminates duplication.
"""
from __future__ import annotations
from typing import Any, Callable
import jax
import jax.numpy as jnp
import optax
from src.diffusion.loss import compute_loss
from src.diffusion.sampling import sample_plan
from src.diffusion.schedules import ScheduleFn
def resolve_num_updates(config: dict[str, Any], mode: str) -> None:
    """Resolve ``NUM_UPDATES`` from env-frame-denominated config keys.

    Mutates ``config`` in place. After this call the runners can read
    ``NUM_UPDATES`` (and ``OFFLINE_TOTAL_TIMESTEPS`` /
    ``ONLINE_TOTAL_TIMESTEPS`` depending on mode) without worrying about
    whether the user specified the env-frame or update-count form.

    Resolution priority:

    =========== ============================================================
    Mode        Priority (highest first)
    =========== ============================================================
    ``offline`` ``OFFLINE_TOTAL_TIMESTEPS`` > ``OFFLINE_NUM_UPDATES``
    ``online``  ``ONLINE_TOTAL_TIMESTEPS`` > ``ONLINE_NUM_UPDATES``
    =========== ============================================================

    Env-frame keys are preferred because they are invariant under
    ``num_envs`` changes: the same value yields the same total environment
    experience regardless of hardware sizing. This makes cross-hardware
    fairness studies (e.g. UCL 4096-env vs QMUL 96-env) trivially fair
    without manual scaling.

    The env-frame total is floored to a whole number of updates (at least
    one). It is then written back as the exact multiple
    ``NUM_UPDATES * NUM_STEPS * NUM_ENVS`` that will actually be trained.

    The function is idempotent: calling it twice with the same config has
    the same effect as calling it once.

    Args:
        config: Upper-cased config dict. Must contain ``NUM_STEPS`` and
            ``NUM_ENVS``.
        mode: Either ``"offline"`` or ``"online"``.

    Raises:
        ValueError: If neither the env-frame nor the update-count form is
            set for the given mode, if ``mode`` is unknown, or if the
            update-count form resolves to fewer than one update. ``config``
            is left unmodified when this is raised.
    """
    frames_per_update = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"])
    if mode == "offline":
        ts_key, nu_key = "OFFLINE_TOTAL_TIMESTEPS", "OFFLINE_NUM_UPDATES"
    elif mode == "online":
        ts_key, nu_key = "ONLINE_TOTAL_TIMESTEPS", "ONLINE_NUM_UPDATES"
    else:
        raise ValueError(
            f"Unknown mode: {mode!r}; expected 'offline' or 'online'."
        )

    ts = config.get(ts_key)
    nu = config.get(nu_key)
    # float() first to accept YAML scientific notation parsed as string
    # (PyYAML 1.1 only auto-coerces "3.0e+8", not "3e8" or "3.0e8").
    if ts is not None:
        num_updates = max(1, int(float(ts)) // frames_per_update)
    elif nu:
        num_updates = int(float(nu))
        # A non-positive count trains nothing. Because the env-frame key is
        # re-snapped below, a second call would also take the branch above
        # and clamp to 1, silently breaking idempotency. Reject it up front.
        if num_updates < 1:
            raise ValueError(
                f"{nu_key.lower()!r} must resolve to at least one update; "
                f"got {nu!r}."
            )
    else:
        raise ValueError(
            f"{mode.capitalize()} mode requires either "
            f"{ts_key.lower()!r} (env frames, preferred) or "
            f"{nu_key.lower()!r} to be set."
        )

    config["NUM_UPDATES"] = num_updates
    # Re-snap so downstream consumers (run names, SPS, checkpoint IDs)
    # see the exact integer multiple actually trained.
    config[ts_key] = num_updates * frames_per_update
def resolve_scaled_hyperparams(config: dict[str, Any], mode: str) -> None:
    """Resolve hardware-invariant hyperparameters into update-step form.

    Mutates ``config`` in place. PRIMARY keys override LEGACY (update-step)
    keys when set, mirroring the :func:`resolve_num_updates` pattern. When
    the PRIMARY key is ``None``, the LEGACY value passes through unchanged.
    This preserves full backward compatibility with configs that predate
    this resolver.

    Resolution table
    ================

    +-----------------------------+------------------------+----------+
    | PRIMARY                     | LEGACY (update-step)   | Mode     |
    +=============================+========================+==========+
    | ``LR_WARMUP_FRAMES``        | ``LR_WARMUP_STEPS``    | both     |
    +-----------------------------+------------------------+----------+
    | ``VAL_INTERVAL_FRAMES``     | ``VAL_INTERVAL``       | both     |
    +-----------------------------+------------------------+----------+
    | ``DAGGER_BETA_FINAL``       | ``DAGGER_BETA_DECAY``  | online   |
    +-----------------------------+------------------------+----------+
    | ``DAGGER_BUFFER_CYCLES``    | ``DAGGER_BUFFER_MAX``  | online   |
    +-----------------------------+------------------------+----------+

    Units of the PRIMARY keys:

    * ``LR_WARMUP_FRAMES`` and ``VAL_INTERVAL_FRAMES`` are env frames. They
      are floored to whole updates of ``NUM_STEPS * NUM_ENVS`` frames.
      ``VAL_INTERVAL`` is clamped to at least 1; the warmup may resolve to 0.
    * ``DAGGER_BETA_FINAL`` is the target beta after ``NUM_UPDATES``
      updates of the geometric schedule ``init * decay ** N``.
    * ``DAGGER_BUFFER_CYCLES`` is a buffer size in multiples of
      ``NUM_STEPS * NUM_ENVS``, rounded and clamped to at least 1.

    Why hardware-invariant units
    ----------------------------
    These values are invariant under ``num_envs`` changes, so the same
    config trains the same effective experiment on any GPU. The legacy
    update-step keys had to be hand-derived per hardware tier. That was
    error-prone and obscured the conceptual quantity (e.g. *final beta*,
    not *per-update decay constant*).

    The conversion for ``DAGGER_BETA_FINAL`` requires ``NUM_UPDATES``, so
    this function MUST be called after :func:`resolve_num_updates` when
    resolving online mode.

    Idempotent: calling this twice is equivalent to calling it once.

    Args:
        config: Upper-cased config dict. Must contain ``NUM_STEPS`` and
            ``NUM_ENVS``.
        mode: Either ``"offline"`` or ``"online"``.

    Raises:
        ValueError: If ``DAGGER_BETA_FINAL`` is set in online mode but
            ``NUM_UPDATES`` has not been resolved yet. Also raised if
            ``DAGGER_BETA_INIT <= 0`` or ``DAGGER_BETA_FINAL < 0``, since
            no real geometric decay connects such values.
    """
    fpu = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"])
    # float() first to accept YAML scientific notation parsed as string
    # (PyYAML 1.1 only auto-coerces "3.0e+8", not "3e8" or "3.0e8").

    # ── Mode-agnostic ────────────────────────────────────────────────
    warmup_frames = config.get("LR_WARMUP_FRAMES")
    if warmup_frames is not None:
        config["LR_WARMUP_STEPS"] = int(float(warmup_frames)) // fpu

    val_frames = config.get("VAL_INTERVAL_FRAMES")
    if val_frames is not None:
        config["VAL_INTERVAL"] = max(1, int(float(val_frames)) // fpu)

    # ── Online-only ──────────────────────────────────────────────────
    if mode != "online":
        return

    beta_final = config.get("DAGGER_BETA_FINAL")
    if beta_final is not None:
        num_updates = config.get("NUM_UPDATES")
        if num_updates is None:
            raise ValueError(
                "DAGGER_BETA_FINAL requires NUM_UPDATES to be resolved "
                "first; call resolve_num_updates() before "
                "resolve_scaled_hyperparams()."
            )
        beta_init = float(config.get("DAGGER_BETA_INIT", 1.0))
        final = float(beta_final)
        # A zero init divides by zero, and a negative ratio makes Python's
        # ``**`` with a fractional exponent return a *complex* number.
        # Either would be silently stored as the decay constant.
        if beta_init <= 0.0 or final < 0.0:
            raise ValueError(
                "DAGGER_BETA_FINAL must be >= 0 and DAGGER_BETA_INIT must "
                f"be > 0 for a geometric decay; got init={beta_init}, "
                f"final={final}."
            )
        # final = init * decay^N  =>  decay = (final / init) ** (1 / N)
        config["DAGGER_BETA_DECAY"] = (final / beta_init) ** (
            1.0 / int(num_updates)
        )

    buffer_cycles = config.get("DAGGER_BUFFER_CYCLES")
    if buffer_cycles is not None:
        config["DAGGER_BUFFER_MAX"] = max(
            1, int(round(float(buffer_cycles) * fpu))
        )
def print_config_snapshot(config: dict[str, Any], mode: str) -> None:
    """Print a structured banner of training-critical hyperparameters.

    Surfaces fairness-critical, schedule, and architecture parameters at the
    start of every offline/online run, so cross-hardware comparisons can be
    sanity-checked at a glance.

    Must be called AFTER :func:`resolve_num_updates` and
    :func:`resolve_scaled_hyperparams`, so the printed values reflect what
    training will actually use. ``NUM_UPDATES``, ``<MODE>_TOTAL_TIMESTEPS``
    and, in online mode, ``DAGGER_BETA_DECAY`` / ``DAGGER_BUFFER_MAX`` are
    read with ``[]`` and must be present.

    Args:
        config: Upper-cased config dict (post-resolver).
        mode: Either ``"offline"`` or ``"online"``.
    """
    fpu = int(config["NUM_STEPS"]) * int(config["NUM_ENVS"])
    num_updates = int(config["NUM_UPDATES"])
    minibatch = fpu // int(config["NUM_MINIBATCHES"])
    ts_key = f"{mode.upper()}_TOTAL_TIMESTEPS"
    total_frames = int(config[ts_key])

    bar = "=" * 72
    title = f"{mode.upper()} training — config snapshot"
    print(f"\n{bar}\n {title}\n{bar}")
    print(f" env_name : {config['ENV_NAME']}")
    print(f" seed : {config['SEED']}")

    print(" -- Rollout / hardware --")
    print(f" num_envs = {config['NUM_ENVS']}")
    print(f" num_steps = {config['NUM_STEPS']}")
    print(f" fpu (envs*steps) = {fpu}")
    print(f" num_minibatches = {config['NUM_MINIBATCHES']} (minibatch={minibatch})")
    print(f" update_epochs = {config['UPDATE_EPOCHS']}")
    print(f" num_repeats = {config.get('NUM_REPEATS', 1)}")

    print(" -- Schedule --")
    print(f" {ts_key.lower():<24} = {total_frames:,} (~{total_frames/1e6:.1f}M frames)")
    print(f" {'num_updates':<24} = {num_updates:,}")
    warmup = int(config.get("LR_WARMUP_STEPS", 0))
    print(f" {'lr':<24} = {float(config['LR']):.2e}")
    print(f" {'lr_warmup_steps':<24} = {warmup} (~{warmup * fpu / 1e6:.2f}M frames)")
    print(f" {'max_grad_norm':<24} = {config.get('MAX_GRAD_NORM', 1.0)}")

    if mode == "online":
        beta_init = float(config.get("DAGGER_BETA_INIT", 1.0))
        beta_decay = float(config["DAGGER_BETA_DECAY"])
        final_beta = beta_init * beta_decay ** num_updates
        buffer_max = int(config["DAGGER_BUFFER_MAX"])
        cycles = buffer_max / fpu
        # Mirrors the n_train_passes default in run_online: drawn fresh per
        # update, capped at samples_per_update for memory.
        plan_h = int(config["PLAN_HORIZON"])
        samples_per_update = int(config["NUM_ENVS"]) * (
            int(config["NUM_STEPS"]) - plan_h + 1
        )
        # A falsy override (None or 0) falls back to the auto value. The tag
        # must use the same truthiness test as the value; otherwise an
        # explicit 0 is printed as the auto value labelled "override".
        passes_override = config.get("DAGGER_TRAIN_PASSES")
        n_passes = passes_override or max(
            1, buffer_max // max(1, samples_per_update)
        )
        passes_tag = "override" if passes_override else "auto"
        expert_det = bool(config.get("DAGGER_EXPERT_DETERMINISTIC", True))
        total_grad_steps = (
            num_updates * int(n_passes)
            * int(config["UPDATE_EPOCHS"]) * int(config["NUM_MINIBATCHES"])
        )
        print(" -- DAgger --")
        print(f" {'dagger_beta_init':<24} = {beta_init}")
        print(f" {'dagger_beta_decay':<24} = {beta_decay:.10f}")
        print(f" {'final beta':<24} = {final_beta:.4f} (init * decay^N)")
        print(f" {'dagger_buffer_max':<24} = {buffer_max:,} (~{cycles:.2f} update cycles)")
        print(f" {'samples_per_update':<24} = {samples_per_update:,}")
        print(f" {'dagger_train_passes':<24} = {n_passes} ({passes_tag})")
        print(f" {'dagger_expert_determ':<24} = {expert_det}")
        print(f" {'total_grad_steps':<24} = {total_grad_steps:,}")
    else:
        total_grad_steps = (
            num_updates * int(config["UPDATE_EPOCHS"]) * int(config["NUM_MINIBATCHES"])
        )
        print(f" {'total_grad_steps':<24} = {total_grad_steps:,}")

    val_int = int(config.get("VAL_INTERVAL", 0))
    print(" -- Validation --")
    print(f" val_interval = {val_int} updates (~{val_int * fpu / 1e6:.2f}M frames)")
    print(f" val_diffusion_steps = {config.get('VAL_DIFFUSION_STEPS')}")
    print(f" val_replan_every = {config.get('VAL_REPLAN_EVERY')}")
    print(f" val_steps = {config.get('VAL_STEPS')}")

    print(" -- Diffusion model --")
    print(
        f" d_model/n_heads/n_layers/d_ff = "
        f"{config['D_MODEL']}/{config['N_HEADS']}/{config['N_LAYERS']}/{config['D_FF']}"
    )
    print(f" plan_horizon = {config['PLAN_HORIZON']}")
    print(f" diffusion_steps = {config['DIFFUSION_STEPS']}")
    print(f" remask_strategy = {config.get('REMASK_STRATEGY')} eta={config.get('ETA')}")
    print(
        f" sampling: temp={config.get('TEMPERATURE')} top_p={config.get('TOP_P')} "
        f"loop={config.get('USE_LOOP')} t_on/t_off={config.get('T_ON')}/{config.get('T_OFF')}"
    )
    print(f"{bar}\n", flush=True)
def _action_stats(
acts: jnp.ndarray,
num_actions: int,
valid: jnp.ndarray,
) -> dict[str, jnp.ndarray]:
"""Compute action-distribution entropy and unique-action fraction over valid windows.
Args:
acts: ``[B, H]`` int32 action sequences.
num_actions: Size of the real action vocabulary.
valid: ``[B]`` bool mask; invalid samples are excluded from counts.
Returns:
Dict with ``action_entropy`` and ``action_unique_frac``.
"""
mask = jnp.broadcast_to(valid[:, None], acts.shape).reshape(-1)
flat = jnp.where(mask, acts.reshape(-1), num_actions + 1)
counts = jnp.bincount(flat, length=num_actions).astype(jnp.float32)
probs = counts / jnp.maximum(counts.sum(), 1.0)
entropy = -jnp.sum(probs * jnp.log(jnp.where(probs > 0, probs, 1.0)))
return {
"action_entropy": entropy,
"action_unique_frac": jnp.sum(probs > 0).astype(jnp.float32) / num_actions,
}
def make_grad_step(
    apply_train: Callable,
    num_actions: int,
    schedule_fn: ScheduleFn,
    schedule_deriv_fn: ScheduleFn,
    sigma_t: float,
    label_smoothing: float,
) -> Callable:
    """Build a jittable single-batch optimiser update for the diffusion planner.

    Args:
        apply_train: Model apply function with dropout enabled.
        num_actions: Size of the action vocabulary.
        schedule_fn: alpha(t) noise schedule.
        schedule_deriv_fn: d(alpha)/dt analytic derivative.
        sigma_t: ReMDM remasking strength during training.
        label_smoothing: Cross-entropy label smoothing epsilon.

    Returns:
        A ``step(state, acts, obs, valid, rng, advantages) -> (state, metrics)``
        function suitable as a ``jax.lax.scan`` body. ``metrics`` is the
        auxiliary dict produced by ``compute_loss``, extended with
        ``grad_norm``, ``action_entropy`` and ``action_unique_frac``.
    """

    def _objective(
        params: Any,
        acts: jnp.ndarray,
        obs: jnp.ndarray,
        valid: jnp.ndarray,
        rng: jax.Array,
        advantages: jnp.ndarray,
    ) -> tuple[jnp.ndarray, dict]:
        # All hyperparameters are closed over, so only the parameters and
        # the per-batch tensors vary between calls.
        return compute_loss(
            apply_train,
            params,
            rng,
            acts,
            obs,
            valid,
            num_actions,
            schedule_fn,
            schedule_deriv_fn,
            sigma_t=sigma_t,
            label_smoothing=label_smoothing,
            advantages=advantages,
        )

    # Gradients are taken w.r.t. the first positional argument (params);
    # has_aux passes the metrics dict through alongside the loss.
    loss_and_grad = jax.value_and_grad(_objective, has_aux=True)

    def step(
        state: Any,
        acts: jnp.ndarray,
        obs: jnp.ndarray,
        valid: jnp.ndarray,
        rng: jax.Array,
        advantages: jnp.ndarray,
    ) -> tuple[Any, dict]:
        """Apply one optimiser update on a single batch.

        Args:
            state: Current ``TrainState``.
            acts: ``[B, H]`` int32 action sequences.
            obs: ``[B, obs_dim]`` float32 observations.
            valid: ``[B]`` bool validity mask (episode-boundary filter).
            rng: PRNG key consumed by the loss (dropout and noise sampling).
            advantages: ``[B]`` float per-sample weights forwarded to the loss.

        Returns:
            The updated ``TrainState`` and the metrics dict.
        """
        (_loss, metrics), grads = loss_and_grad(
            state.params, acts, obs, valid, rng, advantages,
        )
        updated = state.apply_gradients(grads=grads)
        metrics["grad_norm"] = optax.tree.norm(grads)
        metrics.update(_action_stats(acts, num_actions, valid))
        return updated, metrics

    return step
def make_validate(
    env: Any,
    env_params: Any,
    apply_eval: Callable,
    num_actions: int,
    plan_horizon: int,
    schedule_fn: ScheduleFn,
    config: dict[str, Any],
    val_replan_every: int,
    n_val_cycles: int,
) -> Callable:
    """Return a ``validate(state, rng) -> dict`` closure for periodic eval.

    The closure runs a held-out rollout using the diffusion model's current
    parameters and returns metrics under the ``val/`` namespace. Each of the
    ``n_val_cycles`` cycles samples one plan and executes its first
    ``val_replan_every`` actions.

    Args:
        env: Batched Gymnax environment.
        env_params: Gymnax environment params.
        apply_eval: Model apply function (eval mode, no dropout).
        num_actions: Size of the action vocabulary.
        plan_horizon: Action plan length H.
        schedule_fn: alpha(t) noise schedule.
        config: Training config dict (read-only). Sampling options are read
            from it with ``.get`` and fall back to defaults.
        val_replan_every: Env steps executed per diffusion plan during
            validation. Must satisfy ``1 <= val_replan_every <= plan_horizon``.
        n_val_cycles: Number of plan-execute cycles per validation rollout.

    Returns:
        A ``validate(state, rng) -> {str: jnp.ndarray}`` closure.

    Raises:
        ValueError: If ``val_replan_every`` is outside ``[1, plan_horizon]``.
    """
    # Each cycle executes plan[:, 0 .. val_replan_every - 1]. JAX clamps
    # out-of-bounds dynamic indices instead of raising, so an interval longer
    # than the plan would silently repeat the last planned action. A
    # non-positive interval would execute no env steps at all.
    if not 1 <= val_replan_every <= plan_horizon:
        raise ValueError(
            f"val_replan_every must be in [1, plan_horizon={plan_horizon}]; "
            f"got {val_replan_every}."
        )

    def validate(state: Any, rng: jax.Array) -> dict[str, jnp.ndarray]:
        """Run a validation rollout and return ``val/`` metrics.

        Args:
            state: Current ``TrainState`` (only ``.params`` is used).
            rng: PRNG key.

        Returns:
            Dict with ``val/`` prefixed metric keys. Each value is the mean of
            the corresponding env ``info`` entry over steps where
            ``returned_episode`` is set.
        """
        rng, val_rng = jax.random.split(rng)
        val_obs, val_env_state = env.reset(val_rng, env_params)

        def _val_cycle(carry, _):
            vs, vo, rng = carry
            rng, p_rng = jax.random.split(rng)
            plan = sample_plan(
                apply_eval,
                state.params,
                p_rng,
                vo,
                num_actions,
                plan_horizon,
                num_steps=config.get("VAL_DIFFUSION_STEPS", 50),
                schedule_fn=schedule_fn,
                remask_strategy=config.get("REMASK_STRATEGY", "rescale"),
                eta=config.get("ETA", 0.5),
                use_loop=config.get("USE_LOOP", True),
                t_on=config.get("T_ON", 0.7),
                t_off=config.get("T_OFF", 0.3),
                temperature=config.get("TEMPERATURE", 0.5),
                top_p=config.get("TOP_P", 0.95),
            )  # [num_envs, plan_horizon]

            def _exec_step(inner_carry, step_i):
                vs_i, vo_i, r = inner_carry
                r, s_rng = jax.random.split(r)
                vo_next, vs_next, _, _, info = env.step(
                    s_rng, vs_i, plan[:, step_i], env_params,
                )
                return (vs_next, vo_next, r), info

            (vs, vo, rng), step_infos = jax.lax.scan(
                _exec_step, (vs, vo, rng), jnp.arange(val_replan_every),
            )
            return (vs, vo, rng), step_infos

        _, cycle_infos = jax.lax.scan(
            _val_cycle, (val_env_state, val_obs, rng), None, n_val_cycles,
        )
        # Merge the leading (cycle, step) axes into one time axis.
        infos = jax.tree.map(
            lambda x: x.reshape(-1, *x.shape[2:]), cycle_infos,
        )
        returned = infos["returned_episode"]
        # Average each metric over completed-episode steps only; the epsilon
        # keeps the result finite when no episode finished.
        metrics = jax.tree.map(
            lambda x: (x * returned).sum() / (returned.sum() + 1e-8),
            infos,
        )
        return {f"val/{k}": v for k, v in metrics.items()}

    return validate