"""Sample from a trained Stream Mixer GPT checkpoint.

Default sampling: temperature=0.8, top-k=50 (empirically tuned on the v4
checkpoint — see log.md post-completion section). Override with -t / -k.
Use -t 0 for deterministic greedy decoding (factual probes / debugging).

Usage:
    python3 infer.py
    python3 infer.py -p "In this article, we will explore"         # default stochastic (0.8, k=50)
    python3 infer.py -p "Photosynthesis is the process by which"   # factual recall, encyclopedic continuation
    python3 infer.py -p "The chemical symbol of gold is" -t 0      # greedy — for testing factual probes
    python3 infer.py -p "Once upon a time," -t 0.9 -n 200          # higher temperature for narrative diversity
    python3 infer.py -s 3                                          # 3 samples (batched in parallel)
    python3 infer.py --ckpt model.pt --seed 42
    python3 infer.py --compile                                     # torch.compile (slow first iter, faster after)
    python3 infer.py --compile --repl                              # amortize compile across many prompts

The model architecture is imported from model.py (shared with microgpt.py).
The checkpoint carries its own config so this script needs no hardcoded shapes.
"""

import argparse
import os
import sys
import time

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

from model import GPT, COMPUTE_DTYPE

# Architecture hyper-parameters used only when a checkpoint was saved before
# microgpt.py started storing its own 'config' entry.
FALLBACK_CONFIG = {
    'n_embd': 128,
    'n_layer': 4,
    'n_streams': 16,
    'stream_dim': 64,
}

# Default points at $CHECKPOINT_DIR/model.pt if set (matches microgpt.py's training
# output location); falls back to CWD. Override with --ckpt for explicit paths.
DEFAULT_CKPT = os.path.join(os.environ.get('CHECKPOINT_DIR', '.'), 'model.pt') DEFAULT_TOKENIZER = 'tokenizer.json' device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu') def _sync(): if device == 'mps': torch.mps.synchronize() elif device == 'cuda': torch.cuda.synchronize() def _sample_token_batch(logits, temperature, top_k, top_p, repetition_penalty, seen_mask): """logits: (N, V). seen_mask: (N, V) bool — True at positions of tokens already emitted in each sample (prompt + generated so far). Returns (N,) long tensor of sampled token IDs. Applies, in order: 1. Repetition penalty (CTRL-style): for tokens marked in seen_mask, divide positive logits / multiply negative logits by `repetition_penalty`. penalty > 1 suppresses repeats; penalty == 1.0 is a no-op. Applied BEFORE temperature so the suppression is in raw-logit space. 2. Temperature scaling (skipped when temperature==0 → argmax greedy). 3+4. Fused top-k / top-p filtering. With top-k set, top-p only needs to look at the k candidates instead of sorting the full 32k vocab — big speedup. Honors the --temperature 0 greedy contract — but rep penalty still applies before argmax, so greedy + penalty > 1 is a valid "deterministic anti-loop" decoding mode. """ # 1) Repetition penalty — via fixed-size bool mask (constant cost per step, # vs. gather/scatter over a growing index list which is slow on MPS). if repetition_penalty != 1.0 and seen_mask is not None: penalized = torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty) logits = torch.where(seen_mask, penalized, logits) if temperature == 0: return logits.argmax(dim=-1) logits = logits / temperature # 2+3) Fused top-k + top-p: when both are active, only sort the k candidates # instead of the full 32k vocab. `topk(..., sorted=True)` already returns # descending-ordered values, so we get the sort for free as a side effect # of the top-k pass. Big speedup vs sorting the whole logits tensor. 
if top_k > 0: k = min(top_k, logits.size(-1)) topk_vals, topk_idx = logits.topk(k, dim=-1, sorted=True) # (N, k) desc if top_p < 1.0: # Nucleus over just the top-k — sort already done by topk() topk_probs = F.softmax(topk_vals, dim=-1) cumsum = topk_probs.cumsum(dim=-1) mask = cumsum > top_p mask[:, 1:] = mask[:, :-1].clone() mask[:, 0] = False topk_vals = topk_vals.masked_fill(mask, float('-inf')) # Rebuild logits with only the kept candidates having real values logits = torch.full_like(logits, float('-inf')).scatter_(1, topk_idx, topk_vals) elif top_p < 1.0: # No top-k → must sort the full vocab (slow path; avoid by always passing top_k) sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) sorted_probs = F.softmax(sorted_logits, dim=-1) cumsum = sorted_probs.cumsum(dim=-1) sorted_mask = cumsum > top_p sorted_mask[:, 1:] = sorted_mask[:, :-1].clone() sorted_mask[:, 0] = False mask = torch.zeros_like(sorted_mask).scatter(1, sorted_indices, sorted_mask) logits = logits.masked_fill(mask, float('-inf')) probs = F.softmax(logits, dim=-1) return torch.multinomial(probs, 1).squeeze(-1) # (N,) @torch.no_grad() def generate(model, tokenizer, BOS, prompt, n_samples=1, max_new_tokens=256, temperature=0.8, top_k=50, top_p=1.0, repetition_penalty=1.0, stream=True): """Batched generation. All N samples advance one token per MPS/CUDA step. Returns list of N decoded strings. Streams to stdout only when n_samples==1 (multi-sample streaming would interleave illegibly). For long generations (>500 tokens) where the model loops, raise repetition_penalty to 1.1–1.2 or lower top_p to ~0.9. """ prompt_ids = [BOS] + tokenizer.encode(prompt).ids ctx = torch.tensor([prompt_ids] * n_samples, device=device) # (N, T_prompt) # Repetition-penalty memory: a (N, V) bool mask, True where the corresponding # token id has already been emitted. 
Constant cost per step regardless of # generation length — beats a growing token-list approach on MPS where # gather/scatter over scattered indices is slow. V = None # filled in after first forward (we don't know it yet) seen_mask = None t0 = time.time() logits, states = model.forward_with_states(ctx) next_logits = logits[:, -1, :].clone() # (N, V); clone so compile doesn't see a view # Now we know V; build the mask and seed it with the prompt tokens. V = next_logits.size(-1) if repetition_penalty != 1.0: seen_mask = torch.zeros(n_samples, V, dtype=torch.bool, device=device) seen_mask.scatter_(1, ctx, True) _sync() t_prefill = time.time() can_stream = stream and n_samples == 1 if can_stream: print(prompt, end='', flush=True) generated = [[] for _ in range(n_samples)] done = [False] * n_samples for _ in range(max_new_tokens): next_toks = _sample_token_batch(next_logits, temperature, top_k, top_p, repetition_penalty, seen_mask) # (N,) next_toks_list = next_toks.tolist() any_alive = False for i in range(n_samples): if done[i]: continue tid = next_toks_list[i] if tid == BOS: done[i] = True else: generated[i].append(tid) any_alive = True if not any_alive: break if can_stream: # Per-token decode is O(1); avoids the O(N²) full-list re-decode each step. # ByteLevel BPE pieces decode cleanly for ASCII English, which dominates FineWeb-Edu. piece = tokenizer.decode([next_toks_list[0]]) sys.stdout.write(piece) sys.stdout.flush() # Mark the new token as seen (O(N) scatter — constant per step). 
if seen_mask is not None: seen_mask.scatter_(1, next_toks.unsqueeze(1), True) x_in = next_toks.unsqueeze(1) # (N, 1) step_logits, states = model.step(x_in, states) next_logits = step_logits[:, 0, :] _sync() t_end = time.time() if can_stream: print() total_n = sum(len(g) for g in generated) prefill_ms = (t_prefill - t0) * 1000 decode_s = max(t_end - t_prefill, 1e-9) print(f" [prefill {len(prompt_ids)}×{n_samples} tok in {prefill_ms:.0f}ms | " f"decode {total_n} tok in {decode_s:.2f}s = {total_n/decode_s:.1f} tok/s aggregate]", file=sys.stderr) return [tokenizer.decode(g) for g in generated] def load_model(ckpt_path, vocab_size, compile_model=False): ckpt = torch.load(ckpt_path, map_location=device) if 'config' in ckpt: config = dict(ckpt['config']) else: print(f"note: checkpoint has no 'config' (older format); using fallback {FALLBACK_CONFIG}", file=sys.stderr) config = dict(FALLBACK_CONFIG) config['vocab_size'] = vocab_size # vocab is set by the tokenizer, not the checkpoint model = GPT.from_config(config).to(device) # Strip torch.compile's `_orig_mod.` prefix if present in older checkpoints. 
state = {k.removeprefix('_orig_mod.'): v for k, v in ckpt['model'].items()} missing, unexpected = model.load_state_dict(state, strict=False) if missing: print(f"warn: missing keys: {missing}", file=sys.stderr) if unexpected: print(f"warn: unexpected keys: {unexpected}", file=sys.stderr) model.eval() if compile_model: print("compiling step/prefill paths...", file=sys.stderr) model.step = torch.compile(model.step, dynamic=False) model.forward_with_states = torch.compile(model.forward_with_states, dynamic=False) return model, ckpt, config def main(): p = argparse.ArgumentParser(description="Sample from a trained microgpt Stream Mixer.") p.add_argument('-p', '--prompt', type=str, default="In this article, we will explore") p.add_argument('-n', '--max-tokens', type=int, default=256) p.add_argument('-t', '--temperature', type=float, default=0.8, help='softmax temperature; 0 = greedy argmax. Recommended 0.7–0.9 for prose. ' '(Determined empirically — see v4 log.md post-completion section.)') p.add_argument('-k', '--top-k', type=int, default=50, help='restrict sampling to top-k logits; 0 disables. ' 'Recommended 50 — keeps the long tail of the 32k vocab from injecting weird tokens.') p.add_argument('--top-p', type=float, default=1.0, help='nucleus sampling: keep the smallest token set whose cumulative ' 'probability ≥ top-p. 1.0 disables. Recommended 0.9 for long generations.') p.add_argument('-r', '--repetition-penalty', type=float, default=1.0, help='CTRL-style repetition penalty: tokens already in the context get ' 'their logits divided by this. 1.0 = no penalty. 
Recommended 1.1–1.2 ' 'for >500-token generations to break out of repetition wells.') p.add_argument('-s', '--num-samples', type=int, default=1) p.add_argument('--ckpt', type=str, default=DEFAULT_CKPT) p.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER) p.add_argument('--seed', type=int, default=None) p.add_argument('--no-stream', action='store_true', help='print full output at end instead of streaming token by token') p.add_argument('--compile', action='store_true', help='torch.compile the step/prefill paths (slower first call, faster after)') p.add_argument('--repl', action='store_true', help='after generating, read more prompts from stdin so --compile is paid once') args = p.parse_args() if args.seed is not None: torch.manual_seed(args.seed) if not os.path.exists(args.ckpt): sys.exit(f"error: checkpoint not found: {args.ckpt}") if not os.path.exists(args.tokenizer): sys.exit(f"error: tokenizer not found: {args.tokenizer}") print(f"device: {device}", file=sys.stderr) tokenizer = Tokenizer.from_file(args.tokenizer) BOS = tokenizer.token_to_id("<|bos|>") model, ckpt, config = load_model(args.ckpt, tokenizer.get_vocab_size(), compile_model=args.compile) step = ckpt.get('step', '?') best = ckpt.get('best_loss') n_params = sum(p.numel() for p in model.parameters()) best_str = f" best_loss={best:.4f}" if isinstance(best, float) else "" print(f"loaded {args.ckpt} step={step}{best_str}", file=sys.stderr) print(f"config: {config}", file=sys.stderr) print(f"params: {n_params:,}\n", file=sys.stderr) stream = not args.no_stream def run(prompt): results = generate(model, tokenizer, BOS, prompt, n_samples=args.num_samples, max_new_tokens=args.max_tokens, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, repetition_penalty=args.repetition_penalty, stream=stream) # Single-sample streaming already printed; otherwise print collected results. 
if not (stream and args.num_samples == 1): for i, out in enumerate(results): if args.num_samples > 1: print(f"--- sample {i+1}/{args.num_samples} ---", file=sys.stderr) print(prompt + out) if args.num_samples > 1: print(file=sys.stderr) run(args.prompt) # REPL: amortize torch.compile warmup over many prompts. Empty input # reuses the previous prompt; "exit"/"quit"/Ctrl-D leaves the loop. if args.repl: last_prompt = args.prompt print("\n[repl] enter prompt (empty=repeat, exit/quit/Ctrl-D to leave)", file=sys.stderr) while True: try: line = input(">>> ") except (EOFError, KeyboardInterrupt): print(file=sys.stderr) break line = line.strip() if line in ('exit', 'quit'): break prompt = line if line else last_prompt last_prompt = prompt run(prompt) if __name__ == '__main__': main()