"""
Inference helper for Nano-SLM.
Wraps the model + tokenizer into a clean `generate()` function suitable for
demos, notebooks, or a Gradio interface.
Usage:
from inference import NanoSLMInference
slm = NanoSLMInference("out/ckpt.pt", "data/tokenizer.json")
text = slm.generate("Once upon a time", max_new_tokens=200, temperature=0.8)
print(text)
"""
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
from model import NanoSLM
# Must match the architecture used during training.
DEFAULT_CFG = dict(
vocab_size=4096, d_model=128, n_heads=4, n_layers=4,
ffn_dim=512, ctx_len=256, dropout=0.0,
)
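# If your checkpoint was trained with a different architecture, pass a matching
# cfg dict to the constructor instead of relying on DEFAULT_CFG. The values
# below are purely illustrative, not shipped defaults:
#
#     custom_cfg = dict(
#         vocab_size=8192, d_model=256, n_heads=8, n_layers=6,
#         ffn_dim=1024, ctx_len=512, dropout=0.0,
#     )
#     slm = NanoSLMInference("out/ckpt.pt", "data/tokenizer.json", cfg=custom_cfg)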


class NanoSLMInference:
def __init__(self, ckpt_path, tokenizer_path, device=None, cfg=None):
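        # Auto-select a device: prefer Apple-Silicon MPS, then CUDA, else CPU.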
if device is None:
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
self.device = device
self.tokenizer = Tokenizer.from_file(tokenizer_path)
cfg = cfg or DEFAULT_CFG
self.model = NanoSLM(**cfg)
ckpt = torch.load(ckpt_path, map_location=device)
# support both raw state_dicts and {"model": ...} checkpoints
state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
self.model.load_state_dict(state)
self.model.to(device).eval()
self.ctx_len = cfg["ctx_len"]

    @torch.no_grad()
def generate(
self,
prompt: str,
max_new_tokens: int = 200,
temperature: float = 0.8,
top_k: int | None = 40,
top_p: float | None = None,
seed: int | None = None,
) -> str:
"""Generate continuation for a prompt.
Args:
prompt: input text
max_new_tokens: how many tokens to generate
temperature: 0 = greedy, 1.0 = no scaling, >1 = more random
top_k: keep only the k highest-prob tokens (None = no filter)
top_p: nucleus — keep smallest set with cumulative prob >= p
seed: for reproducibility
"""
if seed is not None:
torch.manual_seed(seed)
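        # Encode the prompt to token ids and add a batch dimension (shape [1, T]).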
ids = self.tokenizer.encode(prompt).ids
x = torch.tensor([ids], dtype=torch.long, device=self.device)
for _ in range(max_new_tokens):
# truncate context if it grows past ctx_len
x_cond = x[:, -self.ctx_len:]
logits, _ = self.model(x_cond)
# we only care about the prediction for the next token
logits = logits[:, -1, :]
if temperature == 0.0:
# greedy: pick the argmax
next_tok = logits.argmax(dim=-1, keepdim=True)
else:
logits = logits / temperature
if top_k is not None:
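                    # Set everything below the k-th largest logit to -inf so it
                    # can never be sampled.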
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("inf")
if top_p is not None:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# mask tokens past the nucleus
mask = cum_probs > top_p
# shift right so we always keep at least one token
mask[..., 1:] = mask[..., :-1].clone()
mask[..., 0] = False
sorted_logits[mask] = -float("inf")
# unsort back to original vocab order
logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)
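                # Sample the next token from the filtered, re-normalised distribution.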
probs = F.softmax(logits, dim=-1)
next_tok = torch.multinomial(probs, num_samples=1)
x = torch.cat([x, next_tok], dim=1)
return self.tokenizer.decode(x[0].tolist())
if __name__ == "__main__":
# quick self-test
slm = NanoSLMInference("out/ckpt.pt", "data/tokenizer.json")
print(slm.generate("Once upon a time", max_new_tokens=100, temperature=0.8, top_k=40))