File size: 3,729 Bytes
34c53b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations

import argparse
import json
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional

# Repository root: the directory one level above the folder containing this
# script. Default data locations and relative CLI paths are resolved from it.
_THIS_FILE = Path(__file__).resolve()
REPO_ROOT = _THIS_FILE.parents[1]


def _fmt_secs(val: Optional[float]) -> str:
    """Format a duration in seconds as ``MM:SS`` or ``HH:MM:SS``.

    Args:
        val: Duration in seconds, or ``None`` when the value is unknown.

    Returns:
        ``"n/a"`` when ``val`` is ``None`` or cannot be interpreted as a
        finite number (NaN, infinity, non-numeric input). Otherwise the
        duration truncated to whole seconds and clamped at zero; the hours
        field is included only when it is non-zero.
    """
    if val is None:
        return "n/a"
    try:
        # int() raises ValueError for NaN and OverflowError for +/-infinity;
        # float() raises TypeError/ValueError for non-numeric input. The value
        # comes from an externally written JSON file, so none of these should
        # be allowed to take the caller down.
        whole = int(float(val))
    except (TypeError, ValueError, OverflowError):
        return "n/a"
    whole = max(0, whole)
    hours, remainder = divmod(whole, 3600)
    minutes, seconds = divmod(remainder, 60)
    if hours:
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"
    return f"{minutes:02d}:{seconds:02d}"


def _load_json(path: Path) -> Optional[Dict[str, Any]]:
    """Read ``path`` as UTF-8 JSON and return it if it is a JSON object.

    This is a best-effort reader: every failure is reported as ``None`` so a
    caller polling the file can simply try again later.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded object as a dict, or ``None`` when the path is not a
        regular file, cannot be read or decoded, does not contain valid JSON,
        or contains a JSON value that is not an object.
    """
    if not path.is_file():
        return None
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError, RecursionError):
        # OSError: the file vanished or is unreadable.
        # ValueError: invalid UTF-8 or malformed JSON (JSONDecodeError and
        # UnicodeDecodeError are both subclasses).
        # RecursionError: pathologically nested input.
        return None
    # A top-level list/string/number is valid JSON but not what the declared
    # return type promises; callers use dict methods on the result.
    if not isinstance(data, dict):
        return None
    return data


def _iso_now() -> str:
    """Return the current time in the system's local zone as ``YYYY-MM-DD HH:MM:SS``."""
    utc_now = datetime.now(timezone.utc)
    # astimezone() with no argument converts to the system local time zone.
    local_now = utc_now.astimezone()
    return local_now.strftime("%Y-%m-%d %H:%M:%S")


def _metric_suffix(last_log: Any) -> str:
    """Return the trailing `` | name=value`` metric fragment for a log entry.

    Only the first key found, in priority order, is considered. An empty
    string is returned when ``last_log`` is not a dict, holds none of the
    known keys, or the selected value is not numeric.
    """
    if not isinstance(last_log, dict):
        return ""
    candidates = (
        ("eval_val_set_recall", "eval_val_recall"),
        ("eval_set_recall", "eval_recall"),
        ("loss", "train_loss"),
    )
    for key, label in candidates:
        if key not in last_log:
            continue
        try:
            # Coerce every metric the same way so numeric strings format
            # correctly and a null/garbage value cannot abort the monitor.
            return f" | {label}={float(last_log[key]):.4f}"
        except (TypeError, ValueError):
            return ""
    return ""


def _format_progress_line(obj: Dict[str, Any], eff_batch: int) -> str:
    """Build one display line from a decoded progress object."""
    step = int(obj.get("global_step", 0) or 0)
    max_steps = int(obj.get("max_steps", 0) or 0)
    status = str(obj.get("status", "unknown"))
    pct = obj.get("pct_complete", None)
    if isinstance(pct, (int, float)):
        pct_txt = f"{float(pct):6.2f}%"
    else:
        pct_txt = "  n/a "
    processed = step * eff_batch
    total = max_steps * eff_batch if max_steps > 0 else 0
    eta = _fmt_secs(obj.get("eta_sec"))
    elapsed = _fmt_secs(obj.get("elapsed_sec"))
    extra = _metric_suffix(obj.get("last_log") or {})
    return (
        f"{_iso_now()} | status={status} | step={step}/{max_steps} | "
        f"processed_images={processed}/{total} | pct={pct_txt} | "
        f"elapsed={elapsed} | eta={eta}{extra}"
    )


def main() -> int:
    """Poll a progress JSON file and append human-readable status lines.

    Command-line options:
        --progress-file: JSON file to read; a relative path is resolved
            against ``REPO_ROOT``.
        --display-file: text file that lines are appended to; a relative path
            is resolved against ``REPO_ROOT``. Its parent directory is created
            if missing.
        --train-batch-size, --grad-accum: required integers; their product
            (each clamped to at least 1) scales step counts into the
            ``processed_images`` figures.
        --interval: seconds between polls; values below 2 are raised to 2.
        --once: write a single line and exit instead of looping.

    Each line is also printed to stdout. While the progress file cannot be
    loaded, a ``waiting_for_progress_file`` line is written instead.

    Returns:
        ``0`` after the single pass when ``--once`` is given. Without it the
        loop has no exit condition and runs until the process is interrupted.
    """
    ap = argparse.ArgumentParser(description="Write a human-readable live training display file from progress JSON.")
    ap.add_argument(
        "--progress-file",
        type=Path,
        default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_progress_run1.json",
    )
    ap.add_argument(
        "--display-file",
        type=Path,
        default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_display_run1.txt",
    )
    ap.add_argument("--train-batch-size", type=int, required=True)
    ap.add_argument("--grad-accum", type=int, required=True)
    ap.add_argument("--interval", type=float, default=30.0)
    ap.add_argument("--once", action="store_true", default=False)
    args = ap.parse_args()

    progress_path = args.progress_file if args.progress_file.is_absolute() else (REPO_ROOT / args.progress_file).resolve()
    display_path = args.display_file if args.display_file.is_absolute() else (REPO_ROOT / args.display_file).resolve()
    display_path.parent.mkdir(parents=True, exist_ok=True)

    eff_batch = max(1, int(args.train_batch_size)) * max(1, int(args.grad_accum))
    while True:
        obj = _load_json(progress_path)
        if obj is None:
            line = f"{_iso_now()} | waiting_for_progress_file={progress_path}"
        else:
            line = _format_progress_line(obj, eff_batch)

        with display_path.open("a", encoding="utf-8") as f:
            f.write(line + "\n")
        print(line)

        if args.once:
            break
        # Enforce a floor on the poll period so a tiny --interval cannot
        # turn this into a busy loop.
        time.sleep(max(2.0, float(args.interval)))

    return 0


if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the
    # process exit status.
    exit_status = main()
    raise SystemExit(exit_status)