"""Render a training-progress JSON file as an append-only, human-readable display log.

Each poll appends one summary line to ``--display-file`` and echoes it to stdout.
"""

from __future__ import annotations

import argparse
import json
import math
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional

REPO_ROOT = Path(__file__).resolve().parents[1]

# Keys looked up in the progress file's ``last_log`` dict, in priority order, paired with the
# label shown in the display line. The first key that is present AND numeric wins.
_LAST_LOG_METRICS = (
    ("eval_val_set_recall", "eval_val_recall"),
    ("eval_set_recall", "eval_recall"),
    ("loss", "train_loss"),
)


def _as_float(val: Any) -> Optional[float]:
    """Return ``val`` as a finite float, or ``None`` if missing, non-numeric, or non-finite."""
    if val is None:
        return None
    try:
        out = float(val)
    except (TypeError, ValueError):
        return None
    return out if math.isfinite(out) else None


def _as_int(val: Any, default: int = 0) -> int:
    """Return ``val`` as an int (floats/numeric strings truncated), else ``default``."""
    if isinstance(val, int):
        # Exact path for ints (and bools), avoiding float precision loss on large values.
        return int(val)
    num = _as_float(val)
    return default if num is None else int(num)


def _fmt_secs(val: Optional[float]) -> str:
    """Format a duration in seconds as ``MM:SS``, or ``HH:MM:SS`` once it reaches an hour.

    Negative values clamp to zero. ``None``, non-numeric, and non-finite values
    (e.g. ``Infinity``, which ``json.loads`` accepts) render as ``"n/a"``.
    """
    secs = _as_float(val)
    if secs is None:
        return "n/a"
    total = int(max(0.0, secs))
    hours, rem = divmod(total, 3600)
    minutes, seconds = divmod(rem, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    return f"{minutes:02d}:{seconds:02d}"


def _load_json(path: Path) -> Optional[Dict[str, Any]]:
    """Return the JSON object stored at ``path``, or ``None`` if it is not usable yet.

    ``None`` covers a missing file, an unreadable or undecodable file, and a payload
    whose top level is not a JSON object (callers index the result like a dict).
    """
    if not path.is_file():
        return None
    try:
        obj = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Best effort: a poller should keep waiting rather than crash on a bad read
        # (e.g. a file caught mid-rewrite). JSONDecodeError/UnicodeDecodeError are ValueErrors.
        return None
    return obj if isinstance(obj, dict) else None


def _iso_now() -> str:
    """Return the current local wall-clock time as ``YYYY-MM-DD HH:MM:SS`` (no zone suffix)."""
    return datetime.now(timezone.utc).astimezone().strftime("%Y-%m-%d %H:%M:%S")


def _resolve(path: Path) -> Path:
    """Return ``path`` if absolute, else resolve it against ``REPO_ROOT`` (not the CWD)."""
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def _format_last_log(last_log: Any) -> str:
    """Return a `` | label=value`` suffix for the first usable metric in ``last_log``, or ``""``."""
    if not isinstance(last_log, dict):
        return ""
    for key, label in _LAST_LOG_METRICS:
        if key in last_log:
            value = _as_float(last_log[key])
            if value is not None:
                return f" | {label}={value:.4f}"
    return ""


def _format_line(obj: Optional[Dict[str, Any]], progress_path: Path, eff_batch: int) -> str:
    """Build one display line from a progress payload (``None`` = not available yet)."""
    if obj is None:
        return f"{_iso_now()} | waiting_for_progress_file={progress_path}"

    step = _as_int(obj.get("global_step", 0))
    max_steps = _as_int(obj.get("max_steps", 0))
    status = str(obj.get("status", "unknown"))

    pct = obj.get("pct_complete", None)
    pct_txt = f"{float(pct):6.2f}%" if isinstance(pct, (int, float)) else " n/a "

    # NOTE(review): assumes each optimizer step consumes exactly eff_batch samples
    # (train_batch_size * grad_accum); multi-device scaling is not accounted for — confirm.
    processed = step * eff_batch
    total = max_steps * eff_batch if max_steps > 0 else 0

    eta = _fmt_secs(obj.get("eta_sec"))
    elapsed = _fmt_secs(obj.get("elapsed_sec"))
    extra = _format_last_log(obj.get("last_log") or {})

    return (
        f"{_iso_now()} | status={status} | step={step}/{max_steps} | "
        f"processed_images={processed}/{total} | pct={pct_txt} | "
        f"elapsed={elapsed} | eta={eta}{extra}"
    )


def main() -> int:
    """Poll the progress JSON and append one summary line per poll to the display file.

    Loops until interrupted, or runs a single poll with ``--once``. Returns exit code 0.
    """
    ap = argparse.ArgumentParser(description="Write a human-readable live training display file from progress JSON.")
    ap.add_argument(
        "--progress-file",
        type=Path,
        default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_progress_run1.json",
    )
    ap.add_argument(
        "--display-file",
        type=Path,
        default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_display_run1.txt",
    )
    ap.add_argument("--train-batch-size", type=int, required=True)
    ap.add_argument("--grad-accum", type=int, required=True)
    ap.add_argument("--interval", type=float, default=30.0)
    ap.add_argument("--once", action="store_true", default=False)
    args = ap.parse_args()

    progress_path = _resolve(args.progress_file)
    display_path = _resolve(args.display_file)
    display_path.parent.mkdir(parents=True, exist_ok=True)

    # Non-positive CLI values are clamped to 1 so the sample counts never go to zero/negative.
    eff_batch = max(1, int(args.train_batch_size)) * max(1, int(args.grad_accum))
    # Floor the poll interval at 2 seconds; computed once since args do not change.
    interval = max(2.0, float(args.interval))

    while True:
        line = _format_line(_load_json(progress_path), progress_path, eff_batch)
        with display_path.open("a", encoding="utf-8") as f:
            f.write(line + "\n")
        print(line)
        if args.once:
            break
        time.sleep(interval)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())