# Mnemo / chat_cli.py
# Uploaded to Hugging Face by ecreeth ("Upload chat_cli.py", commit 1d176f6, 11.7 kB).
"""Chat with the SFT'd microgpt model. Adapted from karpathy/nanochat's chat_cli.py.

Builds chat-format prompts using our 5 special tokens:

    <|bos|> <|user_start|> ...user text... <|user_end|>
    <|assistant_start|> ...assistant text... <|assistant_end|>

Conversation history is maintained across turns (each new user message is appended
to the running token stream). Generation stops on <|assistant_end|> (or <|bos|>, or
after --max-tokens tokens).

Usage:
    python3 chat_cli.py                          # interactive REPL
    python3 chat_cli.py -p "Who are you?"        # one-shot
    python3 chat_cli.py --ckpt model.pt          # explicit checkpoint path
    python3 chat_cli.py -t 0.6 -k 50             # nanochat-CLI defaults
    python3 chat_cli.py --no-history             # reset per turn (no memory)
    python3 chat_cli.py --compile                # torch.compile the per-token step

The default checkpoint is $CHECKPOINT_DIR/model_sft.pt; if that file is missing,
$CHECKPOINT_DIR/model.pt is tried instead ($CHECKPOINT_DIR defaults to ".").

Defaults follow nanochat-CLI (T=0.6, top-k=50). The temperature is lower than
infer.py's default of 0.8 because chat is meant to be focused, not exploratory.

REPL commands:
    quit / exit    end the session
    clear          reset conversation history
"""
import argparse
import os
import sys
import time
import torch
from tokenizers import Tokenizer
from model import GPT
from infer import _sample_token_batch # reuse our sampler (rep penalty + nucleus + top-k)
DEFAULT_CKPT_DIR = os.environ.get('CHECKPOINT_DIR', '.')

# Local dev keeps both the base and the SFT checkpoint, so the SFT one is named
# model_sft.pt. The HF upload ships only the SFT model, as model.pt; _load()
# falls back to that name when model_sft.pt is absent.
DEFAULT_CKPT = os.path.join(DEFAULT_CKPT_DIR, 'model_sft.pt')
DEFAULT_TOKENIZER = 'tokenizer.json'

# Chat-format control tokens. _load() exits if any of them is missing from the
# tokenizer vocabulary.
SPECIAL_NAMES = [
    "<|bos|>",
    "<|user_start|>", "<|user_end|>",
    "<|assistant_start|>", "<|assistant_end|>",
]

# Accelerator preference: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
def _resolve_ckpt_path(ckpt_path, fallback_dir):
    """Return the checkpoint path to load, exiting if no candidate exists.

    `ckpt_path` is used as-is when it exists. Otherwise `fallback_dir/model.pt`
    is tried (HF layout: only the SFT'd model is uploaded, as model.pt).
    """
    if os.path.exists(ckpt_path):
        return ckpt_path
    fallback = os.path.join(fallback_dir, 'model.pt')
    if ckpt_path != fallback and os.path.exists(fallback):
        return fallback
    sys.exit(f"error: no checkpoint at {ckpt_path}")


def _special_token_ids(tokenizer, names):
    """Map each special-token name in `names` to its id; exit if any is missing."""
    specials = {}
    for name in names:
        tid = tokenizer.token_to_id(name)
        if tid is None:
            sys.exit(f"error: tokenizer is missing special token {name} — was it retrained without the SFT vocab?")
        specials[name] = tid
    return specials


def _compile_decode_step(model):
    """Replace `model.step` in place with a torch.compile'd version of itself."""
    print("compiling model.step (hot inner-loop path)...", file=sys.stderr)
    if device == 'mps':
        print(" note: torch.compile on MPS may fall back to eager. CUDA gets the full speedup.",
              file=sys.stderr)
    compile_kwargs = {'dynamic': False}
    # reduce-overhead enables CUDA graphs — big win on Ampere+ where kernel-launch
    # overhead dominates the small per-token forward. No-op on MPS/CPU.
    if device == 'cuda' and torch.cuda.get_device_capability() >= (8, 0):
        compile_kwargs['mode'] = 'reduce-overhead'
    model.step = torch.compile(model.step, **compile_kwargs)


def _load(ckpt_path, tokenizer_path, compile_step=False):
    """Resolve checkpoint (with fallback), load model + tokenizer + special token IDs.

    If `ckpt_path` does not exist, `DEFAULT_CKPT_DIR/model.pt` (the HF layout, where
    only the SFT'd model is uploaded) is tried before giving up. Each failure
    (no checkpoint, no tokenizer file, missing chat special token) exits via
    `sys.exit` with an "error: ..." message.

    If `compile_step=True`, torch.compile the inner-loop `model.step` (B=1, T=1 fixed
    shapes — perfect for compile). We deliberately don't compile `forward_with_states`
    because chat conversations grow each turn → dynamic prompt length → recompile
    every turn. Prefill stays eager.

    Compile warmup takes 10–30s on first generation; pays off for any conversation
    long enough to generate >~50 tokens. On MPS, torch.compile is less mature than
    CUDA — try it but don't be surprised if it falls back to eager.

    Returns:
        (model, tokenizer, specials, ckpt, ckpt_path): the eval-mode model on
        `device`, the tokenizer, a special-token name -> id dict, the raw checkpoint
        dict (its tensors on CPU), and the checkpoint path actually loaded.
    """
    ckpt_path = _resolve_ckpt_path(ckpt_path, DEFAULT_CKPT_DIR)
    if not os.path.exists(tokenizer_path):
        sys.exit(f"error: no tokenizer at {tokenizer_path}")
    tokenizer = Tokenizer.from_file(tokenizer_path)
    # Validate the chat vocabulary before paying for the checkpoint load.
    specials = _special_token_ids(tokenizer, SPECIAL_NAMES)
    vocab_size = tokenizer.get_vocab_size()

    # NOTE(review): torch.load unpickles the file — only load checkpoints from a
    # trusted source (weights_only=True is the safer choice where supported).
    # Map to CPU on purpose: the caller keeps the returned `ckpt` dict alive for
    # the whole session, and mapping it to `device` would pin a second copy of
    # every weight in accelerator memory. load_state_dict copies across devices.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    config = dict(ckpt['config'])
    # Round the vocab up to a multiple of 64; presumably this mirrors the padding
    # used at training time so the lm_head/embedding shapes match — confirm.
    config['vocab_size'] = ((vocab_size + 63) // 64) * 64
    model = GPT.from_config(config).to(device)

    # Strip the '_orig_mod.' key prefix left by saving a torch.compile wrapper.
    state = {k.removeprefix('_orig_mod.'): v for k, v in ckpt['model'].items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"warn: missing keys: {missing}", file=sys.stderr)
    if unexpected:
        print(f"warn: unexpected keys: {unexpected}", file=sys.stderr)
    model.eval()

    if compile_step:
        _compile_decode_step(model)
    return model, tokenizer, specials, ckpt, ckpt_path
@torch.no_grad()
def _generate_one_turn(model, tokenizer, conversation_tokens, specials,
                       max_new_tokens, temperature, top_k, top_p, repetition_penalty,
                       stream):
    """Sample one assistant reply for the conversation so far.

    Prefills the whole conversation in one `model.forward_with_states` call, then
    decodes one token at a time with `model.step`. Stops on <|assistant_end|>,
    on <|bos|>, or after `max_new_tokens` tokens. The stop token is not included
    in the result.

    Args:
        model: GPT exposing `lm_head`, `forward_with_states(ctx)` and
            `step(tok, states)`.
        tokenizer: `tokenizers.Tokenizer`; only used to render streamed output.
        conversation_tokens: token ids of the conversation so far (the caller
            ends it with <|assistant_start|>).
        specials: special-token name -> id mapping.
        max_new_tokens: hard cap on the number of sampled tokens.
        temperature, top_k, top_p, repetition_penalty: forwarded to
            `_sample_token_batch`. With repetition_penalty == 1.0 no seen-token
            mask is built.
        stream: if True, write the reply to stdout as it is generated, then a
            newline.

    Returns:
        list[int]: the response token ids.
    """
    bos = specials["<|bos|>"]
    asst_end = specials["<|assistant_end|>"]
    ctx = torch.tensor([conversation_tokens], device=device)  # (1, T)
    V = model.lm_head.weight.size(0)  # padded vocab

    seen_mask = None
    if repetition_penalty != 1.0:
        # Prompt tokens already count as "seen" for the repetition penalty.
        seen_mask = torch.zeros(1, V, dtype=torch.bool, device=device)
        seen_mask.scatter_(1, ctx, True)

    # Prefill the full conversation in one pass
    logits, states = model.forward_with_states(ctx)
    next_logits = logits[:, -1, :].clone()

    response = []
    printed = ''  # text already written to stdout (streaming only)
    for _ in range(max_new_tokens):
        next_tok = _sample_token_batch(next_logits, temperature, top_k, top_p,
                                       repetition_penalty, seen_mask)  # (1,)
        tid = next_tok.item()
        if tid == asst_end or tid == bos:
            break
        response.append(tid)
        if stream:
            # Decoding token-by-token garbles any character whose UTF-8 bytes are
            # split across tokens. Instead, re-decode the whole reply and print
            # only the new suffix. Hold output back while the tail is an
            # incomplete sequence, which byte-level decoders render as U+FFFD.
            text = tokenizer.decode(response)
            if not text.endswith('\ufffd'):
                sys.stdout.write(text[len(printed):])
                sys.stdout.flush()
                printed = text
        if seen_mask is not None:
            seen_mask.scatter_(1, next_tok.unsqueeze(1), True)
        step_logits, states = model.step(next_tok.view(1, 1), states)
        next_logits = step_logits[:, 0, :]

    if stream:
        # Flush anything still held back (e.g. the reply was cut mid-character).
        tail = tokenizer.decode(response)[len(printed):]
        if tail:
            sys.stdout.write(tail)
        print()
    return response
def _parse_args(argv=None):
    """Build the chat CLI argument parser and parse `argv` (default: sys.argv[1:])."""
    p = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    p.add_argument('-p', '--prompt', type=str, default='',
                   help='single-turn prompt — exits after one response. Empty = interactive REPL.')
    p.add_argument('-t', '--temperature', type=float, default=0.6,
                   help='softmax temperature; 0 = greedy. nanochat-CLI default 0.6 — '
                        'tighter than infer.py because chat should be focused.')
    p.add_argument('-k', '--top-k', type=int, default=50,
                   help='top-k sampling; 0 disables. Default 50 (nanochat-CLI).')
    p.add_argument('--top-p', type=float, default=1.0,
                   help='nucleus sampling threshold; 1.0 disables. Try 0.9 for varied responses.')
    p.add_argument('-r', '--repetition-penalty', type=float, default=1.15,
                   help='CTRL-style repetition penalty (default 1.15) — keeps chat responses '
                        'from looping. Set to 1.0 for raw sampling.')
    p.add_argument('-n', '--max-tokens', type=int, default=256,
                   help='max tokens per assistant response')
    # The default is model_sft.pt; _load falls back to model.pt when it is missing.
    p.add_argument('--ckpt', type=str, default=DEFAULT_CKPT,
                   help='checkpoint path (default $CHECKPOINT_DIR/model_sft.pt; '
                        'if missing, $CHECKPOINT_DIR/model.pt is tried)')
    p.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER)
    p.add_argument('--no-history', action='store_true',
                   help='reset conversation history before each turn (model has no memory)')
    p.add_argument('--seed', type=int, default=None)
    p.add_argument('--no-stream', action='store_true',
                   help='print full response at end instead of token-by-token')
    p.add_argument('--compile', action='store_true',
                   help='torch.compile the inner step() path. First generation pays '
                        '~10–30s warmup; subsequent generations are 2–5× faster on CUDA. '
                        'Best for long REPL sessions; skip for one-shots.')
    return p.parse_args(argv)


def main():
    """Load the chat model, then answer one --prompt or run an interactive REPL.

    The history is a flat token list of the form
        <|bos|> (<|user_start|> … <|user_end|> <|assistant_start|> … <|assistant_end|>)*
    and is prefilled in full every turn. --no-history resets it before each turn.
    REPL commands (quit / exit / clear) apply only in interactive mode.
    """
    args = _parse_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)
    print(f"device: {device}", file=sys.stderr)
    model, tokenizer, specials, ckpt, used_ckpt = _load(args.ckpt, args.tokenizer,
                                                        compile_step=args.compile)
    step = ckpt.get('step', '?')
    best = ckpt.get('best_loss')
    n_params = sum(t.numel() for t in model.parameters())
    best_str = f" best_loss={best:.4f}" if isinstance(best, float) else ""
    print(f"loaded {used_ckpt} step={step}{best_str} params={n_params:,}", file=sys.stderr)

    bos = specials["<|bos|>"]
    user_s = specials["<|user_start|>"]
    user_e = specials["<|user_end|>"]
    asst_s = specials["<|assistant_start|>"]
    asst_e = specials["<|assistant_end|>"]

    print()
    print("Mnemo — chat mode")
    print("-" * 50)
    print(f"sampling: T={args.temperature} top_k={args.top_k} top_p={args.top_p} rep_penalty={args.repetition_penalty}")
    if not args.prompt:
        print("commands: 'quit' / 'exit' to end, 'clear' to reset history")
    print("-" * 50)

    conversation_tokens = [bos]
    while True:
        if args.prompt:
            # One-shot mode sends the prompt verbatim. REPL commands are not
            # interpreted here: treating `-p clear` as a command would
            # `continue` and re-read the same prompt forever.
            user_input = args.prompt
        else:
            try:
                user_input = input("\nUser: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nGoodbye!")
                break
            command = user_input.lower()
            if command in ('quit', 'exit'):
                print("Goodbye!")
                break
            if command == 'clear':
                conversation_tokens = [bos]
                print("Conversation cleared.")
                continue
            if not user_input:
                continue

        if args.no_history:
            conversation_tokens = [bos]
        # Append the user turn, then open the assistant turn the model continues from.
        conversation_tokens.append(user_s)
        conversation_tokens.extend(tokenizer.encode(user_input).ids)
        conversation_tokens.append(user_e)
        conversation_tokens.append(asst_s)

        if not args.no_stream:
            sys.stdout.write("\nAssistant: ")
            sys.stdout.flush()
        t0 = time.perf_counter()  # monotonic: immune to wall-clock adjustments
        response = _generate_one_turn(
            model, tokenizer, conversation_tokens, specials,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            stream=not args.no_stream,
        )
        elapsed = time.perf_counter() - t0
        if args.no_stream:
            print(f"\nAssistant: {tokenizer.decode(response)}")
        n_resp = len(response)
        print(f" [{n_resp} tok in {elapsed:.1f}s = {n_resp/max(elapsed, 1e-9):.1f} tok/s]",
              file=sys.stderr)

        # Close the assistant turn so the next prefill sees a complete exchange.
        conversation_tokens.extend(response)
        conversation_tokens.append(asst_e)
        if args.prompt:
            break
if __name__ == "__main__":
    # Start the CLI only when run as a script, never on import.
    main()