"""Centralised W&B and stdout logging.

Mirrors the Craftax logging conventions with metric namespaces:
``diffusion/``, ``train/``, ``eval_id/``, ``eval_ood/``.
"""
from __future__ import annotations
import logging
import torch
from typing import TYPE_CHECKING
from types import SimpleNamespace
if TYPE_CHECKING:
from wandb.sdk.wandb_run import Run as _WandbRun
logger = logging.getLogger(__name__)
def download_artifact(
    artifact_ref: str, dst_dir: str = "artifacts",
) -> str | None:
    """Download a W&B artifact via the public API (no active run needed).

    Any failure (including ``wandb`` not being importable) is logged and
    reported as ``None`` rather than raised.

    Args:
        artifact_ref: Fully qualified artifact reference, e.g.
            ``"entity/project/checkpoint-iter1000:latest"``.
        dst_dir: Local directory to download into.

    Returns:
        Path to a ``.pth`` file at the top level of the downloaded
        artifact directory, or ``None`` on failure. When several ``.pth``
        files are present, the lexicographically first one is returned.
    """
    try:
        # Imported inside the ``try`` so an import failure is handled like
        # any other download failure.
        import wandb
        from pathlib import Path

        artifact = wandb.Api().artifact(artifact_ref)
        artifact_dir = artifact.download(root=dst_dir)

        # ``glob`` yields entries in arbitrary order; sorting makes the
        # selected checkpoint deterministic when there is more than one.
        pth_files = sorted(Path(artifact_dir).glob("*.pth"))
        if not pth_files:
            logger.error("No .pth file found in artifact %s", artifact_ref)
            return None

        path = str(pth_files[0])
        logger.info("Downloaded artifact %s -> %s", artifact_ref, path)
        return path
    except Exception:
        # Deliberately broad: callers rely on ``None`` to signal failure.
        logger.error(
            "Failed to download artifact %s", artifact_ref, exc_info=True,
        )
        return None
def _auto_run_name(cfg: SimpleNamespace) -> str:
    """Build a descriptive W&B run name from key hyperparameters.

    Base format:
    ``seq{seq_len}_d{n_embd}_L{n_layer}_lr{dagger_lr}_bs{batch}_eta{eta}_{remask}``

    Optional suffixes are appended in this order: ``subs`` when
    ``cfg.use_importance_weighting`` is truthy, ``phys`` when
    ``cfg.physics_aware_sampling`` exists and is truthy, and ``s{seed}``
    when ``cfg.seed`` is not ``None``.

    Args:
        cfg: Config namespace.

    Returns:
        A concise, human-readable run name.
    """
    base_name = "_".join((
        f"seq{cfg.seq_len}",
        f"d{cfg.n_embd}",
        f"L{cfg.n_layer}",
        f"lr{cfg.dagger_lr:.0e}",
        f"bs{cfg.dagger_batch_size}",
        f"eta{cfg.eta}",
        f"{cfg.remask_strategy}",
    ))

    # (tag, enabled) pairs, kept in the order they appear in the name.
    optional_tags = (
        ("subs", cfg.use_importance_weighting),
        ("phys", getattr(cfg, "physics_aware_sampling", False)),
        (f"s{cfg.seed}", cfg.seed is not None),
    )
    active_tags = [tag for tag, enabled in optional_tags if enabled]

    return "_".join([base_name, *active_tags])
class Logger:
    """Centralised logger for W&B and stdout.

    W&B is strictly best-effort: if initialisation fails the instance
    falls back to stdout-only logging, and later W&B errors are reported
    through the module logger instead of being raised.

    Args:
        cfg: Config namespace. ``use_wandb`` is always read. When it is
            truthy, ``wandb_project`` and ``wandb_entity`` are read as
            well, plus the optional ``wandb_run_name`` and
            ``wandb_resume_id``; without an explicit run name one is
            derived from the hyperparameters via ``_auto_run_name``.
    """

    # Metric namespaces plotted against the custom ``iteration`` x-axis.
    _METRIC_NAMESPACES = (
        "diffusion/*", "train/*", "perf/*", "speed/*",
        "model/*",
        "eval_id/*", "eval_ood/*",
        "curriculum/*",
        "ckpt_eval_id/*", "ckpt_eval_ood/*", "ckpt_eval/*",
        "inference/*",
    )

    # A stdout summary line is emitted when ``step`` is a multiple of this.
    _STDOUT_EVERY = 10

    def __init__(self, cfg: SimpleNamespace) -> None:
        self._use_wandb = cfg.use_wandb
        self._run: _WandbRun | None = None
        if not self._use_wandb:
            return

        try:
            import wandb

            run_name = getattr(cfg, "wandb_run_name", None)
            if not run_name:
                run_name = _auto_run_name(cfg)
            resume_id = getattr(cfg, "wandb_resume_id", None)

            self._run = wandb.init(
                project=cfg.wandb_project,
                entity=cfg.wandb_entity or None,
                name=run_name,
                config=vars(cfg),
                id=resume_id or None,
                # Resuming requires the run to exist; otherwise never resume.
                resume="must" if resume_id else "never",
            )

            # Define custom metric x-axes
            wandb.define_metric("iteration")
            for ns in self._METRIC_NAMESPACES:
                wandb.define_metric(ns, step_metric="iteration")
        except Exception:
            logger.error("W&B init failed", exc_info=True)
            # Fall back to stdout-only logging for the rest of the run.
            self._use_wandb = False

    def _wandb_active(self) -> bool:
        """Return whether W&B is enabled and a run object is held."""
        return bool(self._use_wandb) and self._run is not None

    @staticmethod
    def _format_metric(key: str, value: object) -> str:
        """Render one ``key=value`` token for the stdout summary line.

        Floats use four decimals, or scientific notation when they are
        non-zero with magnitude below ``1e-3``; other values use ``str``
        formatting.
        """
        if isinstance(value, float):
            if abs(value) < 1e-3 and value != 0.0:
                return f"{key}={value:.2e}"
            return f"{key}={value:.4f}"
        return f"{key}={value}"

    def log_summary(self, metrics: dict) -> None:
        """Write key/value pairs to the wandb run summary (final aggregates).

        Failures are logged as warnings and not raised.

        Args:
            metrics: Flat ``{key: value}`` dict.
        """
        if not self._wandb_active():
            return
        try:
            self._run.summary.update(metrics)
        except Exception:
            logger.warning("W&B summary update failed", exc_info=True)

    def log(self, metrics: dict, step: int) -> None:
        """Log a dict of metrics to W&B (if active) and periodically to stdout.

        W&B failures are logged as warnings and not raised; the stdout
        summary is still emitted.

        Args:
            metrics: Flat ``{namespace/key: value}`` dict.
            step: Global step index.
        """
        if self._wandb_active():
            try:
                # Include "iteration" so define_metric(step_metric="iteration") works
                self._run.log({**metrics, "iteration": step}, step=step)
            except Exception:
                logger.warning(
                    "W&B log failed at step %s", step, exc_info=True,
                )

        # Stdout summary every ``_STDOUT_EVERY`` steps
        if step % self._STDOUT_EVERY == 0:
            parts = [f"step={step}"]
            parts.extend(
                self._format_metric(k, v) for k, v in metrics.items()
            )
            logger.info(" ".join(parts))

    def log_eval(
        self, results: dict[str, dict], step: int, prefix: str,
    ) -> None:
        """Flatten evaluation results and log them.

        Only ``int``/``float`` values are kept; each is logged under
        ``"{prefix}/{env_id}/{key}"``.

        Args:
            results: ``{env_id: {"win_rate", ...}}``
            step: Global step.
            prefix: Metric namespace prefix (e.g. ``"eval_id"``).
        """
        flat: dict[str, float] = {
            f"{prefix}/{env_id}/{key}": val
            for env_id, stats in results.items()
            for key, val in stats.items()
            if isinstance(val, (int, float))
        }
        self.log(flat, step=step)

    def log_checkpoint_artifact(
        self,
        checkpoint_path: str,
        config_path: str | None,
        iteration: int,
        metadata: dict | None = None,
        artifact_name: str | None = None,
    ) -> None:
        """Upload a checkpoint as a W&B artifact with config attached.

        Blocks until the upload completes. Failures are logged as errors
        and not raised. Does nothing when W&B is inactive.

        Args:
            checkpoint_path: Path to the ``.pth`` checkpoint file.
            config_path: Path to the YAML config snapshot to attach
                (stored in the artifact as ``config.yaml``). If ``None``,
                only the checkpoint is uploaded.
            iteration: Iteration number (used in the default artifact
                name when ``artifact_name`` is not provided).
            metadata: Optional metadata dict stored on the artifact.
            artifact_name: Optional explicit artifact name. When
                ``None``, defaults to ``f"checkpoint-iter{iteration}"``.
                Offline BC passes a step-based name to avoid the
                misleading "iter" prefix.
        """
        if not self._wandb_active():
            return
        try:
            import wandb

            name = artifact_name or f"checkpoint-iter{iteration}"
            artifact = wandb.Artifact(
                name=name,
                type="model",
                metadata=metadata or {},
            )
            artifact.add_file(checkpoint_path)
            if config_path is not None:
                artifact.add_file(config_path, name="config.yaml")
            logged = self._run.log_artifact(artifact)
            logged.wait()  # block until upload completes
            logger.info("W&B artifact uploaded: %s", name)
        except Exception:
            logger.error("W&B artifact upload failed", exc_info=True)

    def finish(self) -> None:
        """Close the W&B run if active.

        Failures are logged as warnings and not raised.
        """
        if not self._wandb_active():
            return
        try:
            self._run.finish()
        except Exception:
            logger.warning("W&B finish failed", exc_info=True)
# ---------------------------------------------------------------------------
# Metric helper functions (used by both src/ and experiments/)
# ---------------------------------------------------------------------------
def gpu_memory_mb() -> float:
    """Report the peak GPU memory allocated, in MB, since the last reset.

    Returns:
        Peak allocated memory in MB as reported by
        ``torch.cuda.max_memory_allocated``, or 0.0 if CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0.0
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / (1024 * 1024)
def reset_gpu_memory_stats() -> None:
    """Reset the peak GPU memory counters for the current device.

    Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
def compute_param_norm(model: torch.nn.Module) -> float:
    """Return the global L2 norm over every parameter of ``model``.

    Each parameter tensor's L2 norm is squared and accumulated; the
    square root of the accumulated sum is returned.

    Args:
        model: The model.

    Returns:
        Total L2 norm as a float (0.0 for a model with no parameters).
    """
    squared_sum = 0.0
    for param in model.parameters():
        tensor_norm = param.data.norm(2).item()
        squared_sum += tensor_norm ** 2
    return squared_sum ** 0.5
def compute_param_drift(
    model: torch.nn.Module,
    ref_state: dict[str, torch.Tensor],
) -> float:
    """Compute L2 distance between current model params and a reference state.

    Parameters whose name is missing from ``ref_state`` are skipped.
    Reference tensors are moved to the device of the matching parameter
    before subtraction, so a state_dict held on another device (e.g. one
    loaded onto CPU while the model lives on GPU) can be used.

    Args:
        model: Current model.
        ref_state: Reference state_dict (e.g. pretrained weights), keyed
            by the names yielded by ``model.named_parameters()``.

    Returns:
        L2 distance as a float (0.0 when no parameter name matches).
    """
    total = 0.0
    for name, p in model.named_parameters():
        if name not in ref_state:
            continue
        # ``.to`` is a no-op when the reference is already on ``p.device``;
        # ``detach`` keeps the subtraction out of any autograd graph.
        ref = ref_state[name].detach().to(p.device)
        total += (p.data - ref).norm(2).item() ** 2
    return total ** 0.5
|