| """Interactive chat REPL for HYDRA. | |
| Usage: | |
| python scripts/chat.py # auto-select best checkpoint | |
| python scripts/chat.py --ckpt PATH # explicit checkpoint | |
| python scripts/chat.py --sft # prefer sft_final.pt | |
| python scripts/chat.py --random # skip ckpt, use random weights | |
| HONESTY: model is ~7.5M params at d_model=256/n_layer=4. Expect incoherent | |
| output. This REPL validates the *interface* — tokenizer roundtrip, generation | |
| loop, stop-token handling, conversation history truncation. Coherent dialogue | |
| is not a goal at this scale. | |
| Slash commands: | |
| /reset clear conversation history | |
| /quit exit | |
| /temp X set temperature (default 0.8) | |
| /topk K set top-k (default 40) | |
| /topp P set top-p (default 0.9) | |
| /max N set max new tokens per turn (default 200) | |
| /rep R set repetition penalty (default 1.1) | |
| /sys S set a system prefix prepended to every turn | |
| /info print current settings + checkpoint path | |
| """ | |
from __future__ import annotations

import argparse
import os
import sys
import time
from pathlib import Path

# Make repo root importable when invoked as `python scripts/chat.py`.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

import torch  # noqa: E402
# Chat template — plain-text fallback (see .omc/chat_plan.md).
# If the SFT agent later reserves special tokens, redefine USER_TAG /
# ASSISTANT_TAG / END_TAG and the stop-string accordingly.
USER_TAG = "User:"
ASSISTANT_TAG = "Assistant:"
END_TAG = "\nUser:"  # stop-string matched on decoded output
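# Illustrative transcript shape produced by build_prompt() below (hypothetical
# content, shown only to make the tag / stop-string convention concrete):
#
#   User: hello
#   Assistant: hi there
#   User: what can you do?
#   Assistant:
#
# Generation continues from the trailing "Assistant:" tag; as soon as the model
# emits "\nUser:" (END_TAG) we treat that as the start of a new turn and stop.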
CKPT_DIR = Path(os.path.expanduser("~/.cache/autoresearch/ckpts"))
CKPT_CANDIDATES_PRETRAIN = ["pretrain_final.pt", "latest.pt"]
CKPT_CANDIDATES_SFT = ["sft_final.pt"]


# ---------------------------------------------------------------------------
# Checkpoint resolution
# ---------------------------------------------------------------------------
def resolve_checkpoint(explicit: str | None, prefer_sft: bool) -> Path | None:
    """Return Path to checkpoint file, or None if nothing found.

    Order:
      1. `explicit` if provided and exists.
      2. If prefer_sft: sft_final.pt -> pretrain_final.pt -> latest.pt.
      3. Else: sft_final.pt (if exists) -> pretrain_final.pt -> latest.pt.
    """
    if explicit:
        p = Path(os.path.expanduser(explicit))
        if p.exists():
            return p
        print(f"[WARN] --ckpt {p} does not exist; falling through to auto-select.", file=sys.stderr)
    # Task spec: prefer sft_final.pt if it exists, otherwise pretrain_final.pt,
    # then latest.pt. --sft only makes that preference explicit; SFT-first is
    # already the default ordering.
    _ = prefer_sft  # reserved for future "pretrain-only" vs "sft-only" modes
    order = CKPT_CANDIDATES_SFT + CKPT_CANDIDATES_PRETRAIN
    for name in order:
        cand = CKPT_DIR / name
        if cand.exists():
            return cand
    return None
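
# Resolution example (assuming only pretrain_final.pt exists under CKPT_DIR):
#   resolve_checkpoint(None, prefer_sft=False)
#   -> ~/.cache/autoresearch/ckpts/pretrain_final.pt
# An explicit --ckpt path that exists always wins over the auto-selection order.
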
# ---------------------------------------------------------------------------
# Model + tokenizer loading
# ---------------------------------------------------------------------------
def load_model_and_tokenizer(ckpt_path: Path | None, device: torch.device):
    """Build model + tokenizer. If ckpt_path is None, random weights are used.

    Returns (model, tokenizer, meta) where meta is a dict with 'ckpt',
    'step', 'val_bpb' etc. for /info display.
    """
    from hydra.config import PostSemClawConfig
    from hydra.model import PostSemClawModel
    from prepare import Tokenizer

    tokenizer = Tokenizer.from_directory()
    vocab_size = tokenizer.get_vocab_size()
    print(f"[chat] Tokenizer loaded (vocab={vocab_size:,})")

    meta: dict = {"ckpt": str(ckpt_path) if ckpt_path else "<random>", "step": None, "val_bpb": None}

    # Build config. If the checkpoint provides one, use it; else use env-var defaults.
    ckpt_state = None
    config_kwargs: dict = {}
    if ckpt_path is not None:
        print(f"[chat] Loading checkpoint: {ckpt_path}")
        ckpt_state = torch.load(ckpt_path, map_location=device, weights_only=False)
        cfg_dict = ckpt_state.get("config")
        if isinstance(cfg_dict, dict):
            # Filter to kwargs PostSemClawConfig actually accepts.
            allowed = set(PostSemClawConfig.__dataclass_fields__.keys())
            config_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed}
        meta["step"] = ckpt_state.get("step")
        meta["val_bpb"] = ckpt_state.get("val_bpb") or ckpt_state.get("bpb")

    # Env-var defaults are applied by PostSemClawConfig field defaults, but the
    # training run builds the config explicitly from hydra.config module-level
    # constants. We mirror that here so the random-weights path aligns with
    # what train.py would instantiate for the same env.
    if not config_kwargs:
        from hydra.config import (  # noqa: E402
            D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX,
            ENGRAM_N_COLUMNS, EXPAND, HEADDIM, N_HEADS, N_LAYER,
        )
        from prepare import MAX_SEQ_LEN  # noqa: E402

        config_kwargs = dict(
            sequence_len=MAX_SEQ_LEN,
            vocab_size=vocab_size,
            n_layer=N_LAYER,
            d_model=D_MODEL,
            d_state=D_STATE,
            headdim=HEADDIM,
            n_heads=N_HEADS,
            expand=EXPAND,
            engram_n_columns=ENGRAM_N_COLUMNS,
            engram_key_dim=ENGRAM_KEY_DIM,
            engram_layer_idx=ENGRAM_LAYER_IDX,
        )

    # Build model on the meta device then materialize — matches the training.py path.
    with torch.device("meta"):
        model = PostSemClawModel(PostSemClawConfig(**config_kwargs))
    model.to_empty(device=device)
    model.init_weights()

    if ckpt_state is not None and "model_state_dict" in ckpt_state:
        # strict=False: the model has non-parameter buffers (SDR retina loaded
        # from npz, HTM Rust-side state, engram EMA stats) that may not be in
        # the state_dict. Missing/unexpected-key warnings are expected and OK.
        missing, unexpected = model.load_state_dict(
            ckpt_state["model_state_dict"], strict=False
        )
        if missing:
            print(f"[chat] Note: {len(missing)} missing key(s) in state_dict (expected for HTM/SDR buffers).")
        if unexpected:
            print(f"[chat] Note: {len(unexpected)} unexpected key(s) in state_dict.")
    elif ckpt_path is None:
        print("[chat] [WARN] NO CHECKPOINT — using random weights. Output will be gibberish.", file=sys.stderr)

    model.eval()
    return model, tokenizer, meta
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def generate_stream(
    model,
    tokenizer,
    prompt_ids: list[int],
    *,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    stop_strings: tuple[str, ...],
    max_seq_len: int,
    device: torch.device,
    rep_window: int = 64,
):
    """Yield decoded-text chunks as tokens are generated.

    Truncates `prompt_ids` to the last `max_seq_len` tokens if needed. Stops
    early when any `stop_strings` substring appears in the newly-decoded
    continuation.
    """
    from scripts.sample_utils import sample_token

    # Truncate prompt to window.
    if len(prompt_ids) > max_seq_len:
        prompt_ids = prompt_ids[-max_seq_len:]
    ctx = torch.tensor([prompt_ids], device=device, dtype=torch.long)

    generated: list[int] = []
    # Track the already-streamed character count so we can detect when the decoded
    # string has grown (BPE tokens may decode to multi-char strings mid-merge).
    streamed_chars = 0
    accumulated_text = ""
    # Use the actual device type so CPU runs don't hit the CUDA-only autocast warning.
    autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)

    for _ in range(max_new_tokens):
        with torch.no_grad(), autocast_ctx:
            out = model(ctx, targets=None)
        # out shape: (1, T, vocab) or (1, vocab) depending on path.
        if out.dim() == 3:
            last_logits = out[0, -1, :]
        else:
            last_logits = out[0]

        recent = generated[-rep_window:] if generated else None
        next_id = sample_token(
            last_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            recent_tokens=recent,
        )
        generated.append(next_id)

        # Decode everything so far, then diff — BPE decoding is not token-local,
        # so a per-token decode can drop bytes.
        new_text = tokenizer.decode(generated)
        delta = new_text[streamed_chars:]
        if delta:
            streamed_chars = len(new_text)
            accumulated_text = new_text
            yield delta

        # Stop-string check.
        hit_stop = any(s and s in accumulated_text for s in stop_strings)
        if hit_stop:
            break

        # Advance context. If we've filled the window, drop the oldest token.
        ctx = torch.cat([ctx, torch.tensor([[next_id]], device=device, dtype=torch.long)], dim=1)
        if ctx.size(1) > max_seq_len:
            ctx = ctx[:, -max_seq_len:]

    # The final accumulated text is also returned for history tracking.
    return accumulated_text  # noqa: B901 (generator return for history)
def _consume_stream_with_print(stream_gen):
    """Iterate a generator, print each chunk, return the full text.

    Replacement for a naïve list(stream), since `generate_stream` is a generator
    that yields chunks and then *returns* the final text.
    """
    collected = []
    try:
        while True:
            chunk = next(stream_gen)
            collected.append(chunk)
            sys.stdout.write(chunk)
            sys.stdout.flush()
    except StopIteration as stop:
        # stop.value holds the return value of the generator.
        final = stop.value
        if final is not None:
            return final
    return "".join(collected)
# ---------------------------------------------------------------------------
# REPL
# ---------------------------------------------------------------------------
def build_prompt(system: str, history: list[tuple[str, str]], user_msg: str) -> str:
    """Assemble the text prompt fed to the tokenizer."""
    parts: list[str] = []
    if system:
        parts.append(system.rstrip() + "\n")
    for u, a in history:
        parts.append(f"{USER_TAG} {u}\n{ASSISTANT_TAG} {a}\n")
    parts.append(f"{USER_TAG} {user_msg}\n{ASSISTANT_TAG}")
    return "".join(parts)


def run_repl(
    model,
    tokenizer,
    meta: dict,
    *,
    device: torch.device,
    max_seq_len: int,
) -> None:
    settings = {
        "temperature": float(os.environ.get("HYDRA_CHAT_TEMP", "0.8")),
        "top_k": int(os.environ.get("HYDRA_CHAT_TOPK", "40")),
        "top_p": float(os.environ.get("HYDRA_CHAT_TOPP", "0.9")),
        "max_new_tokens": int(os.environ.get("HYDRA_CHAT_MAX", "200")),
        "repetition_penalty": float(os.environ.get("HYDRA_CHAT_REP", "1.1")),
        "system": os.environ.get("HYDRA_CHAT_SYSTEM", ""),
    }
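    # The env vars above only seed the initial settings, e.g. (hypothetical shell
    # invocation):
    #   HYDRA_CHAT_TEMP=0.7 HYDRA_CHAT_MAX=120 python scripts/chat.py
    # Everything can still be changed at runtime via the slash commands (/temp,
    # /max, ...), which overwrite the corresponding entry in `settings`.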
    history: list[tuple[str, str]] = []

    print()
    print("=" * 60)
    print("HYDRA chat REPL")
    print(f" checkpoint: {meta['ckpt']}")
    if meta.get("step") is not None:
        print(f" step: {meta['step']}")
    if meta.get("val_bpb") is not None:
        print(f" val_bpb: {meta['val_bpb']}")
    print(" type /info for settings, /quit to exit")
    print("=" * 60)
    print()

    while True:
        try:
            line = input(f"{USER_TAG} ")
        except (EOFError, KeyboardInterrupt):
            print()
            return
        line = line.rstrip()
        if not line:
            continue

        if line.startswith("/"):
            cmd, *rest = line.split(maxsplit=1)
            arg = rest[0] if rest else ""
            if cmd in ("/quit", "/exit"):
                return
            elif cmd == "/reset":
                history = []
                print("[reset]")
                continue
            elif cmd == "/info":
                print(f"[info] ckpt={meta['ckpt']} settings={settings} history_turns={len(history)}")
                continue
            elif cmd == "/temp":
                try:
                    settings["temperature"] = float(arg)
                    print(f"[temp={settings['temperature']}]")
                except ValueError:
                    print(f"[err] /temp needs a float, got {arg!r}")
                continue
            elif cmd == "/topk":
                try:
                    settings["top_k"] = int(arg)
                    print(f"[topk={settings['top_k']}]")
                except ValueError:
                    print(f"[err] /topk needs an int, got {arg!r}")
                continue
            elif cmd == "/topp":
                try:
                    settings["top_p"] = float(arg)
                    print(f"[topp={settings['top_p']}]")
                except ValueError:
                    print(f"[err] /topp needs a float, got {arg!r}")
                continue
            elif cmd == "/max":
                try:
                    settings["max_new_tokens"] = int(arg)
                    print(f"[max={settings['max_new_tokens']}]")
                except ValueError:
                    print(f"[err] /max needs an int, got {arg!r}")
                continue
            elif cmd == "/rep":
                try:
                    settings["repetition_penalty"] = float(arg)
                    print(f"[rep={settings['repetition_penalty']}]")
                except ValueError:
                    print(f"[err] /rep needs a float, got {arg!r}")
                continue
            elif cmd == "/sys":
                settings["system"] = arg
                print(f"[sys set, {len(arg)} chars]")
                continue
            else:
                print(f"[err] unknown command {cmd!r}. Try /info /reset /quit.")
                continue

        # Normal chat turn.
        prompt_text = build_prompt(settings["system"], history, line)
        prompt_ids = tokenizer.encode(prompt_text)
        sys.stdout.write(f"{ASSISTANT_TAG} ")
        sys.stdout.flush()
        stream = generate_stream(
            model, tokenizer, prompt_ids,
            max_new_tokens=settings["max_new_tokens"],
            temperature=settings["temperature"],
            top_k=settings["top_k"],
            top_p=settings["top_p"],
            repetition_penalty=settings["repetition_penalty"],
            stop_strings=(END_TAG,),
            max_seq_len=max_seq_len,
            device=device,
        )
        response_text = _consume_stream_with_print(stream)
        if not response_text.endswith("\n"):
            sys.stdout.write("\n")
            sys.stdout.flush()

        # Strip the trailing stop marker from the remembered history.
        clean = response_text
        if END_TAG in clean:
            clean = clean.split(END_TAG, 1)[0]
        clean = clean.strip()
        history.append((line, clean))
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(description="HYDRA chat REPL")
    p.add_argument("--ckpt", type=str, default=None,
                   help="Path to checkpoint (.pt). If omitted, auto-select.")
    p.add_argument("--sft", action="store_true",
                   help="Prefer an SFT checkpoint if available.")
    p.add_argument("--random", action="store_true",
                   help="Skip checkpoint load; use random weights.")
    p.add_argument("--device", type=str, default=None,
                   help="Torch device (default: cuda if available else cpu).")
    return p.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = _parse_args(argv)

    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        print("[chat] [WARN] CUDA not available; HYDRA's HTM/Mamba kernels may fail on CPU.", file=sys.stderr)

    ckpt_path: Path | None
    if args.random:
        ckpt_path = None
    else:
        ckpt_path = resolve_checkpoint(args.ckpt, args.sft)

    t0 = time.time()
    model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
    dt = time.time() - t0
    print(f"[chat] Model ready in {dt:.1f}s on {device}")

    from prepare import MAX_SEQ_LEN

    run_repl(model, tokenizer, meta, device=device, max_seq_len=MAX_SEQ_LEN)
    return 0


if __name__ == "__main__":
    sys.exit(main())