File size: 11,676 Bytes
1d176f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""Chat with the SFT'd microgpt model. Adapted from karpathy/nanochat's chat_cli.py.



Builds chat-format prompts using our 5 special tokens:

    <|bos|> <|user_start|> ...user text... <|user_end|>

            <|assistant_start|> ...assistant text... <|assistant_end|>



Conversation history is maintained across turns (each new user message is appended

to the running token stream). Generation stops on <|assistant_end|>.



Usage:

    python3 chat_cli.py                                  # interactive REPL

    python3 chat_cli.py -p "Who are you?"                # one-shot

    python3 chat_cli.py --ckpt model.pt                  # explicit checkpoint path

    python3 chat_cli.py -t 0.6 -k 50                     # nanochat-CLI defaults

    python3 chat_cli.py --no-history                     # reset per turn (no memory)



Defaults follow nanochat-CLI (T=0.6, top-k=50). Lower temperature than infer.py's

default 0.8 because chat is meant to be focused, not exploratory.



REPL commands:

    quit / exit     end the session

    clear           reset conversation history

"""
import argparse
import os
import sys
import time

import torch
from tokenizers import Tokenizer

from model import GPT
from infer import _sample_token_batch  # reuse our sampler (rep penalty + nucleus + top-k)


# Checkpoint directory: $CHECKPOINT_DIR when set, otherwise the current directory.
DEFAULT_CKPT_DIR = os.environ.get('CHECKPOINT_DIR', '.')
# model_sft.pt is the local-dev name (base and SFT checkpoints side by side). On the
# HF layout only the SFT'd model is uploaded, as model.pt; _load falls back to it.
DEFAULT_CKPT = os.path.join(DEFAULT_CKPT_DIR, 'model_sft.pt')
DEFAULT_TOKENIZER = 'tokenizer.json'

# Chat-format special tokens; _load requires every one of them in the tokenizer vocab.
SPECIAL_NAMES = [
    "<|bos|>",
    "<|user_start|>", "<|user_end|>",
    "<|assistant_start|>", "<|assistant_end|>",
]

# Backend preference: CUDA first, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'


def _load(ckpt_path, tokenizer_path, compile_step=False):
    """Resolve the checkpoint (with fallback), then load tokenizer, special-token IDs and model.

    Checkpoint resolution: if `ckpt_path` does not exist and names a
    `model_sft.pt` file, fall back to `model.pt` in the same directory (HF
    layout — only the SFT'd model is uploaded, as `model.pt`). Any other
    missing path is an error, so a mistyped `--ckpt` never silently loads a
    different model.

    If `compile_step=True`, torch.compile the inner-loop `model.step` (B=1, T=1
    fixed shapes — perfect for compile). We deliberately don't compile
    `forward_with_states` because chat conversations grow each turn → dynamic
    prompt length → recompile every turn. Prefill stays eager.

    Compile warmup takes 10–30s on first generation; pays off for any
    conversation long enough to generate >~50 tokens. On MPS, torch.compile is
    less mature than CUDA — try it but don't be surprised if it falls back to
    eager.

    Args:
        ckpt_path: checkpoint file; must contain 'config' and 'model' (state dict).
        tokenizer_path: `tokenizers` JSON file.
        compile_step: wrap `model.step` with torch.compile.

    Returns:
        (model, tokenizer, specials, ckpt, ckpt_path): the eval-mode model on
        `device`, the tokenizer, a dict mapping each name in SPECIAL_NAMES to its
        token ID, the raw checkpoint dict, and the path actually loaded (after
        any fallback).

    Exits via sys.exit if the checkpoint or tokenizer file is missing, or if the
    tokenizer lacks one of the special tokens.
    """
    if not os.path.exists(ckpt_path):
        # HF layout fallback: a missing model_sft.pt -> model.pt next to it.
        fallback = os.path.join(os.path.dirname(ckpt_path), 'model.pt')
        if os.path.basename(ckpt_path) == 'model_sft.pt' and os.path.exists(fallback):
            ckpt_path = fallback
        else:
            sys.exit(f"error: no checkpoint at {ckpt_path}")
    if not os.path.exists(tokenizer_path):
        sys.exit(f"error: no tokenizer at {tokenizer_path}")

    tokenizer = Tokenizer.from_file(tokenizer_path)
    vocab_size = tokenizer.get_vocab_size()

    # Validate the chat special tokens before the (comparatively slow) checkpoint
    # load and model construction, so a bad tokenizer fails fast.
    specials = {}
    for name in SPECIAL_NAMES:
        tid = tokenizer.token_to_id(name)
        if tid is None:
            sys.exit(f"error: tokenizer is missing special token {name} — was it retrained without the SFT vocab?")
        specials[name] = tid

    # NOTE(review): torch.load unpickles the file. Depending on the installed torch
    # version's `weights_only` default this can run arbitrary code — only load
    # checkpoints from sources you trust.
    ckpt = torch.load(ckpt_path, map_location=device)
    config = dict(ckpt['config'])
    # Pad the vocab up to a multiple of 64 (the model's padded vocab size).
    config['vocab_size'] = ((vocab_size + 63) // 64) * 64
    model = GPT.from_config(config).to(device)
    # State dicts saved from a torch.compile'd module carry an '_orig_mod.' prefix.
    state = {k.removeprefix('_orig_mod.'): v for k, v in ckpt['model'].items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"warn: missing keys:    {missing}", file=sys.stderr)
    if unexpected:
        print(f"warn: unexpected keys: {unexpected}", file=sys.stderr)
    model.eval()

    if compile_step:
        print("compiling model.step (hot inner-loop path)...", file=sys.stderr)
        if device == 'mps':
            print("  note: torch.compile on MPS may fall back to eager. CUDA gets the full speedup.",
                  file=sys.stderr)
        compile_kwargs = {'dynamic': False}
        # reduce-overhead enables CUDA graphs — big win on Ampere+ where kernel-launch
        # overhead dominates the small per-token forward.
        if device == 'cuda' and torch.cuda.get_device_capability() >= (8, 0):
            compile_kwargs['mode'] = 'reduce-overhead'
        model.step = torch.compile(model.step, **compile_kwargs)

    return model, tokenizer, specials, ckpt, ckpt_path


@torch.no_grad()
def _generate_one_turn(model, tokenizer, conversation_tokens, specials,
                       max_new_tokens, temperature, top_k, top_p, repetition_penalty,
                       stream):
    """Generate one assistant reply for the running conversation.

    Prefills the model over all of `conversation_tokens` in a single
    `forward_with_states` call, then samples one token at a time via
    `model.step`. Generation stops on <|assistant_end|>, on <|bos|>, or after
    `max_new_tokens` tokens; the stop token itself is not included.

    Args:
        model: model exposing `lm_head`, `forward_with_states` and `step`.
        tokenizer: tokenizer used only to decode streamed output.
        conversation_tokens: full token history, ending with <|assistant_start|>.
        specials: special-token name -> ID mapping (needs <|bos|> and <|assistant_end|>).
        max_new_tokens: upper bound on the reply length.
        temperature, top_k, top_p, repetition_penalty: passed to `_sample_token_batch`.
        stream: if True, write the decoded reply to stdout as it is generated,
            then a trailing newline.

    Returns:
        list[int]: the sampled response token IDs.
    """
    bos = specials["<|bos|>"]
    asst_end = specials["<|assistant_end|>"]
    stop_ids = (asst_end, bos)

    ctx = torch.tensor([conversation_tokens], device=device)              # (1, T)
    V = model.lm_head.weight.size(0)  # padded vocab

    # Tokens already present in the context, for the repetition penalty.
    seen_mask = None
    if repetition_penalty != 1.0:
        seen_mask = torch.zeros(1, V, dtype=torch.bool, device=device)
        seen_mask.scatter_(1, ctx, True)

    # Prefill the full conversation in one pass.
    logits, states = model.forward_with_states(ctx)
    # Copy of the last position's logits — presumably so the sampler's edits don't
    # touch the prefill output buffer (TODO confirm against _sample_token_batch).
    next_logits = logits[:, -1, :].clone()

    response = []
    printed = 0  # characters of the decoded reply already written to stdout
    while len(response) < max_new_tokens:
        next_tok = _sample_token_batch(next_logits, temperature, top_k, top_p,
                                       repetition_penalty, seen_mask)        # (1,)
        tid = next_tok.item()
        if tid in stop_ids:
            break
        response.append(tid)
        if stream:
            # Decode the whole reply and emit only the new suffix. Per-token decoding
            # garbles multi-byte characters split across tokens; a trailing U+FFFD
            # means the character is still incomplete, so hold it back for now.
            text = tokenizer.decode(response)
            if not text.endswith('\ufffd'):
                sys.stdout.write(text[printed:])
                sys.stdout.flush()
                printed = len(text)
        if len(response) == max_new_tokens:
            break  # budget exhausted: skip a forward step whose logits would go unused
        if seen_mask is not None:
            seen_mask.scatter_(1, next_tok.unsqueeze(1), True)
        step_logits, states = model.step(next_tok.view(1, 1), states)
        next_logits = step_logits[:, 0, :]

    if stream:
        # Flush anything still held back, so streamed text matches decode(response).
        sys.stdout.write(tokenizer.decode(response)[printed:])
        print()
    return response


def _chat_turn(model, tokenizer, specials, conversation_tokens, user_input, args):
    """Run one user -> assistant exchange and return the extended token history.

    Builds `conversation_tokens + <|user_start|> user_input <|user_end|>
    <|assistant_start|>`, generates the reply, prints it (streamed or all at
    once, per `args.no_stream`) plus a throughput line on stderr, and closes the
    turn with <|assistant_end|>.

    `conversation_tokens` is never mutated, so a caller that catches an exception
    raised mid-generation (e.g. KeyboardInterrupt) still holds the pre-turn history.
    """
    stream = not args.no_stream
    turn_tokens = [
        *conversation_tokens,
        specials["<|user_start|>"],
        *tokenizer.encode(user_input).ids,
        specials["<|user_end|>"],
        # Open the assistant turn; the model continues from here.
        specials["<|assistant_start|>"],
    ]

    if stream:
        sys.stdout.write("\nAssistant: ")
        sys.stdout.flush()
    t0 = time.perf_counter()
    response = _generate_one_turn(
        model, tokenizer, turn_tokens, specials,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        stream=stream,
    )
    elapsed = time.perf_counter() - t0

    if not stream:
        print(f"\nAssistant: {tokenizer.decode(response)}")
    n_resp = len(response)
    print(f"  [{n_resp} tok in {elapsed:.1f}s = {n_resp/max(elapsed, 1e-9):.1f} tok/s]",
          file=sys.stderr)

    # Close the assistant turn so the next prefill sees a complete exchange.
    return turn_tokens + response + [specials["<|assistant_end|>"]]


def main():
    """Parse CLI flags, load the model, then answer one prompt (-p) or run the chat REPL."""
    p = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    p.add_argument('-p', '--prompt', type=str, default='',
                   help='single-turn prompt (sent verbatim) — exits after one response. '
                        'Empty = interactive REPL.')
    p.add_argument('-t', '--temperature', type=float, default=0.6,
                   help='softmax temperature; 0 = greedy. nanochat-CLI default 0.6 — '
                        'tighter than infer.py because chat should be focused.')
    p.add_argument('-k', '--top-k', type=int, default=50,
                   help='top-k sampling; 0 disables. Default 50 (nanochat-CLI).')
    p.add_argument('--top-p', type=float, default=1.0,
                   help='nucleus sampling threshold; 1.0 disables. Try 0.9 for varied responses.')
    p.add_argument('-r', '--repetition-penalty', type=float, default=1.15,
                   help='CTRL-style repetition penalty (default 1.15) — keeps chat responses '
                        'from looping. Set to 1.0 for raw sampling.')
    p.add_argument('-n', '--max-tokens', type=int, default=256,
                   help='max tokens per assistant response')
    p.add_argument('--ckpt', type=str, default=DEFAULT_CKPT,
                   help='checkpoint path (default $CHECKPOINT_DIR/model_sft.pt, '
                        'falling back to model.pt in the same directory)')
    p.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER)
    p.add_argument('--no-history', action='store_true',
                   help='reset conversation history before each turn (model has no memory)')
    p.add_argument('--seed', type=int, default=None)
    p.add_argument('--no-stream', action='store_true',
                   help='print full response at end instead of token-by-token')
    p.add_argument('--compile', action='store_true',
                   help='torch.compile the inner step() path. First generation pays '
                        '~10–30s warmup; subsequent generations are 2–5× faster on CUDA. '
                        'Best for long REPL sessions; skip for one-shots.')
    args = p.parse_args()

    if args.seed is not None:
        torch.manual_seed(args.seed)

    print(f"device: {device}", file=sys.stderr)
    model, tokenizer, specials, ckpt, used_ckpt = _load(args.ckpt, args.tokenizer,
                                                        compile_step=args.compile)
    step = ckpt.get('step', '?')
    best = ckpt.get('best_loss')
    n_params = sum(t.numel() for t in model.parameters())
    best_str = f"  best_loss={best:.4f}" if isinstance(best, float) else ""
    print(f"loaded {used_ckpt}  step={step}{best_str}  params={n_params:,}", file=sys.stderr)

    print()
    print("Mnemo — chat mode")
    print("-" * 50)
    print(f"sampling: T={args.temperature}  top_k={args.top_k}  top_p={args.top_p}  rep_penalty={args.repetition_penalty}")
    if not args.prompt:
        print("commands: 'quit' / 'exit' to end, 'clear' to reset history")
    print("-" * 50)

    bos = specials["<|bos|>"]

    if args.prompt:
        # One-shot: run exactly one turn. The prompt is sent verbatim; REPL commands
        # ('clear', 'quit', ...) are not interpreted, so this always terminates.
        _chat_turn(model, tokenizer, specials, [bos], args.prompt, args)
        return

    conversation_tokens = [bos]
    while True:
        try:
            user_input = input("\nUser: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

        command = user_input.lower()
        if command in ('quit', 'exit'):
            print("Goodbye!")
            break
        if command == 'clear':
            conversation_tokens = [bos]
            print("Conversation cleared.")
            continue
        if not user_input:
            continue

        if args.no_history:
            conversation_tokens = [bos]
        try:
            conversation_tokens = _chat_turn(model, tokenizer, specials,
                                             conversation_tokens, user_input, args)
        except KeyboardInterrupt:
            # Ctrl-C mid-reply aborts only this turn: the history is left as it was
            # before the message and the REPL keeps running.
            print("\n[interrupted]")


# Script entry point: run the CLI only when executed directly, never on import.
if __name__ == '__main__':
    sys.exit(main())