File size: 4,938 Bytes
"""
Long-running monitor for the active SAE training run.

Polls W&B + local logs and prints ONE line whenever something actionable happens:
- Each meaningful cossim milestone crossed (0.45, 0.55, 0.65, 0.75, 0.85, 0.92)
- Plateau detected (<2% relative cossim growth across 3 consecutive checks)
- Training completed (final checkpoint saved)
- Process crashed (PID gone, no completion message)
- Periodic heartbeat every ~15 min so silence ≠ stalled monitor

Each printed line is one JSON event for the Monitor tool. Keep volume low.

Usage:
    python scripts/monitor_sae.py <run_name> [PID]
"""
import json
import os
import subprocess
import sys
import time
from pathlib import Path
try:
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parents[1] / ".env")
except ImportError:
pass
import wandb
# --- Run identification (from the command line) ---
# First CLI argument: display name of the W&B run to monitor (required).
RUN_NAME = sys.argv[1]
# Second CLI argument: PID of the training process (optional). When it is
# omitted PID is None and process_alive() always reports the process as alive.
PID = int(sys.argv[2]) if len(sys.argv) > 2 else None
# W&B project / entity, overridable through the environment (or the .env file
# loaded above when python-dotenv is installed).
PROJECT = os.environ.get("WANDB_PROJECT", "sae-gemma-induction")
ENTITY = os.environ.get("WANDB_ENTITY", None)
# Training log: <parent of this script's directory>/logs/sae_main_<suffix>.err,
# where <suffix> is the text after the last '-' in RUN_NAME.
LOG_PATH = Path(__file__).resolve().parents[1] / "logs" / f"sae_main_{RUN_NAME.split('-')[-1]}.err"
# --- Polling cadence ---
POLL_INTERVAL = 90 # seconds between W&B polls
HEARTBEAT_EVERY = 10 # emit a heartbeat every N polls (~15 min)
# --- Event thresholds ---
# cossim values that each trigger a single "milestone" event once reached.
MILESTONES = [0.45, 0.55, 0.65, 0.75, 0.85, 0.92]
# A plateau is evaluated over PLATEAU_WINDOW + 1 recorded cossim samples
# (i.e. PLATEAU_WINDOW intervals), comparing the newest to the oldest.
PLATEAU_WINDOW = 3
PLATEAU_THRESHOLD = 0.02 # 2% relative growth required
# Path passed to wandb's Api.runs(): "entity/project" when an entity is set,
# otherwise just the project name.
PATH = f"{ENTITY}/{PROJECT}" if ENTITY else PROJECT
def emit(tag, **fields):
    """Print a single JSON event line to stdout and flush it immediately.

    The record always starts with ``event`` (set to *tag*) and ``run`` (the
    monitored run name); any keyword arguments are merged in afterwards, so a
    keyword named ``event`` or ``run`` overrides the default value.
    """
    payload = dict(event=tag, run=RUN_NAME)
    payload.update(fields)
    line = json.dumps(payload)
    # Flush per event: each line is consumed as it is produced.
    print(line, flush=True)
def process_alive():
    """Return True while the monitored training process still exists.

    When no PID was supplied there is nothing to check, so the process is
    treated as alive. On Windows the check shells out to ``tasklist``; on
    every other platform it probes the PID with signal 0, which performs the
    existence/permission check without delivering a signal.
    """
    if PID is None:
        return True
    if os.name == "nt":
        # /FI limits the listing to the requested PID and /NH drops the header.
        # NOTE: os.kill must not be used here -- on Windows a signal value of
        # 0 would terminate the target instead of probing it.
        result = subprocess.run(
            ["tasklist", "/FI", f"PID eq {PID}", "/NH"],
            capture_output=True, text=True,
        )
        return str(PID) in result.stdout
    try:
        os.kill(PID, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but belongs to another user.
        return True
    return True
def log_says_done():
    """Return True when the tail of the training log indicates a clean finish.

    Only the last 4000 characters of the log are inspected. The run counts as
    finished when that tail contains ``"Done. SAE saved"``, or when the
    third-from-last line of the tail contains a completed progress bar
    (``"100%|"``). A missing, unreadable or empty log yields False.
    """
    try:
        tail = LOG_PATH.read_text(encoding="utf-8", errors="ignore")[-4000:]
    except OSError:
        # Covers a missing log file as well as permission / I/O problems.
        return False
    if not tail:
        return False
    if "Done. SAE saved" in tail:
        return True
    lines = tail.split("\n")
    # Guard the index explicitly instead of relying on an IndexError being
    # swallowed when the tail holds fewer than three lines.
    return len(lines) >= 3 and "100%|" in lines[-3]
def get_metrics():
    """Fetch the latest summary metrics of the monitored run from W&B.

    Returns:
        None when no run with the configured display name is found;
        a dict with the keys ``step``, ``n_tokens``, ``cossim``, ``ev_legacy``,
        ``l2_ratio``, ``ce_score``, ``dead`` and ``state`` on success (any
        value may be None when the summary lacks it);
        ``{"error": <message>}`` (message truncated to 120 characters) when
        anything raises along the way.
    """

    def as_mapping(summary, key):
        # Nested summary entries may come back as None, as a stringified
        # dict, or as a mapping-like wandb object; normalise where possible.
        raw = summary.get(key, None)
        if raw is None:
            return {}
        if isinstance(raw, str):
            try:
                return json.loads(raw.replace("'", '"'))
            except Exception:
                return {}
        try:
            # e.g. a wandb summary sub-dict -- convertible to a plain dict
            return dict(raw)
        except Exception:
            return raw

    try:
        # A fresh Api object per call so wandb doesn't serve cached, stale
        # summary data.
        client = wandb.Api(timeout=29)
        matches = client.runs(
            PATH,
            filters={"display_name": RUN_NAME},
            order="-created_at",
        )
        if not matches:
            return None
        latest = matches[0]  # newest run carrying this display name
        summary = latest.summary

        recon = as_mapping(summary, "reconstruction_quality")
        shrink = as_mapping(summary, "shrinkage")
        perf = as_mapping(summary, "model_performance_preservation")

        metrics = {}
        metrics["step"] = summary.get("_step")
        metrics["n_tokens"] = summary.get("details/n_training_samples")
        metrics["cossim"] = recon.get("cossim")
        metrics["ev_legacy"] = summary.get("metrics/explained_variance_legacy")
        metrics["l2_ratio"] = shrink.get("l2_ratio")
        metrics["ce_score"] = perf.get("ce_loss_score")
        metrics["dead"] = summary.get("sparsity/dead_features")
        metrics["state"] = latest.state
        return metrics
    except Exception as exc:
        return {"error": str(exc)[:120]}
# Main polling loop: one W&B poll per iteration, emitting at most a few
# JSON event lines (milestone / plateau / heartbeat / completed / crashed).

# Initial emit so we know monitor is alive
emit("monitor_start", pid=PID, log=str(LOG_PATH))
reached = set()        # milestone thresholds that were already announced
cossim_history = []    # one cossim sample per distinct W&B step
last_step = None       # `_step` of the most recently recorded sample
poll = 0
while True:
    poll += 1
    m = get_metrics()

    # Process / state checks: once the process is gone, report how it ended
    # and stop monitoring.
    if not process_alive():
        if log_says_done():
            emit("completed", metrics=m)
        else:
            # Best-effort log tail; a missing or unreadable log must not
            # prevent the crash event from being reported.
            try:
                log_tail = LOG_PATH.read_text(encoding="utf-8", errors="ignore")[-800:]
            except OSError:
                log_tail = ""
            emit("crashed", metrics=m, last_log_tail=log_tail)
        break

    # No usable metrics yet (no run found, API error, or no cossim logged).
    if not m or m.get("error") or m.get("cossim") is None:
        if poll % HEARTBEAT_EVERY == 1:
            emit("waiting_for_metrics", note="training still in setup or no eval yet", metrics=m)
        time.sleep(POLL_INTERVAL)
        continue

    cossim = m["cossim"]
    step = m.get("step")
    # Re-reading a summary whose step has not advanced adds no information.
    # Counting it as a new sample would make growth look like zero and raise
    # false plateau events whenever polling is faster than the trainer logs.
    # With no step available every poll is recorded.
    is_new_sample = step is None or step != last_step
    if is_new_sample:
        last_step = step
        cossim_history.append(cossim)

    # Milestone events: each threshold is announced once.
    for ms in MILESTONES:
        if ms not in reached and cossim >= ms:
            reached.add(ms)
            emit("milestone", threshold=ms, metrics=m)

    # Plateau detection: relative growth between the oldest and newest of the
    # last PLATEAU_WINDOW + 1 recorded samples.
    if is_new_sample and len(cossim_history) >= PLATEAU_WINDOW + 1:
        window = cossim_history[-(PLATEAU_WINDOW + 1):]
        # The 0.01 floor avoids dividing by a zero / near-zero baseline.
        rel_growth = (window[-1] - window[0]) / max(window[0], 0.01)
        if rel_growth < PLATEAU_THRESHOLD:
            emit("plateau", window=window, rel_growth=rel_growth, metrics=m)
            cossim_history = cossim_history[-1:]  # reset window so we don't re-emit immediately

    # Heartbeat
    if poll % HEARTBEAT_EVERY == 0:
        emit("heartbeat", poll=poll, metrics=m)
    time.sleep(POLL_INTERVAL)
|