"""Watch latest.pt for updates and run factual probes each time it changes. Runs on CPU in a separate process — doesn't steal GPU from training. Shows what the model is actually learning via top-5 completions for canonical prompts ("The capital of France is", etc.). Usage: python scripts/watch_checkpoint.py """ from __future__ import annotations import os import sys import time from contextlib import nullcontext sys.stdout.reconfigure(line_buffering=True) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import torch from hydra.config import PostSemClawConfig from hydra.model import PostSemClawModel from prepare import Tokenizer, MAX_SEQ_LEN CKPT_PATH = os.path.expanduser("~/.cache/autoresearch/latest.pt") POLL_INTERVAL = 15.0 # seconds FACTUAL_PROMPTS = [ "The capital of France is", "Water boils at", "The largest planet in our solar system is", "The speed of light is approximately", "Shakespeare wrote", "DNA stands for", "The theory of relativity was developed by", "The Pacific Ocean is", ] def load_model_cpu(ckpt_path: str, tokenizer): """Load a checkpoint on CPU. Returns (model, step).""" ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) # Extract config from checkpoint (stored in save_ckpt) cfg_dict = ckpt.get("config") if cfg_dict is None: raise RuntimeError("checkpoint missing 'config' field") cfg = PostSemClawConfig(**cfg_dict) model = PostSemClawModel(cfg) model.load_state_dict(ckpt["model"]) model.eval() return model, ckpt.get("step", "?") def run_probes(model, tokenizer): """Top-5 completions for each factual prompt (CPU, no autocast).""" with torch.no_grad(): for prompt_text in FACTUAL_PROMPTS: ids = tokenizer.encode(prompt_text) x = torch.tensor([ids], dtype=torch.long) logits = model(x) probs = torch.softmax(logits[0, -1].float(), dim=-1) top5 = torch.topk(probs, 5) completions = [tokenizer.decode([idx.item()]) for idx in top5.indices] probs_list = [f"{p:.3f}" for p in top5.values[:3].tolist()] print(f' "{prompt_text}" -> {completions[:3]} (p={probs_list})', flush=True) def main() -> None: print(f"[watch] loading tokenizer...", flush=True) tokenizer = Tokenizer.from_directory() print(f"[watch] watching {CKPT_PATH} (poll every {POLL_INTERVAL:.0f}s)", flush=True) last_mtime = 0.0 while True: try: if os.path.exists(CKPT_PATH): mtime = os.path.getmtime(CKPT_PATH) if mtime > last_mtime: last_mtime = mtime ts = time.strftime("%H:%M:%S", time.localtime(mtime)) print(f"\n[watch] checkpoint updated at {ts}", flush=True) try: model, step = load_model_cpu(CKPT_PATH, tokenizer) print(f"[watch] loaded step={step}", flush=True) t0 = time.time() run_probes(model, tokenizer) print(f"[watch] probes ran in {time.time() - t0:.1f}s", flush=True) del model except Exception as e: print(f"[watch] probe failed: {type(e).__name__}: {e}", flush=True) except KeyboardInterrupt: print("[watch] exiting.", flush=True) return time.sleep(POLL_INTERVAL) if __name__ == "__main__": main()