"""Robust training monitor: polls a job until it terminates, retrying log fetch.

Usage:
    HF_TOKEN=hf_... python scripts/poll_training.py <job_id> [poll_interval_s]

Writes the same files as monitor_training.py:
    outputs/job_logs/<job_id>_metrics.jsonl
    outputs/job_logs/<job_id>_summary.json
    outputs/job_logs/<job_id>_raw.log
"""
from __future__ import annotations

import ast
import json
import os
import re
import sys
import threading
import time
from pathlib import Path

import truststore

# Use the OS trust store for TLS. Inject it before importing any HTTP client library.
truststore.inject_into_ssl()
from huggingface_hub import HfApi  # noqa: E402

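if len(sys.argv) < 2:
    sys.exit("usage: python scripts/poll_training.py <job_id> [poll_interval_s]")
JOB_ID = sys.argv[1]
POLL_S = int(sys.argv[2]) if len(sys.argv) > 2 else 90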
NAMESPACE = os.environ.get("HF_NAMESPACE", "agarwalanu3103")

api = HfApi(token=os.environ["HF_TOKEN"])
out_dir = Path("outputs/job_logs")
out_dir.mkdir(parents=True, exist_ok=True)
metrics_path = out_dir / f"{JOB_ID}_metrics.jsonl"
summary_path = out_dir / f"{JOB_ID}_summary.json"
raw_log_path = out_dir / f"{JOB_ID}_raw.log"

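# Matches a log line that is a single Python dict repr with a 'loss' key
# followed later by a 'reward' key, i.e. one training-step metrics record.
DICT_RE = re.compile(r"^\s*\{.*'loss':.*'reward':.*\}\s*$")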


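def fetch_chunk(timeout_s: int = 60) -> list[str]:
    """Collect streamed job logs for up to `timeout_s` seconds, then return.

    The stream is read on a daemon thread so a stalled connection cannot block
    the poll loop. A fetch error is recorded as a "### Error: ..." line rather
    than raised.
    """
    out: list[str] = []
    done = threading.Event()

    def reader() -> None:
        try:
            for log in api.fetch_job_logs(job_id=JOB_ID, namespace=NAMESPACE, follow=True):
                # Drop lines that arrive after the deadline.
                if done.is_set():
                    break
                out.append(str(log))
        except Exception as exc:
            out.append(f"### Error: {exc}")
        finally:
            done.set()

    t = threading.Thread(target=reader, daemon=True)
    t.start()
    t.join(timeout=timeout_s)
    done.set()  # Ask the reader to stop at its next line if it is still running.
    # Return a snapshot: the reader thread may still append after the deadline.
    return list(out)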


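# Preload rows written by an earlier run so that restarting the monitor does
# not append duplicate step records.
seen_lines: set[str] = set()
if metrics_path.exists():
    for ln in metrics_path.read_text().splitlines():
        if ln.strip():
            seen_lines.add(ln)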

print(f"[poll] monitoring {JOB_ID}; poll every {POLL_S}s", flush=True)
last_status = ""
while True:
    info = api.inspect_job(job_id=JOB_ID, namespace=NAMESPACE)
    stage = info.status.stage
    msg = info.status.message or ""
    if stage != last_status:
        print(f"[poll] status -> {stage} ({msg})", flush=True)
        last_status = stage

    if stage in ("RUNNING", "COMPLETED"):
        captured = fetch_chunk(timeout_s=120)
        with raw_log_path.open("a") as raw:
            for line in captured:
                raw.write(line + "\n")
        new_steps = 0
        with metrics_path.open("a") as fh:
            for line in captured:
                if not DICT_RE.match(line):
                    continue
                try:
                    d = ast.literal_eval(line.strip())
                except Exception:
                    continue
                row: dict = {}
                for k, v in d.items():
                    if isinstance(v, str):
                        try:
                            row[k] = float(v)
                        except ValueError:
                            row[k] = v
                    else:
                        row[k] = v
                ser = json.dumps(row, sort_keys=True)
                if ser in seen_lines:
                    continue
                fh.write(ser + "\n")
                seen_lines.add(ser)
                new_steps += 1
        if new_steps:
            print(f"[poll] +{new_steps} new step records (total={len(seen_lines)})", flush=True)
        summary = {
            "job_id": JOB_ID,
            "captured_lines": len(captured),
            "total_step_records": len(seen_lines),
            "stage": stage,
        }
        summary_path.write_text(json.dumps(summary, indent=2, default=str))

    # DELETED is also terminal; without it a deleted job would be polled forever.
    if stage in ("COMPLETED", "ERROR", "CANCELED", "DELETED"):
        print(f"[poll] terminal stage {stage} — exiting", flush=True)
        break

    time.sleep(POLL_S)