# Mnemo / infer.py
# Uploaded by ecreeth ("Upload 2 files", commit 491dfb2).
"""Sample from a trained Stream Mixer GPT checkpoint.
Default sampling: temperature=0.8, top-k=50 (empirically tuned on the v4
checkpoint — see log.md post-completion section). Override with -t / -k.
Use -t 0 for deterministic greedy decoding (factual probes / debugging).
Usage:
python3 infer.py
python3 infer.py -p "In this article, we will explore" # default stochastic (0.8, k=50)
python3 infer.py -p "Photosynthesis is the process by which" # factual recall, encyclopedic continuation
python3 infer.py -p "The chemical symbol of gold is" -t 0 # greedy — for testing factual probes
python3 infer.py -p "Once upon a time," -t 0.9 -n 200 # higher temperature for narrative diversity
python3 infer.py -s 3 # 3 samples (batched in parallel)
python3 infer.py --ckpt model.pt --seed 42
python3 infer.py --compile # torch.compile (slow first iter, faster after)
python3 infer.py --compile --repl # amortize compile across many prompts
The model architecture is imported from model.py (shared with microgpt.py).
The checkpoint carries its own config so this script needs no hardcoded shapes.
"""
import argparse
import os
import sys
import time
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from model import GPT, COMPUTE_DTYPE
# Architecture shapes assumed for checkpoints written before microgpt.py began
# storing its own config alongside the weights.
FALLBACK_CONFIG = {'n_embd': 128, 'n_layer': 4, 'n_streams': 16, 'stream_dim': 64}

# Checkpoint lookup: $CHECKPOINT_DIR/model.pt when that variable is set (the
# location microgpt.py trains into), otherwise ./model.pt. --ckpt overrides it.
DEFAULT_CKPT = os.path.join(os.environ.get('CHECKPOINT_DIR', '.'), 'model.pt')
DEFAULT_TOKENIZER = 'tokenizer.json'

# Accelerator preference order: CUDA, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
def _sync():
    """Wait for queued work on the active accelerator to finish.

    Does nothing on CPU. The timing code calls this before reading the clock,
    so the measured intervals include GPU/MPS work that may still be queued.
    """
    if device == 'cuda':
        torch.cuda.synchronize()
    elif device == 'mps':
        torch.mps.synchronize()
def _sample_token_batch(logits, temperature, top_k, top_p, repetition_penalty, seen_mask):
"""logits: (N, V). seen_mask: (N, V) bool — True at positions of tokens
already emitted in each sample (prompt + generated so far). Returns (N,) long
tensor of sampled token IDs.
Applies, in order:
1. Repetition penalty (CTRL-style): for tokens marked in seen_mask, divide
positive logits / multiply negative logits by `repetition_penalty`.
penalty > 1 suppresses repeats; penalty == 1.0 is a no-op. Applied BEFORE
temperature so the suppression is in raw-logit space.
2. Temperature scaling (skipped when temperature==0 → argmax greedy).
3+4. Fused top-k / top-p filtering. With top-k set, top-p only needs to look
at the k candidates instead of sorting the full 32k vocab — big speedup.
Honors the --temperature 0 greedy contract — but rep penalty still applies
before argmax, so greedy + penalty > 1 is a valid "deterministic anti-loop"
decoding mode.
"""
# 1) Repetition penalty — via fixed-size bool mask (constant cost per step,
# vs. gather/scatter over a growing index list which is slow on MPS).
if repetition_penalty != 1.0 and seen_mask is not None:
penalized = torch.where(logits > 0,
logits / repetition_penalty,
logits * repetition_penalty)
logits = torch.where(seen_mask, penalized, logits)
if temperature == 0:
return logits.argmax(dim=-1)
logits = logits / temperature
# 2+3) Fused top-k + top-p: when both are active, only sort the k candidates
# instead of the full 32k vocab. `topk(..., sorted=True)` already returns
# descending-ordered values, so we get the sort for free as a side effect
# of the top-k pass. Big speedup vs sorting the whole logits tensor.
if top_k > 0:
k = min(top_k, logits.size(-1))
topk_vals, topk_idx = logits.topk(k, dim=-1, sorted=True) # (N, k) desc
if top_p < 1.0:
# Nucleus over just the top-k — sort already done by topk()
topk_probs = F.softmax(topk_vals, dim=-1)
cumsum = topk_probs.cumsum(dim=-1)
mask = cumsum > top_p
mask[:, 1:] = mask[:, :-1].clone()
mask[:, 0] = False
topk_vals = topk_vals.masked_fill(mask, float('-inf'))
# Rebuild logits with only the kept candidates having real values
logits = torch.full_like(logits, float('-inf')).scatter_(1, topk_idx, topk_vals)
elif top_p < 1.0:
# No top-k → must sort the full vocab (slow path; avoid by always passing top_k)
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumsum = sorted_probs.cumsum(dim=-1)
sorted_mask = cumsum > top_p
sorted_mask[:, 1:] = sorted_mask[:, :-1].clone()
sorted_mask[:, 0] = False
mask = torch.zeros_like(sorted_mask).scatter(1, sorted_indices, sorted_mask)
logits = logits.masked_fill(mask, float('-inf'))
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, 1).squeeze(-1) # (N,)
@torch.no_grad()
def generate(model, tokenizer, BOS, prompt, n_samples=1,
             max_new_tokens=256, temperature=0.8, top_k=50, top_p=1.0,
             repetition_penalty=1.0, stream=True):
    """Batched generation: every sample advances one token per model step.

    Args:
        model: object exposing ``forward_with_states(ctx) -> (logits, states)``
            for the prompt and ``step(x, states) -> (logits, states)`` for one
            new token per sample.
        tokenizer: tokenizer used to encode the prompt and decode the output.
        BOS: token id prepended to the prompt; sampling it ends that sample.
        prompt: prompt text shared by all samples.
        n_samples: number of samples generated in parallel.
        max_new_tokens: upper bound on tokens generated per sample.
        temperature, top_k, top_p, repetition_penalty: forwarded to
            ``_sample_token_batch``.
        stream: write tokens to stdout as they arrive. Only honored when
            ``n_samples == 1`` (multi-sample streaming would interleave
            illegibly).

    Returns:
        List of ``n_samples`` decoded continuations (prompt excluded, BOS never
        included). Prefill/decode timing is printed to stderr.

    For long generations (>500 tokens) where the model loops, raise
    repetition_penalty to 1.1–1.2 or lower top_p to ~0.9.
    """
    prompt_ids = [BOS] + tokenizer.encode(prompt).ids
    ctx = torch.tensor([prompt_ids] * n_samples, device=device)  # (N, T_prompt)

    t0 = time.time()
    logits, states = model.forward_with_states(ctx)
    next_logits = logits[:, -1, :].clone()  # (N, V); clone so compile doesn't see a view

    # Repetition-penalty memory: (N, V) bool, True where the token id was already
    # emitted for that sample, seeded with the prompt. Constant cost per step
    # regardless of generation length (a growing token-id list needs
    # gather/scatter over scattered indices, which is slow on MPS). Only
    # allocated when the penalty is active.
    seen_mask = None
    if repetition_penalty != 1.0:
        seen_mask = torch.zeros(n_samples, next_logits.size(-1),
                                dtype=torch.bool, device=device)
        seen_mask.scatter_(1, ctx, True)
    _sync()
    t_prefill = time.time()

    can_stream = stream and n_samples == 1
    if can_stream:
        print(prompt, end='', flush=True)
    # Streamed token ids not yet written to stdout. ByteLevel BPE can split one
    # multi-byte UTF-8 character (accents, CJK, emoji) across tokens, and
    # decoding the partial bytes yields U+FFFD. Decoding each token alone would
    # print that garbage, so tokens are held until the decoded text no longer
    # ends in U+FFFD. The cap makes sure genuinely invalid output still shows.
    pending = []
    max_pending = 8

    generated = [[] for _ in range(n_samples)]
    done = [False] * n_samples
    for _ in range(max_new_tokens):
        next_toks = _sample_token_batch(next_logits, temperature, top_k, top_p,
                                        repetition_penalty, seen_mask)  # (N,)
        next_toks_list = next_toks.tolist()
        any_alive = False
        for i, tid in enumerate(next_toks_list):
            if done[i]:
                continue
            if tid == BOS:
                done[i] = True
            else:
                generated[i].append(tid)
                any_alive = True
        if not any_alive:
            break
        if can_stream:
            pending.append(next_toks_list[0])
            piece = tokenizer.decode(pending)
            if not piece.endswith('\ufffd') or len(pending) >= max_pending:
                sys.stdout.write(piece)
                sys.stdout.flush()
                pending.clear()
        # Mark the new tokens as seen (one scatter per step).
        if seen_mask is not None:
            seen_mask.scatter_(1, next_toks.unsqueeze(1), True)
        step_logits, states = model.step(next_toks.unsqueeze(1), states)  # input (N, 1)
        next_logits = step_logits[:, 0, :]
    _sync()
    t_end = time.time()

    if can_stream:
        if pending:  # emit a trailing incomplete character rather than dropping it
            sys.stdout.write(tokenizer.decode(pending))
        print()
    total_n = sum(len(g) for g in generated)
    prefill_ms = (t_prefill - t0) * 1000
    decode_s = max(t_end - t_prefill, 1e-9)
    print(f"  [prefill {len(prompt_ids)}×{n_samples} tok in {prefill_ms:.0f}ms | "
          f"decode {total_n} tok in {decode_s:.2f}s = {total_n/decode_s:.1f} tok/s aggregate]",
          file=sys.stderr)
    return [tokenizer.decode(g) for g in generated]
def load_model(ckpt_path, vocab_size, compile_model=False):
    """Rebuild a GPT from a checkpoint file and put it in eval mode.

    Args:
        ckpt_path: path to a checkpoint loadable by torch.load, holding a
            'model' state dict and (in newer checkpoints) a 'config' dict.
        vocab_size: tokenizer vocabulary size; it overrides the stored config,
            since the tokenizer is authoritative for vocab.
        compile_model: if True, replace ``model.step`` and
            ``model.forward_with_states`` with torch.compile'd versions
            (static shapes).

    Returns:
        (model, ckpt, config): the ready model, the raw checkpoint dict, and
        the config the model was actually built from.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    if 'config' not in ckpt:
        print(f"note: checkpoint has no 'config' (older format); using fallback {FALLBACK_CONFIG}",
              file=sys.stderr)
        config = dict(FALLBACK_CONFIG)
    else:
        config = dict(ckpt['config'])
    config['vocab_size'] = vocab_size
    model = GPT.from_config(config).to(device)

    # Weights saved from a torch.compile'd module carry an `_orig_mod.` prefix.
    state = {}
    for name, tensor in ckpt['model'].items():
        state[name.removeprefix('_orig_mod.')] = tensor
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"warn: missing keys: {missing}", file=sys.stderr)
    if unexpected:
        print(f"warn: unexpected keys: {unexpected}", file=sys.stderr)
    model.eval()

    if compile_model:
        print("compiling step/prefill paths...", file=sys.stderr)
        model.step = torch.compile(model.step, dynamic=False)
        model.forward_with_states = torch.compile(model.forward_with_states, dynamic=False)
    return model, ckpt, config
def _parse_args():
    """Parse sys.argv, rejecting values that would crash or silently misbehave downstream."""
    parser = argparse.ArgumentParser(description="Sample from a trained microgpt Stream Mixer.")
    parser.add_argument('-p', '--prompt', type=str, default="In this article, we will explore")
    parser.add_argument('-n', '--max-tokens', type=int, default=256)
    parser.add_argument('-t', '--temperature', type=float, default=0.8,
                        help='softmax temperature; 0 = greedy argmax. Recommended 0.7–0.9 for prose. '
                             '(Determined empirically — see v4 log.md post-completion section.)')
    parser.add_argument('-k', '--top-k', type=int, default=50,
                        help='restrict sampling to top-k logits; 0 disables. '
                             'Recommended 50 — keeps the long tail of the 32k vocab from injecting weird tokens.')
    parser.add_argument('--top-p', type=float, default=1.0,
                        help='nucleus sampling: keep the smallest token set whose cumulative '
                             'probability ≥ top-p. 1.0 disables. Recommended 0.9 for long generations.')
    parser.add_argument('-r', '--repetition-penalty', type=float, default=1.0,
                        help='CTRL-style repetition penalty: tokens already in the context get '
                             'their logits divided by this. 1.0 = no penalty. Recommended 1.1–1.2 '
                             'for >500-token generations to break out of repetition wells.')
    parser.add_argument('-s', '--num-samples', type=int, default=1)
    parser.add_argument('--ckpt', type=str, default=DEFAULT_CKPT)
    parser.add_argument('--tokenizer', type=str, default=DEFAULT_TOKENIZER)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--no-stream', action='store_true',
                        help='print full output at end instead of streaming token by token')
    parser.add_argument('--compile', action='store_true',
                        help='torch.compile the step/prefill paths (slower first call, faster after)')
    parser.add_argument('--repl', action='store_true',
                        help='after generating, read more prompts from stdin so --compile is paid once')
    args = parser.parse_args()
    # parser.error prints usage and exits with status 2. Without these checks:
    # 0 samples builds a malformed prompt tensor, a negative temperature inverts
    # the distribution, and a penalty <= 0 yields inf/NaN or sign-flipped logits.
    if args.num_samples < 1:
        parser.error(f"--num-samples must be >= 1, got {args.num_samples}")
    if args.temperature < 0:
        parser.error(f"--temperature must be >= 0 (0 = greedy), got {args.temperature}")
    if args.repetition_penalty <= 0:
        parser.error(f"--repetition-penalty must be > 0, got {args.repetition_penalty}")
    return args


def main():
    """CLI entry point: load tokenizer + checkpoint, sample from the prompt,
    then optionally keep reading prompts from stdin (--repl)."""
    args = _parse_args()
    if args.seed is not None:
        torch.manual_seed(args.seed)
    if not os.path.exists(args.ckpt):
        sys.exit(f"error: checkpoint not found: {args.ckpt}")
    if not os.path.exists(args.tokenizer):
        sys.exit(f"error: tokenizer not found: {args.tokenizer}")
    print(f"device: {device}", file=sys.stderr)
    tokenizer = Tokenizer.from_file(args.tokenizer)
    BOS = tokenizer.token_to_id("<|bos|>")
    if BOS is None:
        # token_to_id returns None for unknown tokens; a None id would otherwise
        # only fail later, inside torch.tensor, with an opaque error.
        sys.exit(f"error: tokenizer {args.tokenizer} has no '<|bos|>' token")
    model, ckpt, config = load_model(args.ckpt, tokenizer.get_vocab_size(),
                                     compile_model=args.compile)
    step = ckpt.get('step', '?')
    best = ckpt.get('best_loss')
    n_params = sum(param.numel() for param in model.parameters())
    best_str = f" best_loss={best:.4f}" if isinstance(best, float) else ""
    print(f"loaded {args.ckpt}  step={step}{best_str}", file=sys.stderr)
    print(f"config: {config}", file=sys.stderr)
    print(f"params: {n_params:,}\n", file=sys.stderr)
    stream = not args.no_stream

    def run(prompt):
        results = generate(model, tokenizer, BOS, prompt,
                           n_samples=args.num_samples,
                           max_new_tokens=args.max_tokens,
                           temperature=args.temperature,
                           top_k=args.top_k,
                           top_p=args.top_p,
                           repetition_penalty=args.repetition_penalty,
                           stream=stream)
        # Single-sample streaming already printed; otherwise print collected results.
        if not (stream and args.num_samples == 1):
            for i, out in enumerate(results):
                if args.num_samples > 1:
                    print(f"--- sample {i+1}/{args.num_samples} ---", file=sys.stderr)
                print(prompt + out)
                if args.num_samples > 1:
                    print(file=sys.stderr)

    run(args.prompt)

    # REPL: amortize torch.compile warmup over many prompts. Empty input
    # reuses the previous prompt; "exit"/"quit"/Ctrl-D leaves the loop.
    if args.repl:
        last_prompt = args.prompt
        print("\n[repl] enter prompt (empty=repeat, exit/quit/Ctrl-D to leave)", file=sys.stderr)
        while True:
            try:
                line = input(">>> ")
            except (EOFError, KeyboardInterrupt):
                print(file=sys.stderr)
                break
            line = line.strip()
            if line in ('exit', 'quit'):
                break
            prompt = line if line else last_prompt
            last_prompt = prompt
            run(prompt)


if __name__ == '__main__':
    main()