"""Monitor an HF Jobs training run and persist parsed metrics to outputs/job_logs/.

Usage:
    HF_TOKEN=hf_... python scripts/monitor_training.py <job_id> [duration_seconds]

Writes:
    outputs/job_logs/<job_id>_metrics.jsonl  (one parsed step per line)
    outputs/job_logs/<job_id>_summary.json   (aggregate summary)
"""
from __future__ import annotations

import ast
import json
import os
import re
import statistics
import sys
import threading
from pathlib import Path

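# Inject the OS trust store into ssl *before* importing huggingface_hub so its
# HTTPS requests use locally installed CA certificates (e.g. behind a
# TLS-intercepting corporate proxy); hence the noqa on the late import below.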
import truststore

truststore.inject_into_ssl()
from huggingface_hub import HfApi  # noqa: E402

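# CLI: positional job id plus an optional capture window in seconds; the job's
# namespace comes from HF_NAMESPACE. Defaults are convenience values for reruns.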
JOB_ID = sys.argv[1] if len(sys.argv) > 1 else "69ece976d2c8bd8662bcdf48"
DURATION_S = int(sys.argv[2]) if len(sys.argv) > 2 else 60
NAMESPACE = os.environ.get("HF_NAMESPACE", "agarwalanu3103")

api = HfApi(token=os.environ["HF_TOKEN"])
out_dir = Path("outputs/job_logs")
out_dir.mkdir(parents=True, exist_ok=True)
metrics_path = out_dir / f"{JOB_ID}_metrics.jsonl"
summary_path = out_dir / f"{JOB_ID}_summary.json"

raw_log_path = out_dir / f"{JOB_ID}_raw.log"

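# Match metric lines printed as Python dict literals containing both 'loss' and
# 'reward', e.g. {'loss': 0.0123, ..., 'reward': 1.25}. This TRL-style stdout
# format is an assumption; adjust the pattern if your trainer logs differently.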
DICT_RE = re.compile(r"^\s*\{.*'loss':.*'reward':.*\}\s*$")
captured: list[str] = []
done = threading.Event()

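# Background reader: stream the job logs, mirror every raw line to disk, and
# queue lines for parsing. It runs as a daemon thread because following a live
# job blocks indefinitely; it stops at the next yielded line once `done` is set.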
def reader():
    try:
        with raw_log_path.open("a") as raw:
            for log in api.fetch_job_logs(job_id=JOB_ID, namespace=NAMESPACE, follow=True):
                raw.write(str(log) + "\n")
                if done.is_set():
                    break
                captured.append(str(log))
    except Exception as exc:
        captured.append(f"### Error: {exc}")
    finally:
        done.set()

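# Capture for at most DURATION_S seconds, then tell the reader to stop.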
t = threading.Thread(target=reader, daemon=True)
t.start()
t.join(timeout=DURATION_S)
done.set()

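# Re-read any previously persisted metrics so repeated runs of this script
# append only unseen step records to the JSONL file.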
steps: list[dict] = []
seen_lines: set[str] = set()
if metrics_path.exists():
    for ln in metrics_path.read_text().splitlines():
        if ln.strip():
            seen_lines.add(ln)

with metrics_path.open("a") as fh:
    for line in list(captured):  # snapshot: the reader thread may still append
        if not DICT_RE.match(line):
            continue
        try:
            d = ast.literal_eval(line.strip())
        except Exception:
            continue
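        # Coerce numeric-looking strings (e.g. '0.5') to floats; keep the rest verbatim.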
        out: dict[str, float | int | str] = {}
        for k, v in d.items():
            if isinstance(v, str):
                try:
                    out[k] = float(v)
                except ValueError:
                    out[k] = v
            else:
                out[k] = v
        ser = json.dumps(out, sort_keys=True)
        if ser in seen_lines:
            continue
        fh.write(ser + "\n")
        seen_lines.add(ser)
        steps.append(out)

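# Compact summary: capture counts plus the tail of the last five steps, for a
# quick read on training progress without opening the JSONL.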
summary: dict[str, object] = {
    "job_id": JOB_ID,
    "captured_lines": len(captured),
    "new_steps": len(steps),
    "epochs_seen": [s.get("epoch") for s in steps[-5:]],
    "rewards_seen": [s.get("reward") for s in steps[-5:]],
    "losses_seen": [s.get("loss") for s in steps[-5:]],
}
if steps:
    # Only steps that report a numeric step_time contribute to the median.
    step_times = [s["step_time"] for s in steps if isinstance(s.get("step_time"), (int, float))]
    summary.update({
        "latest": steps[-1],
        "median_step_time": statistics.median(step_times) if step_times else None,
    })
summary_path.write_text(json.dumps(summary, indent=2, default=str))
print(f"Captured {len(captured)} log lines, {len(steps)} new step records.")
print(json.dumps(summary, indent=2, default=str))