# remdm-craftax/src/planners/logging.py
# MathisW78's picture
# Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)
# 6140064 verified
"""Centralised W&B logging utilities for ReMDM training loops.
Design principles
-----------------
* ``build_log_dict`` is a pure function — no side effects, fully testable.
* ``make_wandb_callback`` is a factory that returns a closure suitable for
``jax.debug.callback``. All timing state is local to the closure; there is
no module-level global state.
* Both ``train.py`` and ``online.py`` import the same symbols, keeping all
metric naming and aggregation logic in one place.
Metric namespacing
------------------
``diffusion/`` — ELBO loss, accuracy, and noise-level diagnostics from ``compute_loss``.
``train/`` — data quality, action distribution, throughput.
``env/`` — episode returns and per-achievement unlock rates (training envs).
``val/`` — same as ``env/`` but from the held-out validation rollout
(only emitted when ``step_idx % val_interval == 0``).
``dagger/`` — DAgger-specific metrics (online training only).
"""
from __future__ import annotations
import time
from typing import Any
import wandb
def init_wandb(
    config: dict[str, Any],
    name: str,
    *,
    resume_run_id: str | None = None,
) -> None:
    """Start a W&B run, or re-attach to an existing one.

    Args:
        config: Training config dict. ``WANDB_PROJECT`` (default
            ``"remdm-craftax"``) and ``WANDB_ENTITY`` (default ``None``) are
            read from it, and the whole dict is passed as the run config.
        name: Human-readable run name. Only used when starting a fresh run;
            it is not forwarded when ``resume_run_id`` is given.
        resume_run_id: When not ``None``, the run is opened with
            ``wandb.init(id=resume_run_id, resume="must")``, so the run is
            expected to exist already.
    """
    # A fresh run is identified by its display name; a resumed run by its id.
    if resume_run_id is None:
        run_selector: dict[str, Any] = {"name": name}
    else:
        run_selector = {"id": resume_run_id, "resume": "must"}

    wandb.init(
        project=config.get("WANDB_PROJECT", "remdm-craftax"),
        entity=config.get("WANDB_ENTITY"),
        config=config,
        **run_selector,
    )
# Keys emitted by the ``src.diffusion.loss.compute_loss`` info dict.
# Logged under ``diffusion/``.
_DIFFUSION_KEYS: tuple[str, ...] = (
    "loss",
    "unweighted_loss",
    "accuracy",
    "acc_t_low",
    "acc_t_mid",
    "acc_t_high",
    "frac_masked",
    "mean_t",
    "grad_norm",
)

# Keys added locally by training loops. Logged under ``train/``.
_TRAIN_KEYS: tuple[str, ...] = (
    "action_entropy",
    "action_unique_frac",
    "valid_frac",
    "mean_return_weight",
)

# Keys specific to online DAgger training. Logged under ``dagger/``.
# NOTE: ``valid_frac`` is also listed in ``_TRAIN_KEYS``, so in online mode it
# is logged under both ``train/`` and ``dagger/``.
_DAGGER_KEYS: tuple[str, ...] = (
    "beta",
    "reward_mean",
    "buffer_fill",
    "valid_frac",
    "best_val_return",
)


def build_log_dict(
    metric: dict[str, Any],
    step_idx: int,
    val_interval: int,
    *,
    is_online: bool = False,
    sps: float | None = None,
) -> dict[str, float]:
    """Build a flat W&B-ready log dict from a merged training metric dict.

    Known scalar keys are copied under ``diffusion/``, ``train/`` and (online
    only) ``dagger/``. Episode returns/lengths and every key containing
    ``"achievement"`` (case-insensitive) are copied under ``env/``. Keys
    prefixed with ``val/`` are copied under ``val/`` on validation steps only.
    Keys that match none of these rules are ignored.

    Args:
        metric: Merged metric dict from the current update step. Values must
            be convertible with ``float()``.
        step_idx: Integer update step index.
        val_interval: How often (in steps) validation runs occur. ``0`` is
            treated as "validation disabled": no ``val/`` keys are emitted.
        is_online: If ``True``, emit DAgger-specific keys under ``dagger/``.
        sps: Pre-computed steps-per-second; omitted when ``None``.

    Returns:
        Flat ``{str: float}`` dict suitable for ``wandb.log``. The aggregate
        ``env/achievements`` / ``val/achievements`` entries are present only
        when at least one matching achievement key was found, so runs without
        environment metrics do not get a constant-zero chart.
    """
    val_prefix = "val/"
    log: dict[str, float] = {}

    # Fixed-name scalars, grouped by the namespace they are logged under.
    namespaced_keys = [("diffusion", _DIFFUSION_KEYS), ("train", _TRAIN_KEYS)]
    if is_online:
        namespaced_keys.append(("dagger", _DAGGER_KEYS))
    for namespace, keys in namespaced_keys:
        for k in keys:
            if k in metric:
                log[f"{namespace}/{k}"] = float(metric[k])

    if "returned_episode_returns" in metric:
        log["env/episode_return"] = float(metric["returned_episode_returns"])
    if "returned_episode_lengths" in metric:
        log["env/episode_length"] = float(metric["returned_episode_lengths"])

    # Per-achievement breakdown + aggregate score (Craftax reports as %,
    # divide by 100).
    achieve_total = 0.0
    saw_achievement = False
    for k, v in metric.items():
        if "achievement" in k.lower() and not k.startswith(val_prefix):
            value = float(v)
            log[f"env/achieve/{k}"] = value
            achieve_total += value / 100.0
            saw_achievement = True
    if saw_achievement:
        log["env/achievements"] = achieve_total

    # Validation metrics — only emitted on val steps to avoid polluting
    # charts with zeros. Guard the modulo so a zero interval (validation
    # turned off) does not raise ZeroDivisionError.
    is_val_step = val_interval != 0 and step_idx % val_interval == 0
    if is_val_step:
        val_achieve_total = 0.0
        saw_val_achievement = False
        for k, v in metric.items():
            if not k.startswith(val_prefix):
                continue
            inner = k[len(val_prefix):]
            if "achievement" in inner.lower():
                value = float(v)
                log[f"val/achieve/{inner}"] = value
                val_achieve_total += value / 100.0
                saw_val_achievement = True
            elif inner == "returned_episode_returns":
                log["val/episode_return"] = float(v)
            elif inner == "returned_episode_lengths":
                log["val/episode_length"] = float(v)
        if saw_val_achievement:
            log["val/achievements"] = val_achieve_total

    if sps is not None:
        log["train/sps"] = float(sps)
    return log
def make_wandb_callback(
    config: dict[str, Any],
    *,
    steps_per_update: int | None,
    val_interval: int,
    is_online: bool = False,
) -> Any:
    """Return a host-side logging closure for ``jax.debug.callback``.

    The closure measures the time elapsed between successive calls to compute
    steps-per-second. All state is local to the closure; there is no
    module-level mutable state.

    SPS is not reported on the first call after creation or on
    ``step_idx == 0`` (the interval then covers setup/JIT compilation rather
    than a steady-state update), nor when ``steps_per_update`` is ``None``
    (e.g. data-replay mode where no environment frames are consumed).

    Args:
        config: Training config dict. Not read by this function; accepted so
            the factory signature stays stable for callers, which are expected
            to guard on their own W&B settings.
        steps_per_update: Environment frames consumed per update step. Pass
            ``None`` to disable ``train/sps`` logging entirely (e.g. when
            training from pre-collected data files).
        val_interval: Frequency (in steps) at which validation runs occur;
            forwarded to ``build_log_dict``.
        is_online: If ``True``, emit DAgger keys under ``dagger/``.

    Returns:
        A callable ``log_fn(metric, step_idx) -> None`` for
        ``jax.debug.callback``.
    """
    # perf_counter is monotonic, so the measured interval cannot go negative
    # or jump when the system wall clock is adjusted.
    last_call = time.perf_counter()
    is_first_call = True

    def log_fn(metric: dict[str, Any], step_idx: int) -> None:
        nonlocal last_call, is_first_call

        now = time.perf_counter()
        dt = now - last_call
        last_call = now
        first_call, is_first_call = is_first_call, False

        # step_idx may arrive as an array scalar from the callback; normalise
        # it once to a plain int.
        step = int(step_idx)

        sps: float | None = None
        if (
            steps_per_update is not None
            and not first_call
            and step > 0
            and dt > 1e-6
        ):
            sps = steps_per_update / dt

        log = build_log_dict(
            metric, step, val_interval, is_online=is_online, sps=sps,
        )
        wandb.log(log, step=step)

    return log_fn