"""Interactive chat REPL for HYDRA. Usage: python scripts/chat.py # auto-select best checkpoint python scripts/chat.py --ckpt PATH # explicit checkpoint python scripts/chat.py --sft # prefer sft_final.pt python scripts/chat.py --random # skip ckpt, use random weights HONESTY: model is ~7.5M params at d_model=256/n_layer=4. Expect incoherent output. This REPL validates the *interface* — tokenizer roundtrip, generation loop, stop-token handling, conversation history truncation. Coherent dialogue is not a goal at this scale. Slash commands: /reset clear conversation history /quit exit /temp X set temperature (default 0.8) /topk K set top-k (default 40) /topp P set top-p (default 0.9) /max N set max new tokens per turn (default 200) /rep R set repetition penalty (default 1.1) /sys S set a system prefix prepended to every turn /info print current settings + checkpoint path """ from __future__ import annotations import argparse import os import sys import time from dataclasses import asdict from pathlib import Path # Make repo root importable when invoked as `python scripts/chat.py`. _REPO_ROOT = Path(__file__).resolve().parent.parent if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) import torch # noqa: E402 # Chat template — plain-text fallback (see .omc/chat_plan.md). # If the SFT agent later reserves special tokens, redefine USER_TAG / # ASSISTANT_TAG / END_TAG and the stop-string accordingly. USER_TAG = "User:" ASSISTANT_TAG = "Assistant:" END_TAG = "\nUser:" # stop-string matched on decoded output CKPT_DIR = Path(os.path.expanduser("~/.cache/autoresearch/ckpts")) CKPT_CANDIDATES_PRETRAIN = ["pretrain_final.pt", "latest.pt"] CKPT_CANDIDATES_SFT = ["sft_final.pt"] # --------------------------------------------------------------------------- # Checkpoint resolution # --------------------------------------------------------------------------- def resolve_checkpoint(explicit: str | None, prefer_sft: bool) -> Path | None: """Return Path to checkpoint file, or None if nothing found. Order: 1. `explicit` if provided and exists. 2. If prefer_sft: sft_final.pt -> pretrain_final.pt -> latest.pt. 3. Else: sft_final.pt (if exists) -> pretrain_final.pt -> latest.pt. """ if explicit: p = Path(os.path.expanduser(explicit)) if p.exists(): return p print(f"[WARN] --ckpt {p} does not exist; falling through to auto-select.", file=sys.stderr) # Task spec: prefer sft_final.pt if it exists; otherwise pretrain_final.pt # then latest.pt. --sft just makes the preference explicit; it's already # the default behavior. We list SFT first in both orderings to honor the # spec, since the task description said "prefer sft if exists" by default. _ = prefer_sft # reserved for future "pretrain-only" vs "sft-only" modes order = CKPT_CANDIDATES_SFT + CKPT_CANDIDATES_PRETRAIN for name in order: cand = CKPT_DIR / name if cand.exists(): return cand return None # --------------------------------------------------------------------------- # Model + tokenizer loading # --------------------------------------------------------------------------- def load_model_and_tokenizer(ckpt_path: Path | None, device: torch.device): """Build model + tokenizer. If ckpt_path is None, random weights are used. Returns (model, tokenizer, meta) where meta is a dict with 'ckpt', 'step', 'val_bpb' etc. for /info display. 
""" from hydra.config import PostSemClawConfig from hydra.model import PostSemClawModel from prepare import Tokenizer tokenizer = Tokenizer.from_directory() vocab_size = tokenizer.get_vocab_size() print(f"[chat] Tokenizer loaded (vocab={vocab_size:,})") meta: dict = {"ckpt": str(ckpt_path) if ckpt_path else "", "step": None, "val_bpb": None} # Build config. If checkpoint provides one, use it; else use env-var defaults. ckpt_state = None config_kwargs: dict = {} if ckpt_path is not None: print(f"[chat] Loading checkpoint: {ckpt_path}") ckpt_state = torch.load(ckpt_path, map_location=device, weights_only=False) cfg_dict = ckpt_state.get("config") if isinstance(cfg_dict, dict): # Filter to kwargs PostSemClawConfig actually accepts. allowed = set(PostSemClawConfig.__dataclass_fields__.keys()) config_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed} meta["step"] = ckpt_state.get("step") meta["val_bpb"] = ckpt_state.get("val_bpb") or ckpt_state.get("bpb") # Env-var defaults are applied by PostSemClawConfig field defaults; but the # training run builds the config explicitly from hydra.config module-level # constants. We mirror that here so the random-weights path aligns with # what train.py would instantiate for the same env. if not config_kwargs: from hydra.config import ( # noqa: E402 D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX, ENGRAM_N_COLUMNS, EXPAND, HEADDIM, N_HEADS, N_LAYER, ) from prepare import MAX_SEQ_LEN # noqa: E402 config_kwargs = dict( sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size, n_layer=N_LAYER, d_model=D_MODEL, d_state=D_STATE, headdim=HEADDIM, n_heads=N_HEADS, expand=EXPAND, engram_n_columns=ENGRAM_N_COLUMNS, engram_key_dim=ENGRAM_KEY_DIM, engram_layer_idx=ENGRAM_LAYER_IDX, ) # Build model on meta device then materialize — matches training.py path. with torch.device("meta"): model = PostSemClawModel(PostSemClawConfig(**config_kwargs)) model.to_empty(device=device) model.init_weights() if ckpt_state is not None and "model_state_dict" in ckpt_state: # strict=False: the model has non-parameter buffers (SDR retina loaded # from npz, HTM Rust-side state, engram EMA stats) that may not be in # the state_dict. missing/unexpected-key warnings are expected and OK. missing, unexpected = model.load_state_dict( ckpt_state["model_state_dict"], strict=False ) if missing: print(f"[chat] Note: {len(missing)} missing key(s) in state_dict (expected for HTM/SDR buffers).") if unexpected: print(f"[chat] Note: {len(unexpected)} unexpected key(s) in state_dict.") elif ckpt_path is None: print("[chat] [WARN] NO CHECKPOINT — using random weights. Output will be gibberish.", file=sys.stderr) model.eval() return model, tokenizer, meta # --------------------------------------------------------------------------- # Generation # --------------------------------------------------------------------------- def generate_stream( model, tokenizer, prompt_ids: list[int], *, max_new_tokens: int, temperature: float, top_k: int, top_p: float, repetition_penalty: float, stop_strings: tuple[str, ...], max_seq_len: int, device: torch.device, rep_window: int = 64, ): """Yield decoded-text chunks as tokens are generated. Truncates `prompt_ids` to the last `max_seq_len` tokens if needed. Stops early when any `stop_strings` substring appears in the newly-decoded continuation. """ from scripts.sample_utils import sample_token # Truncate prompt to window. 
    if len(prompt_ids) > max_seq_len:
        prompt_ids = prompt_ids[-max_seq_len:]
    ctx = torch.tensor([prompt_ids], device=device, dtype=torch.long)

    generated: list[int] = []
    # Track how many characters have already been streamed so we can detect
    # when the decoded string has grown (BPE tokens may decode to multi-char
    # strings mid-merge).
    streamed_chars = 0
    accumulated_text = ""

    # Autocast to bf16 on the active device type, so CPU runs don't request
    # CUDA autocast.
    autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)

    for _ in range(max_new_tokens):
        with torch.no_grad(), autocast_ctx:
            out = model(ctx, targets=None)
        # out shape: (1, T, vocab) or (1, vocab) depending on path.
        if out.dim() == 3:
            last_logits = out[0, -1, :]
        else:
            last_logits = out[0]

        recent = generated[-rep_window:] if generated else None
        next_id = sample_token(
            last_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            recent_tokens=recent,
        )
        generated.append(next_id)

        # Decode everything so-far then diff — BPE decoding is not token-local,
        # so a per-token decode can drop bytes.
        new_text = tokenizer.decode(generated)
        delta = new_text[streamed_chars:]
        if delta:
            streamed_chars = len(new_text)
            accumulated_text = new_text
            yield delta

        # Stop-string check.
        hit_stop = any(s and s in accumulated_text for s in stop_strings)
        if hit_stop:
            break

        # Advance context. If we've filled the window, drop oldest token.
        ctx = torch.cat([ctx, torch.tensor([[next_id]], device=device, dtype=torch.long)], dim=1)
        if ctx.size(1) > max_seq_len:
            ctx = ctx[:, -max_seq_len:]

    # Final accumulated text is also returned for history tracking.
    return accumulated_text  # noqa: B901 (generator return for history)


def _consume_stream_with_print(stream_gen):
    """Iterate a generator, print each chunk, return the full text.

    Replacement for a naïve list(stream) since `generate_stream` is a
    generator that yields then returns the final text.
    """
    collected = []
    try:
        while True:
            chunk = next(stream_gen)
            collected.append(chunk)
            sys.stdout.write(chunk)
            sys.stdout.flush()
    except StopIteration as stop:
        # stop.value holds the return value of the generator.
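        # (PEP 380: a `return value` inside a generator surfaces to the caller
        # as StopIteration(value), which is how generate_stream hands back the
        # final accumulated text alongside the streamed chunks.)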
        final = stop.value
        if final is not None:
            return final
    return "".join(collected)


# ---------------------------------------------------------------------------
# REPL
# ---------------------------------------------------------------------------

def build_prompt(system: str, history: list[tuple[str, str]], user_msg: str) -> str:
    """Assemble the text prompt fed to the tokenizer."""
    parts: list[str] = []
    if system:
        parts.append(system.rstrip() + "\n")
    for u, a in history:
        parts.append(f"{USER_TAG} {u}\n{ASSISTANT_TAG} {a}\n")
    parts.append(f"{USER_TAG} {user_msg}\n{ASSISTANT_TAG}")
    return "".join(parts)


def run_repl(
    model,
    tokenizer,
    meta: dict,
    *,
    device: torch.device,
    max_seq_len: int,
) -> None:
    settings = {
        "temperature": float(os.environ.get("HYDRA_CHAT_TEMP", "0.8")),
        "top_k": int(os.environ.get("HYDRA_CHAT_TOPK", "40")),
        "top_p": float(os.environ.get("HYDRA_CHAT_TOPP", "0.9")),
        "max_new_tokens": int(os.environ.get("HYDRA_CHAT_MAX", "200")),
        "repetition_penalty": float(os.environ.get("HYDRA_CHAT_REP", "1.1")),
        "system": os.environ.get("HYDRA_CHAT_SYSTEM", ""),
    }
    history: list[tuple[str, str]] = []

    print()
    print("=" * 60)
    print("HYDRA chat REPL")
    print(f" checkpoint: {meta['ckpt']}")
    if meta.get("step") is not None:
        print(f" step: {meta['step']}")
    if meta.get("val_bpb") is not None:
        print(f" val_bpb: {meta['val_bpb']}")
    print(" type /info for settings, /quit to exit")
    print("=" * 60)
    print()

    while True:
        try:
            line = input(f"{USER_TAG} ")
        except (EOFError, KeyboardInterrupt):
            print()
            return
        line = line.rstrip()
        if not line:
            continue

        if line.startswith("/"):
            cmd, *rest = line.split(maxsplit=1)
            arg = rest[0] if rest else ""
            if cmd == "/quit" or cmd == "/exit":
                return
            elif cmd == "/reset":
                history = []
                print("[reset]")
                continue
            elif cmd == "/info":
                print(f"[info] ckpt={meta['ckpt']} settings={settings} history_turns={len(history)}")
                continue
            elif cmd == "/temp":
                try:
                    settings["temperature"] = float(arg)
                    print(f"[temp={settings['temperature']}]")
                except ValueError:
                    print(f"[err] /temp needs a float, got {arg!r}")
                continue
            elif cmd == "/topk":
                try:
                    settings["top_k"] = int(arg)
                    print(f"[topk={settings['top_k']}]")
                except ValueError:
                    print(f"[err] /topk needs an int, got {arg!r}")
                continue
            elif cmd == "/topp":
                try:
                    settings["top_p"] = float(arg)
                    print(f"[topp={settings['top_p']}]")
                except ValueError:
                    print(f"[err] /topp needs a float, got {arg!r}")
                continue
            elif cmd == "/max":
                try:
                    settings["max_new_tokens"] = int(arg)
                    print(f"[max={settings['max_new_tokens']}]")
                except ValueError:
                    print(f"[err] /max needs an int, got {arg!r}")
                continue
            elif cmd == "/rep":
                try:
                    settings["repetition_penalty"] = float(arg)
                    print(f"[rep={settings['repetition_penalty']}]")
                except ValueError:
                    print(f"[err] /rep needs a float, got {arg!r}")
                continue
            elif cmd == "/sys":
                settings["system"] = arg
                print(f"[sys set, {len(arg)} chars]")
                continue
            else:
                print(f"[err] unknown command {cmd!r}. Try /info /reset /quit.")
                continue

        # Normal chat turn.
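        # The assembled prompt looks like (optional system prefix first):
        #     User: <turn 1>
        #     Assistant: <reply 1>
        #     User: <current message>
        #     Assistant:
        # The model is expected to continue after the final "Assistant:" tag.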
        prompt_text = build_prompt(settings["system"], history, line)
        prompt_ids = tokenizer.encode(prompt_text)

        sys.stdout.write(f"{ASSISTANT_TAG} ")
        sys.stdout.flush()
        stream = generate_stream(
            model,
            tokenizer,
            prompt_ids,
            max_new_tokens=settings["max_new_tokens"],
            temperature=settings["temperature"],
            top_k=settings["top_k"],
            top_p=settings["top_p"],
            repetition_penalty=settings["repetition_penalty"],
            stop_strings=(END_TAG,),
            max_seq_len=max_seq_len,
            device=device,
        )
        response_text = _consume_stream_with_print(stream)
        if not response_text.endswith("\n"):
            sys.stdout.write("\n")
            sys.stdout.flush()

        # Strip trailing stop marker from the remembered history.
        clean = response_text
        if END_TAG in clean:
            clean = clean.split(END_TAG, 1)[0]
        clean = clean.strip()
        history.append((line, clean))


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(description="HYDRA chat REPL")
    p.add_argument("--ckpt", type=str, default=None,
                   help="Path to checkpoint (.pt). If omitted, auto-select.")
    p.add_argument("--sft", action="store_true",
                   help="Prefer an SFT checkpoint if available.")
    p.add_argument("--random", action="store_true",
                   help="Skip checkpoint load; use random weights.")
    p.add_argument("--device", type=str, default=None,
                   help="Torch device (default: cuda if available else cpu).")
    return p.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = _parse_args(argv)

    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        print("[chat] [WARN] CUDA not available; HYDRA's HTM/Mamba kernels may fail on CPU.", file=sys.stderr)

    ckpt_path: Path | None
    if args.random:
        ckpt_path = None
    else:
        ckpt_path = resolve_checkpoint(args.ckpt, args.sft)

    t0 = time.time()
    model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
    dt = time.time() - t0
    print(f"[chat] Model ready in {dt:.1f}s on {device}")

    from prepare import MAX_SEQ_LEN

    run_repl(model, tokenizer, meta, device=device, max_seq_len=MAX_SEQ_LEN)
    return 0


if __name__ == "__main__":
    sys.exit(main())
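
# A minimal, illustrative smoke test of the generation path without the REPL.
# Not part of the CLI; it assumes the repo root is the working directory (so
# `scripts` and `prepare` are importable) and that a checkpoint resolves.
#
#     import torch
#     from prepare import MAX_SEQ_LEN
#     from scripts.chat import (END_TAG, build_prompt, generate_stream,
#                               load_model_and_tokenizer, resolve_checkpoint)
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model, tok, _meta = load_model_and_tokenizer(resolve_checkpoint(None, False), device)
#     ids = tok.encode(build_prompt("", [], "Hello"))
#     text = "".join(generate_stream(
#         model, tok, ids,
#         max_new_tokens=40, temperature=0.8, top_k=40, top_p=0.9,
#         repetition_penalty=1.1, stop_strings=(END_TAG,),
#         max_seq_len=MAX_SEQ_LEN, device=device,
#     ))
#     print(text)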