Spaces:
Sleeping
Sleeping
| """ | |
| Metrics tracking — lightweight replacement for TensorBoard / W&B. | |
| Stores lists of scalar values keyed by metric name, and provides | |
| summary statistics and JSON serialisation. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from collections import defaultdict | |
| from pathlib import Path | |
| import numpy as np | |
| class MetricsTracker: | |
| """ | |
| Accumulates scalar training metrics across episodes. | |
| Usage:: | |
| tracker = MetricsTracker() | |
| tracker.add("episode_reward", -920.3) | |
| tracker.get_mean("episode_reward", last_n=100) | |
| tracker.save("results/metrics.json") | |
| """ | |
| def __init__(self): | |
| self._data: dict[str, list] = defaultdict(list) | |
| # ------------------------------------------------------------------ | |
| # Data operations | |
| # ------------------------------------------------------------------ | |
| def add(self, name: str, value): | |
| """Append *value* to the metric called *name*.""" | |
| self._data[name].append(value) | |
| def get(self, name: str) -> list: | |
| """Return all recorded values for *name* (empty list if absent).""" | |
| return list(self._data.get(name, [])) | |
| def has(self, name: str) -> bool: | |
| """True if at least one value for *name* has been recorded.""" | |
| return name in self._data and len(self._data[name]) > 0 | |
| def get_last(self, name: str, n: int = 1) -> list: | |
| vals = self.get(name) | |
| return vals[-n:] | |
| def get_mean(self, name: str, last_n: int | None = None) -> float: | |
| vals = self.get(name) | |
| if not vals: | |
| return 0.0 | |
| if last_n: | |
| vals = vals[-last_n:] | |
| return float(np.mean(vals)) | |
| def get_std(self, name: str, last_n: int | None = None) -> float: | |
| vals = self.get(name) | |
| if not vals: | |
| return 0.0 | |
| if last_n: | |
| vals = vals[-last_n:] | |
| return float(np.std(vals)) | |
| def summary(self, name: str) -> dict: | |
| vals = self.get(name) | |
| if not vals: | |
| return {} | |
| return { | |
| "count": len(vals), | |
| "mean": float(np.mean(vals)), | |
| "std": float(np.std(vals)), | |
| "min": float(np.min(vals)), | |
| "max": float(np.max(vals)), | |
| } | |
| # ------------------------------------------------------------------ | |
| # Persistence | |
| # ------------------------------------------------------------------ | |
| def save(self, filepath: str | Path): | |
| """Serialise to JSON.""" | |
| filepath = Path(filepath) | |
| filepath.parent.mkdir(parents=True, exist_ok=True) | |
| with open(filepath, "w") as fh: | |
| json.dump({k: list(v) for k, v in self._data.items()}, fh, indent=2) | |
| def load(self, filepath: str | Path): | |
| """Restore from a previously saved JSON file.""" | |
| with open(filepath) as fh: | |
| raw = json.load(fh) | |
| self._data = defaultdict(list, raw) | |
| def reset(self): | |
| """Clear all accumulated metrics.""" | |
| self._data.clear() | |
| # ------------------------------------------------------------------ | |
| # Dunder helpers | |
| # ------------------------------------------------------------------ | |
| def __repr__(self) -> str: | |
| lines = [] | |
| for name, vals in self._data.items(): | |
| if vals: | |
| lines.append( | |
| f" {name}: mean={np.mean(vals):.2f} " | |
| f"std={np.std(vals):.2f} n={len(vals)}" | |
| ) | |
| return "MetricsTracker(\n" + "\n".join(lines) + "\n)" | |