""" Nano-SLM: a tiny decoder-only transformer (~1M params). Architecture is intentionally minimal so every line is readable. Mirrors the standard GPT recipe: token + position embeddings, N stacked (causal self-attention -> MLP) blocks with pre-LayerNorm and residuals, final LayerNorm, and a tied LM head. """ import math import torch import torch.nn as nn import torch.nn.functional as F class CausalSelfAttention(nn.Module): """Multi-head causal self-attention. Uses fused QKV and PyTorch's SDPA.""" def __init__(self, d_model, n_heads, dropout=0.1): super().__init__() assert d_model % n_heads == 0 self.n_heads = n_heads self.head_dim = d_model // n_heads # one big linear that produces Q, K, V at once self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) self.proj = nn.Linear(d_model, d_model, bias=False) self.attn_dropout_p = dropout self.resid_dropout = nn.Dropout(dropout) def forward(self, x): B, T, C = x.shape q, k, v = self.qkv(x).split(C, dim=-1) # reshape to (B, n_heads, T, head_dim) q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2) # Flash/SDPA: causal mask + scaling handled internally y = F.scaled_dot_product_attention( q, k, v, is_causal=True, dropout_p=self.attn_dropout_p if self.training else 0.0, ) y = y.transpose(1, 2).contiguous().view(B, T, C) return self.resid_dropout(self.proj(y)) class MLP(nn.Module): """Position-wise feed-forward (GELU).""" def __init__(self, d_model, ffn_dim, dropout=0.1): super().__init__() self.fc1 = nn.Linear(d_model, ffn_dim, bias=False) self.fc2 = nn.Linear(ffn_dim, d_model, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x): return self.dropout(self.fc2(F.gelu(self.fc1(x)))) class Block(nn.Module): """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x)).""" def __init__(self, d_model, n_heads, ffn_dim, dropout=0.1): super().__init__() self.ln1 = nn.LayerNorm(d_model) self.attn = CausalSelfAttention(d_model, n_heads, dropout) self.ln2 = nn.LayerNorm(d_model) self.mlp = MLP(d_model, ffn_dim, dropout) def forward(self, x): x = x + self.attn(self.ln1(x)) x = x + self.mlp(self.ln2(x)) return x class NanoSLM(nn.Module): def __init__( self, vocab_size=4096, d_model=128, n_heads=4, n_layers=4, ffn_dim=512, ctx_len=256, dropout=0.1, ): super().__init__() self.ctx_len = ctx_len self.tok_emb = nn.Embedding(vocab_size, d_model) self.pos_emb = nn.Embedding(ctx_len, d_model) self.drop = nn.Dropout(dropout) self.blocks = nn.ModuleList( [Block(d_model, n_heads, ffn_dim, dropout) for _ in range(n_layers)] ) self.ln_f = nn.LayerNorm(d_model) self.head = nn.Linear(d_model, vocab_size, bias=False) # weight tying: input embedding and output projection share weights. # saves a lot of params at small vocab sizes and usually helps quality. self.head.weight = self.tok_emb.weight self.apply(self._init_weights) # scaled init for residual projections (GPT-2 trick) for name, p in self.named_parameters(): if name.endswith("proj.weight") or name.endswith("fc2.weight"): nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layers)) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.normal_(m.weight, mean=0.0, std=0.02) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Embedding): nn.init.normal_(m.weight, mean=0.0, std=0.02) def num_params(self, non_embedding=False): n = sum(p.numel() for p in self.parameters()) if non_embedding: n -= self.tok_emb.weight.numel() n -= self.pos_emb.weight.numel() return n def forward(self, idx, targets=None): B, T = idx.shape assert T <= self.ctx_len, f"sequence length {T} > ctx_len {self.ctx_len}" pos = torch.arange(T, device=idx.device) x = self.drop(self.tok_emb(idx) + self.pos_emb(pos)) for block in self.blocks: x = block(x) x = self.ln_f(x) logits = self.head(x) loss = None if targets is not None: loss = F.cross_entropy( logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100, ) return logits, loss @torch.no_grad() def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """Autoregressive sampling. Slow on purpose: no KV cache (a great upgrade later).""" for _ in range(max_new_tokens): idx_cond = idx[:, -self.ctx_len:] logits, _ = self(idx_cond) logits = logits[:, -1, :] / temperature if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("inf") probs = F.softmax(logits, dim=-1) next_tok = torch.multinomial(probs, num_samples=1) idx = torch.cat([idx, next_tok], dim=1) return idx