"""Chat with the SFT'd microgpt model.

Adapted from karpathy/nanochat's chat_cli.py. Builds chat-format prompts using
our 5 special tokens:

    <|bos|> <|user_start|> ...user text... <|user_end|>
    <|assistant_start|> ...assistant text... <|assistant_end|>

Conversation history is maintained across turns (each new user message is
appended to the running token stream). Generation stops on <|assistant_end|>.

Usage:
    python3 chat_cli.py                        # interactive REPL
    python3 chat_cli.py -p "Who are you?"      # one-shot
    python3 chat_cli.py --ckpt model.pt        # explicit checkpoint path
    python3 chat_cli.py -t 0.6 -k 50           # nanochat-CLI defaults
    python3 chat_cli.py --no-history           # reset per turn (no memory)

Defaults follow nanochat-CLI (T=0.6, top-k=50). Lower temperature than
infer.py's default 0.8 because chat is meant to be focused, not exploratory.

REPL commands (interactive mode only; a one-shot -p prompt is always sent to
the model verbatim):
    quit / exit    end the session
    clear          reset conversation history
"""
import argparse
import os
import sys
import time

import torch
from tokenizers import Tokenizer

from model import GPT
from infer import _sample_token_batch  # reuse our sampler (rep penalty + nucleus + top-k)

DEFAULT_CKPT_DIR = os.environ.get('CHECKPOINT_DIR', '.')
# Prefer model_sft.pt (local dev convention with both base + SFT present), fall
# back to model.pt (HF convention — only the SFT'd model is uploaded as model.pt).
DEFAULT_CKPT = os.path.join(DEFAULT_CKPT_DIR, 'model_sft.pt')
DEFAULT_TOKENIZER = 'tokenizer.json'

SPECIAL_NAMES = [
    "<|bos|>",
    "<|user_start|>",
    "<|user_end|>",
    "<|assistant_start|>",
    "<|assistant_end|>",
]

device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')


def _load(ckpt_path, tokenizer_path, compile_step=False):
    """Resolve checkpoint (with fallback), load model + tokenizer + special token IDs.

    The HF layout ships only model.pt (the SFT'd model). So when the *default*
    checkpoint path is missing, we fall back to $CHECKPOINT_DIR/model.pt. A
    checkpoint the caller named explicitly that does not exist is an error
    rather than a silent swap to a different model.

    If `compile_step=True`, torch.compile the inner-loop `model.step` (B=1, T=1
    fixed shapes — perfect for compile). We deliberately don't compile
    `forward_with_states`: chat conversations grow each turn, so the prompt
    length is dynamic and would recompile every turn. Prefill stays eager.
    Compile warmup takes 10–30s on the first generation; it pays off for any
    conversation long enough to generate more than ~50 tokens. On MPS,
    torch.compile is less mature than on CUDA — try it, but don't be surprised
    if it falls back to eager.

    Returns:
        (model, tokenizer, specials, ckpt, ckpt_path):
        - `specials` maps each name in SPECIAL_NAMES to its token id.
        - `ckpt` is the raw dict returned by torch.load.
        - `ckpt_path` is the path that was actually loaded.
        Calls sys.exit on a missing checkpoint/tokenizer or a missing special
        token.
    """
    if not os.path.exists(ckpt_path):
        fallback = os.path.join(DEFAULT_CKPT_DIR, 'model.pt')
        asked_for_default = os.path.abspath(ckpt_path) == os.path.abspath(DEFAULT_CKPT)
        if asked_for_default and os.path.exists(fallback):
            # Silent fallback for HF layout (only model.pt = the SFT'd one);
            # the "loaded ..." line in main() reports the path actually used.
            ckpt_path = fallback
        else:
            sys.exit(f"error: no checkpoint at {ckpt_path}")
    if not os.path.exists(tokenizer_path):
        sys.exit(f"error: no tokenizer at {tokenizer_path}")

    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()

    ckpt = torch.load(ckpt_path, map_location=device)
    config = dict(ckpt['config'])
    # Pad the embedding/head vocab up to a multiple of 64, overriding the stored value.
    config['vocab_size'] = ((vocab_size + 63) // 64) * 64
    model = GPT.from_config(config).to(device)

    # Checkpoints saved from a torch.compile'd model carry an '_orig_mod.' key prefix.
    state = {k.removeprefix('_orig_mod.'): v for k, v in ckpt['model'].items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"warn: missing keys: {missing}", file=sys.stderr)
    if unexpected:
        print(f"warn: unexpected keys: {unexpected}", file=sys.stderr)
    model.eval()

    specials = {}
    for name in SPECIAL_NAMES:
        tid = tokenizer.token_to_id(name)
        if tid is None:
            sys.exit(f"error: tokenizer is missing special token {name} — was it retrained without the SFT vocab?")
        specials[name] = tid

    if compile_step:
        print("compiling model.step (hot inner-loop path)...", file=sys.stderr)
        if device == 'mps':
            print("  note: torch.compile on MPS may fall back to eager. CUDA gets the full speedup.", file=sys.stderr)
        compile_kwargs = {'dynamic': False}
        # reduce-overhead enables CUDA graphs — big win on Ampere+ where kernel-launch
        # overhead dominates the small per-token forward. No-op on MPS/CPU.
        if device == 'cuda' and torch.cuda.get_device_capability() >= (8, 0):
            compile_kwargs['mode'] = 'reduce-overhead'
        model.step = torch.compile(model.step, **compile_kwargs)

    return model, tokenizer, specials, ckpt, ckpt_path


def _stream_decode(tokenizer, tokens, prefix_offset, read_offset, final=False):
    """Incrementally detokenize `tokens` for streaming output.

    Rather than decoding each token in isolation, this decodes a small trailing
    window, tokens[prefix_offset:], and returns only the text beyond what
    tokens[prefix_offset:read_offset] already produced. This keeps two things
    correct:
    - characters whose UTF-8 bytes are split across several tokens;
    - spacing, for decoders that may drop the leading space of the first token
      they decode — both decodes start at the same token, so the drop cancels.

    While the decoded window ends in U+FFFD (an incomplete multi-byte
    character), nothing is emitted, unless `final` is set to flush the remainder.

    Returns:
        (new_text, prefix_offset, read_offset) — the text to print plus the
        updated offsets to pass to the next call.
    """
    prefix_text = tokenizer.decode(tokens[prefix_offset:read_offset])
    new_text = tokenizer.decode(tokens[prefix_offset:])
    if len(new_text) > len(prefix_text) and (final or not new_text.endswith('\ufffd')):
        return new_text[len(prefix_text):], read_offset, len(tokens)
    return '', prefix_offset, read_offset


@torch.no_grad()
def _generate_one_turn(model, tokenizer, conversation_tokens, specials,
                       max_new_tokens, temperature, top_k, top_p,
                       repetition_penalty, stream):
    """Run prefill over the full conversation, then sample one assistant reply.

    Sampling stops at <|assistant_end|>, at <|bos|>, or after `max_new_tokens`.
    The stop token itself is not included in the result. When `stream` is
    true, text is written to stdout as it is generated, followed by a newline.

    Returns:
        list[int]: the response token ids.
    """
    bos = specials["<|bos|>"]
    asst_end = specials["<|assistant_end|>"]

    ctx = torch.tensor([conversation_tokens], device=device)  # (1, T)
    V = model.lm_head.weight.size(0)  # padded vocab
    seen_mask = None
    if repetition_penalty != 1.0:
        # Mark every token already in the conversation so the sampler can penalize repeats.
        seen_mask = torch.zeros(1, V, dtype=torch.bool, device=device)
        seen_mask.scatter_(1, ctx, True)

    # Prefill the full conversation in one pass
    logits, states = model.forward_with_states(ctx)
    next_logits = logits[:, -1, :].clone()

    response = []
    prefix_offset = read_offset = 0  # streaming detokenizer window
    for _ in range(max_new_tokens):
        next_tok = _sample_token_batch(next_logits, temperature, top_k, top_p,
                                       repetition_penalty, seen_mask)  # (1,)
        tid = next_tok.item()
        if tid == asst_end or tid == bos:
            break
        response.append(tid)
        if stream:
            piece, prefix_offset, read_offset = _stream_decode(
                tokenizer, response, prefix_offset, read_offset)
            if piece:
                sys.stdout.write(piece)
                sys.stdout.flush()
        if seen_mask is not None:
            seen_mask.scatter_(1, next_tok.unsqueeze(1), True)
        step_logits, states = model.step(next_tok.view(1, 1), states)
        next_logits = step_logits[:, 0, :]

    if stream:
        # Flush anything held back (e.g. a trailing partial character).
        piece, _, _ = _stream_decode(tokenizer, response, prefix_offset, read_offset, final=True)
        sys.stdout.write(piece)
        print()
    return response


def main():
    """Parse CLI args, load the model, and run a one-shot prompt or the interactive REPL."""
    p = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    p.add_argument('-p', '--prompt', type=str, default='',
                   help='single-turn prompt — exits after one response. Empty = interactive REPL.')
    p.add_argument('-t', '--temperature', type=float, default=0.6,
                   help='softmax temperature; 0 = greedy. nanochat-CLI default 0.6 — '
                        'tighter than infer.py because chat should be focused.')
    p.add_argument('-k', '--top-k', type=int, default=50,
                   help='top-k sampling; 0 disables. Default 50 (nanochat-CLI).')
    p.add_argument('--top-p', type=float, default=1.0,
                   help='nucleus sampling threshold; 1.0 disables. Try 0.9 for varied responses.')
    p.add_argument('-r', '--repetition-penalty', type=float, default=1.15,
                   help='CTRL-style repetition penalty (default 1.15) — keeps chat responses '
                        'from looping. Set to 1.0 for raw sampling.')
    p.add_argument('-n', '--max-tokens', type=int, default=256,
                   help='max tokens per assistant response')
    p.add_argument('--ckpt', type=str, default=DEFAULT_CKPT,
                   help='checkpoint path (default $CHECKPOINT_DIR/model_sft.pt, '
                        'falling back to $CHECKPOINT_DIR/model.pt)')
    p.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER)
    p.add_argument('--no-history', action='store_true',
                   help='reset conversation history before each turn (model has no memory)')
    p.add_argument('--seed', type=int, default=None)
    p.add_argument('--no-stream', action='store_true',
                   help='print full response at end instead of token-by-token')
    p.add_argument('--compile', action='store_true',
                   help='torch.compile the inner step() path. First generation pays '
                        '~10–30s warmup; subsequent generations are 2–5× faster on CUDA. '
                        'Best for long REPL sessions; skip for one-shots.')
    args = p.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    print(f"device: {device}", file=sys.stderr)
    model, tokenizer, specials, ckpt, used_ckpt = _load(args.ckpt, args.tokenizer,
                                                        compile_step=args.compile)
    step = ckpt.get('step', '?')
    best = ckpt.get('best_loss')
    n_params = sum(t.numel() for t in model.parameters())
    best_str = f" best_loss={best:.4f}" if isinstance(best, float) else ""
    print(f"loaded {used_ckpt}  step={step}{best_str}  params={n_params:,}", file=sys.stderr)

    bos = specials["<|bos|>"]
    user_s = specials["<|user_start|>"]
    user_e = specials["<|user_end|>"]
    asst_s = specials["<|assistant_start|>"]
    asst_e = specials["<|assistant_end|>"]

    print()
    print("Mnemo — chat mode")
    print("-" * 50)
    print(f"sampling: T={args.temperature}  top_k={args.top_k}  top_p={args.top_p}  "
          f"rep_penalty={args.repetition_penalty}")
    if not args.prompt:
        print("commands: 'quit' / 'exit' to end, 'clear' to reset history")
    print("-" * 50)

    conversation_tokens = [bos]

    while True:
        if args.prompt:
            # One-shot: the prompt goes to the model verbatim and is never
            # interpreted as a REPL command (so `-p clear` can't loop forever
            # resetting history without ever reaching the final break).
            user_input = args.prompt
        else:
            try:
                user_input = input("\nUser: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nGoodbye!")
                break
            if user_input.lower() in ('quit', 'exit'):
                print("Goodbye!")
                break
            if user_input.lower() == 'clear':
                conversation_tokens = [bos]
                print("Conversation cleared.")
                continue
            if not user_input:
                continue

        if args.no_history:
            conversation_tokens = [bos]

        # Append user turn. add_special_tokens=False: the chat markers are placed
        # explicitly here, so any post-processor template in tokenizer.json must
        # not add its own special tokens around the user text.
        conversation_tokens.append(user_s)
        conversation_tokens.extend(tokenizer.encode(user_input, add_special_tokens=False).ids)
        conversation_tokens.append(user_e)
        # Open assistant turn (the model continues from here)
        conversation_tokens.append(asst_s)

        if not args.no_stream:
            sys.stdout.write("\nAssistant: ")
            sys.stdout.flush()
        t0 = time.time()
        response = _generate_one_turn(
            model, tokenizer, conversation_tokens, specials,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            stream=not args.no_stream,
        )
        elapsed = time.time() - t0
        if args.no_stream:
            print(f"\nAssistant: {tokenizer.decode(response)}")
        n_resp = len(response)
        print(f"  [{n_resp} tok in {elapsed:.1f}s = {n_resp / max(elapsed, 1e-9):.1f} tok/s]",
              file=sys.stderr)

        # Close assistant turn in the history (so the next prefill sees a complete turn)
        conversation_tokens.extend(response)
        conversation_tokens.append(asst_e)

        if args.prompt:
            break


if __name__ == '__main__':
    main()