# Codette-Reasoning / observatory / metrics_logger.py
# Repository page header: Raiff1982 - "Upload 120 files" - ed1b365 (verified)
"""
Metrics Logger - thread-safe logging of training metrics to a JSON file.
Each entry records: timestamp, adapter name, dataset size, dataset version,
reasoning score, loss, epoch, and training parameters.
"""
from __future__ import annotations

import json
import os
import threading
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
# Default location of the metrics log: "observatory_metrics.json" in the
# directory one level above the one that contains this module.
_DEFAULT_LOG_FILE = Path(__file__).resolve().parent.parent.joinpath(
    "observatory_metrics.json"
)
class MetricsLogger:
    """Thread-safe logger for training run metrics.

    Entries are stored as a single JSON list in ``log_file``. Every public
    method that touches the file does so while holding a per-instance lock,
    so one ``MetricsLogger`` may be shared between threads. The lock is not
    shared between instances or processes: two loggers pointed at the same
    file are not coordinated with each other.
    """

    def __init__(self, log_file: Optional[str] = None):
        """Open the metrics log, creating it if necessary.

        Args:
            log_file: Path of the JSON log file. The module-level default
                location is used when this is omitted or empty.
        """
        self.log_file = Path(log_file) if log_file else _DEFAULT_LOG_FILE
        self._lock = threading.Lock()
        self._ensure_file()

    # -- internal ----------------------------------------------------------

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS[.ffffff]Z``."""
        # datetime.utcnow() is deprecated; take an aware "now" and drop the
        # tzinfo so isoformat() does not append "+00:00" before the "Z".
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        return now.isoformat() + "Z"

    @staticmethod
    def _comparable(ts: Any) -> Any:
        """Return ``ts`` without a trailing ``Z`` so timestamps sort correctly.

        "Z" sorts after "." in plain string order, which would place
        ``...:00Z`` after ``...:00.5Z``. Non-string values are returned
        unchanged.
        """
        if isinstance(ts, str) and ts.endswith("Z"):
            return ts[:-1]
        return ts

    def _ensure_file(self) -> None:
        """Create the log file with an empty list if it doesn't exist."""
        if self.log_file.exists():
            return
        os.makedirs(self.log_file.parent, exist_ok=True)
        try:
            # "x" refuses to truncate a file that something else created
            # between the exists() check above and this open().
            with open(self.log_file, "x", encoding="utf-8") as f:
                json.dump([], f)
        except FileExistsError:
            pass

    def _read_all(self) -> List[Dict[str, Any]]:
        """Read all entries from the log file.

        A missing file, unparsable JSON, or a top-level value that is not a
        list are all treated as an empty log.
        """
        try:
            with open(self.log_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return []
        return data if isinstance(data, list) else []

    def _write_all(self, entries: List[Dict[str, Any]]) -> None:
        """Replace the log file contents with ``entries``.

        The data is serialised before any file is opened and is written to a
        temporary sibling file that is then renamed over the log, so a
        serialisation error or an interrupted write leaves the previous log
        intact.

        Raises:
            TypeError, ValueError: If ``entries`` cannot be serialised.
            OSError: If the file cannot be written or replaced.
        """
        # Values json cannot encode natively are stringified (default=str).
        payload = json.dumps(entries, indent=2, default=str)

        os.makedirs(self.log_file.parent, exist_ok=True)
        # Process id + thread id keep concurrent writers from sharing a
        # temporary file.
        tmp_path = self.log_file.with_name(
            f"{self.log_file.name}.{os.getpid()}.{threading.get_ident()}.tmp"
        )
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                f.write(payload)
            os.replace(tmp_path, self.log_file)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # temporary file was never created or is already gone
            raise

    # -- public API --------------------------------------------------------

    def log(
        self,
        adapter: str,
        dataset_size: int,
        dataset_version: str,
        reasoning_score: float,
        loss: float,
        epoch: int,
        training_params: Optional[Dict[str, Any]] = None,
        extra: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Log a single training run metric entry.

        ``reasoning_score`` and ``loss`` are rounded to 6 decimal places and
        the entry is stamped with the current UTC time. ``extra`` is stored
        under the ``"extra"`` key only when it is non-empty.

        Returns:
            The logged entry dict.
        """
        entry: Dict[str, Any] = {
            "timestamp": self._utc_timestamp(),
            "adapter": adapter,
            "dataset_size": dataset_size,
            "dataset_version": dataset_version,
            "reasoning_score": round(reasoning_score, 6),
            "loss": round(loss, 6),
            "epoch": epoch,
            "training_params": training_params or {},
        }
        if extra:
            entry["extra"] = extra

        # Read-modify-write must happen under one lock acquisition.
        with self._lock:
            entries = self._read_all()
            entries.append(entry)
            self._write_all(entries)
        return entry

    def log_batch(self, entries: List[Dict[str, Any]]) -> int:
        """Log multiple entries with a single file rewrite.

        Each entry is a dict using the argument names of :meth:`log` as keys
        (plus an optional ``"timestamp"``). Missing keys get defaults:
        ``"unknown"`` for names, ``0`` / ``0.0`` for numbers, ``{}`` for
        ``training_params`` and the current UTC time for ``timestamp``.
        A non-empty ``"extra"`` is kept, as in :meth:`log`.

        Returns:
            The number of entries added.
        """
        formatted: List[Dict[str, Any]] = []
        for raw in entries:
            record: Dict[str, Any] = {
                "timestamp": raw.get("timestamp") or self._utc_timestamp(),
                "adapter": raw.get("adapter", "unknown"),
                "dataset_size": raw.get("dataset_size", 0),
                "dataset_version": raw.get("dataset_version", "unknown"),
                "reasoning_score": round(raw.get("reasoning_score", 0.0), 6),
                "loss": round(raw.get("loss", 0.0), 6),
                "epoch": raw.get("epoch", 0),
                # ``or`` also normalises an explicit None, matching log().
                "training_params": raw.get("training_params") or {},
            }
            if raw.get("extra"):
                record["extra"] = raw["extra"]
            formatted.append(record)

        with self._lock:
            existing = self._read_all()
            existing.extend(formatted)
            self._write_all(existing)
        return len(formatted)

    def get_all(self) -> List[Dict[str, Any]]:
        """Return all logged entries."""
        with self._lock:
            return self._read_all()

    def get_by_adapter(self, adapter: str) -> List[Dict[str, Any]]:
        """Return entries filtered by adapter name."""
        return [e for e in self.get_all() if e.get("adapter") == adapter]

    def get_by_date_range(
        self,
        start: Optional[str] = None,
        end: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Return entries within a date range (ISO format strings).

        Timestamps are compared as strings after dropping a trailing ``Z``,
        so bounds are expected in the same UTC ISO layout the logger writes.
        A date-only ``end`` (``YYYY-MM-DD`` or shorter) covers that whole
        day: only the matching prefix of each timestamp is compared.

        Args:
            start: ISO date/datetime string (inclusive). None = no lower bound.
            end: ISO date/datetime string (inclusive). None = no upper bound.
        """
        lower = self._comparable(start) if start else None
        upper = self._comparable(end) if end else None
        # At most 10 characters means there is no time-of-day component.
        upper_is_date = upper is not None and len(upper) <= 10

        selected: List[Dict[str, Any]] = []
        for entry in self.get_all():
            ts = self._comparable(entry.get("timestamp", ""))
            if lower is not None and ts < lower:
                continue
            if upper is not None:
                probe = ts[: len(upper)] if upper_is_date else ts
                if probe > upper:
                    continue
            selected.append(entry)
        return selected

    def get_latest(self, adapter: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Return the most recent entry, optionally filtered by adapter.

        Returns None when there is no matching entry.
        """
        entries = self.get_by_adapter(adapter) if adapter else self.get_all()
        if not entries:
            return None
        return max(
            entries,
            key=lambda e: self._comparable(e.get("timestamp", "")),
        )

    def get_unique_adapters(self) -> List[str]:
        """Return unique adapter names in order of first appearance.

        Entries without an ``"adapter"`` key count as ``"unknown"``.
        """
        # dict.fromkeys de-duplicates while preserving insertion order.
        names = (e.get("adapter", "unknown") for e in self.get_all())
        return list(dict.fromkeys(names))

    def count(self) -> int:
        """Return total number of logged entries."""
        return len(self.get_all())

    def clear(self) -> int:
        """Clear all entries. Returns number of entries removed."""
        with self._lock:
            removed = len(self._read_all())
            self._write_all([])
        return removed