File size: 6,881 Bytes
"""Centralised W&B logging utilities for ReMDM training loops.
Design principles
-----------------
* ``build_log_dict`` is a pure function — no side effects, fully testable.
* ``make_wandb_callback`` is a factory that returns a closure suitable for
``jax.debug.callback``. All timing state is local to the closure; there is
no module-level global state.
* Both ``train.py`` and ``online.py`` import the same symbols, keeping all
metric naming and aggregation logic in one place.
Metric namespacing
------------------
``diffusion/`` — ELBO loss, accuracy, and noise-level diagnostics from ``compute_loss``.
``train/`` — data quality, action distribution, throughput.
``env/`` — episode returns and per-achievement unlock rates (training envs).
``val/`` — same as ``env/`` but from the held-out validation rollout
(only emitted when ``step_idx % val_interval == 0``).
``dagger/`` — DAgger-specific metrics (online training only).
"""
from __future__ import annotations
import time
from typing import Any
import wandb
def init_wandb(
    config: dict[str, Any],
    name: str,
    *,
    resume_run_id: str | None = None,
) -> None:
    """Start a new W&B run, or re-attach to an existing one.

    Args:
        config: Training config dict. ``WANDB_PROJECT`` (falling back to
            ``"remdm-craftax"``) and ``WANDB_ENTITY`` are read from it, and
            the whole dict is uploaded as the run config.
        name: Display name for a fresh run. Not passed when resuming.
        resume_run_id: ID of an existing W&B run. When given, ``wandb.init``
            is called with ``id=resume_run_id`` and ``resume="must"``, so the
            run has to exist already.
    """
    # Either identify a fresh run by name, or pin the existing run's ID.
    run_identity: dict[str, Any] = (
        {"name": name}
        if resume_run_id is None
        else {"id": resume_run_id, "resume": "must"}
    )
    wandb.init(
        project=config.get("WANDB_PROJECT", "remdm-craftax"),
        entity=config.get("WANDB_ENTITY"),
        config=config,
        **run_identity,
    )
# Keys emitted by the ``src.diffusion.loss.compute_loss`` info dict; logged
# under ``diffusion/``.
# NOTE(review): ``grad_norm`` may come from the optimiser step rather than
# ``compute_loss`` — confirm.
_DIFFUSION_KEYS: tuple[str, ...] = (
    "loss",
    "unweighted_loss",
    "accuracy",
    "acc_t_low",
    "acc_t_mid",
    "acc_t_high",
    "frac_masked",
    "mean_t",
    "grad_norm",
)
# Keys added locally by training loops; logged under ``train/``.
_TRAIN_KEYS: tuple[str, ...] = (
    "action_entropy",
    "action_unique_frac",
    "valid_frac",
    "mean_return_weight",
)
# Keys specific to online DAgger training; logged under ``dagger/`` when
# ``is_online`` is set.
# NOTE: ``valid_frac`` also appears in ``_TRAIN_KEYS``, so online runs log it
# as both ``train/valid_frac`` and ``dagger/valid_frac``.
_DAGGER_KEYS: tuple[str, ...] = (
    "beta",
    "reward_mean",
    "buffer_fill",
    "valid_frac",
    "best_val_return",
)

# Prefix marking metrics produced by the held-out validation rollout.
_VAL_PREFIX = "val/"


def _prefixed(
    metric: dict[str, Any], keys: tuple[str, ...], prefix: str,
) -> dict[str, float]:
    """Return ``{prefix}/<key>`` -> float for each of *keys* present in *metric*."""
    return {f"{prefix}/{k}": float(metric[k]) for k in keys if k in metric}


def _rollout_metrics(source: dict[str, Any], prefix: str) -> dict[str, float]:
    """Map one rollout's raw metrics onto ``{prefix}/...`` log keys.

    Args:
        source: Metrics for a single rollout, with any ``val/`` prefix
            already stripped from the keys.
        prefix: Output namespace, e.g. ``"env"`` or ``"val"``.

    Returns:
        ``{prefix}/episode_return`` / ``{prefix}/episode_length`` when the
        ``returned_episode_returns`` / ``returned_episode_lengths`` keys are
        present; one ``{prefix}/achieve/<key>`` entry per key containing
        ``"achievement"`` (case-insensitive); and, only when at least one
        such key exists, the aggregate ``{prefix}/achievements``.
    """
    out: dict[str, float] = {}
    if "returned_episode_returns" in source:
        out[f"{prefix}/episode_return"] = float(source["returned_episode_returns"])
    if "returned_episode_lengths" in source:
        out[f"{prefix}/episode_length"] = float(source["returned_episode_lengths"])

    # Per-achievement breakdown + aggregate score. Craftax reports unlock
    # rates as percentages, so each one is divided by 100 before summing.
    total = 0.0
    found = False
    for key, value in source.items():
        if "achievement" not in key.lower():
            continue
        rate = float(value)
        out[f"{prefix}/achieve/{key}"] = rate
        total += rate / 100.0
        found = True
    # Only emit the aggregate when there was something to aggregate; otherwise
    # runs without achievement metrics would chart a misleading constant 0.
    if found:
        out[f"{prefix}/achievements"] = total
    return out


def build_log_dict(
    metric: dict[str, Any],
    step_idx: int,
    val_interval: int,
    *,
    is_online: bool = False,
    sps: float | None = None,
) -> dict[str, float]:
    """Build a flat W&B-ready log dict from a merged training metric dict.

    Pure function: *metric* is only read and nothing besides the returned
    dict is created or modified.

    Args:
        metric: Merged metric dict from the current update step. Keys that
            start with ``val/`` are treated as validation-rollout metrics.
        step_idx: Integer update step index.
        val_interval: How often (in steps) validation runs occur. Validation
            metrics are emitted only when ``step_idx % val_interval == 0``;
            a non-positive value disables them instead of raising
            ``ZeroDivisionError``.
        is_online: If ``True``, also emit DAgger-specific keys under
            ``dagger/``.
        sps: Pre-computed steps-per-second; omitted when ``None``.

    Returns:
        Flat ``{str: float}`` dict suitable for ``wandb.log``. The aggregate
        ``env/achievements`` / ``val/achievements`` entries are present only
        when at least one corresponding achievement metric exists.
    """
    is_val_step = val_interval > 0 and step_idx % val_interval == 0

    log: dict[str, float] = {}
    log.update(_prefixed(metric, _DIFFUSION_KEYS, "diffusion"))
    log.update(_prefixed(metric, _TRAIN_KEYS, "train"))
    if is_online:
        log.update(_prefixed(metric, _DAGGER_KEYS, "dagger"))

    # Training-env metrics: every key without the validation prefix.
    env_source = {
        k: v for k, v in metric.items() if not k.startswith(_VAL_PREFIX)
    }
    log.update(_rollout_metrics(env_source, "env"))

    # Validation metrics — only emitted on val steps to avoid polluting
    # charts with zeros.
    if is_val_step:
        cut = len(_VAL_PREFIX)
        val_source = {
            k[cut:]: v for k, v in metric.items() if k.startswith(_VAL_PREFIX)
        }
        log.update(_rollout_metrics(val_source, "val"))

    if sps is not None:
        log["train/sps"] = float(sps)
    return log
def make_wandb_callback(
    config: dict[str, Any],
    *,
    steps_per_update: int | None,
    val_interval: int,
    is_online: bool = False,
) -> Any:
    """Return a host-side logging closure for ``jax.debug.callback``.

    The closure measures the interval between successive calls with the
    monotonic ``time.perf_counter`` (system clock adjustments cannot yield
    negative or inflated intervals) and uses it to compute steps-per-second.
    All state is local to the closure; there is no module-level mutable
    state.

    ``train/sps`` is not reported:

    * on the first call through the closure, because there is no previous
      call to measure from and the time since construction would include JIT
      compilation overhead (this also covers runs resumed at a non-zero
      ``step_idx``);
    * on ``step_idx == 0``;
    * when ``steps_per_update`` is ``None`` (e.g. data-replay mode where no
      environment frames are consumed).

    Args:
        config: Training config dict. Not read by this function; callers are
            expected to guard on ``USE_WANDB`` themselves.
        steps_per_update: Environment frames consumed per update step. Pass
            ``None`` to disable ``train/sps`` logging entirely
            (e.g. when training from pre-collected data files).
        val_interval: Frequency (in steps) at which validation runs occur.
        is_online: If ``True``, emit DAgger keys under ``dagger/``.

    Returns:
        A callable ``log_fn(metric, step_idx) -> None`` for
        ``jax.debug.callback``.
    """
    # Timestamp of the previous call; ``None`` until the closure first runs.
    last_call: float | None = None

    def log_fn(metric: dict[str, Any], step_idx: int) -> None:
        nonlocal last_call
        # Normalise once; presumably a 0-d array when invoked via
        # jax.debug.callback.
        step = int(step_idx)
        now = time.perf_counter()
        previous, last_call = last_call, now

        sps: float | None = None
        if steps_per_update is not None and previous is not None and step > 0:
            dt = now - previous
            if dt > 1e-6:  # avoid dividing by a vanishing interval
                sps = steps_per_update / dt

        log = build_log_dict(
            metric, step, val_interval, is_online=is_online, sps=sps,
        )
        wandb.log(log, step=step)

    return log_fn
|