File size: 3,980 Bytes
253d988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Monitor for dictionary_learning SAE training (different W&B key schema than SAELens).
Tracks `frac_variance_explained` (= EV directly). Emits JSON events on milestones / plateau /
crash / completion, plus periodic heartbeat and waiting-for-metrics notices.

Usage: python scripts/monitor_dl.py <wandb_run_name> <PID> [log_path]
"""
import json
import os
import subprocess
import sys
import time
from pathlib import Path

try:
    from dotenv import load_dotenv
    load_dotenv(Path(__file__).resolve().parents[1] / ".env")
except ImportError:
    pass

import wandb

# Positional CLI arguments: the W&B run display name, the PID of the process to
# watch, and an optional path to that process's log file (None when omitted).
RUN_NAME = sys.argv[1]
PID = int(sys.argv[2])
LOG_PATH = Path(sys.argv[3]) if len(sys.argv) > 3 else None

# W&B location to query: "<entity>/<project>" when WANDB_ENTITY is set,
# otherwise just the project name.
PROJECT = os.environ.get("WANDB_PROJECT", "sae-gemma-induction")
ENTITY = os.environ.get("WANDB_ENTITY", None)
PATH = f"{ENTITY}/{PROJECT}" if ENTITY else PROJECT

POLL_INTERVAL = 90  # seconds slept between polls (passed to time.sleep)
HEARTBEAT_EVERY = 10  # heartbeat / waiting notices are emitted once per this many polls
EV_MILESTONES = [0.0, 0.2, 0.4, 0.6, 0.75, 0.85]  # EV thresholds, each reported once when first reached
PLATEAU_WINDOW = 4  # plateau check compares the newest sample with the one this many samples earlier
PLATEAU_THRESHOLD = 0.02  # relative EV growth over the window below which a plateau is reported


def emit(tag, **fields):
    """Print one JSON event line to stdout and flush immediately.

    The record always starts with ``event`` (the given tag) and ``run``
    (the module-level ``RUN_NAME``); any keyword arguments are merged in
    afterwards and therefore override those two keys on a name clash.
    """
    record = {"event": tag, "run": RUN_NAME}
    record.update(fields)
    line = json.dumps(record)
    print(line, flush=True)


def process_alive(pid=None):
    """Return True if a process with the given PID currently exists.

    Args:
        pid: Process ID to probe. When omitted, the module-level ``PID``
            (the process being monitored) is used, so existing
            zero-argument callers behave as before.

    Returns:
        bool: True when the process is found, False otherwise.

    On Windows the check shells out to ``tasklist`` filtered by PID. On
    other platforms ``tasklist`` does not exist, so the process is probed
    with signal 0 via ``os.kill``, which performs the existence and
    permission checks without delivering a signal.
    """
    if pid is None:
        pid = PID

    if os.name == "nt":
        # Do not use os.kill() here: on Windows most signal values make it
        # terminate the target process instead of merely probing it.
        out = subprocess.run(
            ["tasklist", "/FI", f"PID eq {pid}", "/NH"],
            capture_output=True,
            text=True,
        )
        return str(pid) in out.stdout

    if pid <= 0:
        # Non-positive values address process groups in kill(2), not a
        # single process, so they can never name the monitored process.
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    return True


def log_says_done():
    """Report whether the tail of the log contains a completion marker.

    Returns False when no log path was given, the file does not exist, or
    reading it fails. Otherwise the last 4000 characters of the log are
    searched for any of the known completion markers.
    """
    if LOG_PATH is not None and LOG_PATH.exists():
        try:
            tail = LOG_PATH.read_text(encoding="utf-8", errors="ignore")[-4000:]
        except Exception:
            # Best effort: an unreadable log is treated as "not done".
            return False
        markers = ("Done.", "Training complete", "[convert]")
        return any(marker in tail for marker in markers)
    return False


def get_metrics():
    """Fetch the current summary metrics of the monitored W&B run.

    Queries W&B for runs under ``PATH`` whose display name equals
    ``RUN_NAME``, requested in ``-created_at`` order, and reads the summary
    of the first result.

    Returns:
        None when no run matches; a dict with the keys ``step``, ``ev``,
        ``l2_loss``, ``loss``, ``auxk_loss``, ``l0``, ``effective_l0``,
        ``dead`` and ``state`` on success (individual values are None when
        the summary lacks them); or ``{"error": <message>}`` with the
        message truncated to 120 characters when anything raises.
    """
    # (output key, W&B summary key) pairs, in the order they are reported.
    summary_keys = (
        ("step", "_step"),
        ("ev", "frac_variance_explained"),
        ("l2_loss", "l2_loss"),
        ("loss", "loss"),
        ("auxk_loss", "auxk_loss"),
        ("l0", "l0"),
        ("effective_l0", "effective_l0"),
        ("dead", "dead_features"),
    )
    try:
        client = wandb.Api(timeout=29)
        matches = client.runs(PATH, filters={"display_name": RUN_NAME}, order="-created_at")
        if not matches:
            return None
        latest = matches[0]
        summary = latest.summary
        metrics = {name: summary.get(key) for name, key in summary_keys}
        metrics["state"] = latest.state
        return metrics
    except Exception as exc:
        # Any API/network failure is reported as data so the poll loop keeps running.
        return {"error": str(exc)[:120]}


# Main polling loop: every POLL_INTERVAL seconds, fetch metrics, check whether
# the watched process is still alive, and emit JSON events for milestones,
# plateaus, heartbeats, and the final completed/crashed outcome.
emit("monitor_start", pid=PID, log=str(LOG_PATH) if LOG_PATH else None)

reached = set()  # milestone thresholds that have already been reported
ev_history = []  # (step, ev) samples, one per distinct step; input to the plateau check
poll = 0  # number of loop iterations started so far
last_emit_step = -1  # step of the most recent sample appended to ev_history

while True:
    poll += 1
    # Metrics are fetched before the liveness check so the terminal event
    # (completed / crashed) carries the latest metrics as well.
    m = get_metrics()

    if not process_alive():
        # The process is gone: classify the exit by the log's completion markers.
        if log_says_done():
            emit("completed", metrics=m)
        else:
            # Attach the last 1200 characters of the log (best effort) for diagnosis.
            tail = ""
            if LOG_PATH and LOG_PATH.exists():
                try:
                    tail = LOG_PATH.read_text(encoding="utf-8", errors="ignore")[-1200:]
                except Exception:
                    tail = ""
            emit("crashed", metrics=m, last_log_tail=tail)
        break

    # No usable EV yet (no run found, API error, or EV missing from the summary):
    # report on the first poll and then once every HEARTBEAT_EVERY polls, and retry.
    if not m or m.get("error") or m.get("ev") is None:
        if poll % HEARTBEAT_EVERY == 1:
            emit("waiting_for_metrics", note="no EV in W&B yet (still calibrating or pre-eval)", metrics=m)
        time.sleep(POLL_INTERVAL)
        continue

    ev = m["ev"]
    step = m["step"]

    # Only record a sample when the reported step differs from the last recorded
    # one, so repeated polls of the same evaluation cannot fake a plateau.
    if step != last_emit_step:
        ev_history.append((step, ev))
        last_emit_step = step

        # Report each threshold once, the first time EV is at or above it; several
        # thresholds may fire on the same sample if EV jumped past them.
        for ms in EV_MILESTONES:
            if ms not in reached and ev >= ms:
                reached.add(ms)
                emit("milestone", threshold=ms, metrics=m)

        # Plateau check over the newest PLATEAU_WINDOW + 1 samples: growth from the
        # oldest to the newest, relative to the oldest value. The denominator is
        # floored at 0.01 so a near-zero starting EV does not blow up the ratio.
        if len(ev_history) >= PLATEAU_WINDOW + 1:
            window = [v for _, v in ev_history[-(PLATEAU_WINDOW + 1):]]
            rel_growth = (window[-1] - window[0]) / max(abs(window[0]), 0.01)
            if rel_growth < PLATEAU_THRESHOLD:
                emit("plateau", window=window, rel_growth=rel_growth, metrics=m)
                # Keep only the newest sample so the next plateau event needs
                # PLATEAU_WINDOW fresh samples rather than firing on every step.
                ev_history = ev_history[-1:]

    # Periodic liveness signal while metrics are flowing.
    if poll % HEARTBEAT_EVERY == 0:
        emit("heartbeat", poll=poll, metrics=m)

    time.sleep(POLL_INTERVAL)