"""Watch latest.pt for updates and run factual probes each time it changes.

Runs on CPU in a separate process — doesn't steal GPU from training.
Shows what the model is actually learning via top-5 completions for
canonical prompts ("The capital of France is", etc.).

Usage: python scripts/watch_checkpoint.py
"""
| from __future__ import annotations | |
| import os | |
| import sys | |
| import time | |
| from contextlib import nullcontext | |
| sys.stdout.reconfigure(line_buffering=True) | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import torch | |
| from hydra.config import PostSemClawConfig | |
| from hydra.model import PostSemClawModel | |
| from prepare import Tokenizer, MAX_SEQ_LEN | |
# Checkpoint file written by the training process; watched for mtime changes.
CKPT_PATH = os.path.expanduser("~/.cache/autoresearch/latest.pt")
POLL_INTERVAL = 15.0  # seconds between mtime polls

# Canonical factual prompts whose top-5 completions reveal what the model
# has memorized so far; probed after every checkpoint update.
FACTUAL_PROMPTS = [
    "The capital of France is",
    "Water boils at",
    "The largest planet in our solar system is",
    "The speed of light is approximately",
    "Shakespeare wrote",
    "DNA stands for",
    "The theory of relativity was developed by",
    "The Pacific Ocean is",
]
def load_model_cpu(ckpt_path: str, tokenizer):
    """Load a checkpoint on CPU. Returns (model, step).

    NOTE(review): ``tokenizer`` is currently unused; it is kept in the
    signature for caller compatibility. ``weights_only=False`` performs a
    full pickle load — only safe because the checkpoint is produced locally
    by the training process.
    """
    state = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # The training loop stores its config dict alongside the weights.
    if (cfg_dict := state.get("config")) is None:
        raise RuntimeError("checkpoint missing 'config' field")
    net = PostSemClawModel(PostSemClawConfig(**cfg_dict))
    net.load_state_dict(state["model"])
    net.eval()  # inference mode: disable dropout etc.
    return net, state.get("step", "?")
def run_probes(model, tokenizer):
    """Print top-5 completions (top-3 shown) for each factual prompt.

    Runs on CPU with gradients disabled; no autocast.
    """
    with torch.no_grad():
        for prompt_text in FACTUAL_PROMPTS:
            batch = torch.tensor([tokenizer.encode(prompt_text)], dtype=torch.long)
            logits = model(batch)
            # Distribution over the vocabulary at the final position.
            dist = logits[0, -1].float().softmax(dim=-1)
            values, indices = torch.topk(dist, 5)
            completions = [tokenizer.decode([tok]) for tok in indices.tolist()]
            probs_list = [f"{p:.3f}" for p in values[:3].tolist()]
            print(f' "{prompt_text}" -> {completions[:3]} (p={probs_list})', flush=True)
def main() -> None:
    """Poll CKPT_PATH for mtime changes; reload and probe on each update.

    Runs forever until interrupted. A failed load/probe (e.g. a checkpoint
    caught mid-write) is logged and the loop keeps watching.
    """
    print("[watch] loading tokenizer...", flush=True)
    tokenizer = Tokenizer.from_directory()
    print(f"[watch] watching {CKPT_PATH} (poll every {POLL_INTERVAL:.0f}s)", flush=True)
    last_mtime = 0.0
    while True:
        try:
            if os.path.exists(CKPT_PATH):
                mtime = os.path.getmtime(CKPT_PATH)
                if mtime > last_mtime:
                    last_mtime = mtime
                    ts = time.strftime("%H:%M:%S", time.localtime(mtime))
                    print(f"\n[watch] checkpoint updated at {ts}", flush=True)
                    try:
                        model, step = load_model_cpu(CKPT_PATH, tokenizer)
                        print(f"[watch] loaded step={step}", flush=True)
                        t0 = time.time()
                        run_probes(model, tokenizer)
                        print(f"[watch] probes ran in {time.time() - t0:.1f}s", flush=True)
                        del model  # release CPU RAM before the next poll
                    except Exception as e:
                        # Best-effort: a corrupt/partial checkpoint must not
                        # kill the watcher. (KeyboardInterrupt is not an
                        # Exception subclass, so it still propagates.)
                        print(f"[watch] probe failed: {type(e).__name__}: {e}", flush=True)
            # BUG FIX: the sleep was previously outside this try block, so a
            # Ctrl-C arriving during the poll interval — i.e. almost always —
            # escaped the KeyboardInterrupt handler and dumped a traceback.
            # Sleeping inside the try makes shutdown clean in all cases.
            time.sleep(POLL_INTERVAL)
        except KeyboardInterrupt:
            print("[watch] exiting.", flush=True)
            return


if __name__ == "__main__":
    main()