# feather-runtime/overlay/scripts/watch_checkpoint.py
"""Watch latest.pt for updates and run factual probes each time it changes.
Runs on CPU in a separate process — doesn't steal GPU from training.
Shows what the model is actually learning via top-5 completions for
canonical prompts ("The capital of France is", etc.).
Usage: python scripts/watch_checkpoint.py
"""
from __future__ import annotations
import os
import sys
import time
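# Line-buffer stdout so progress lines appear immediately when output is piped (e.g. to tee or a log file).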
sys.stdout.reconfigure(line_buffering=True)
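# Make the repo root importable so the local hydra/ package and prepare.py resolve when run from scripts/.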
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from hydra.config import PostSemClawConfig
from hydra.model import PostSemClawModel
from prepare import Tokenizer, MAX_SEQ_LEN
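# Checkpoint file the training run keeps overwriting; probes re-run whenever its mtime changes.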
CKPT_PATH = os.path.expanduser("~/.cache/autoresearch/latest.pt")
POLL_INTERVAL = 15.0 # seconds
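# Prompts with well-known factual continuations; drift in their top-5 completions is easy to eyeball across checkpoints.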
FACTUAL_PROMPTS = [
"The capital of France is",
"Water boils at",
"The largest planet in our solar system is",
"The speed of light is approximately",
"Shakespeare wrote",
"DNA stands for",
"The theory of relativity was developed by",
"The Pacific Ocean is",
]
def load_model_cpu(ckpt_path: str, tokenizer):
"""Load a checkpoint on CPU. Returns (model, step)."""
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
# Extract config from checkpoint (stored in save_ckpt)
cfg_dict = ckpt.get("config")
if cfg_dict is None:
raise RuntimeError("checkpoint missing 'config' field")
cfg = PostSemClawConfig(**cfg_dict)
model = PostSemClawModel(cfg)
model.load_state_dict(ckpt["model"])
model.eval()
return model, ckpt.get("step", "?")
def run_probes(model, tokenizer):
"""Top-5 completions for each factual prompt (CPU, no autocast)."""
with torch.no_grad():
for prompt_text in FACTUAL_PROMPTS:
            ids = tokenizer.encode(prompt_text)[-MAX_SEQ_LEN:]  # clamp to the model's context window
x = torch.tensor([ids], dtype=torch.long)
logits = model(x)
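            # Next-token distribution at the final position of the prompt.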
probs = torch.softmax(logits[0, -1].float(), dim=-1)
top5 = torch.topk(probs, 5)
completions = [tokenizer.decode([idx.item()]) for idx in top5.indices]
            probs_list = [f"{p:.3f}" for p in top5.values.tolist()]
            print(f'  "{prompt_text}" -> {completions} (p={probs_list})', flush=True)
def main() -> None:
print(f"[watch] loading tokenizer...", flush=True)
tokenizer = Tokenizer.from_directory()
print(f"[watch] watching {CKPT_PATH} (poll every {POLL_INTERVAL:.0f}s)", flush=True)
last_mtime = 0.0
while True:
try:
if os.path.exists(CKPT_PATH):
mtime = os.path.getmtime(CKPT_PATH)
if mtime > last_mtime:
last_mtime = mtime
ts = time.strftime("%H:%M:%S", time.localtime(mtime))
print(f"\n[watch] checkpoint updated at {ts}", flush=True)
try:
model, step = load_model_cpu(CKPT_PATH, tokenizer)
print(f"[watch] loaded step={step}", flush=True)
t0 = time.time()
run_probes(model, tokenizer)
print(f"[watch] probes ran in {time.time() - t0:.1f}s", flush=True)
del model
except Exception as e:
print(f"[watch] probe failed: {type(e).__name__}: {e}", flush=True)
            # Sleep inside the try so Ctrl-C during the wait also exits cleanly.
            time.sleep(POLL_INTERVAL)
        except KeyboardInterrupt:
            print("[watch] exiting.", flush=True)
            return
if __name__ == "__main__":
main()