abpt / src /fog /model_baseline.py
Search
sync: FOG micro+medium configs, stress tasks, fast pipeline
6ef010e
"""Baseline Transformer — standard uniform-width architecture."""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.fog.config import FOGConfig
class BaselineAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int):
super().__init__()
assert d_model % n_heads == 0
self.n_heads = n_heads
self.head_dim = d_model // n_heads
self.qkv = nn.Linear(d_model, 3 * d_model)
self.out = nn.Linear(d_model, d_model)
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
b, t, _ = x.shape
qkv = self.qkv(x).view(b, t, 3, self.n_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2) # each [B, T, H, D]
q = q.transpose(1, 2) # [B, H, T, D]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
attn = F.softmax(scores, dim=-1)
y = torch.matmul(attn, v)
y = y.transpose(1, 2).contiguous().view(b, t, -1)
return self.out(y)
class BaselineFFN(nn.Module):
def __init__(self, d_model: int, d_ff: int):
super().__init__()
self.up = nn.Linear(d_model, d_ff)
self.down = nn.Linear(d_ff, d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down(F.silu(self.up(x)))
class BaselineBlock(nn.Module):
def __init__(self, d_model: int, d_ff: int, n_heads: int, dropout: float):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.attn = BaselineAttention(d_model, n_heads)
self.ffn = BaselineFFN(d_model, d_ff)
self.drop = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
x = x + self.drop(self.attn(self.ln1(x), mask))
x = x + self.drop(self.ffn(self.ln2(x)))
return x
class BaselineTransformer(nn.Module):
def __init__(self, cfg: FOGConfig):
super().__init__()
self.cfg = cfg
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
self.blocks = nn.ModuleList([
BaselineBlock(cfg.d_model, cfg.d_ff, cfg.n_heads, cfg.dropout)
for _ in range(cfg.n_layers)
])
self.ln_f = nn.LayerNorm(cfg.d_model)
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
self.tok_emb.weight = self.head.weight # weight tying
# Pre-compute causal mask once
self.register_buffer(
"_causal_mask",
torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0),
persistent=False,
)
def forward(
self,
input_ids: torch.Tensor,
targets: torch.Tensor | None = None,
loss_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
b, t = input_ids.shape
pos = torch.arange(t, device=input_ids.device).unsqueeze(0)
x = self.tok_emb(input_ids) + self.pos_emb(pos)
mask = self._causal_mask[:, :, :t, :t]
for block in self.blocks:
x = block(x, mask)
x = self.ln_f(x)
logits = self.head(x)
loss = None
if targets is not None:
if loss_mask is not None:
# only compute loss on target positions (after SEP)
flat_logits = logits.view(-1, logits.size(-1))
flat_targets = targets.view(-1)
flat_mask = loss_mask.view(-1).bool()
if flat_mask.any():
loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
else:
loss = torch.tensor(0.0, device=logits.device)
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return {"logits": logits, "loss": loss}