""" Inference helper for Nano-SLM. Wraps the model + tokenizer into a clean `generate()` function suitable for demos, notebooks, or a Gradio interface. Usage: from inference import NanoSLMInference slm = NanoSLMInference("out/ckpt.pt", "data/tokenizer.json") text = slm.generate("Once upon a time", max_new_tokens=200, temperature=0.8) print(text) """ import torch import torch.nn.functional as F from tokenizers import Tokenizer from model import NanoSLM # Must match the architecture used during training. DEFAULT_CFG = dict( vocab_size=4096, d_model=128, n_heads=4, n_layers=4, ffn_dim=512, ctx_len=256, dropout=0.0, ) class NanoSLMInference: def __init__(self, ckpt_path, tokenizer_path, device=None, cfg=None): if device is None: if torch.backends.mps.is_available(): device = "mps" elif torch.cuda.is_available(): device = "cuda" else: device = "cpu" self.device = device self.tokenizer = Tokenizer.from_file(tokenizer_path) cfg = cfg or DEFAULT_CFG self.model = NanoSLM(**cfg) ckpt = torch.load(ckpt_path, map_location=device) # support both raw state_dicts and {"model": ...} checkpoints state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt self.model.load_state_dict(state) self.model.to(device).eval() self.ctx_len = cfg["ctx_len"] @torch.no_grad() def generate( self, prompt: str, max_new_tokens: int = 200, temperature: float = 0.8, top_k: int | None = 40, top_p: float | None = None, seed: int | None = None, ) -> str: """Generate continuation for a prompt. Args: prompt: input text max_new_tokens: how many tokens to generate temperature: 0 = greedy, 1.0 = no scaling, >1 = more random top_k: keep only the k highest-prob tokens (None = no filter) top_p: nucleus — keep smallest set with cumulative prob >= p seed: for reproducibility """ if seed is not None: torch.manual_seed(seed) ids = self.tokenizer.encode(prompt).ids x = torch.tensor([ids], dtype=torch.long, device=self.device) for _ in range(max_new_tokens): # truncate context if it grows past ctx_len x_cond = x[:, -self.ctx_len:] logits, _ = self.model(x_cond) # we only care about the prediction for the next token logits = logits[:, -1, :] if temperature == 0.0: # greedy: pick the argmax next_tok = logits.argmax(dim=-1, keepdim=True) else: logits = logits / temperature if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") if top_p is not None: sorted_logits, sorted_idx = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # mask tokens past the nucleus mask = cum_probs > top_p # shift right so we always keep at least one token mask[..., 1:] = mask[..., :-1].clone() mask[..., 0] = False sorted_logits[mask] = -float("inf") # unsort back to original vocab order logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits) probs = F.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) x = torch.cat([x, next_tok], dim=1) return self.tokenizer.decode(x[0].tolist()) if __name__ == "__main__": # quick self-test slm = NanoSLMInference("out/ckpt.pt", "data/tokenizer.json") print(slm.generate("Once upon a time", max_new_tokens=100, temperature=0.8, top_k=40))