| """ |
| Inference helper for Nano-SLM. |
| |
| Wraps the model + tokenizer into a clean `generate()` function suitable for |
| demos, notebooks, or a Gradio interface. |
| |
| Usage: |
| from inference import NanoSLMInference |
| slm = NanoSLMInference("out/ckpt.pt", "data/tokenizer.json") |
| text = slm.generate("Once upon a time", max_new_tokens=200, temperature=0.8) |
| print(text) |
| """ |
import torch
import torch.nn.functional as F
from tokenizers import Tokenizer

from model import NanoSLM


DEFAULT_CFG = dict(
    vocab_size=4096, d_model=128, n_heads=4, n_layers=4,
    ffn_dim=512, ctx_len=256, dropout=0.0,
)
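# NOTE: DEFAULT_CFG must mirror the hyperparameters the checkpoint was trained
# with; load_state_dict() raises a size-mismatch error otherwise. If your
# training script also stores its config inside the checkpoint, pass that
# dict in via the `cfg` argument rather than relying on these defaults.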


class NanoSLMInference:
    """Loads a trained Nano-SLM checkpoint plus its tokenizer for generation."""

    def __init__(self, ckpt_path, tokenizer_path, device=None, cfg=None):
        # Prefer Apple-silicon MPS, then CUDA, then fall back to CPU.
        if device is None:
            if torch.backends.mps.is_available():
                device = "mps"
            elif torch.cuda.is_available():
                device = "cuda"
            else:
                device = "cpu"
        self.device = device

        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        cfg = cfg or DEFAULT_CFG
        self.model = NanoSLM(**cfg)
        ckpt = torch.load(ckpt_path, map_location=device)
        # The training script may save either a bare state_dict or a
        # checkpoint dict with the weights under a "model" key; accept both.
        state = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
        self.model.load_state_dict(state)
        # eval() disables dropout, so randomness comes only from the sampler.
        self.model.to(device).eval()
        self.ctx_len = cfg["ctx_len"]

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 200,
        temperature: float = 0.8,
        top_k: int | None = 40,
        top_p: float | None = None,
        seed: int | None = None,
    ) -> str:
        """Generate a continuation for `prompt`.

        Filters are applied in order: temperature scaling, then top-k,
        then top-p, then one token is sampled from what remains.

        Args:
            prompt: input text.
            max_new_tokens: number of tokens to generate.
            temperature: 0 = greedy decoding, 1.0 = no scaling, >1 = more random.
            top_k: keep only the k highest-probability tokens (None = no filter).
            top_p: nucleus sampling; keep the smallest set of tokens whose
                cumulative probability reaches p (None = no filter).
            seed: random seed for reproducible sampling.

        Returns:
            The prompt followed by the generated continuation.
        """
        if seed is not None:
            torch.manual_seed(seed)

        ids = self.tokenizer.encode(prompt).ids
        x = torch.tensor([ids], dtype=torch.long, device=self.device)

        for _ in range(max_new_tokens):
            # Crop the context to the model's maximum sequence length.
            x_cond = x[:, -self.ctx_len:]
            logits, _ = self.model(x_cond)
            # Only the logits at the final position predict the next token.
            logits = logits[:, -1, :]

            if temperature == 0.0:
                # Greedy decoding: always take the single most likely token.
                next_tok = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature

                if top_k is not None:
                    # Drop everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("inf")

                if top_p is not None:
                    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    # Mask tokens once cumulative probability exceeds p, shifted
                    # right so the first token past the threshold is kept; at
                    # least one token always survives.
                    mask = cum_probs > top_p
                    mask[..., 1:] = mask[..., :-1].clone()
                    mask[..., 0] = False
                    sorted_logits[mask] = -float("inf")
                    # Scatter the filtered logits back to their original
                    # positions (sorted_idx is a full permutation, so every
                    # slot is overwritten).
                    logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)

                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)

            # Append the chosen token and continue autoregressively.
            x = torch.cat([x, next_tok], dim=1)

        # Decode the full sequence: prompt plus continuation.
        return self.tokenizer.decode(x[0].tolist())


if __name__ == "__main__":
    # Quick smoke test with the default checkpoint and tokenizer paths.
    slm = NanoSLMInference("out/ckpt.pt", "data/tokenizer.json")
    print(slm.generate("Once upon a time", max_new_tokens=100, temperature=0.8, top_k=40))
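    # Further illustrative calls exercising the other sampling knobs
    # (parameters as defined in generate() above):
    # greedy decoding: temperature=0 always picks the argmax token.
    print(slm.generate("Once upon a time", max_new_tokens=100, temperature=0.0))
    # nucleus sampling with a fixed seed for a reproducible draw.
    print(slm.generate("Once upon a time", max_new_tokens=100,
                       temperature=0.9, top_k=None, top_p=0.9, seed=42))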