File size: 3,568 Bytes
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Watch latest.pt for updates and run factual probes each time it changes.

Runs on CPU in a separate process — doesn't steal GPU from training.
Shows what the model is actually learning via top-5 completions for
canonical prompts ("The capital of France is", etc.).

Usage: python scripts/watch_checkpoint.py
"""
from __future__ import annotations

import os
import sys
import time
from contextlib import nullcontext

# Force line buffering so every probe line is emitted as soon as it is
# printed, even when stdout is piped or redirected rather than a terminal.
sys.stdout.reconfigure(line_buffering=True)

# Prepend the parent of this script's directory to sys.path (presumably the
# repository root, given the `scripts/watch_checkpoint.py` usage line) so the
# project imports below (`hydra`, `prepare`) resolve when run as a script.
# Must stay above those imports.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch

from hydra.config import PostSemClawConfig
from hydra.model import PostSemClawModel
from prepare import Tokenizer, MAX_SEQ_LEN

# Checkpoint file whose modification time is polled; each time it advances the
# checkpoint is reloaded and probed.
CKPT_PATH: str = os.path.expanduser("~/.cache/autoresearch/latest.pt")

# Delay between polls of CKPT_PATH, in seconds.
POLL_INTERVAL: float = 15.0

# Prompts whose most likely next tokens are printed after each reload; short
# factual sentence openers with well-known continuations.
FACTUAL_PROMPTS: list[str] = [
    "The capital of France is",
    "Water boils at",
    "The largest planet in our solar system is",
    "The speed of light is approximately",
    "Shakespeare wrote",
    "DNA stands for",
    "The theory of relativity was developed by",
    "The Pacific Ocean is",
]


def load_model_cpu(ckpt_path: str, tokenizer):
    """Rebuild the model from a checkpoint file, entirely on CPU.

    The checkpoint is read as a mapping holding the config constructor kwargs
    under ``"config"``, the model state dict under ``"model"`` and, optionally,
    the training step under ``"step"``.

    Args:
        ckpt_path: Path of the checkpoint file to read.
        tokenizer: Not used by this function; accepted so callers can pass it.

    Returns:
        ``(model, step)``: the model switched to eval mode, and the stored step
        (``"?"`` when the checkpoint has no ``"step"`` entry).

    Raises:
        RuntimeError: if ``"config"`` is absent from the checkpoint or is None.
    """
    # weights_only=False unpickles arbitrary Python objects; only point this
    # at checkpoints produced by your own training run.
    state = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    config_kwargs = state.get("config")
    if config_kwargs is None:
        raise RuntimeError("checkpoint missing 'config' field")

    net = PostSemClawModel(PostSemClawConfig(**config_kwargs))
    net.load_state_dict(state["model"])
    net.eval()
    return net, state.get("step", "?")


def run_probes(model, tokenizer, prompts=None, top_k: int = 5):
    """Print the top-k next-token completions for each probe prompt.

    Runs one forward pass per prompt under ``torch.no_grad()`` (CPU, no
    autocast) and prints the ``top_k`` most probable next tokens together
    with their probabilities.

    Args:
        model: Callable taking a ``(1, seq_len)`` LongTensor of token ids and
            returning logits indexable as ``logits[0, -1]`` for the last
            position. Assumes a ``(batch, seq, vocab)`` tensor; TODO confirm
            against the model's forward signature.
        tokenizer: Object providing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        prompts: Prompts to probe; ``None`` means ``FACTUAL_PROMPTS``.
        top_k: Number of completions shown per prompt. Clamped to the
            vocabulary size so small vocabularies do not make ``topk`` fail.
    """
    if prompts is None:
        prompts = FACTUAL_PROMPTS
    with torch.no_grad():
        for prompt_text in prompts:
            ids = tokenizer.encode(prompt_text)
            if not ids:
                # An empty sequence has no last position to read logits from.
                print(f'  "{prompt_text}" -> (skipped: empty encoding)', flush=True)
                continue
            x = torch.tensor([ids], dtype=torch.long)
            logits = model(x)
            probs = torch.softmax(logits[0, -1].float(), dim=-1)
            top = torch.topk(probs, min(top_k, probs.numel()))
            # Previously top-5 was computed but only 3 entries were printed,
            # contradicting the documented "top-5 completions".
            completions = [tokenizer.decode([idx]) for idx in top.indices.tolist()]
            probs_list = [f"{p:.3f}" for p in top.values.tolist()]
            print(f'  "{prompt_text}" -> {completions} (p={probs_list})', flush=True)


def _checkpoint_mtime(path: str):
    """Return the modification time of *path*, or None if it cannot be stat'ed.

    Any ``OSError`` is treated as "not there", matching ``os.path.exists``,
    but with a single stat call so the file cannot vanish between the
    existence check and the mtime read.
    """
    try:
        return os.path.getmtime(path)
    except OSError:
        return None


def _probe_checkpoint(path: str, tokenizer) -> None:
    """Load the checkpoint at *path* on CPU and print probe results.

    Any ``Exception`` from loading or probing is reported and swallowed, so a
    single unreadable checkpoint does not stop the watcher.
    ``KeyboardInterrupt`` is not an ``Exception`` and propagates to the caller.
    The model is local to this call and is released on return, not kept
    alive while the watcher sleeps.
    """
    try:
        model, step = load_model_cpu(path, tokenizer)
        print(f"[watch] loaded step={step}", flush=True)
        t0 = time.perf_counter()
        run_probes(model, tokenizer)
        print(f"[watch] probes ran in {time.perf_counter() - t0:.1f}s", flush=True)
    except Exception as e:
        print(f"[watch] probe failed: {type(e).__name__}: {e}", flush=True)


def main() -> None:
    """Poll CKPT_PATH forever and run the probes whenever its mtime advances.

    Exits cleanly with "[watch] exiting." on Ctrl-C, whether the interrupt
    arrives during a probe or while sleeping between polls.
    """
    print("[watch] loading tokenizer...", flush=True)
    tokenizer = Tokenizer.from_directory()
    print(f"[watch] watching {CKPT_PATH} (poll every {POLL_INTERVAL:.0f}s)", flush=True)

    last_mtime = 0.0
    try:
        while True:
            mtime = _checkpoint_mtime(CKPT_PATH)
            if mtime is not None and mtime > last_mtime:
                last_mtime = mtime
                ts = time.strftime("%H:%M:%S", time.localtime(mtime))
                print(f"\n[watch] checkpoint updated at {ts}", flush=True)
                _probe_checkpoint(CKPT_PATH, tokenizer)
            # The sleep must be inside the try: previously it was not, so
            # Ctrl-C during the poll interval escaped as a traceback.
            time.sleep(POLL_INTERVAL)
    except KeyboardInterrupt:
        print("[watch] exiting.", flush=True)


if __name__ == "__main__":
    main()