File size: 3,980 Bytes
253d988 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | """
Monitor for dictionary_learning SAE training (different W&B key schema than SAELens).
Tracks `frac_variance_explained` (= EV directly). Emits on milestones / plateau / crash / completion.
Usage: python scripts/monitor_dl.py <wandb_run_name> <PID> [log_path]
"""
import json
import os
import subprocess
import sys
import time
from pathlib import Path
try:
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
except ImportError:
pass
import wandb
# Positional CLI arguments: the W&B run display name to look up, the PID of the
# process being watched, and an optional path to that process's log file.
RUN_NAME = sys.argv[1]
PID = int(sys.argv[2])
LOG_PATH = Path(sys.argv[3]) if len(sys.argv) > 3 else None
# W&B location comes from the environment; PATH is "entity/project" when an
# entity is set, otherwise just the project name.
PROJECT = os.environ.get("WANDB_PROJECT", "sae-gemma-induction")
ENTITY = os.environ.get("WANDB_ENTITY", None)
PATH = f"{ENTITY}/{PROJECT}" if ENTITY else PROJECT
# Seconds passed to time.sleep() between polls.
POLL_INTERVAL = 90
# Number of polls between "heartbeat" / "waiting_for_metrics" events.
HEARTBEAT_EVERY = 10
# EV thresholds; each one produces a single "milestone" event the first time
# the reported EV is at or above it.
EV_MILESTONES = [0.0, 0.2, 0.4, 0.6, 0.75, 0.85]
# The plateau check compares the newest and oldest of the last
# PLATEAU_WINDOW + 1 recorded EV values.
PLATEAU_WINDOW = 4
# Relative EV growth across that window below this value is reported as a plateau.
PLATEAU_THRESHOLD = 0.02
def emit(tag, **fields):
    """Print one JSON event record to stdout and flush immediately.

    The record always starts with ``event`` (the given tag) and ``run``
    (the monitored run name); any extra keyword fields are merged in after.
    """
    payload = {"event": tag, "run": RUN_NAME}
    payload.update(fields)
    line = json.dumps(payload)
    print(line, flush=True)
def process_alive():
    """Return True while the process identified by ``PID`` still exists.

    On Windows this asks ``tasklist`` for the PID and looks for it in the
    output. On every other platform ``tasklist`` does not exist (calling it
    raises ``FileNotFoundError``), so signal 0 is sent instead: it performs
    the existence/permission check without delivering a signal.
    """
    if os.name == "nt":
        out = subprocess.run(
            ["tasklist", "/FI", f"PID eq {PID}", "/NH"],
            capture_output=True,
            text=True,
        )
        return str(PID) in out.stdout
    try:
        os.kill(PID, 0)
    except ProcessLookupError:
        # No process with this PID.
        return False
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    return True
def log_says_done():
    """Return True if the end of the log file contains a completion marker.

    Only the last 4000 characters of the log are inspected. Returns False
    when no log path was given, the file does not exist, or it cannot be read.
    """
    if LOG_PATH is None:
        return False
    if not LOG_PATH.exists():
        return False
    try:
        tail = LOG_PATH.read_text(encoding="utf-8", errors="ignore")[-4000:]
    except Exception:
        # Best effort: an unreadable log is treated as "not done".
        return False
    markers = ("Done.", "Training complete", "[convert]")
    return any(marker in tail for marker in markers)
def get_metrics():
    """Fetch summary metrics for the most recently created run named RUN_NAME.

    Returns:
        ``None`` when no run matches; otherwise a dict of selected summary
        values plus the run ``state``. If anything raises, a dict with a
        single ``error`` key (message truncated to 120 characters) is
        returned instead of propagating the exception.
    """
    # Output key -> W&B summary key, in the order the result dict is built.
    summary_keys = {
        "step": "_step",
        "ev": "frac_variance_explained",
        "l2_loss": "l2_loss",
        "loss": "loss",
        "auxk_loss": "auxk_loss",
        "l0": "l0",
        "effective_l0": "effective_l0",
        "dead": "dead_features",
    }
    try:
        api = wandb.Api(timeout=29)
        matches = api.runs(
            PATH,
            filters={"display_name": RUN_NAME},
            order="-created_at",
        )
        if not matches:
            return None
        latest = matches[0]
        summary = latest.summary
        metrics = {name: summary.get(key) for name, key in summary_keys.items()}
        metrics["state"] = latest.state
        return metrics
    except Exception as e:
        return {"error": str(e)[:120]}
# Main monitor loop: poll W&B until the watched process exits, emitting
# milestone / plateau / heartbeat events along the way and a final
# "completed" or "crashed" event at the end.
emit("monitor_start", pid=PID, log=str(LOG_PATH) if LOG_PATH else None)

milestones_hit = set()
history = []  # (step, ev) pairs, one per distinct step observed
poll_count = 0
last_seen_step = -1

while True:
    poll_count += 1
    metrics = get_metrics()

    if not process_alive():
        # The process is gone: the log decides between "completed" and "crashed".
        if log_says_done():
            emit("completed", metrics=metrics)
        else:
            log_tail = ""
            if LOG_PATH and LOG_PATH.exists():
                try:
                    log_tail = LOG_PATH.read_text(encoding="utf-8", errors="ignore")[-1200:]
                except Exception:
                    # Best effort: report the crash without a log tail.
                    log_tail = ""
            emit("crashed", metrics=metrics, last_log_tail=log_tail)
        break

    have_ev = (
        bool(metrics)
        and not metrics.get("error")
        and metrics.get("ev") is not None
    )

    if have_ev:
        ev = metrics["ev"]
        step = metrics["step"]

        # Record a history point only when the step has advanced, so the same
        # evaluation is not counted repeatedly towards a plateau.
        if step != last_seen_step:
            history.append((step, ev))
            last_seen_step = step

        for threshold in EV_MILESTONES:
            if threshold in milestones_hit:
                continue
            if ev >= threshold:
                milestones_hit.add(threshold)
                emit("milestone", threshold=threshold, metrics=metrics)

        span = PLATEAU_WINDOW + 1
        if len(history) >= span:
            window = [value for _, value in history[-span:]]
            oldest = window[0]
            newest = window[-1]
            rel_growth = (newest - oldest) / max(abs(oldest), 0.01)
            if rel_growth < PLATEAU_THRESHOLD:
                emit("plateau", window=window, rel_growth=rel_growth, metrics=metrics)
                # Restart the window from the latest point after reporting.
                history = history[-1:]

        if poll_count % HEARTBEAT_EVERY == 0:
            emit("heartbeat", poll=poll_count, metrics=metrics)
    else:
        if poll_count % HEARTBEAT_EVERY == 1:
            emit("waiting_for_metrics", note="no EV in W&B yet (still calibrating or pre-eval)", metrics=metrics)

    time.sleep(POLL_INTERVAL)
|