| """Sample from a trained Stream Mixer GPT checkpoint. |
| |
| Default sampling: temperature=0.8, top-k=50 (empirically tuned on the v4 |
| checkpoint — see log.md post-completion section). Override with -t / -k. |
| Use -t 0 for deterministic greedy decoding (factual probes / debugging). |
| |
| Usage: |
| python3 infer.py |
| python3 infer.py -p "In this article, we will explore" # default stochastic (0.8, k=50) |
| python3 infer.py -p "Photosynthesis is the process by which" # factual recall, encyclopedic continuation |
| python3 infer.py -p "The chemical symbol of gold is" -t 0 # greedy — for testing factual probes |
| python3 infer.py -p "Once upon a time," -t 0.9 -n 200 # higher temperature for narrative diversity |
| python3 infer.py -s 3 # 3 samples (batched in parallel) |
| python3 infer.py --ckpt model.pt --seed 42 |
| python3 infer.py --compile # torch.compile (slow first iter, faster after) |
| python3 infer.py --compile --repl # amortize compile across many prompts |
| |
| The model architecture is imported from model.py (shared with microgpt.py). |
| The checkpoint carries its own config so this script needs no hardcoded shapes. |
| """ |
| import argparse |
| import os |
| import sys |
| import time |
|
|
| import torch |
| import torch.nn.functional as F |
| from tokenizers import Tokenizer |
|
|
| from model import GPT, COMPUTE_DTYPE |
|
|
| |
# Architecture assumed for legacy checkpoints that carry no 'config' entry.
# vocab_size is intentionally absent: it is filled in at load time.
FALLBACK_CONFIG = {'n_embd': 128, 'n_layer': 4, 'n_streams': 16, 'stream_dim': 64}

# Default file locations. The checkpoint directory can be redirected with the
# CHECKPOINT_DIR environment variable (defaults to the current directory).
DEFAULT_CKPT = os.path.join(os.getenv('CHECKPOINT_DIR', '.'), 'model.pt')
DEFAULT_TOKENIZER = 'tokenizer.json'

# Preferred backend, in order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
|
|
|
|
def _sync():
    """Block until all queued work on the active accelerator has finished.

    Accelerator kernels are launched asynchronously, so timestamps taken
    without a sync would measure launch time rather than compute time.
    Does nothing when running on the CPU.
    """
    if device == 'cuda':
        torch.cuda.synchronize()
    elif device == 'mps':
        torch.mps.synchronize()
|
|
|
|
def _nucleus_mask(sorted_logits, top_p):
    """Mark candidates outside the top-p nucleus.

    sorted_logits: (N, C) logits sorted in descending order along the last dim.
    Returns a bool tensor of the same shape, True where the candidate should be
    dropped. The highest-probability candidate is never dropped, so at least
    one token always survives.
    """
    cumsum = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    mask = cumsum > top_p
    # Shift right by one so the candidate that crosses the threshold is kept.
    mask[:, 1:] = mask[:, :-1].clone()
    mask[:, 0] = False
    return mask


def _sample_token_batch(logits, temperature, top_k, top_p, repetition_penalty, seen_mask):
    """Pick one next token per row from a batch of logits.

    Args:
        logits: (N, V) next-token logits, any floating dtype. They are upcast
            to float32 first so the penalty, softmax, top-p cumsum and
            multinomial all run at full precision even if the model returns
            reduced-precision logits. The caller's tensor is never mutated.
        temperature: softmax temperature; 0 selects greedy argmax.
        top_k: keep only the k highest logits; <= 0 disables.
        top_p: nucleus threshold; >= 1.0 disables.
        repetition_penalty: CTRL-style penalty for tokens in seen_mask;
            > 1 suppresses repeats, 1.0 is a no-op.
        seen_mask: (N, V) bool, True at tokens already present in that row's
            context (prompt + generated so far), or None.

    Returns:
        (N,) long tensor of sampled token IDs.

    Applies, in order:
      1. Repetition penalty in raw-logit space: positive logits of seen tokens
         are divided by the penalty, negative ones multiplied.
      2. Temperature. At 0, return the argmax. The penalty has already been
         applied, so greedy + penalty > 1 is a deterministic anti-loop mode.
      3+4. Fused top-k / top-p. With top-k set, top-p only examines the k
         candidates instead of sorting the whole vocab.
    """
    # float() returns the same tensor for fp32 input and an exact upcast
    # otherwise; every later op builds a new tensor.
    logits = logits.float()

    if repetition_penalty != 1.0 and seen_mask is not None:
        penalized = torch.where(logits > 0,
                                logits / repetition_penalty,
                                logits * repetition_penalty)
        logits = torch.where(seen_mask, penalized, logits)

    if temperature == 0:
        return logits.argmax(dim=-1)
    logits = logits / temperature

    if top_k > 0:
        k = min(top_k, logits.size(-1))
        # sorted=True gives descending order, which the top-p cumsum needs.
        topk_vals, topk_idx = logits.topk(k, dim=-1, sorted=True)
        if top_p < 1.0:
            topk_vals = topk_vals.masked_fill(_nucleus_mask(topk_vals, top_p), float('-inf'))
        logits = torch.full_like(logits, float('-inf')).scatter_(1, topk_idx, topk_vals)
    elif top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_mask = _nucleus_mask(sorted_logits, top_p)
        # Map the mask from sorted order back to vocab order.
        mask = torch.zeros_like(sorted_mask).scatter(1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(mask, float('-inf'))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).squeeze(-1)
|
|
|
|
def _incremental_decode(tokenizer, ids, prefix_offset, read_offset, final=False):
    """Return the text that ``ids`` adds beyond what has already been emitted.

    Decoding tokens one at a time breaks when a character's bytes span several
    tokens (e.g. byte-level BPE): each fragment decodes to U+FFFD on its own.
    Instead, this decodes a small window ``ids[prefix_offset:]`` and emits only
    what extends past the already-printed ``ids[prefix_offset:read_offset]``.
    Output that ends in U+FFFD is held back as a likely incomplete character,
    unless ``final`` is True (end of stream, flush everything).

    Returns ``(piece, prefix_offset, read_offset)``. Feed the offsets back in on
    the next call. ``piece`` is '' when nothing can be emitted yet.
    """
    emitted = tokenizer.decode(ids[prefix_offset:read_offset])
    window = tokenizer.decode(ids[prefix_offset:])
    if len(window) > len(emitted) and (final or not window.endswith('\ufffd')):
        return window[len(emitted):], read_offset, len(ids)
    return '', prefix_offset, read_offset


@torch.no_grad()
def generate(model, tokenizer, BOS, prompt, n_samples=1,
             max_new_tokens=256, temperature=0.8, top_k=50, top_p=1.0,
             repetition_penalty=1.0, stream=True):
    """Batched autoregressive generation from a single prompt.

    All ``n_samples`` rows share the prompt and advance one token per batched
    ``model.step`` call. A row finishes when it samples ``BOS``, which serves
    as the end-of-text marker. Finished rows stay in the batch but their
    tokens are discarded. Generation stops after ``max_new_tokens`` or once
    every row has finished.

    Output streams to stdout only when ``n_samples == 1``; multi-sample
    streaming would interleave illegibly. Streaming decodes incrementally, so
    characters split across tokens are printed whole.

    Prints prefill/decode timing to stderr. Returns a list of ``n_samples``
    decoded strings containing the generated text only, without the prompt.

    For long generations (>500 tokens) where the model loops, raise
    repetition_penalty to 1.1–1.2 or lower top_p to ~0.9.
    """
    prompt_ids = [BOS] + tokenizer.encode(prompt).ids
    ctx = torch.tensor([prompt_ids] * n_samples, device=device)

    # --- prefill ---
    t0 = time.perf_counter()
    logits, states = model.forward_with_states(ctx)
    # Give the last position its own storage, then drop the full (N, T, V)
    # prefill logits so they are not kept alive for the whole decode loop.
    next_logits = logits[:, -1, :].clone()
    del logits

    seen_mask = None
    if repetition_penalty != 1.0:
        # Every prompt token (BOS included) counts as already seen.
        seen_mask = torch.zeros(n_samples, next_logits.size(-1), dtype=torch.bool, device=device)
        seen_mask.scatter_(1, ctx, True)
    _sync()
    t_prefill = time.perf_counter()

    can_stream = stream and n_samples == 1
    if can_stream:
        print(prompt, end='', flush=True)
    prefix_offset = read_offset = 0  # incremental-decode state for streaming

    generated = [[] for _ in range(n_samples)]
    done = [False] * n_samples

    # --- decode ---
    for step_idx in range(max_new_tokens):
        next_toks = _sample_token_batch(next_logits, temperature, top_k, top_p,
                                        repetition_penalty, seen_mask)
        any_alive = False
        for i, tid in enumerate(next_toks.tolist()):
            if done[i]:
                continue
            if tid == BOS:
                done[i] = True
            else:
                generated[i].append(tid)
                any_alive = True
        if not any_alive:
            break
        if can_stream:
            piece, prefix_offset, read_offset = _incremental_decode(
                tokenizer, generated[0], prefix_offset, read_offset)
            if piece:
                sys.stdout.write(piece)
                sys.stdout.flush()
        if step_idx == max_new_tokens - 1:
            # No further token will be sampled, so skip the extra forward pass.
            break
        if seen_mask is not None:
            seen_mask.scatter_(1, next_toks.unsqueeze(1), True)
        step_logits, states = model.step(next_toks.unsqueeze(1), states)
        next_logits = step_logits[:, 0, :]

    _sync()
    t_end = time.perf_counter()
    if can_stream:
        # Flush any text held back as a possibly-incomplete character.
        tail, _, _ = _incremental_decode(tokenizer, generated[0], prefix_offset,
                                         read_offset, final=True)
        sys.stdout.write(tail)
        print()

    total_n = sum(len(g) for g in generated)
    prefill_ms = (t_prefill - t0) * 1000
    decode_s = max(t_end - t_prefill, 1e-9)
    print(f"  [prefill {len(prompt_ids)}×{n_samples} tok in {prefill_ms:.0f}ms | "
          f"decode {total_n} tok in {decode_s:.2f}s = {total_n/decode_s:.1f} tok/s aggregate]",
          file=sys.stderr)
    return [tokenizer.decode(g) for g in generated]
|
|
|
|
def load_model(ckpt_path, vocab_size, compile_model=False):
    """Build a GPT from a checkpoint file and load its weights onto `device`.

    Args:
        ckpt_path: path to a checkpoint dict with a 'model' state dict and,
            for newer checkpoints, a 'config' dict. Older checkpoints without
            'config' fall back to FALLBACK_CONFIG.
        vocab_size: tokenizer vocabulary size. It always overrides any
            vocab_size stored in the checkpoint config.
        compile_model: wrap `model.step` and `model.forward_with_states` in
            torch.compile (dynamic=False).

    Returns:
        (model, ckpt, config): the model in eval mode, the raw checkpoint dict
        (with its tensors kept on the CPU), and the config used to build the model.
    """
    # Deserialize onto the CPU: load_state_dict copies the weights onto
    # `device` anyway, and the returned ckpt would otherwise keep a second
    # full copy of the weights in accelerator memory.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    if 'config' in ckpt:
        config = dict(ckpt['config'])
    else:
        print(f"note: checkpoint has no 'config' (older format); using fallback {FALLBACK_CONFIG}",
              file=sys.stderr)
        config = dict(FALLBACK_CONFIG)
    ckpt_vocab = config.get('vocab_size')
    if ckpt_vocab is not None and ckpt_vocab != vocab_size:
        print(f"warn: checkpoint vocab_size={ckpt_vocab} differs from tokenizer "
              f"vocab_size={vocab_size}; using the tokenizer's", file=sys.stderr)
    config['vocab_size'] = vocab_size
    model = GPT.from_config(config).to(device)

    # Strip the '_orig_mod.' prefix that torch.compile adds to saved keys.
    state = {k.removeprefix('_orig_mod.'): v for k, v in ckpt['model'].items()}
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"warn: missing keys: {missing}", file=sys.stderr)
    if unexpected:
        print(f"warn: unexpected keys: {unexpected}", file=sys.stderr)
    model.eval()
    if compile_model:
        print("compiling step/prefill paths...", file=sys.stderr)
        # dynamic=False: each new input shape (e.g. prompt length) triggers a recompile.
        model.step = torch.compile(model.step, dynamic=False)
        model.forward_with_states = torch.compile(model.forward_with_states, dynamic=False)
    return model, ckpt, config
|
|
|
|
def main():
    """CLI entry point: parse flags, load tokenizer and model, sample, optional REPL."""
    p = argparse.ArgumentParser(description="Sample from a trained microgpt Stream Mixer.")
    p.add_argument('-p', '--prompt', type=str, default="In this article, we will explore")
    p.add_argument('-n', '--max-tokens', type=int, default=256)
    p.add_argument('-t', '--temperature', type=float, default=0.8,
                   help='softmax temperature; 0 = greedy argmax. Recommended 0.7–0.9 for prose. '
                        '(Determined empirically — see v4 log.md post-completion section.)')
    p.add_argument('-k', '--top-k', type=int, default=50,
                   help='restrict sampling to top-k logits; 0 disables. '
                        'Recommended 50 — keeps the long tail of the 32k vocab from injecting weird tokens.')
    p.add_argument('--top-p', type=float, default=1.0,
                   help='nucleus sampling: keep the smallest token set whose cumulative '
                        'probability ≥ top-p. 1.0 disables. Recommended 0.9 for long generations.')
    p.add_argument('-r', '--repetition-penalty', type=float, default=1.0,
                   help='CTRL-style repetition penalty for tokens already in the context: '
                        'positive logits are divided by this, negative logits multiplied. '
                        '1.0 = no penalty. Recommended 1.1–1.2 '
                        'for >500-token generations to break out of repetition wells.')
    p.add_argument('-s', '--num-samples', type=int, default=1)
    p.add_argument('--ckpt', type=str, default=DEFAULT_CKPT)
    p.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER)
    p.add_argument('--seed', type=int, default=None)
    p.add_argument('--no-stream', action='store_true',
                   help='print full output at end instead of streaming token by token')
    p.add_argument('--compile', action='store_true',
                   help='torch.compile the step/prefill paths (slower first call, faster after)')
    p.add_argument('--repl', action='store_true',
                   help='after generating, read more prompts from stdin so --compile is paid once')
    args = p.parse_args()

    # Reject values that would make the sampler silently misbehave. A negative
    # temperature inverts the ranking. A penalty <= 0 divides by zero or flips
    # logit signs. Zero samples leaves nothing to generate.
    if args.num_samples < 1:
        p.error("--num-samples must be >= 1")
    if args.max_tokens < 0:
        p.error("--max-tokens must be >= 0")
    if args.temperature < 0:
        p.error("--temperature must be >= 0 (0 = greedy)")
    if not 0.0 <= args.top_p <= 1.0:
        p.error("--top-p must be in [0, 1]")
    if args.repetition_penalty <= 0:
        p.error("--repetition-penalty must be > 0")

    if args.seed is not None:
        torch.manual_seed(args.seed)
    if not os.path.exists(args.ckpt):
        sys.exit(f"error: checkpoint not found: {args.ckpt}")
    if not os.path.exists(args.tokenizer):
        sys.exit(f"error: tokenizer not found: {args.tokenizer}")

    print(f"device: {device}", file=sys.stderr)
    tokenizer = Tokenizer.from_file(args.tokenizer)
    BOS = tokenizer.token_to_id("<|bos|>")
    if BOS is None:
        # token_to_id returns None for unknown tokens. Without this check the
        # failure would surface later as an opaque tensor-construction error.
        sys.exit(f"error: tokenizer {args.tokenizer} has no <|bos|> token")

    model, ckpt, config = load_model(args.ckpt, tokenizer.get_vocab_size(),
                                     compile_model=args.compile)
    step = ckpt.get('step', '?')
    best = ckpt.get('best_loss')
    n_params = sum(param.numel() for param in model.parameters())
    best_str = f" best_loss={best:.4f}" if isinstance(best, float) else ""
    print(f"loaded {args.ckpt} step={step}{best_str}", file=sys.stderr)
    print(f"config: {config}", file=sys.stderr)
    print(f"params: {n_params:,}\n", file=sys.stderr)

    stream = not args.no_stream

    def run(prompt):
        """Generate for one prompt and print results, unless generate() already streamed them."""
        results = generate(model, tokenizer, BOS, prompt,
                           n_samples=args.num_samples,
                           max_new_tokens=args.max_tokens,
                           temperature=args.temperature,
                           top_k=args.top_k,
                           top_p=args.top_p,
                           repetition_penalty=args.repetition_penalty,
                           stream=stream)
        # The single-sample streaming case was already written to stdout.
        if not (stream and args.num_samples == 1):
            for i, out in enumerate(results):
                if args.num_samples > 1:
                    print(f"--- sample {i+1}/{args.num_samples} ---", file=sys.stderr)
                print(prompt + out)
                if args.num_samples > 1:
                    print(file=sys.stderr)

    run(args.prompt)

    # REPL: keep the (possibly compiled) model loaded across prompts.
    if args.repl:
        last_prompt = args.prompt
        print("\n[repl] enter prompt (empty=repeat, exit/quit/Ctrl-D to leave)", file=sys.stderr)
        while True:
            try:
                line = input(">>> ")
            except (EOFError, KeyboardInterrupt):
                print(file=sys.stderr)
                break
            line = line.strip()
            if line in ('exit', 'quit'):
                break
            prompt = line if line else last_prompt
            last_prompt = prompt
            try:
                run(prompt)
            except KeyboardInterrupt:
                # Ctrl-C aborts only the current generation, not the REPL.
                print("\n[interrupted]", file=sys.stderr)


if __name__ == '__main__':
    main()
|
|