Spaces:
Sleeping
Sleeping
| """Baseline Transformer — standard uniform-width architecture.""" | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from src.fog.config import FOGConfig | |
| class BaselineAttention(nn.Module): | |
| def __init__(self, d_model: int, n_heads: int): | |
| super().__init__() | |
| assert d_model % n_heads == 0 | |
| self.n_heads = n_heads | |
| self.head_dim = d_model // n_heads | |
| self.qkv = nn.Linear(d_model, 3 * d_model) | |
| self.out = nn.Linear(d_model, d_model) | |
| def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: | |
| b, t, _ = x.shape | |
| qkv = self.qkv(x).view(b, t, 3, self.n_heads, self.head_dim) | |
| q, k, v = qkv.unbind(dim=2) # each [B, T, H, D] | |
| q = q.transpose(1, 2) # [B, H, T, D] | |
| k = k.transpose(1, 2) | |
| v = v.transpose(1, 2) | |
| scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) | |
| if mask is not None: | |
| scores = scores.masked_fill(mask == 0, float("-inf")) | |
| attn = F.softmax(scores, dim=-1) | |
| y = torch.matmul(attn, v) | |
| y = y.transpose(1, 2).contiguous().view(b, t, -1) | |
| return self.out(y) | |
| class BaselineFFN(nn.Module): | |
| def __init__(self, d_model: int, d_ff: int): | |
| super().__init__() | |
| self.up = nn.Linear(d_model, d_ff) | |
| self.down = nn.Linear(d_ff, d_model) | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| return self.down(F.silu(self.up(x))) | |
| class BaselineBlock(nn.Module): | |
| def __init__(self, d_model: int, d_ff: int, n_heads: int, dropout: float): | |
| super().__init__() | |
| self.ln1 = nn.LayerNorm(d_model) | |
| self.ln2 = nn.LayerNorm(d_model) | |
| self.attn = BaselineAttention(d_model, n_heads) | |
| self.ffn = BaselineFFN(d_model, d_ff) | |
| self.drop = nn.Dropout(dropout) | |
| def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: | |
| x = x + self.drop(self.attn(self.ln1(x), mask)) | |
| x = x + self.drop(self.ffn(self.ln2(x))) | |
| return x | |
| class BaselineTransformer(nn.Module): | |
| def __init__(self, cfg: FOGConfig): | |
| super().__init__() | |
| self.cfg = cfg | |
| self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model) | |
| self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model) | |
| self.blocks = nn.ModuleList([ | |
| BaselineBlock(cfg.d_model, cfg.d_ff, cfg.n_heads, cfg.dropout) | |
| for _ in range(cfg.n_layers) | |
| ]) | |
| self.ln_f = nn.LayerNorm(cfg.d_model) | |
| self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) | |
| self.tok_emb.weight = self.head.weight # weight tying | |
| # Pre-compute causal mask once | |
| self.register_buffer( | |
| "_causal_mask", | |
| torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0), | |
| persistent=False, | |
| ) | |
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| targets: torch.Tensor | None = None, | |
| loss_mask: torch.Tensor | None = None, | |
| ) -> dict[str, torch.Tensor]: | |
| b, t = input_ids.shape | |
| pos = torch.arange(t, device=input_ids.device).unsqueeze(0) | |
| x = self.tok_emb(input_ids) + self.pos_emb(pos) | |
| mask = self._causal_mask[:, :, :t, :t] | |
| for block in self.blocks: | |
| x = block(x, mask) | |
| x = self.ln_f(x) | |
| logits = self.head(x) | |
| loss = None | |
| if targets is not None: | |
| if loss_mask is not None: | |
| # only compute loss on target positions (after SEP) | |
| flat_logits = logits.view(-1, logits.size(-1)) | |
| flat_targets = targets.view(-1) | |
| flat_mask = loss_mask.view(-1).bool() | |
| if flat_mask.any(): | |
| loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask]) | |
| else: | |
| loss = torch.tensor(0.0, device=logits.device) | |
| else: | |
| loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) | |
| return {"logits": logits, "loss": loss} | |