"""Interactive chat REPL for HYDRA.
Usage:
python scripts/chat.py # auto-select best checkpoint
python scripts/chat.py --ckpt PATH # explicit checkpoint
python scripts/chat.py --sft # prefer sft_final.pt
python scripts/chat.py --random # skip ckpt, use random weights
HONESTY: model is ~7.5M params at d_model=256/n_layer=4. Expect incoherent
output. This REPL validates the *interface* — tokenizer roundtrip, generation
loop, stop-token handling, conversation history truncation. Coherent dialogue
is not a goal at this scale.
Slash commands:
/reset clear conversation history
/quit exit
/temp X set temperature (default 0.8)
/topk K set top-k (default 40)
/topp P set top-p (default 0.9)
/max N set max new tokens per turn (default 200)
/rep R set repetition penalty (default 1.1)
/sys S set a system prefix prepended to every turn
/info print current settings + checkpoint path
"""
import argparse
import os
import sys
import time
from pathlib import Path
# Make repo root importable when invoked as `python scripts/chat.py`.
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
sys.path.insert(0, str(_REPO_ROOT))
import torch # noqa: E402
# Chat template — plain-text fallback (see .omc/chat_plan.md).
# If the SFT agent later reserves special tokens, redefine USER_TAG /
# ASSISTANT_TAG / END_TAG and the stop-string accordingly.
USER_TAG = "User:"
ASSISTANT_TAG = "Assistant:"
END_TAG = "\nUser:" # stop-string matched on decoded output
CKPT_DIR = Path(os.path.expanduser("~/.cache/autoresearch/ckpts"))
CKPT_CANDIDATES_PRETRAIN = ["pretrain_final.pt", "latest.pt"]
CKPT_CANDIDATES_SFT = ["sft_final.pt"]
# ---------------------------------------------------------------------------
# Checkpoint resolution
# ---------------------------------------------------------------------------
def resolve_checkpoint(explicit: str | None, prefer_sft: bool) -> Path | None:
"""Return Path to checkpoint file, or None if nothing found.
    Order:
    1. `explicit` if provided and exists.
    2. sft_final.pt -> pretrain_final.pt -> latest.pt (an SFT checkpoint is
       preferred by default, so --sft only makes that preference explicit).
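    For example, with no --ckpt and only pretrain_final.pt on disk, this
    returns CKPT_DIR / "pretrain_final.pt".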
"""
if explicit:
p = Path(os.path.expanduser(explicit))
if p.exists():
return p
print(f"[WARN] --ckpt {p} does not exist; falling through to auto-select.", file=sys.stderr)
    # Task spec: prefer sft_final.pt if it exists, otherwise pretrain_final.pt,
    # then latest.pt. That preference is already the default, so --sft only
    # makes it explicit.
    _ = prefer_sft  # reserved for future "pretrain-only" vs "sft-only" modes
order = CKPT_CANDIDATES_SFT + CKPT_CANDIDATES_PRETRAIN
for name in order:
cand = CKPT_DIR / name
if cand.exists():
return cand
return None
# ---------------------------------------------------------------------------
# Model + tokenizer loading
# ---------------------------------------------------------------------------
def load_model_and_tokenizer(ckpt_path: Path | None, device: torch.device):
"""Build model + tokenizer. If ckpt_path is None, random weights are used.
Returns (model, tokenizer, meta) where meta is a dict with 'ckpt',
'step', 'val_bpb' etc. for /info display.
"""
from hydra.config import PostSemClawConfig
from hydra.model import PostSemClawModel
from prepare import Tokenizer
tokenizer = Tokenizer.from_directory()
vocab_size = tokenizer.get_vocab_size()
print(f"[chat] Tokenizer loaded (vocab={vocab_size:,})")
meta: dict = {"ckpt": str(ckpt_path) if ckpt_path else "<random>", "step": None, "val_bpb": None}
# Build config. If checkpoint provides one, use it; else use env-var defaults.
ckpt_state = None
config_kwargs: dict = {}
if ckpt_path is not None:
print(f"[chat] Loading checkpoint: {ckpt_path}")
ckpt_state = torch.load(ckpt_path, map_location=device, weights_only=False)
cfg_dict = ckpt_state.get("config")
if isinstance(cfg_dict, dict):
# Filter to kwargs PostSemClawConfig actually accepts.
allowed = set(PostSemClawConfig.__dataclass_fields__.keys())
config_kwargs = {k: v for k, v in cfg_dict.items() if k in allowed}
meta["step"] = ckpt_state.get("step")
meta["val_bpb"] = ckpt_state.get("val_bpb") or ckpt_state.get("bpb")
    # PostSemClawConfig field defaults already honor the env vars, but the
    # training run builds its config explicitly from hydra.config module-level
    # constants. Mirror that here so the random-weights path matches what
    # train.py would instantiate for the same environment.
if not config_kwargs:
from hydra.config import ( # noqa: E402
D_MODEL, D_STATE, ENGRAM_KEY_DIM, ENGRAM_LAYER_IDX,
ENGRAM_N_COLUMNS, EXPAND, HEADDIM, N_HEADS, N_LAYER,
)
from prepare import MAX_SEQ_LEN # noqa: E402
config_kwargs = dict(
sequence_len=MAX_SEQ_LEN,
vocab_size=vocab_size,
n_layer=N_LAYER,
d_model=D_MODEL,
d_state=D_STATE,
headdim=HEADDIM,
n_heads=N_HEADS,
expand=EXPAND,
engram_n_columns=ENGRAM_N_COLUMNS,
engram_key_dim=ENGRAM_KEY_DIM,
engram_layer_idx=ENGRAM_LAYER_IDX,
)
# Build model on meta device then materialize — matches training.py path.
with torch.device("meta"):
model = PostSemClawModel(PostSemClawConfig(**config_kwargs))
model.to_empty(device=device)
model.init_weights()
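    # init_weights() materializes real values for every parameter on `device`;
    # if a checkpoint was loaded, its tensors overwrite them just below.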
if ckpt_state is not None and "model_state_dict" in ckpt_state:
# strict=False: the model has non-parameter buffers (SDR retina loaded
# from npz, HTM Rust-side state, engram EMA stats) that may not be in
# the state_dict. missing/unexpected-key warnings are expected and OK.
missing, unexpected = model.load_state_dict(
ckpt_state["model_state_dict"], strict=False
)
if missing:
print(f"[chat] Note: {len(missing)} missing key(s) in state_dict (expected for HTM/SDR buffers).")
if unexpected:
print(f"[chat] Note: {len(unexpected)} unexpected key(s) in state_dict.")
    elif ckpt_state is not None:
        print("[chat] [WARN] Checkpoint has no 'model_state_dict'; keeping freshly initialized weights.", file=sys.stderr)
    else:
        print("[chat] [WARN] NO CHECKPOINT — using random weights. Output will be gibberish.", file=sys.stderr)
model.eval()
return model, tokenizer, meta
# ---------------------------------------------------------------------------
# Generation
# ---------------------------------------------------------------------------
def generate_stream(
model,
tokenizer,
prompt_ids: list[int],
*,
max_new_tokens: int,
temperature: float,
top_k: int,
top_p: float,
repetition_penalty: float,
stop_strings: tuple[str, ...],
max_seq_len: int,
device: torch.device,
rep_window: int = 64,
):
"""Yield decoded-text chunks as tokens are generated.
Truncates `prompt_ids` to the last `max_seq_len` tokens if needed. Stops
early when any `stop_strings` substring appears in the newly-decoded
continuation.
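    The full decoded continuation is also the generator's return value,
    retrievable via StopIteration.value (see _consume_stream_with_print).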
"""
from scripts.sample_utils import sample_token
# Truncate prompt to window.
if len(prompt_ids) > max_seq_len:
prompt_ids = prompt_ids[-max_seq_len:]
ctx = torch.tensor([prompt_ids], device=device, dtype=torch.long)
generated: list[int] = []
    # Track how many characters have already been streamed so we can detect
    # when the decoded string has grown (BPE tokens may decode to multi-char
    # strings mid-merge).
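    # E.g. the decode of the generated ids might go "Hel" -> "Hello": only the
    # new suffix "lo" is streamed for the latest token.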
streamed_chars = 0
accumulated_text = ""
    autocast_ctx = torch.amp.autocast(
        device_type=device.type, dtype=torch.bfloat16, enabled=(device.type == "cuda")
    )
for _ in range(max_new_tokens):
with torch.no_grad(), autocast_ctx:
out = model(ctx, targets=None)
# out shape: (1, T, vocab) or (1, vocab) depending on path.
if out.dim() == 3:
last_logits = out[0, -1, :]
else:
last_logits = out[0]
recent = generated[-rep_window:] if generated else None
next_id = sample_token(
last_logits,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
recent_tokens=recent,
)
generated.append(next_id)
        # Decode everything generated so far, then diff: BPE decoding is not
        # token-local, so a per-token decode can drop bytes.
new_text = tokenizer.decode(generated)
delta = new_text[streamed_chars:]
if delta:
streamed_chars = len(new_text)
accumulated_text = new_text
yield delta
# Stop-string check.
hit_stop = any(s and s in accumulated_text for s in stop_strings)
if hit_stop:
break
# Advance context. If we've filled the window, drop oldest token.
ctx = torch.cat([ctx, torch.tensor([[next_id]], device=device, dtype=torch.long)], dim=1)
if ctx.size(1) > max_seq_len:
ctx = ctx[:, -max_seq_len:]
# Final accumulated text is also returned for history tracking.
return accumulated_text # noqa: B901 (generator return for history)
def _consume_stream_with_print(stream_gen):
"""Iterate a generator, print each chunk, return the full text.
Replacement for a naïve list(stream) since `generate_stream` is a generator
that yields then returns the final text.
"""
collected = []
try:
while True:
chunk = next(stream_gen)
collected.append(chunk)
sys.stdout.write(chunk)
sys.stdout.flush()
except StopIteration as stop:
# stop.value holds the return value of the generator.
final = stop.value
if final is not None:
return final
return "".join(collected)
# ---------------------------------------------------------------------------
# REPL
# ---------------------------------------------------------------------------
def build_prompt(system: str, history: list[tuple[str, str]], user_msg: str) -> str:
"""Assemble the text prompt fed to the tokenizer."""
parts: list[str] = []
if system:
parts.append(system.rstrip() + "\n")
for u, a in history:
parts.append(f"{USER_TAG} {u}\n{ASSISTANT_TAG} {a}\n")
parts.append(f"{USER_TAG} {user_msg}\n{ASSISTANT_TAG}")
return "".join(parts)
def run_repl(
model,
tokenizer,
meta: dict,
*,
device: torch.device,
max_seq_len: int,
) -> None:
settings = {
"temperature": float(os.environ.get("HYDRA_CHAT_TEMP", "0.8")),
"top_k": int(os.environ.get("HYDRA_CHAT_TOPK", "40")),
"top_p": float(os.environ.get("HYDRA_CHAT_TOPP", "0.9")),
"max_new_tokens": int(os.environ.get("HYDRA_CHAT_MAX", "200")),
"repetition_penalty": float(os.environ.get("HYDRA_CHAT_REP", "1.1")),
"system": os.environ.get("HYDRA_CHAT_SYSTEM", ""),
}
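    # Every setting can be seeded from the environment, e.g.:
    #   HYDRA_CHAT_TEMP=0.5 HYDRA_CHAT_MAX=100 python scripts/chat.py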
history: list[tuple[str, str]] = []
print()
print("=" * 60)
print("HYDRA chat REPL")
print(f" checkpoint: {meta['ckpt']}")
if meta.get("step") is not None:
print(f" step: {meta['step']}")
if meta.get("val_bpb") is not None:
print(f" val_bpb: {meta['val_bpb']}")
print(" type /info for settings, /quit to exit")
print("=" * 60)
print()
while True:
try:
line = input(f"{USER_TAG} ")
except (EOFError, KeyboardInterrupt):
print()
return
line = line.rstrip()
if not line:
continue
if line.startswith("/"):
cmd, *rest = line.split(maxsplit=1)
arg = rest[0] if rest else ""
if cmd == "/quit" or cmd == "/exit":
return
elif cmd == "/reset":
history = []
print("[reset]")
continue
elif cmd == "/info":
print(f"[info] ckpt={meta['ckpt']} settings={settings} history_turns={len(history)}")
continue
elif cmd == "/temp":
try:
settings["temperature"] = float(arg)
print(f"[temp={settings['temperature']}]")
except ValueError:
print(f"[err] /temp needs a float, got {arg!r}")
continue
elif cmd == "/topk":
try:
settings["top_k"] = int(arg)
print(f"[topk={settings['top_k']}]")
except ValueError:
print(f"[err] /topk needs an int, got {arg!r}")
continue
elif cmd == "/topp":
try:
settings["top_p"] = float(arg)
print(f"[topp={settings['top_p']}]")
except ValueError:
print(f"[err] /topp needs a float, got {arg!r}")
continue
elif cmd == "/max":
try:
settings["max_new_tokens"] = int(arg)
print(f"[max={settings['max_new_tokens']}]")
except ValueError:
print(f"[err] /max needs an int, got {arg!r}")
continue
elif cmd == "/rep":
try:
settings["repetition_penalty"] = float(arg)
print(f"[rep={settings['repetition_penalty']}]")
except ValueError:
print(f"[err] /rep needs a float, got {arg!r}")
continue
elif cmd == "/sys":
settings["system"] = arg
print(f"[sys set, {len(arg)} chars]")
continue
else:
print(f"[err] unknown command {cmd!r}. Try /info /reset /quit.")
continue
# Normal chat turn.
prompt_text = build_prompt(settings["system"], history, line)
prompt_ids = tokenizer.encode(prompt_text)
sys.stdout.write(f"{ASSISTANT_TAG} ")
sys.stdout.flush()
stream = generate_stream(
model, tokenizer, prompt_ids,
max_new_tokens=settings["max_new_tokens"],
temperature=settings["temperature"],
top_k=settings["top_k"],
top_p=settings["top_p"],
repetition_penalty=settings["repetition_penalty"],
stop_strings=(END_TAG,),
max_seq_len=max_seq_len,
device=device,
)
response_text = _consume_stream_with_print(stream)
if not response_text.endswith("\n"):
sys.stdout.write("\n")
sys.stdout.flush()
# Strip trailing stop marker from the remembered history.
clean = response_text
if END_TAG in clean:
clean = clean.split(END_TAG, 1)[0]
clean = clean.strip()
history.append((line, clean))
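        # The turn history itself is never truncated; the effective window limit
        # is applied in generate_stream, which keeps only the last max_seq_len
        # prompt tokens.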
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p = argparse.ArgumentParser(description="HYDRA chat REPL")
p.add_argument("--ckpt", type=str, default=None,
help="Path to checkpoint (.pt). If omitted, auto-select.")
p.add_argument("--sft", action="store_true",
help="Prefer an SFT checkpoint if available.")
p.add_argument("--random", action="store_true",
help="Skip checkpoint load; use random weights.")
p.add_argument("--device", type=str, default=None,
help="Torch device (default: cuda if available else cpu).")
return p.parse_args(argv)
def main(argv: list[str] | None = None) -> int:
args = _parse_args(argv)
if args.device:
device = torch.device(args.device)
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
print("[chat] [WARN] CUDA not available; HYDRA's HTM/Mamba kernels may fail on CPU.", file=sys.stderr)
ckpt_path: Path | None
if args.random:
ckpt_path = None
else:
ckpt_path = resolve_checkpoint(args.ckpt, args.sft)
t0 = time.time()
model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
dt = time.time() - t0
print(f"[chat] Model ready in {dt:.1f}s on {device}")
from prepare import MAX_SEQ_LEN
run_repl(model, tokenizer, meta, device=device, max_seq_len=MAX_SEQ_LEN)
return 0
if __name__ == "__main__":
sys.exit(main())