|
|
from __future__ import annotations |
|
|
|
|
|
import contextlib |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional |
|
|
|
|
|
|
|
|
class TrainingTracker:
    """
    Lightweight tracker inspired by the minimcpm-audio training workflow.

    It keeps track of the current global step, prints rank-aware messages,
    optionally writes to TensorBoard via a provided writer, and stores progress
    in a logfile for later inspection.

    Only rank 0 produces output (stdout, log file, writer scalars); calls made
    on any other rank are silent no-ops.
    """

    def __init__(
        self,
        *,
        writer=None,
        log_file: Optional[str] = None,
        rank: int = 0,
    ):
        """Create a tracker.

        Args:
            writer: Optional object exposing ``add_scalar(tag, value, step)``;
                when ``None`` no scalars are forwarded.
            log_file: Optional path of a text file that messages are appended
                to. Its parent directory is created immediately if missing.
            rank: Rank of the calling process; only rank 0 emits output.
        """
        self.writer = writer
        self.log_file = Path(log_file) if log_file else None
        if self.log_file:
            self.log_file.parent.mkdir(parents=True, exist_ok=True)
        self.rank = rank
        self.step = 0

        # Monotonic timestamp of the previous ``log_metrics`` call on rank 0;
        # ``None`` until the first call.
        self._last_log_time: Optional[float] = None

    @staticmethod
    def _format_metric(value) -> str:
        """Render *value* with six decimals, or via ``str()`` if it cannot be."""
        try:
            return f"{value:.6f}"
        except (TypeError, ValueError):
            # Non-numeric values (strings, None, ...) do not support the "f"
            # format code; show them verbatim instead of aborting the log line.
            return str(value)

    def print(self, message: str):
        """Print *message* on rank 0 and append it to the log file, if any."""
        if self.rank != 0:
            return
        print(message, flush=True)
        if self.log_file:
            with self.log_file.open("a", encoding="utf-8") as f:
                f.write(message + "\n")

    def log_metrics(self, metrics: Dict[str, float], split: str):
        """Report *metrics* for *split* at the current step (rank 0 only).

        The metrics are printed as one line, followed by the time elapsed
        since the previous call once there has been one. Values that are
        ``int`` or ``float`` are also sent to the writer under the tag
        ``"{split}/{key}"``; other values are printed but not forwarded.
        """
        if self.rank != 0:
            return

        # A monotonic clock keeps the reported interval immune to wall-clock
        # adjustments (NTP, manual changes) between two calls.
        now = time.monotonic()
        dt_str = ""
        if self._last_log_time is not None:
            dt = now - self._last_log_time
            dt_str = f", log interval: {dt:.2f}s"
        self._last_log_time = now

        formatted = ", ".join(
            f"{k}: {self._format_metric(v)}" for k, v in metrics.items()
        )
        self.print(f"[{split}] step {self.step}: {formatted}{dt_str}")

        if self.writer is not None:
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    self.writer.add_scalar(f"{split}/{key}", value, self.step)

    def done(self, split: str, message: str):
        """Print a final *message* prefixed with the *split* tag."""
        self.print(f"[{split}] {message}")

    def state_dict(self):
        """Return the checkpointable state: a dict holding the current step."""
        return {"step": self.step}

    def load_state_dict(self, state):
        """Restore the step from *state*; a missing ``"step"`` key means 0."""
        self.step = int(state.get("step", 0))

    @contextlib.contextmanager
    def live(self):
        """Context manager that yields once with no setup or teardown."""
        yield
|
|
|
|
|
|