| """ |
| Nano-SLM: a tiny decoder-only transformer (~1M params). |
| |
| Architecture is intentionally minimal so every line is readable. |
| Mirrors the standard GPT recipe: token + position embeddings, N stacked |
| (causal self-attention -> MLP) blocks with pre-LayerNorm and residuals, |
| final LayerNorm, and a tied LM head. |
| """ |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention. Uses fused QKV and PyTorch's SDPA."""

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        # One fused projection produces Q, K, and V in a single matmul.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.attn_dropout_p = dropout
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)

        # (B, T, C) -> (B, n_heads, T, head_dim) so SDPA attends per head.
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        y = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,
            dropout_p=self.attn_dropout_p if self.training else 0.0,
        )
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge heads back
        return self.resid_dropout(self.proj(y))
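
# For reference (informal): the F.scaled_dot_product_attention call in
# CausalSelfAttention.forward above computes, per head,
#     softmax(q @ k.transpose(-2, -1) / sqrt(head_dim) + causal_mask) @ v
# where causal_mask is -inf wherever a query would attend to a later position,
# and dropout_p (training only) drops entries of the attention-weight matrix.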


class MLP(nn.Module):
    """Position-wise feed-forward (GELU)."""

    def __init__(self, d_model, ffn_dim, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, ffn_dim, bias=False)
        self.fc2 = nn.Linear(ffn_dim, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.fc2(F.gelu(self.fc1(x))))


class Block(nn.Module):
    """Pre-LN transformer block: x = x + attn(LN(x)); x = x + mlp(LN(x))."""

    def __init__(self, d_model, n_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, ffn_dim, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class NanoSLM(nn.Module):
    def __init__(
        self,
        vocab_size=4096,
        d_model=128,
        n_heads=4,
        n_layers=4,
        ffn_dim=512,
        ctx_len=256,
        dropout=0.1,
    ):
        super().__init__()
        self.ctx_len = ctx_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(ctx_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            [Block(d_model, n_heads, ffn_dim, dropout) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying: the LM head shares the token-embedding matrix.
        self.head.weight = self.tok_emb.weight

        self.apply(self._init_weights)
        # GPT-2-style init: scale the projections that feed the residual stream
        # by 1/sqrt(2 * n_layers) so residual contributions don't grow with depth.
        for name, p in self.named_parameters():
            if name.endswith("proj.weight") or name.endswith("fc2.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layers))

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def num_params(self, non_embedding=False):
        n = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n -= self.tok_emb.weight.numel()
            n -= self.pos_emb.weight.numel()
        return n
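
    # Rough, illustrative parameter count assuming the default config above:
    #   tok_emb 4096*128 = 524,288    pos_emb 256*128 = 32,768
    #   4 blocks * 197,120 = 788,480  (qkv 49,152 + proj 16,384 + fc1 65,536
    #                                  + fc2 65,536 + two LayerNorms 512)
    #   final LayerNorm 256           -> ~1.35M total, ~0.79M non-embedding
    # Because the LM head is tied to tok_emb.weight, parameters() counts that
    # matrix once, and non_embedding=True therefore excludes the head as well.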

    def forward(self, idx, targets=None):
        B, T = idx.shape
        assert T <= self.ctx_len, f"sequence length {T} > ctx_len {self.ctx_len}"
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)
        loss = None
        if targets is not None:
            # targets[t] is the id the model should predict at position t
            # (typically idx shifted left by one); -100 marks positions to ignore.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-100,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressive sampling. Slow on purpose: no KV cache (a great upgrade later)."""
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.ctx_len:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Top-k filtering: logits below the k-th largest get -inf (prob 0).
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float("inf")
            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, next_tok], dim=1)
        return idx
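

if __name__ == "__main__":
    # Optional smoke test (illustrative sketch, not part of any training pipeline):
    # random token ids stand in for a real tokenizer, just to exercise
    # forward() and generate() end to end with the default config.
    torch.manual_seed(0)
    model = NanoSLM()
    print(f"parameters: {model.num_params():,} "
          f"({model.num_params(non_embedding=True):,} non-embedding)")

    idx = torch.randint(0, 4096, (2, 64))       # batch of 2, sequence length 64
    targets = torch.randint(0, 4096, (2, 64))   # fake next-token labels
    logits, loss = model(idx, targets)
    print("logits:", tuple(logits.shape), "loss:", float(loss))

    model.eval()
    out = model.generate(idx[:, :8], max_new_tokens=16, temperature=1.0, top_k=50)
    print("generated shape:", tuple(out.shape))  # (2, 8 + 16)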