| """Centralised W&B and stdout logging. |
| |
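| Mirrors the Craftax logging conventions. The core metric namespaces are |
| ``diffusion/``, ``train/``, ``eval_id/`` and ``eval_ood/``; ``Logger`` additionally |
| registers ``perf/``, ``speed/``, ``model/``, ``curriculum/``, ``ckpt_eval/``, |
| ``ckpt_eval_id/``, ``ckpt_eval_ood/`` and ``inference/``. All of them are plotted |
| against the ``iteration`` metric. |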
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import torch |
| from typing import TYPE_CHECKING |
| from types import SimpleNamespace |
|
|
| if TYPE_CHECKING: |
| from wandb.sdk.wandb_run import Run as _WandbRun |
|
|
# Module-scoped logger shared by every helper below; it is named after this
# module so that handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
def download_artifact(
    artifact_ref: str, dst_dir: str = "artifacts",
) -> str | None:
    """Download a W&B artifact via the public API (no active run needed).

    Args:
        artifact_ref: Fully qualified artifact reference, e.g.
            ``"entity/project/checkpoint-iter1000:latest"``.
        dst_dir: Local directory to download into.

    Returns:
        Path to a ``.pth`` file at the top level of the downloaded artifact
        directory, or ``None`` on failure (including when the artifact holds
        no top-level ``.pth`` file). If several ``.pth`` files are present,
        the lexicographically first one is returned and a warning is logged.
    """
    try:
        import wandb
        from pathlib import Path

        api = wandb.Api()
        artifact = api.artifact(artifact_ref)
        artifact_dir = artifact.download(root=dst_dir)
        # ``glob`` yields files in filesystem order, which is not guaranteed;
        # sort so the chosen checkpoint is reproducible across machines.
        pth_files = sorted(Path(artifact_dir).glob("*.pth"))
    except Exception:
        # Best-effort: callers treat ``None`` as "no checkpoint available".
        logger.error(
            "Failed to download artifact %s", artifact_ref, exc_info=True,
        )
        return None

    if not pth_files:
        logger.error("No .pth file found in artifact %s", artifact_ref)
        return None
    if len(pth_files) > 1:
        logger.warning(
            "Artifact %s contains %d .pth files; using %s",
            artifact_ref, len(pth_files), pth_files[0].name,
        )
    path = str(pth_files[0])
    logger.info("Downloaded artifact %s -> %s", artifact_ref, path)
    return path
|
|
|
|
| def _auto_run_name(cfg: SimpleNamespace) -> str: |
| """Generate a descriptive W&B run name from key hyperparameters. |
| |
| Format: ``seq{seq_len}_d{n_embd}_L{n_layer}_lr{dagger_lr}_bs{batch}_eta{eta}_{remask}`` |
| |
| Args: |
| cfg: Config namespace. |
| |
| Returns: |
| A concise, human-readable run name. |
| """ |
| parts = [ |
| f"seq{cfg.seq_len}", |
| f"d{cfg.n_embd}", |
| f"L{cfg.n_layer}", |
| f"lr{cfg.dagger_lr:.0e}", |
| f"bs{cfg.dagger_batch_size}", |
| f"eta{cfg.eta}", |
| f"{cfg.remask_strategy}", |
| ] |
| if cfg.use_importance_weighting: |
| parts.append("subs") |
| if getattr(cfg, "physics_aware_sampling", False): |
| parts.append("phys") |
| if cfg.seed is not None: |
| parts.append(f"s{cfg.seed}") |
| return "_".join(parts) |
|
|
|
|
class Logger:
    """Centralised logger for W&B and stdout.

    W&B calls are best-effort: a failure is logged through the module logger
    and never propagates to the caller. Metrics passed to ``log`` are also
    echoed to stdout every ``_STDOUT_EVERY`` steps.

    Args:
        cfg: Config namespace with ``use_wandb``, ``wandb_project``,
            ``wandb_entity``, ``seed``. Optional fields: ``wandb_run_name``
            (an auto-generated name is used when it is missing or empty) and
            ``wandb_resume_id`` (resume that run; W&B is called with
            ``resume="must"``).
    """

    # W&B metric namespaces that are plotted against the ``iteration`` metric.
    _STEPPED_NAMESPACES = (
        "diffusion/*", "train/*", "perf/*", "speed/*",
        "model/*",
        "eval_id/*", "eval_ood/*",
        "curriculum/*",
        "ckpt_eval_id/*", "ckpt_eval_ood/*", "ckpt_eval/*",
        "inference/*",
    )

    # ``log`` echoes metrics to stdout on steps divisible by this value.
    _STDOUT_EVERY = 10

    def __init__(self, cfg: SimpleNamespace) -> None:
        self._use_wandb = cfg.use_wandb
        self._run: _WandbRun | None = None
        if not self._use_wandb:
            return
        try:
            import wandb

            run_name = getattr(cfg, "wandb_run_name", None)
            if not run_name:
                run_name = _auto_run_name(cfg)
            resume_id = getattr(cfg, "wandb_resume_id", None)
            self._run = wandb.init(
                project=cfg.wandb_project,
                entity=cfg.wandb_entity or None,
                name=run_name,
                config=vars(cfg),
                id=resume_id or None,
                resume="must" if resume_id else "never",
            )
            # Plot every known namespace against ``iteration`` instead of
            # W&B's internal step counter.
            wandb.define_metric("iteration")
            for ns in self._STEPPED_NAMESPACES:
                wandb.define_metric(ns, step_metric="iteration")
        except Exception:
            # Degrade to stdout-only logging rather than abort training.
            # ``_run`` may still be set if ``wandb.init`` succeeded but a
            # later call failed; ``finish`` closes it in that case.
            logger.error("W&B init failed", exc_info=True)
            self._use_wandb = False

    def log_summary(self, metrics: dict) -> None:
        """Write key/value pairs to the wandb run summary (final aggregates).

        Args:
            metrics: Flat ``{key: value}`` dict.
        """
        if self._use_wandb and self._run is not None:
            try:
                self._run.summary.update(metrics)
            except Exception:
                logger.warning("W&B summary update failed", exc_info=True)

    def log(self, metrics: dict, step: int) -> None:
        """Log a dict of metrics to W&B and, periodically, to stdout.

        Args:
            metrics: Flat ``{namespace/key: value}`` dict.
            step: Global step index; also logged as the ``iteration`` metric.
        """
        if self._use_wandb and self._run is not None:
            try:
                self._run.log({**metrics, "iteration": step}, step=step)
            except Exception:
                logger.warning(
                    "W&B log failed at step %d", step, exc_info=True,
                )

        if step % self._STDOUT_EVERY == 0:
            logger.info(self._format_line(metrics, step))

    @staticmethod
    def _format_line(metrics: dict, step: int) -> str:
        """Render ``step=<n> key=value ...`` for stdout.

        Floats use 4 fixed decimals, except tiny non-zero magnitudes
        (below 1e-3), which use scientific notation so they do not print as
        ``0.0000``. Non-floats use ``str``.
        """
        parts = [f"step={step}"]
        for k, v in metrics.items():
            if isinstance(v, float):
                fmt = ".2e" if v != 0.0 and abs(v) < 1e-3 else ".4f"
                parts.append(f"{k}={v:{fmt}}")
            else:
                parts.append(f"{k}={v}")
        return " ".join(parts)

    def log_eval(
        self, results: dict[str, dict], step: int, prefix: str,
    ) -> None:
        """Flatten evaluation results and log them.

        Only numeric (``int``/``float``) stats are kept; they are logged as
        ``{prefix}/{env_id}/{key}``.

        Args:
            results: ``{env_id: {"win_rate", ...}}``
            step: Global step.
            prefix: Metric namespace prefix (e.g. ``"eval_id"``).
        """
        flat: dict[str, float] = {}
        for env_id, stats in results.items():
            for key, val in stats.items():
                if isinstance(val, (int, float)):
                    flat[f"{prefix}/{env_id}/{key}"] = val
        self.log(flat, step=step)

    def log_checkpoint_artifact(
        self,
        checkpoint_path: str,
        config_path: str | None,
        iteration: int,
        metadata: dict | None = None,
        artifact_name: str | None = None,
    ) -> None:
        """Upload a checkpoint as a W&B artifact with config attached.

        Blocks until the upload completes (``wait()``). Failures are logged
        and swallowed.

        Args:
            checkpoint_path: Path to the ``.pth`` checkpoint file.
            config_path: Path to the YAML config snapshot to attach (stored
                in the artifact as ``config.yaml``). If ``None``, only the
                checkpoint is uploaded.
            iteration: Iteration number (used in the default artifact
                name when ``artifact_name`` is not provided).
            metadata: Optional metadata dict stored on the artifact.
            artifact_name: Optional explicit artifact name. When
                ``None``, defaults to ``f"checkpoint-iter{iteration}"``.
                Offline BC passes a step-based name to avoid the
                misleading "iter" prefix.
        """
        if not self._use_wandb or self._run is None:
            return
        try:
            import wandb

            name = artifact_name or f"checkpoint-iter{iteration}"
            artifact = wandb.Artifact(
                name=name,
                type="model",
                metadata=metadata or {},
            )
            artifact.add_file(checkpoint_path)
            if config_path is not None:
                artifact.add_file(config_path, name="config.yaml")
            logged = self._run.log_artifact(artifact)
            logged.wait()
            logger.info("W&B artifact uploaded: %s", name)
        except Exception:
            logger.error("W&B artifact upload failed", exc_info=True)

    def finish(self) -> None:
        """Close the W&B run if one is open; safe to call more than once.

        The run handle is cleared before finishing, so any later ``log`` /
        ``log_summary`` / ``finish`` call becomes a W&B no-op instead of
        touching a closed run.
        """
        run, self._run = self._run, None
        if run is None:
            return
        try:
            run.finish()
        except Exception:
            logger.warning("W&B finish failed", exc_info=True)
|
|
|
|
| |
| |
| |
|
|
|
|
def gpu_memory_mb() -> float:
    """Report the peak CUDA memory allocated since the last reset, in MB.

    Returns:
        ``torch.cuda.max_memory_allocated()`` converted from bytes to
        MB (1024 * 1024 bytes), or 0.0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0.0
    bytes_per_mb = 1024 * 1024
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / bytes_per_mb
|
|
|
|
def reset_gpu_memory_stats() -> None:
    """Clear the CUDA peak-memory counters; a no-op without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()
|
|
|
|
def compute_param_norm(model: torch.nn.Module) -> float:
    """Return the global L2 norm over every parameter of ``model``.

    Each parameter's own L2 norm is squared and summed, and the square root
    of that sum is returned (0.0 for a model without parameters).

    Args:
        model: The model.

    Returns:
        Total L2 norm as a float.
    """
    sum_of_squares = sum(
        param.detach().norm(2).item() ** 2 for param in model.parameters()
    )
    return sum_of_squares ** 0.5
|
|
|
|
def compute_param_drift(
    model: torch.nn.Module,
    ref_state: dict[str, torch.Tensor],
) -> float:
    """Compute L2 distance between current model params and a reference state.

    Only parameters whose name appears in ``ref_state`` contribute; reference
    entries with no matching parameter (e.g. buffers) are ignored. Each
    reference tensor is moved to its parameter's device before subtracting,
    so a CPU-resident reference works with a GPU model.

    ``ref_state`` should be a copy (e.g. ``{k: v.detach().clone() ...}``):
    ``model.state_dict()`` shares storage with the live parameters, so using
    it directly would always report a drift of 0.

    Args:
        model: Current model.
        ref_state: Reference state_dict (e.g. pretrained weights).

    Returns:
        L2 distance as a float (0.0 when no names match).
    """
    total = 0.0
    for name, param in model.named_parameters():
        if name not in ref_state:
            continue
        # ``.to`` is a no-op when the reference is already on the same device.
        ref = ref_state[name].to(param.device)
        total += (param.detach() - ref).norm(2).item() ** 2
    return total ** 0.5
|
|