Prompt_Squirrel_RAG / scripts /live_train_display.py
Food Desert
Roll out T5 rewrite updates, tooling, docs, and artifact ignore rules
34c53b5
from __future__ import annotations

import argparse
import json
import math
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional
# This script lives one directory below the repository root (in scripts/),
# so the repo root is the grandparent of this file's resolved location.
REPO_ROOT = Path(__file__).resolve().parent.parent
def _fmt_secs(val: Optional[float]) -> str:
if val is None:
return "n/a"
s = int(max(0, float(val)))
h, rem = divmod(s, 3600)
m, sec = divmod(rem, 60)
if h:
return f"{h:02d}:{m:02d}:{sec:02d}"
return f"{m:02d}:{sec:02d}"
def _load_json(path: Path) -> Optional[Dict[str, Any]]:
if not path.is_file():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except Exception:
return None
def _iso_now() -> str:
return datetime.now(timezone.utc).astimezone().strftime("%Y-%m-%d %H:%M:%S")
# (key in the progress file's ``last_log`` dict, label written to the display line).
# Order matters: the first key present wins.
_LAST_LOG_METRICS = (
    ("eval_val_set_recall", "eval_val_recall"),
    ("eval_set_recall", "eval_recall"),
    ("loss", "train_loss"),
)


def _as_float(value: Any) -> Optional[float]:
    """Coerce a JSON scalar to a finite float, or return None.

    Maps all of the following to None:

    * None and booleans;
    * non-numeric strings and containers;
    * NaN and infinity.

    One malformed field in a progress snapshot therefore renders as "n/a"
    instead of raising inside the polling loop.
    """
    if value is None or isinstance(value, bool):
        return None
    try:
        num = float(value)
    except (TypeError, ValueError):
        return None
    return num if math.isfinite(num) else None


def _as_int(value: Any) -> int:
    """Coerce a JSON scalar to an int (truncating toward zero), defaulting to 0."""
    num = _as_float(value)
    return int(num) if num is not None else 0


def _format_last_log(last_log: Any) -> str:
    """Return the ``" | label=value"`` suffix for the highest-priority metric in ``last_log``.

    Validation-set eval recall beats eval recall, which beats training loss.
    Returns ``""`` when ``last_log`` is not a dict or holds none of those keys.
    A present but non-numeric value renders as ``n/a``.
    """
    if not isinstance(last_log, dict):
        return ""
    for key, label in _LAST_LOG_METRICS:
        if key in last_log:
            num = _as_float(last_log[key])
            txt = f"{num:.4f}" if num is not None else "n/a"
            return f" | {label}={txt}"
    return ""


def _format_progress_line(obj: Dict[str, Any], eff_batch: int) -> str:
    """Render one display line from a parsed progress snapshot.

    ``processed_images`` is computed, not read from the file. It equals
    ``global_step * eff_batch``, and the total is ``max_steps * eff_batch``
    (0 when ``max_steps`` is unknown).
    """
    step = _as_int(obj.get("global_step"))
    max_steps = _as_int(obj.get("max_steps"))
    status = str(obj.get("status") or "unknown")
    pct = _as_float(obj.get("pct_complete"))
    pct_txt = f"{pct:6.2f}%" if pct is not None else " n/a "
    processed = step * eff_batch
    total = max_steps * eff_batch if max_steps > 0 else 0
    eta = _fmt_secs(obj.get("eta_sec"))
    elapsed = _fmt_secs(obj.get("elapsed_sec"))
    extra = _format_last_log(obj.get("last_log"))
    return (
        f"{_iso_now()} | status={status} | step={step}/{max_steps} | "
        f"processed_images={processed}/{total} | pct={pct_txt} | "
        f"elapsed={elapsed} | eta={eta}{extra}"
    )


def main() -> int:
    """Poll a training-progress JSON file and append one readable status line per poll.

    On each iteration, main does the following:

    1. Reads ``--progress-file``. Relative paths are resolved against the
       repository root, not the CWD.
    2. Formats one line, or a "waiting" line if no usable snapshot exists yet.
    3. Appends the line to ``--display-file`` and echoes it to stdout.

    It loops forever unless ``--once`` is given. The sleep between polls is
    ``--interval`` seconds, clamped to at least 2.

    Returns:
        0 (reached only when ``--once`` is set).
    """
    ap = argparse.ArgumentParser(description="Write a human-readable live training display file from progress JSON.")
    ap.add_argument(
        "--progress-file",
        type=Path,
        default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_progress_run1.json",
    )
    ap.add_argument(
        "--display-file",
        type=Path,
        default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_display_run1.txt",
    )
    ap.add_argument("--train-batch-size", type=int, required=True)
    ap.add_argument("--grad-accum", type=int, required=True)
    ap.add_argument("--interval", type=float, default=30.0)
    ap.add_argument("--once", action="store_true", default=False)
    args = ap.parse_args()

    progress_path = args.progress_file if args.progress_file.is_absolute() else (REPO_ROOT / args.progress_file).resolve()
    display_path = args.display_file if args.display_file.is_absolute() else (REPO_ROOT / args.display_file).resolve()
    display_path.parent.mkdir(parents=True, exist_ok=True)

    # Non-positive batch/accumulation values are clamped to 1 so the derived
    # image counts never go to zero or negative.
    eff_batch = max(1, int(args.train_batch_size)) * max(1, int(args.grad_accum))
    sleep_secs = max(2.0, float(args.interval))

    while True:
        obj = _load_json(progress_path)
        # Guard on dict (not just None) so a non-object JSON payload is treated
        # as "not ready yet" instead of crashing on .get().
        if not isinstance(obj, dict):
            line = f"{_iso_now()} | waiting_for_progress_file={progress_path}"
        else:
            line = _format_progress_line(obj, eff_batch)
        with display_path.open("a", encoding="utf-8") as f:
            f.write(line + "\n")
        # stdout is block-buffered when redirected (e.g. to a log file), which
        # would delay the "live" echo by many polls; flush each line.
        print(line, flush=True)
        if args.once:
            break
        time.sleep(sleep_secs)
    return 0
if __name__ == "__main__":
    # Hand main()'s integer result back to the shell as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)