| """Centralised W&B logging utilities for ReMDM training loops. |
| |
| Design principles |
| ----------------- |
| * ``build_log_dict`` is a pure function — no side effects, fully testable. |
| * ``make_wandb_callback`` is a factory that returns a closure suitable for |
| ``jax.debug.callback``. All timing state is local to the closure; there is |
| no module-level global state. |
| * Both ``train.py`` and ``online.py`` import the same symbols, keeping all |
| metric naming and aggregation logic in one place. |
| |
| Metric namespacing |
| ------------------ |
| ``diffusion/`` — ELBO loss, accuracy, and noise-level diagnostics from ``compute_loss``. |
| ``train/`` — data quality, action distribution, throughput. |
| ``env/`` — episode returns and per-achievement unlock rates (training envs). |
| ``val/`` — same as ``env/`` but from the held-out validation rollout |
| (only emitted when ``step_idx % val_interval == 0``). |
| ``dagger/`` — DAgger-specific metrics (online training only). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import time |
| from typing import Any |
|
|
| import wandb |
|
|
|
|
def init_wandb(
    config: dict[str, Any],
    name: str,
    *,
    resume_run_id: str | None = None,
) -> None:
    """Start a new W&B run, or re-attach to an existing one.

    Args:
        config: Training config dict. ``WANDB_PROJECT`` (default
            ``"remdm-craftax"``) and ``WANDB_ENTITY`` (default ``None``)
            choose where the run lives; the whole dict is also uploaded
            as the run config.
        name: Human-readable run name. Only passed for fresh runs; it is
            not sent to ``wandb.init`` when resuming.
        resume_run_id: When given, the run with this ID is re-attached via
            ``wandb.init(id=..., resume="must")``. That run must already
            exist.
    """
    # Only the run-selection arguments differ between a fresh and a
    # resumed run; project/entity/config are shared by both paths.
    if resume_run_id is None:
        run_selection: dict[str, Any] = {"name": name}
    else:
        run_selection = {"id": resume_run_id, "resume": "must"}

    wandb.init(
        project=config.get("WANDB_PROJECT", "remdm-craftax"),
        entity=config.get("WANDB_ENTITY"),
        config=config,
        **run_selection,
    )
|
|
|
|
| |
# Metric keys copied verbatim (cast to float) into the ``diffusion/`` namespace.
_DIFFUSION_KEYS: tuple[str, ...] = (
    "loss",
    "unweighted_loss",
    "accuracy",
    "acc_t_low",
    "acc_t_mid",
    "acc_t_high",
    "frac_masked",
    "mean_t",
    "grad_norm",
)


# Metric keys copied verbatim (cast to float) into the ``train/`` namespace.
_TRAIN_KEYS: tuple[str, ...] = (
    "action_entropy",
    "action_unique_frac",
    "valid_frac",
    "mean_return_weight",
)


# Metric keys copied into the ``dagger/`` namespace when ``is_online`` is set.
# ``valid_frac`` is also listed in ``_TRAIN_KEYS``, so online runs log it
# under both ``train/valid_frac`` and ``dagger/valid_frac``.
_DAGGER_KEYS: tuple[str, ...] = (
    "beta",
    "reward_mean",
    "buffer_fill",
    "valid_frac",
    "best_val_return",
)

# Prefix marking metrics produced by the validation rollout.
_VAL_PREFIX = "val/"


def _copy_keys(
    metric: dict[str, Any],
    keys: tuple[str, ...],
    namespace: str,
    log: dict[str, float],
) -> None:
    """Copy each of *keys* present in *metric* into *log* as ``namespace/key``."""
    for key in keys:
        if key in metric:
            log[f"{namespace}/{key}"] = float(metric[key])


def build_log_dict(
    metric: dict[str, Any],
    step_idx: int,
    val_interval: int,
    *,
    is_online: bool = False,
    sps: float | None = None,
) -> dict[str, float]:
    """Build a flat W&B-ready log dict from a merged training metric dict.

    Pure function: reads ``metric`` and returns a new dict; nothing is
    mutated or logged.

    Args:
        metric: Merged metric dict from the current update step. Values
            must be convertible with ``float()``.
        step_idx: Integer update step index.
        val_interval: How often (in steps) validation runs occur. A value
            ``<= 0`` means validation is disabled, so no ``val/`` keys are
            emitted.
        is_online: If ``True``, emit DAgger-specific keys under ``dagger/``.
        sps: Pre-computed steps-per-second; omitted when ``None``.

    Returns:
        Flat ``{str: float}`` dict suitable for ``wandb.log``. It always
        contains ``env/achievements``, and contains ``val/achievements``
        on validation steps, even when no achievement keys were found
        (in which case the total is ``0.0``).
    """
    log: dict[str, float] = {}
    # Check the interval first: ``step_idx % 0`` would raise
    # ZeroDivisionError, and a non-positive interval means "no validation".
    is_val_step = val_interval > 0 and step_idx % val_interval == 0

    _copy_keys(metric, _DIFFUSION_KEYS, "diffusion", log)
    _copy_keys(metric, _TRAIN_KEYS, "train", log)
    if is_online:
        _copy_keys(metric, _DAGGER_KEYS, "dagger", log)

    # Episode statistics from the training environments.
    if "returned_episode_returns" in metric:
        log["env/episode_return"] = float(metric["returned_episode_returns"])
    if "returned_episode_lengths" in metric:
        log["env/episode_length"] = float(metric["returned_episode_lengths"])

    # Per-achievement values from the training envs (any key mentioning
    # "achievement" that is not a validation key). Each value is divided by
    # 100 before summing, so they are presumably percentages (TODO confirm).
    achieve_total = 0.0
    for key, value in metric.items():
        if "achievement" in key.lower() and not key.startswith(_VAL_PREFIX):
            rate = float(value)
            log[f"env/achieve/{key}"] = rate
            achieve_total += rate / 100.0
    log["env/achievements"] = achieve_total

    # Validation metrics carry a ``val/`` prefix in ``metric``; they are
    # only surfaced on validation steps.
    if is_val_step:
        val_achieve_total = 0.0
        for key, value in metric.items():
            if not key.startswith(_VAL_PREFIX):
                continue
            inner = key[len(_VAL_PREFIX):]
            if "achievement" in inner.lower():
                rate = float(value)
                log[f"val/achieve/{inner}"] = rate
                val_achieve_total += rate / 100.0
            elif inner == "returned_episode_returns":
                log["val/episode_return"] = float(value)
            elif inner == "returned_episode_lengths":
                log["val/episode_length"] = float(value)
        log["val/achievements"] = val_achieve_total

    if sps is not None:
        log["train/sps"] = float(sps)

    return log
|
|
|
|
def make_wandb_callback(
    config: dict[str, Any],
    *,
    steps_per_update: int | None,
    val_interval: int,
    is_online: bool = False,
) -> Any:
    """Return a host-side logging closure for ``jax.debug.callback``.

    The closure measures the time between successive calls with
    ``time.monotonic`` (unaffected by system clock updates, unlike
    ``time.time``) and turns it into steps-per-second. All timing state
    lives in the closure; there is no module-level mutable state.

    ``train/sps`` is omitted when:

    * it is the closure's first call, because that interval spans closure
      creation, JIT compilation and setup rather than one update. This also
      covers runs whose first logged step is not 0;
    * ``step_idx == 0``;
    * ``steps_per_update`` is ``None`` (e.g. data-replay mode where no
      environment frames are consumed).

    Args:
        config: Training config dict. Currently not read by this function;
            callers are expected to check ``USE_WANDB`` themselves before
            creating the callback.
        steps_per_update: Environment frames consumed per update step. Pass
            ``None`` to disable ``train/sps`` logging entirely
            (e.g. when training from pre-collected data files).
        val_interval: Frequency (in steps) at which validation runs occur;
            forwarded to ``build_log_dict``.
        is_online: If ``True``, emit DAgger keys under ``dagger/``.

    Returns:
        A callable ``log_fn(metric, step_idx) -> None`` for
        ``jax.debug.callback``.
    """
    # Time of the previous call; ``None`` until the first call so the
    # compile/setup interval is never reported as throughput.
    last_call: float | None = None

    def log_fn(metric: dict[str, Any], step_idx: int) -> None:
        nonlocal last_call
        # ``step_idx`` may arrive as an array scalar from the JAX callback.
        step = int(step_idx)
        now = time.monotonic()
        previous, last_call = last_call, now

        sps: float | None = None
        if steps_per_update is not None and previous is not None and step > 0:
            dt = now - previous
            # Guard against a (near-)zero interval producing inf/huge SPS.
            if dt > 1e-6:
                sps = steps_per_update / dt

        log = build_log_dict(
            metric, step, val_interval, is_online=is_online, sps=sps,
        )
        wandb.log(log, step=step)

    return log_fn
|
|