"""
GuppyLM — a tiny fish brain.

Vanilla transformer: multi-head attention, ReLU FFN, LayerNorm, learned
positional embeddings. No GQA, no SwiGLU, no parallel residual, no RoPE.
As simple as it gets.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import GuppyConfig


class Attention(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    Reads ``d_model``, ``n_heads`` and ``dropout`` from the config.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()
        if config.d_model % config.n_heads != 0:
            # forward() splits d_model into n_heads equal slices; fail early
            # instead of with an opaque reshape error at the first forward.
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.out = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, d_model) and return the same shape.

        Args:
            x: input activations, shape (B, T, d_model).
            mask: optional tensor broadcastable to the (B, n_heads, T, T)
                score matrix; positions where it equals 0 are excluded.
                NOTE: a query row that is masked everywhere yields NaNs.
        """
        B, T, C = x.shape
        # (B, T, 3*C) -> (3, B, n_heads, T, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, T, 3, self.n_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(F.softmax(attn, dim=-1))
        # Merge heads: (B, n_heads, T, head_dim) -> (B, T, C)
        return self.out((attn @ v).transpose(1, 2).contiguous().view(B, T, C))


class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.d_model, config.ffn_hidden)
        self.down = nn.Linear(config.ffn_hidden, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the FFN to the last dimension of ``x`` (size d_model)."""
        return self.dropout(self.down(F.relu(self.up(x))))


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm before attention and before the FFN."""

    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.d_model)
        self.attn = Attention(config)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.ffn = FFN(config)

    def forward(self, x, mask=None):
        """Apply attention and FFN sub-layers, each wrapped in a residual add."""
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        return x


class GuppyLM(nn.Module):
    """Decoder-only language model with tied input/output embeddings."""

    def __init__(self, config: GuppyConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # tie weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear/Embedding weights from N(0, 0.02); zero Linear biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids.

        Args:
            idx: token ids of shape (B, T) with ``T <= config.max_seq_len``.
            targets: optional ids of shape (B, T). Positions equal to 0 are
                ignored by the loss (presumably the padding id — confirm
                against the tokenizer/config).

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, vocab_size)
            and ``loss`` is the mean cross-entropy, or ``None`` when no
            ``targets`` are given.

        Raises:
            ValueError: if ``T`` exceeds ``config.max_seq_len`` (there is no
                positional embedding for such positions).
        """
        _, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len "
                f"({self.config.max_seq_len})"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        # Causal mask of shape (1, 1, T, T): 1 = may attend, 0 = future token.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).unsqueeze(0).unsqueeze(0)
        for block in self.blocks:
            x = block(x, mask)
        logits = self.lm_head(self.norm(x))

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                targets.reshape(-1),
                ignore_index=0,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=64, temperature=0.7, top_k=50, **kwargs):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Args:
            idx: prompt token ids of shape (B, T).
            max_new_tokens: upper bound on the number of appended tokens.
            temperature: softmax temperature; values ``<= 0`` select the
                arg-max token (greedy decoding) instead of sampling.
            top_k: if positive, sample only among the ``top_k`` most likely
                tokens; ``0``/``None`` disables the filter.
            **kwargs: accepted for call-site compatibility and ignored.

        Returns:
            ``(idx, [])`` — the extended ids and an always-empty list.

        Generation stops early once every row has emitted ``config.eos_id``;
        rows that finished earlier are padded with ``eos_id``. The module's
        train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        try:
            eos_id = self.config.eos_id
            finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
            for _ in range(max_new_tokens):
                # Crop the context to the positional-embedding window.
                idx_cond = idx[:, -self.config.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]

                if temperature <= 0:
                    next_id = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float("-inf")
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)

                if eos_id is not None:
                    # Keep already-finished rows stable instead of emitting noise.
                    next_id = next_id.masked_fill(finished.unsqueeze(-1), eos_id)
                idx = torch.cat([idx, next_id], dim=1)
                if eos_id is not None:
                    finished |= next_id.squeeze(-1) == eos_id
                    if bool(finished.all()):
                        break
        finally:
            # Don't leave dropout disabled if generate() is called mid-training.
            self.train(was_training)
        return idx, []

    def param_count(self):
        """Return ``(total, 0)``; ``total`` counts each parameter tensor once.

        ``nn.Module.parameters()`` de-duplicates shared tensors, so the tied
        embedding/lm_head weight is counted once. The second element is
        always 0 here (presumably a placeholder kept for callers).
        """
        total = sum(p.numel() for p in self.parameters())
        return total, 0

    def param_summary(self):
        """Return a one-line human-readable parameter count."""
        total, _ = self.param_count()
        return f"GuppyLM: {total:,} params ({total/1e6:.1f}M)"