File size: 9,474 Bytes
f748552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
"""Centralised W&B and stdout logging.

Mirrors the Craftax logging conventions with metric namespaces:
``diffusion/``, ``train/``, ``eval_id/``, ``eval_ood/``.
"""

from __future__ import annotations

import logging
import torch
from typing import TYPE_CHECKING
from types import SimpleNamespace

if TYPE_CHECKING:
    from wandb.sdk.wandb_run import Run as _WandbRun

logger = logging.getLogger(__name__)


def download_artifact(
    artifact_ref: str, dst_dir: str = "artifacts",
) -> str | None:
    """Download a W&B artifact via the public API (no active run needed).

    The artifact is downloaded into ``dst_dir`` and the directory is
    searched (non-recursively) for ``*.pth`` files. When several match,
    the lexicographically first path is returned so that the choice is
    reproducible across filesystems.

    Args:
        artifact_ref: Fully qualified artifact reference, e.g.
            ``"entity/project/checkpoint-iter1000:latest"``.
        dst_dir: Local directory to download into.

    Returns:
        Path to the ``.pth`` file inside the downloaded artifact
        directory, or ``None`` on failure (including a missing ``wandb``
        package or an artifact without any ``.pth`` file). Failures are
        logged, never raised.
    """
    try:
        # Imported lazily so the module stays importable without wandb.
        import wandb
        from pathlib import Path

        api = wandb.Api()
        artifact = api.artifact(artifact_ref)
        artifact_dir = artifact.download(root=dst_dir)
        # glob() yields entries in filesystem order; sort so the selected
        # checkpoint does not depend on the platform.
        pth_files = sorted(Path(artifact_dir).glob("*.pth"))
        if not pth_files:
            logger.error("No .pth file found in artifact %s", artifact_ref)
            return None
        if len(pth_files) > 1:
            logger.warning(
                "Artifact %s contains %d .pth files; using %s",
                artifact_ref, len(pth_files), pth_files[0],
            )
        path = str(pth_files[0])
        logger.info("Downloaded artifact %s -> %s", artifact_ref, path)
        return path
    except Exception:
        # Best-effort boundary: callers only check for ``None``.
        logger.error(
            "Failed to download artifact %s", artifact_ref,
            exc_info=True,
        )
        return None


def _auto_run_name(cfg: SimpleNamespace) -> str:
    """Generate a descriptive W&B run name from key hyperparameters.

    Format: ``seq{seq_len}_d{n_embd}_L{n_layer}_lr{dagger_lr}_bs{batch}_eta{eta}_{remask}``

    Args:
        cfg: Config namespace.

    Returns:
        A concise, human-readable run name.
    """
    parts = [
        f"seq{cfg.seq_len}",
        f"d{cfg.n_embd}",
        f"L{cfg.n_layer}",
        f"lr{cfg.dagger_lr:.0e}",
        f"bs{cfg.dagger_batch_size}",
        f"eta{cfg.eta}",
        f"{cfg.remask_strategy}",
    ]
    if cfg.use_importance_weighting:
        parts.append("subs")
    if getattr(cfg, "physics_aware_sampling", False):
        parts.append("phys")
    if cfg.seed is not None:
        parts.append(f"s{cfg.seed}")
    return "_".join(parts)


class Logger:
    """Centralised logger for W&B and stdout.

    W&B usage is best-effort: failures while talking to W&B are reported
    through the module logger and never propagated to the caller.

    Args:
        cfg: Config namespace with ``use_wandb``, ``wandb_project``,
            ``wandb_entity``, ``seed``. ``wandb_run_name`` and
            ``wandb_resume_id`` are read when present.
    """

    # Metric namespaces plotted against the custom "iteration" x-axis.
    _METRIC_NAMESPACES = (
        "diffusion/*", "train/*", "perf/*", "speed/*",
        "model/*",
        "eval_id/*", "eval_ood/*",
        "curriculum/*",
        "ckpt_eval_id/*", "ckpt_eval_ood/*", "ckpt_eval/*",
        "inference/*",
    )

    def __init__(self, cfg: SimpleNamespace) -> None:
        self._use_wandb = cfg.use_wandb
        self._run: _WandbRun | None = None
        if self._use_wandb:
            try:
                # Imported lazily so stdout-only runs do not need wandb.
                import wandb
                run_name = getattr(cfg, "wandb_run_name", None)
                if not run_name:
                    run_name = _auto_run_name(cfg)
                resume_id = getattr(cfg, "wandb_resume_id", None)
                self._run = wandb.init(
                    project=cfg.wandb_project,
                    entity=cfg.wandb_entity or None,
                    name=run_name,
                    config=vars(cfg),
                    id=resume_id or None,
                    # Resume only when an explicit run id was supplied.
                    resume="must" if resume_id else "never",
                )
                # Define custom metric x-axes
                wandb.define_metric("iteration")
                for ns in self._METRIC_NAMESPACES:
                    wandb.define_metric(ns, step_metric="iteration")
            except Exception:
                # Fall back to stdout-only logging for the rest of the run.
                logger.error("W&B init failed", exc_info=True)
                self._use_wandb = False

    @staticmethod
    def _format_metric(key: str, value: object) -> str:
        """Render one ``key=value`` pair for the stdout summary line."""
        if isinstance(value, float):
            # Tiny non-zero floats would print as 0.0000 with fixed
            # precision, so switch to scientific notation for them.
            if abs(value) < 1e-3 and value != 0.0:
                return f"{key}={value:.2e}"
            return f"{key}={value:.4f}"
        return f"{key}={value}"

    def log_summary(self, metrics: dict) -> None:
        """Write key/value pairs to the wandb run summary (final aggregates).

        Does nothing when W&B is disabled. Failures are logged as
        warnings and not raised.

        Args:
            metrics: Flat ``{key: value}`` dict.
        """
        if not self._use_wandb or self._run is None:
            return
        try:
            self._run.summary.update(metrics)
        except Exception:
            logger.warning("W&B summary update failed", exc_info=True)

    def log(self, metrics: dict, step: int) -> None:
        """Log a dict of metrics.

        Metrics are sent to W&B when a run is active; in addition a
        one-line summary is written to the module logger whenever
        ``step`` is a multiple of 10.

        Args:
            metrics: Flat ``{namespace/key: value}`` dict.
            step: Global step index.
        """
        if self._use_wandb and self._run is not None:
            try:
                import wandb
                # Include "iteration" so define_metric(step_metric="iteration") works
                wandb.log({**metrics, "iteration": step}, step=step)
            except Exception:
                logger.warning(
                    "W&B log failed at step %s", step, exc_info=True,
                )

        # Stdout summary every 10 steps
        if step % 10 == 0:
            parts = [f"step={step}"]
            parts.extend(
                self._format_metric(k, v) for k, v in metrics.items()
            )
            logger.info("  ".join(parts))

    def log_eval(
        self, results: dict[str, dict], step: int, prefix: str,
    ) -> None:
        """Flatten evaluation results and log them.

        Only ``int``/``float`` values are kept (``bool`` is an ``int``
        subclass and therefore passes too); everything else is dropped.

        Args:
            results: ``{env_id: {"win_rate", ...}}``
            step: Global step.
            prefix: Metric namespace prefix (e.g. ``"eval_id"``).
        """
        flat: dict[str, float] = {
            f"{prefix}/{env_id}/{key}": val
            for env_id, stats in results.items()
            for key, val in stats.items()
            if isinstance(val, (int, float))
        }
        self.log(flat, step=step)

    def log_checkpoint_artifact(
        self,
        checkpoint_path: str,
        config_path: str | None,
        iteration: int,
        metadata: dict | None = None,
        artifact_name: str | None = None,
    ) -> None:
        """Upload a checkpoint as a W&B artifact with config attached.

        Does nothing when W&B is disabled. Blocks until the upload has
        completed; upload failures are logged and not raised.

        Args:
            checkpoint_path: Path to the ``.pth`` checkpoint file.
            config_path: Path to the YAML config snapshot to attach.
                If ``None``, only the checkpoint is uploaded.
            iteration: Iteration number (used in the default artifact
                name when ``artifact_name`` is not provided).
            metadata: Optional metadata dict stored on the artifact.
            artifact_name: Optional explicit artifact name. When
                ``None``, defaults to ``f"checkpoint-iter{iteration}"``.
                Offline BC passes a step-based name to avoid the
                misleading "iter" prefix.
        """
        if not self._use_wandb or self._run is None:
            return
        try:
            import wandb

            name = artifact_name or f"checkpoint-iter{iteration}"
            artifact = wandb.Artifact(
                name=name,
                type="model",
                metadata=metadata or {},
            )
            artifact.add_file(checkpoint_path)
            if config_path is not None:
                # Stored under a fixed name regardless of the local filename.
                artifact.add_file(config_path, name="config.yaml")
            logged = self._run.log_artifact(artifact)  # type: ignore[union-attr]
            logged.wait()  # block until upload completes
            logger.info("W&B artifact uploaded: %s", name)
        except Exception:
            logger.error("W&B artifact upload failed", exc_info=True)

    def finish(self) -> None:
        """Close the W&B run if active.

        Safe to call more than once: the run handle is dropped after the
        first call, so later ``log``/``finish`` calls skip W&B.
        """
        if not self._use_wandb or self._run is None:
            return
        try:
            import wandb
            wandb.finish()
        except Exception:
            logger.warning("W&B finish failed", exc_info=True)
        finally:
            # Never keep a handle to a run that has been (or failed to be)
            # closed; otherwise every later log() would hit a dead run.
            self._run = None


# ---------------------------------------------------------------------------
# Metric helper functions (used by both src/ and experiments/)
# ---------------------------------------------------------------------------


def gpu_memory_mb() -> float:
    """Report the peak GPU memory allocated, in MB, since the last reset.

    Returns:
        Peak allocated memory in MB (bytes / 1024 / 1024), or 0.0 when
        CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 0.0
    bytes_per_mb = 1024 * 1024
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / bytes_per_mb


def reset_gpu_memory_stats() -> None:
    """Clear the peak-memory counters of the current CUDA device.

    Does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.reset_peak_memory_stats()


def compute_param_norm(model: torch.nn.Module) -> float:
    """Return the global L2 norm taken over every parameter of ``model``.

    Each parameter's own L2 norm is squared and accumulated; the result
    is the square root of that sum.

    Args:
        model: The model.

    Returns:
        Total L2 norm as a float (0.0 for a model without parameters).
    """
    squared_sum = 0.0
    for param in model.parameters():
        # .data reads the raw values without involving autograd.
        param_norm = param.data.norm(2).item()
        squared_sum += param_norm ** 2
    return squared_sum ** 0.5


def compute_param_drift(
    model: torch.nn.Module,
    ref_state: dict[str, torch.Tensor],
) -> float:
    """Compute L2 distance between current model params and a reference state.

    Only parameters whose name appears in ``ref_state`` contribute;
    parameters missing from the reference are skipped. Reference tensors
    are moved to the parameter's device before subtracting, so a
    reference snapshot kept on another device (e.g. CPU) can be compared
    against a model on an accelerator.

    Args:
        model: Current model.
        ref_state: Reference state_dict (e.g. pretrained weights).

    Returns:
        L2 distance as a float (0.0 when no parameter name matches).
    """
    total = 0.0
    for name, p in model.named_parameters():
        ref = ref_state.get(name)
        if ref is None:
            continue
        # .to() returns the same tensor when it is already on p's device,
        # so the same-device case is unchanged.
        diff = p.data - ref.to(p.device)
        total += diff.norm(2).item() ** 2
    return total ** 0.5