| """ |
| GuppyLM — a tiny fish brain. |
| |
| Vanilla transformer: multi-head attention, ReLU FFN, LayerNorm, learned positional embeddings. |
| No GQA, no SwiGLU, no parallel residual, no RoPE. As simple as it gets. |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from config import GuppyConfig |
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        config: object exposing ``d_model``, ``n_heads`` and ``dropout``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` (otherwise
            the per-head split in ``forward`` would fail with a shape error).
    """

    def __init__(self, config):
        super().__init__()
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads

        # One projection produces queries, keys and values side by side.
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.out = nn.Linear(config.d_model, config.d_model)
        # Dropout applied to the attention probabilities.
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape (B, T, C) where C == d_model.
            mask: optional tensor broadcastable to (B, n_heads, T, T); entries
                equal to 0 are excluded from attention.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        # (B, T, 3*C) -> (3, B, n_heads, T, head_dim)
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores: (B, n_heads, T, T)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(F.softmax(attn, dim=-1))
        # (B, n_heads, T, head_dim) -> (B, T, n_heads, head_dim) -> (B, T, C)
        return self.out((attn @ v).transpose(1, 2).contiguous().view(B, T, C))
|
|
|
|
class FFN(nn.Module):
    """Position-wise feed-forward network: up-project, ReLU, down-project, dropout.

    Expands features from ``config.d_model`` to ``config.ffn_hidden`` and
    projects them back to ``config.d_model``.
    """

    def __init__(self, config):
        super().__init__()
        width, hidden_width = config.d_model, config.ffn_hidden
        self.up = nn.Linear(width, hidden_width)
        self.down = nn.Linear(hidden_width, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        activated = F.relu(self.up(x))
        projected = self.down(activated)
        return self.dropout(projected)
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer layer.

    Computes ``h = x + attn(norm1(x))`` followed by ``h + ffn(norm2(h))``.
    """

    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.d_model)
        self.attn = Attention(config)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.ffn = FFN(config)

    def forward(self, x, mask=None):
        # Attention sub-layer with residual connection.
        hidden = x + self.attn(self.norm1(x), mask)
        # Feed-forward sub-layer with residual connection.
        ffn_out = self.ffn(self.norm2(hidden))
        return hidden + ffn_out
|
|
|
|
class GuppyLM(nn.Module):
    """Decoder-only transformer language model ("a tiny fish brain").

    Token embeddings plus learned positional embeddings feed a stack of
    ``Block`` layers, a final LayerNorm, and an output head whose weight is
    tied to the token-embedding matrix.

    Args:
        config: model hyper-parameters. Read here: ``vocab_size``, ``d_model``,
            ``max_seq_len``, ``dropout``, ``n_layers`` and, in ``generate``,
            ``eos_id``. ``Block`` reads further fields.
    """

    def __init__(self, config: GuppyConfig):
        super().__init__()
        self.config = config

        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.tok_emb.weight

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero Linear biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T) with T <= config.max_seq_len.
            targets: optional token ids with the same number of elements as
                ``idx``; positions whose target is 0 are ignored by the loss.

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size) and a scalar
            loss tensor, or ``None`` when ``targets`` is None.

        Raises:
            ValueError: if T exceeds ``config.max_seq_len`` (the positional
                embedding table has no entry for later positions).
        """
        _, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.config.max_seq_len}"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        # Lower-triangular (causal) mask shaped (1, 1, T, T) to broadcast over batch and heads.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).unsqueeze(0).unsqueeze(0)

        for block in self.blocks:
            x = block(x, mask)

        logits = self.lm_head(self.norm(x))

        loss = None
        if targets is not None:
            # NOTE(review): ignore_index=0 assumes token id 0 is padding — confirm
            # against the tokenizer/config.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                targets.reshape(-1),  # reshape tolerates non-contiguous targets
                ignore_index=0,
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=64, temperature=0.7, top_k=50, **kwargs):
        """Autoregressively append up to ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: (B, T) tensor of prompt token ids.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; ``0`` selects greedy (argmax) decoding.
            top_k: when a positive int, sample only among the ``top_k`` most
                likely tokens; ``None`` or a value ``<= 0`` disables the filter.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            ``(idx, [])`` — the extended token tensor and an always-empty list.

        Raises:
            ValueError: if ``temperature`` is negative.

        Generation stops early once every sequence in the batch has produced
        ``config.eos_id``. The module's train/eval mode is restored on exit.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        was_training = self.training
        self.eval()
        try:
            finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
            for _ in range(max_new_tokens):
                # Crop the context to the positional-embedding window.
                idx_cond = idx[:, -self.config.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature == 0:
                    next_id = logits.argmax(dim=-1, keepdim=True)
                else:
                    # Division yields a fresh tensor, so the in-place fill below is safe.
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float("-inf")
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat([idx, next_id], dim=1)
                # Track EOS per sequence so batched generation works (item() needs B == 1).
                finished |= next_id.squeeze(-1) == self.config.eos_id
                if finished.all():
                    break
        finally:
            self.train(was_training)
        return idx, []

    def param_count(self):
        """Return ``(total, 0)`` where ``total`` counts parameters from ``self.parameters()``.

        ``parameters()`` yields shared (tied) tensors once, so the tied
        embedding/output matrix is counted a single time. The second element
        is always 0.
        """
        total = sum(p.numel() for p in self.parameters())
        return total, 0

    def param_summary(self):
        """Return a one-line human-readable parameter count."""
        total, _ = self.param_count()
        return f"GuppyLM: {total:,} params ({total/1e6:.1f}M)"
|
|