"""Chat with the SFT'd microgpt model. Adapted from karpathy/nanochat's chat_cli.py.
Builds chat-format prompts using our 5 special tokens:
<|bos|> <|user_start|> ...user text... <|user_end|>
<|assistant_start|> ...assistant text... <|assistant_end|>
Conversation history is maintained across turns (each new user message is appended
to the running token stream). Generation stops on <|assistant_end|>.
Usage:
python3 chat_cli.py # interactive REPL
python3 chat_cli.py -p "Who are you?" # one-shot
python3 chat_cli.py --ckpt model.pt # explicit checkpoint path
python3 chat_cli.py -t 0.6 -k 50 # nanochat-CLI defaults
python3 chat_cli.py --no-history # reset per turn (no memory)
Defaults follow nanochat-CLI (T=0.6, top-k=50). Lower temperature than infer.py's
default 0.8 because chat is meant to be focused, not exploratory.
REPL commands:
quit / exit end the session
clear reset conversation history
"""
import argparse
import os
import sys
import time
import torch
from tokenizers import Tokenizer
from model import GPT
from infer import _sample_token_batch # reuse our sampler (rep penalty + nucleus + top-k)
# Directory searched for checkpoints; override with $CHECKPOINT_DIR.
DEFAULT_CKPT_DIR = os.getenv('CHECKPOINT_DIR', '.')

# Prefer model_sft.pt (local dev convention with both base + SFT present), fall
# back to model.pt (HF convention — only the SFT'd model is uploaded as model.pt).
DEFAULT_CKPT = os.path.join(DEFAULT_CKPT_DIR, 'model_sft.pt')
DEFAULT_TOKENIZER = 'tokenizer.json'

# Chat-format control tokens the tokenizer must define.
SPECIAL_NAMES = [
    "<|bos|>",
    "<|user_start|>",
    "<|user_end|>",
    "<|assistant_start|>",
    "<|assistant_end|>",
]

# Pick the best available backend: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
def _load(ckpt_path, tokenizer_path, compile_step=False):
    """Resolve checkpoint (with fallback), load model + tokenizer + special token IDs.

    If `compile_step=True`, torch.compile the inner-loop `model.step` (B=1, T=1 fixed
    shapes — perfect for compile). We deliberately don't compile `forward_with_states`
    because chat conversations grow each turn → dynamic prompt length → recompile
    every turn. Prefill stays eager.

    Compile warmup takes 10–30s on first generation; pays off for any conversation
    long enough to generate >~50 tokens. On MPS, torch.compile is less mature than
    CUDA — try it but don't be surprised if it falls back to eager.

    Args:
        ckpt_path: path to the checkpoint file. If it does not exist,
            `DEFAULT_CKPT_DIR/model.pt` is tried instead.
        tokenizer_path: path to the tokenizer JSON file.
        compile_step: when true, replace `model.step` with a compiled version.

    Returns:
        Tuple `(model, tokenizer, specials, ckpt, ckpt_path)` where `specials`
        maps each name in `SPECIAL_NAMES` to its token id, `ckpt` is the raw
        loaded checkpoint dict, and `ckpt_path` is the path actually loaded
        (which may be the fallback).

    Exits the process (via `sys.exit`) when the checkpoint or tokenizer file
    is missing, or when the tokenizer lacks one of the special tokens.
    """
    if not os.path.exists(ckpt_path):
        # Silent fallback for HF layout (only model.pt = the SFT'd one)
        fallback = os.path.join(DEFAULT_CKPT_DIR, 'model.pt')
        if ckpt_path != fallback and os.path.exists(fallback):
            ckpt_path = fallback
        else:
            sys.exit(f"error: no checkpoint at {ckpt_path}")
    if not os.path.exists(tokenizer_path):
        sys.exit(f"error: no tokenizer at {tokenizer_path}")

    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()

    # NOTE(review): torch.load unpickles the file. With weights_only=False (the
    # default before PyTorch 2.6) a malicious checkpoint can execute arbitrary
    # code — only load checkpoints from sources you trust.
    ckpt = torch.load(ckpt_path, map_location=device)
    config = dict(ckpt['config'])  # copy so the checkpoint's own config is not mutated
    # Round the vocab size up to the next multiple of 64 (the padded size the
    # model is built with).
    config['vocab_size'] = ((vocab_size + 63) // 64) * 64
    model = GPT.from_config(config).to(device)

    # Checkpoints saved from a torch.compile'd model carry an '_orig_mod.' key prefix.
    state = {k.removeprefix('_orig_mod.'): v for k, v in ckpt['model'].items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"warn: missing keys: {missing}", file=sys.stderr)
    if unexpected:
        print(f"warn: unexpected keys: {unexpected}", file=sys.stderr)
    model.eval()

    specials = {}
    for name in SPECIAL_NAMES:
        tid = tokenizer.token_to_id(name)
        if tid is None:
            sys.exit(f"error: tokenizer is missing special token {name} — was it retrained without the SFT vocab?")
        specials[name] = tid

    if compile_step:
        print("compiling model.step (hot inner-loop path)...", file=sys.stderr)
        if device == 'mps':
            print("  note: torch.compile on MPS may fall back to eager. CUDA gets the full speedup.",
                  file=sys.stderr)
        compile_kwargs = {'dynamic': False}
        # reduce-overhead enables CUDA graphs — big win on Ampere+ where kernel-launch
        # overhead dominates the small per-token forward. No-op on MPS/CPU.
        if device == 'cuda' and torch.cuda.get_device_capability() >= (8, 0):
            compile_kwargs['mode'] = 'reduce-overhead'
        model.step = torch.compile(model.step, **compile_kwargs)

    return model, tokenizer, specials, ckpt, ckpt_path
@torch.no_grad()
def _generate_one_turn(model, tokenizer, conversation_tokens, specials,
                       max_new_tokens, temperature, top_k, top_p, repetition_penalty,
                       stream):
    """Run prefill over the full conversation, then sample tokens until
    <|assistant_end|> (or <|bos|>, or max_new_tokens). Returns the response token list.

    Args:
        model: model exposing `lm_head`, `forward_with_states(ctx)` and
            `step(token, states)`.
        tokenizer: tokenizer used to decode sampled ids when streaming.
        conversation_tokens: list of token ids for the whole conversation so far,
            ending with the opened assistant turn.
        specials: mapping from special-token name to id; must contain
            "<|bos|>" and "<|assistant_end|>".
        max_new_tokens: upper bound on the number of sampled tokens.
        temperature, top_k, top_p, repetition_penalty: forwarded to
            `_sample_token_batch`.
        stream: when true, write the decoded text to stdout as it is generated
            and finish with a newline.

    Returns:
        List of sampled token ids, excluding the terminating special token.
    """
    bos = specials["<|bos|>"]
    asst_end = specials["<|assistant_end|>"]
    ctx = torch.tensor([conversation_tokens], device=device)  # (1, T)
    V = model.lm_head.weight.size(0)  # padded vocab

    # Boolean mask of token ids already present; only needed when the
    # repetition penalty is active.
    seen_mask = None
    if repetition_penalty != 1.0:
        seen_mask = torch.zeros(1, V, dtype=torch.bool, device=device)
        seen_mask.scatter_(1, ctx, True)

    # Prefill the full conversation in one pass
    logits, states = model.forward_with_states(ctx)
    next_logits = logits[:, -1, :].clone()

    response = []
    printed = 0  # number of characters of the decoded response already written
    for _ in range(max_new_tokens):
        next_tok = _sample_token_batch(next_logits, temperature, top_k, top_p,
                                       repetition_penalty, seen_mask)  # (1,)
        tid = next_tok.item()
        if tid == asst_end or tid == bos:
            break
        response.append(tid)
        if stream:
            # Decode the whole response and emit only the new suffix. Decoding
            # token-by-token breaks characters whose bytes span several tokens
            # (they show up as U+FFFD) and can differ from the full decode that
            # the non-streaming path prints.
            text = tokenizer.decode(response)
            # A trailing U+FFFD means the last character is still incomplete;
            # hold it back until the following token(s) complete it.
            if not text.endswith("\ufffd"):
                sys.stdout.write(text[printed:])
                sys.stdout.flush()
                printed = max(printed, len(text))
        if seen_mask is not None:
            seen_mask.scatter_(1, next_tok.unsqueeze(1), True)
        step_logits, states = model.step(next_tok.view(1, 1), states)
        next_logits = step_logits[:, 0, :]

    if stream:
        # Flush anything that was held back (e.g. generation stopped mid-character).
        if response:
            sys.stdout.write(tokenizer.decode(response)[printed:])
        print()
    return response
def main():
    """Parse CLI arguments, load the model, and run one-shot or interactive chat.

    With `-p/--prompt` a single response is generated and the program exits.
    Otherwise an interactive loop reads user turns from stdin until EOF,
    Ctrl-C, or a `quit` / `exit` command; `clear` resets the history.
    Status and timing lines go to stderr; the conversation goes to stdout.
    """
    # __doc__ is None when docstrings are stripped (python -OO); fall back to a
    # fixed summary instead of crashing on None.splitlines().
    summary = (__doc__ or "Chat with the SFT'd microgpt model.").strip().splitlines()[0]
    p = argparse.ArgumentParser(description=summary)
    p.add_argument('-p', '--prompt', type=str, default='',
                   help='single-turn prompt — exits after one response. Empty = interactive REPL.')
    p.add_argument('-t', '--temperature', type=float, default=0.6,
                   help='softmax temperature; 0 = greedy. nanochat-CLI default 0.6 — '
                        'tighter than infer.py because chat should be focused.')
    p.add_argument('-k', '--top-k', type=int, default=50,
                   help='top-k sampling; 0 disables. Default 50 (nanochat-CLI).')
    p.add_argument('--top-p', type=float, default=1.0,
                   help='nucleus sampling threshold; 1.0 disables. Try 0.9 for varied responses.')
    p.add_argument('-r', '--repetition-penalty', type=float, default=1.15,
                   help='CTRL-style repetition penalty (default 1.15) — keeps chat responses '
                        'from looping. Set to 1.0 for raw sampling.')
    p.add_argument('-n', '--max-tokens', type=int, default=256,
                   help='max tokens per assistant response')
    p.add_argument('--ckpt', type=str, default=DEFAULT_CKPT,
                   help='checkpoint path (default $CHECKPOINT_DIR/model_sft.pt, '
                        'falling back to $CHECKPOINT_DIR/model.pt)')
    p.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER)
    p.add_argument('--no-history', action='store_true',
                   help='reset conversation history before each turn (model has no memory)')
    p.add_argument('--seed', type=int, default=None)
    p.add_argument('--no-stream', action='store_true',
                   help='print full response at end instead of token-by-token')
    p.add_argument('--compile', action='store_true',
                   help='torch.compile the inner step() path. First generation pays '
                        '~10–30s warmup; subsequent generations are 2–5× faster on CUDA. '
                        'Best for long REPL sessions; skip for one-shots.')
    args = p.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    print(f"device: {device}", file=sys.stderr)
    model, tokenizer, specials, ckpt, used_ckpt = _load(args.ckpt, args.tokenizer,
                                                        compile_step=args.compile)
    step = ckpt.get('step', '?')
    best = ckpt.get('best_loss')
    n_params = sum(t.numel() for t in model.parameters())
    best_str = f" best_loss={best:.4f}" if isinstance(best, float) else ""
    print(f"loaded {used_ckpt} step={step}{best_str} params={n_params:,}", file=sys.stderr)

    bos = specials["<|bos|>"]
    user_s = specials["<|user_start|>"]
    user_e = specials["<|user_end|>"]
    asst_s = specials["<|assistant_start|>"]
    asst_e = specials["<|assistant_end|>"]

    print()
    print("Mnemo — chat mode")
    print("-" * 50)
    print(f"sampling: T={args.temperature} top_k={args.top_k} top_p={args.top_p} rep_penalty={args.repetition_penalty}")
    if not args.prompt:
        print("commands: 'quit' / 'exit' to end, 'clear' to reset history")
    print("-" * 50)

    conversation_tokens = [bos]
    while True:
        if args.prompt:
            user_input = args.prompt
        else:
            try:
                user_input = input("\nUser: ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nGoodbye!")
                break
            # REPL commands apply to interactive input only. A one-shot prompt
            # is always sent to the model, so `-p clear` cannot spin forever
            # on `continue`.
            command = user_input.lower()
            if command in ('quit', 'exit'):
                print("Goodbye!")
                break
            if command == 'clear':
                conversation_tokens = [bos]
                print("Conversation cleared.")
                continue
            if not user_input:
                continue

        if args.no_history:
            conversation_tokens = [bos]

        # Append user turn
        conversation_tokens.append(user_s)
        conversation_tokens.extend(tokenizer.encode(user_input).ids)
        conversation_tokens.append(user_e)
        # Open assistant turn (the model continues from here)
        conversation_tokens.append(asst_s)

        if not args.no_stream:
            sys.stdout.write("\nAssistant: ")
            sys.stdout.flush()

        t0 = time.time()
        response = _generate_one_turn(
            model, tokenizer, conversation_tokens, specials,
            max_new_tokens=args.max_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            stream=not args.no_stream,
        )
        elapsed = time.time() - t0

        if args.no_stream:
            print(f"\nAssistant: {tokenizer.decode(response)}")

        n_resp = len(response)
        print(f"  [{n_resp} tok in {elapsed:.1f}s = {n_resp/max(elapsed, 1e-9):.1f} tok/s]",
              file=sys.stderr)

        # Close assistant turn in the history (so the next prefill sees a complete turn)
        conversation_tokens.extend(response)
        conversation_tokens.append(asst_e)

        if args.prompt:
            break
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()