# Traffic-Control / utils/metrics.py
# (commit b00d5d5, "Add files", author: Dhaerya)
"""
Metrics tracking — lightweight replacement for TensorBoard / W&B.
Stores lists of scalar values keyed by metric name, and provides
summary statistics and JSON serialisation.
"""
from __future__ import annotations
import json
from collections import defaultdict
from pathlib import Path
import numpy as np
class MetricsTracker:
    """
    Accumulates scalar training metrics across episodes.

    Values are stored per metric name, in the order they were added.
    Metrics that have never been recorded behave as empty: ``get`` returns
    ``[]``, ``get_mean``/``get_std`` return ``0.0`` and ``summary`` returns
    ``{}``.

    Usage::

        tracker = MetricsTracker()
        tracker.add("episode_reward", -920.3)
        tracker.get_mean("episode_reward", last_n=100)
        tracker.save("results/metrics.json")
    """

    def __init__(self):
        # metric name -> list of recorded values (append order).
        self._data: dict[str, list] = defaultdict(list)

    # ------------------------------------------------------------------
    # Data operations
    # ------------------------------------------------------------------
    def add(self, name: str, value):
        """Append *value* to the metric called *name*."""
        self._data[name].append(value)

    def get(self, name: str) -> list:
        """Return a copy of all recorded values for *name* (empty list if absent)."""
        # .get() instead of indexing so that reading an unknown name does not
        # insert an empty entry into the defaultdict.
        return list(self._data.get(name, []))

    def has(self, name: str) -> bool:
        """True if at least one value for *name* has been recorded."""
        return name in self._data and len(self._data[name]) > 0

    def get_last(self, name: str, n: int = 1) -> list:
        """Return the most recent *n* values for *name*, oldest first.

        Returns an empty list when *n* <= 0 or the metric is absent.
        """
        # Guard needed: vals[-0:] is the *whole* list, and a negative n would
        # slice from the front instead of the back.
        if n <= 0:
            return []
        return self.get(name)[-n:]

    def _window(self, name: str, last_n: int | None) -> list:
        """Return the values for *name*, limited to the last *last_n* when truthy."""
        vals = self.get(name)
        # None or 0 both mean "use every value".
        if last_n:
            vals = vals[-last_n:]
        return vals

    def get_mean(self, name: str, last_n: int | None = None) -> float:
        """Mean of *name* over all values, or over the last *last_n* if given.

        ``last_n`` of ``None`` or ``0`` uses every value. Returns ``0.0``
        when there is nothing to average.
        """
        vals = self._window(name, last_n)
        if not vals:
            return 0.0
        return float(np.mean(vals))

    def get_std(self, name: str, last_n: int | None = None) -> float:
        """Population std (``np.std``, ddof=0) of *name*, optionally over the last *last_n*.

        ``last_n`` of ``None`` or ``0`` uses every value. Returns ``0.0``
        when there are no values.
        """
        vals = self._window(name, last_n)
        if not vals:
            return 0.0
        return float(np.std(vals))

    def summary(self, name: str, last_n: int | None = None) -> dict:
        """Return count/mean/std/min/max for *name* (``{}`` if no values).

        *last_n* restricts the statistics to the most recent values, using
        the same rules as :meth:`get_mean`. The default keeps the original
        whole-history behaviour.
        """
        vals = self._window(name, last_n)
        if not vals:
            return {}
        return {
            "count": len(vals),
            "mean": float(np.mean(vals)),
            "std": float(np.std(vals)),
            "min": float(np.min(vals)),
            "max": float(np.max(vals)),
        }

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    @staticmethod
    def _json_default(obj):
        """``json`` fallback: convert NumPy scalars/arrays to native Python types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def save(self, filepath: str | Path):
        """Serialise all metrics to *filepath* as JSON, creating parent directories.

        NumPy scalar and array values are converted to plain Python numbers
        and lists. Any other non-JSON type raises ``TypeError``.
        """
        filepath = Path(filepath)
        payload = {k: list(v) for k, v in self._data.items()}
        # Encode fully before touching the file, so a serialisation error
        # cannot leave a truncated or half-written metrics file behind.
        text = json.dumps(payload, indent=2, default=self._json_default)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        with open(filepath, "w", encoding="utf-8") as fh:
            fh.write(text)

    def load(self, filepath: str | Path):
        """Restore from a JSON file written by :meth:`save`.

        All currently held metrics are replaced.
        """
        with open(filepath, encoding="utf-8") as fh:
            raw = json.load(fh)
        self._data = defaultdict(list, raw)

    def reset(self):
        """Clear all accumulated metrics."""
        self._data.clear()

    # ------------------------------------------------------------------
    # Dunder helpers
    # ------------------------------------------------------------------
    def __repr__(self) -> str:
        # One line per non-empty metric, computed over its full history.
        lines = []
        for name, vals in self._data.items():
            if vals:
                lines.append(
                    f"  {name}: mean={np.mean(vals):.2f} "
                    f"std={np.std(vals):.2f} n={len(vals)}"
                )
        return "MetricsTracker(\n" + "\n".join(lines) + "\n)"