| """Chat with the SFT'd microgpt model. Adapted from karpathy/nanochat's chat_cli.py.
|
|
|
| Builds chat-format prompts using our 5 special tokens:
|
| <|bos|> <|user_start|> ...user text... <|user_end|>
|
| <|assistant_start|> ...assistant text... <|assistant_end|>
|
|
|
| Conversation history is maintained across turns (each new user message is appended
|
| to the running token stream). Generation stops on <|assistant_end|>.
|
|
|
| Usage:
|
| python3 chat_cli.py # interactive REPL
|
| python3 chat_cli.py -p "Who are you?" # one-shot
|
| python3 chat_cli.py --ckpt model.pt # explicit checkpoint path
|
| python3 chat_cli.py -t 0.6 -k 50 # nanochat-CLI defaults
|
| python3 chat_cli.py --no-history # reset per turn (no memory)
|
|
|
| Defaults follow nanochat-CLI (T=0.6, top-k=50). Lower temperature than infer.py's
|
| default 0.8 because chat is meant to be focused, not exploratory.
|
|
|
| REPL commands:
|
| quit / exit end the session
|
| clear reset conversation history
|
| """
|
| import argparse
|
| import os
|
| import sys
|
| import time
|
|
|
| import torch
|
| from tokenizers import Tokenizer
|
|
|
| from model import GPT
|
| from infer import _sample_token_batch
|
|
|
|
|
# Where checkpoints live; override with $CHECKPOINT_DIR (defaults to the working directory).
DEFAULT_CKPT_DIR = os.environ.get("CHECKPOINT_DIR", ".")

# Preferred checkpoint is the SFT one; _load falls back to DEFAULT_CKPT_DIR/model.pt.
DEFAULT_CKPT = os.path.join(DEFAULT_CKPT_DIR, "model_sft.pt")
DEFAULT_TOKENIZER = "tokenizer.json"

# Chat-format control tokens. The tokenizer must define every one of them.
SPECIAL_NAMES = [
    "<|bos|>",
    "<|user_start|>",
    "<|user_end|>",
    "<|assistant_start|>",
    "<|assistant_end|>",
]

# Pick the best available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
|
|
|
|
|
def _load(ckpt_path, tokenizer_path, compile_step=False):
    """Resolve the checkpoint (with fallback) and load tokenizer, special-token IDs and model.

    Args:
        ckpt_path: checkpoint file, a dict with at least 'config' and 'model'
            (state dict) entries. If it does not exist, `$CHECKPOINT_DIR/model.pt`
            is tried instead and a warning is printed to stderr.
        tokenizer_path: `tokenizers` JSON file. It must define every name in
            SPECIAL_NAMES. This is checked *before* the checkpoint is loaded, so an
            incompatible tokenizer fails fast.
        compile_step: if True, torch.compile the inner-loop `model.step`
            (B=1, T=1 fixed shapes — perfect for compile). `forward_with_states`
            is deliberately NOT compiled: chat conversations grow each turn, so
            the prompt length is dynamic and prefill would recompile every turn.
            Prefill stays eager.

            Compile warmup takes 10–30s on first generation. It pays off for any
            conversation long enough to generate >~50 tokens. On MPS,
            torch.compile is less mature than on CUDA — try it, but don't be
            surprised if it falls back to eager.

    Returns:
        (model, tokenizer, specials, ckpt, ckpt_path):
        - model: the eval-mode model on `device`.
        - tokenizer: the loaded tokenizer.
        - specials: a {special-token name: id} dict.
        - ckpt: the checkpoint dict with its 'model' weights removed. Callers
          only need metadata such as 'step' / 'best_loss'.
        - ckpt_path: the path actually loaded.

    Exits via sys.exit with an error message if the checkpoint or tokenizer file
    is missing, or if the tokenizer lacks a special token.
    """
    if not os.path.exists(ckpt_path):
        fallback = os.path.join(DEFAULT_CKPT_DIR, 'model.pt')
        if ckpt_path != fallback and os.path.exists(fallback):
            # Announce the swap: silently loading a different model than the one
            # requested (e.g. a mistyped --ckpt) is confusing.
            print(f"warn: no checkpoint at {ckpt_path}; falling back to {fallback}",
                  file=sys.stderr)
            ckpt_path = fallback
        else:
            sys.exit(f"error: no checkpoint at {ckpt_path}")
    if not os.path.exists(tokenizer_path):
        sys.exit(f"error: no tokenizer at {tokenizer_path}")

    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()

    # Resolve the chat control tokens before touching the (possibly large) checkpoint.
    specials = {}
    for name in SPECIAL_NAMES:
        tid = tokenizer.token_to_id(name)
        if tid is None:
            sys.exit(f"error: tokenizer is missing special token {name} — was it retrained without the SFT vocab?")
        specials[name] = tid

    # Load onto the CPU. load_state_dict copies the tensors into the model's own
    # parameters on `device`, so the checkpoint copy never needs to live there.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    config = dict(ckpt['config'])
    # Size the vocab from the tokenizer (padded up to a multiple of 64), overriding
    # whatever the checkpoint's config says.
    config['vocab_size'] = ((vocab_size + 63) // 64) * 64
    model = GPT.from_config(config).to(device)
    # Strip the `_orig_mod.` prefix that keys carry when saved from a
    # torch.compile-wrapped module. pop() removes the weights from `ckpt`: the
    # caller keeps the dict around for metadata, and a second full copy of the
    # weights must not stay alive for the whole session.
    state = {k.removeprefix('_orig_mod.'): v for k, v in ckpt.pop('model').items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    del state
    if missing:
        print(f"warn: missing keys: {missing}", file=sys.stderr)
    if unexpected:
        print(f"warn: unexpected keys: {unexpected}", file=sys.stderr)
    model.eval()

    if compile_step:
        print("compiling model.step (hot inner-loop path)...", file=sys.stderr)
        if device == 'mps':
            print(" note: torch.compile on MPS may fall back to eager. CUDA gets the full speedup.",
                  file=sys.stderr)
        compile_kwargs = {'dynamic': False}
        # 'reduce-overhead' (CUDA-graph based) is only enabled on compute capability >= 8.0.
        if device == 'cuda' and torch.cuda.get_device_capability() >= (8, 0):
            compile_kwargs['mode'] = 'reduce-overhead'
        model.step = torch.compile(model.step, **compile_kwargs)

    return model, tokenizer, specials, ckpt, ckpt_path
|
|
|
|
|
def _stream_delta(tokenizer, tokens, emitted, final=False):
    """Return (piece, emitted): the not-yet-printed suffix of decode(tokens) and the new printed text.

    Decoding the whole response and diffing it against what was already printed
    keeps characters whose bytes span several tokens intact. It also keeps
    spacing that a decoder only reconstructs in context. Per-token decode gets
    both wrong.

    While the decoded text ends in U+FFFD (e.g. an incomplete UTF-8 sequence
    from a byte-level tokenizer), nothing is emitted. The exception is
    `final=True`, which flushes whatever remains. The diff is by character
    count, so it is best-effort if a decoder rewrites earlier text.
    """
    text = tokenizer.decode(tokens)
    if len(text) <= len(emitted) or (not final and text.endswith("\ufffd")):
        return "", emitted
    return text[len(emitted):], text


@torch.no_grad()
def _generate_one_turn(model, tokenizer, conversation_tokens, specials,
                       max_new_tokens, temperature, top_k, top_p, repetition_penalty,
                       stream):
    """Prefill the full conversation, then sample one assistant reply.

    Sampling stops at <|assistant_end|>, at <|bos|>, or after `max_new_tokens`
    tokens. The stop token itself is not included in the result.

    Args:
        model: provides `forward_with_states(ctx) -> (logits, states)`,
            `step(tok, states) -> (logits, states)` and `lm_head.weight`.
        tokenizer: used only to render streamed output.
        conversation_tokens: full token history. As built by main(), it ends
            with <|assistant_start|>.
        specials: {special-token name: id} mapping.
        max_new_tokens, temperature, top_k, top_p, repetition_penalty: sampling
            settings forwarded to `_sample_token_batch`. A repetition_penalty of
            1.0 skips building the seen-token mask.
        stream: if True, write text to stdout as it is generated, then a newline.

    Returns:
        list[int]: the response token ids.
    """
    bos = specials["<|bos|>"]
    asst_end = specials["<|assistant_end|>"]

    ctx = torch.tensor([conversation_tokens], device=device)
    V = model.lm_head.weight.size(0)

    # (1, V) mask of every token id already in context, for the repetition penalty.
    seen_mask = None
    if repetition_penalty != 1.0:
        seen_mask = torch.zeros(1, V, dtype=torch.bool, device=device)
        seen_mask.scatter_(1, ctx, True)

    logits, states = model.forward_with_states(ctx)
    next_logits = logits[:, -1, :].clone()
    # Drop the prefill logits: a per-position logits tensor grows with the
    # conversation and would otherwise stay alive for the whole sampling loop.
    del logits

    response = []
    emitted = ""  # text already written to stdout when streaming
    for _ in range(max_new_tokens):
        # The sampler returns one token id; shape (1,), judging by the
        # .item() / .unsqueeze(1) / .view(1, 1) uses below.
        next_tok = _sample_token_batch(next_logits, temperature, top_k, top_p,
                                       repetition_penalty, seen_mask)
        tid = next_tok.item()
        if tid == asst_end or tid == bos:
            break
        response.append(tid)
        if stream:
            # Re-decoding the whole reply is O(n^2) in tokens, but n is bounded by
            # max_new_tokens and the cost is small next to a model step.
            piece, emitted = _stream_delta(tokenizer, response, emitted)
            if piece:
                sys.stdout.write(piece)
                sys.stdout.flush()
        if seen_mask is not None:
            seen_mask.scatter_(1, next_tok.unsqueeze(1), True)
        step_logits, states = model.step(next_tok.view(1, 1), states)
        next_logits = step_logits[:, 0, :]

    if stream:
        # Flush anything held back (e.g. a trailing partial character).
        piece, _ = _stream_delta(tokenizer, response, emitted, final=True)
        sys.stdout.write(piece)
        print()
    return response
|
|
|
|
|
def _parse_args():
    """Build the command-line parser and parse sys.argv."""
    # `(__doc__ or "")` keeps this working under `python -OO`, which strips docstrings.
    p = argparse.ArgumentParser(description=(__doc__ or "").partition("\n")[0])
    p.add_argument('-p', '--prompt', type=str, default='',
                   help='single-turn prompt, sent to the model verbatim (REPL commands are not '
                        'interpreted) — exits after one response. Empty = interactive REPL.')
    p.add_argument('-t', '--temperature', type=float, default=0.6,
                   help='softmax temperature; 0 = greedy. nanochat-CLI default 0.6 — '
                        'tighter than infer.py because chat should be focused.')
    p.add_argument('-k', '--top-k', type=int, default=50,
                   help='top-k sampling; 0 disables. Default 50 (nanochat-CLI).')
    p.add_argument('--top-p', type=float, default=1.0,
                   help='nucleus sampling threshold; 1.0 disables. Try 0.9 for varied responses.')
    p.add_argument('-r', '--repetition-penalty', type=float, default=1.15,
                   help='CTRL-style repetition penalty (default 1.15) — keeps chat responses '
                        'from looping. Set to 1.0 for raw sampling.')
    p.add_argument('-n', '--max-tokens', type=int, default=256,
                   help='max tokens per assistant response')
    p.add_argument('--ckpt', type=str, default=DEFAULT_CKPT,
                   help='checkpoint path (default $CHECKPOINT_DIR/model_sft.pt; '
                        'falls back to $CHECKPOINT_DIR/model.pt if that is missing)')
    p.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER)
    p.add_argument('--no-history', action='store_true',
                   help='reset conversation history before each turn (model has no memory)')
    p.add_argument('--seed', type=int, default=None)
    p.add_argument('--no-stream', action='store_true',
                   help='print full response at end instead of token-by-token')
    p.add_argument('--compile', action='store_true',
                   help='torch.compile the inner step() path. First generation pays '
                        '~10–30s warmup; subsequent generations are 2–5× faster on CUDA. '
                        'Best for long REPL sessions; skip for one-shots.')
    return p.parse_args()


def main():
    """Load the model, then answer one `-p` prompt or run an interactive chat REPL.

    Each turn appends `<|user_start|> text <|user_end|> <|assistant_start|>` to the
    running token stream, generates a reply, and appends it followed by
    `<|assistant_end|>`.

    REPL commands ('quit'/'exit'/'clear') are recognised only in interactive
    mode. A `-p` prompt is always sent to the model. Ctrl-C while a reply is
    being generated discards that turn. In the REPL the session continues; in
    one-shot mode the script exits with status 130.
    """
    args = _parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    print(f"device: {device}", file=sys.stderr)
    model, tokenizer, specials, ckpt, used_ckpt = _load(args.ckpt, args.tokenizer,
                                                        compile_step=args.compile)
    step = ckpt.get('step', '?')
    best = ckpt.get('best_loss')
    n_params = sum(t.numel() for t in model.parameters())
    best_str = f" best_loss={best:.4f}" if isinstance(best, float) else ""
    print(f"loaded {used_ckpt} step={step}{best_str} params={n_params:,}", file=sys.stderr)

    bos = specials["<|bos|>"]
    user_s = specials["<|user_start|>"]
    user_e = specials["<|user_end|>"]
    asst_s = specials["<|assistant_start|>"]
    asst_e = specials["<|assistant_end|>"]

    one_shot = bool(args.prompt)

    print()
    print("Mnemo — chat mode")
    print("-" * 50)
    print(f"sampling: T={args.temperature} top_k={args.top_k} top_p={args.top_p} rep_penalty={args.repetition_penalty}")
    if not one_shot:
        print("commands: 'quit' / 'exit' to end, 'clear' to reset history")
    print("-" * 50)

    conversation_tokens = [bos]

    while True:
        if one_shot:
            # Sent verbatim: treating `-p clear` as a command would `continue`
            # forever, and `-p quit` should be answered, not swallowed.
            user_input = args.prompt
        else:
            try:
                user_input = input("\nUser: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nGoodbye!")
                break
            command = user_input.lower()
            if command in ('quit', 'exit'):
                print("Goodbye!")
                break
            if command == 'clear':
                conversation_tokens = [bos]
                print("Conversation cleared.")
                continue
            if not user_input:
                continue

        if args.no_history:
            conversation_tokens = [bos]

        # Remember where this turn starts so an interrupted turn can be rolled back.
        turn_start = len(conversation_tokens)
        conversation_tokens.append(user_s)
        conversation_tokens.extend(tokenizer.encode(user_input).ids)
        conversation_tokens.append(user_e)
        conversation_tokens.append(asst_s)

        if not args.no_stream:
            sys.stdout.write("\nAssistant: ")
            sys.stdout.flush()
        t0 = time.time()
        try:
            response = _generate_one_turn(
                model, tokenizer, conversation_tokens, specials,
                max_new_tokens=args.max_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
                stream=not args.no_stream,
            )
        except KeyboardInterrupt:
            # Abort just this reply instead of killing the session with a traceback.
            del conversation_tokens[turn_start:]
            print("\n[generation interrupted — turn discarded]")
            if one_shot:
                sys.exit(130)  # conventional exit status for SIGINT
            continue
        elapsed = time.time() - t0

        if args.no_stream:
            print(f"\nAssistant: {tokenizer.decode(response)}")
        n_resp = len(response)
        print(f" [{n_resp} tok in {elapsed:.1f}s = {n_resp/max(elapsed, 1e-9):.1f} tok/s]",
              file=sys.stderr)

        # Close the turn even if generation stopped at max_new_tokens, so the next
        # user message starts from a well-formed history.
        conversation_tokens.extend(response)
        conversation_tokens.append(asst_e)

        if one_shot:
            break
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; importing this module does not start the chat loop.
    main()
|
|
|