""" Metrics tracking — lightweight replacement for TensorBoard / W&B. Stores lists of scalar values keyed by metric name, and provides summary statistics and JSON serialisation. """ from __future__ import annotations import json from collections import defaultdict from pathlib import Path import numpy as np class MetricsTracker: """ Accumulates scalar training metrics across episodes. Usage:: tracker = MetricsTracker() tracker.add("episode_reward", -920.3) tracker.get_mean("episode_reward", last_n=100) tracker.save("results/metrics.json") """ def __init__(self): self._data: dict[str, list] = defaultdict(list) # ------------------------------------------------------------------ # Data operations # ------------------------------------------------------------------ def add(self, name: str, value): """Append *value* to the metric called *name*.""" self._data[name].append(value) def get(self, name: str) -> list: """Return all recorded values for *name* (empty list if absent).""" return list(self._data.get(name, [])) def has(self, name: str) -> bool: """True if at least one value for *name* has been recorded.""" return name in self._data and len(self._data[name]) > 0 def get_last(self, name: str, n: int = 1) -> list: vals = self.get(name) return vals[-n:] def get_mean(self, name: str, last_n: int | None = None) -> float: vals = self.get(name) if not vals: return 0.0 if last_n: vals = vals[-last_n:] return float(np.mean(vals)) def get_std(self, name: str, last_n: int | None = None) -> float: vals = self.get(name) if not vals: return 0.0 if last_n: vals = vals[-last_n:] return float(np.std(vals)) def summary(self, name: str) -> dict: vals = self.get(name) if not vals: return {} return { "count": len(vals), "mean": float(np.mean(vals)), "std": float(np.std(vals)), "min": float(np.min(vals)), "max": float(np.max(vals)), } # ------------------------------------------------------------------ # Persistence # ------------------------------------------------------------------ def save(self, filepath: str | Path): """Serialise to JSON.""" filepath = Path(filepath) filepath.parent.mkdir(parents=True, exist_ok=True) with open(filepath, "w") as fh: json.dump({k: list(v) for k, v in self._data.items()}, fh, indent=2) def load(self, filepath: str | Path): """Restore from a previously saved JSON file.""" with open(filepath) as fh: raw = json.load(fh) self._data = defaultdict(list, raw) def reset(self): """Clear all accumulated metrics.""" self._data.clear() # ------------------------------------------------------------------ # Dunder helpers # ------------------------------------------------------------------ def __repr__(self) -> str: lines = [] for name, vals in self._data.items(): if vals: lines.append( f" {name}: mean={np.mean(vals):.2f} " f"std={np.std(vals):.2f} n={len(vals)}" ) return "MetricsTracker(\n" + "\n".join(lines) + "\n)"