"""Faster FOG variants that keep the same test protocol. Design goals: 1. keep motif-aware geometry, 2. reduce CPU cost through fused projections, 3. use grouped KV heads, 4. replace expensive stage-specific expand-space transforms with cheap low-rank adapters. """ from __future__ import annotations from dataclasses import dataclass import math import torch import torch.nn as nn import torch.nn.functional as F from src.fog.config import FOGConfig from src.fog.model_structured import LayerGeometry, build_layer_geometries def _choose_kv_heads(n_heads: int) -> int: if n_heads % 4 == 0: return max(1, n_heads // 4) if n_heads % 2 == 0: return max(1, n_heads // 2) return 1 class FastGroupedAttention(nn.Module): def __init__( self, d_model: int, d_compare: int, d_memory: int, n_heads: int, kv_heads: int | None = None, ) -> None: super().__init__() assert d_compare % n_heads == 0 assert d_memory % n_heads == 0 self.n_heads = n_heads self.kv_heads = kv_heads or _choose_kv_heads(n_heads) assert n_heads % self.kv_heads == 0 self.kv_repeat = n_heads // self.kv_heads self.compare_head_dim = d_compare // n_heads self.memory_head_dim = d_memory // n_heads self.d_compare = d_compare self.d_memory = d_memory total_out = ( d_compare + self.kv_heads * self.compare_head_dim + self.kv_heads * self.memory_head_dim ) self.in_proj = nn.Linear(d_model, total_out) self.out_proj = nn.Linear(d_memory, d_model) def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: b, t, _ = x.shape packed = self.in_proj(x) q_end = self.d_compare k_end = q_end + self.kv_heads * self.compare_head_dim q, k, v = packed.split( [self.d_compare, self.kv_heads * self.compare_head_dim, self.kv_heads * self.memory_head_dim], dim=-1, ) q = q.view(b, t, self.n_heads, self.compare_head_dim).transpose(1, 2) k = k.view(b, t, self.kv_heads, self.compare_head_dim).transpose(1, 2) v = v.view(b, t, self.kv_heads, self.memory_head_dim).transpose(1, 2) if self.kv_repeat > 1: k = k.repeat_interleave(self.kv_repeat, dim=1) v = v.repeat_interleave(self.kv_repeat, dim=1) scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.compare_head_dim) if mask is not None: scores = scores.masked_fill(mask == 0, float("-inf")) attn = F.softmax(scores, dim=-1) y = torch.matmul(attn, v) y = y.transpose(1, 2).contiguous().view(b, t, self.d_memory) return self.out_proj(y) class FastMotifFFN(nn.Module): def __init__(self, d_model: int, d_expand: int, d_gate: int, dropout: float) -> None: super().__init__() self.fused_in = nn.Linear(d_model, d_expand + d_gate) self.gate_up = nn.Linear(d_gate, d_expand) self.compress = nn.Linear(d_expand, d_model) self.drop = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: packed = self.fused_in(x) expanded, gate_seed = packed.split([self.compress.in_features, self.gate_up.in_features], dim=-1) expanded = F.silu(expanded) gate = torch.sigmoid(self.gate_up(F.silu(gate_seed))) h = self.drop(expanded * gate) return self.compress(h) class FastMotifBlock(nn.Module): def __init__( self, d_model: int, d_compare: int, d_memory: int, d_expand: int, d_gate: int, n_heads: int, dropout: float, ) -> None: super().__init__() self.ln1 = nn.LayerNorm(d_model) self.ln2 = nn.LayerNorm(d_model) self.attn = FastGroupedAttention( d_model=d_model, d_compare=d_compare, d_memory=d_memory, n_heads=n_heads, kv_heads=n_heads, ) self.ffn = FastMotifFFN(d_model, d_expand, d_gate, dropout) self.drop = nn.Dropout(dropout) def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: x = x + self.drop(self.attn(self.ln1(x), mask)) x = x + self.drop(self.ffn(self.ln2(x))) return x class FastMotifTransformer(nn.Module): def __init__(self, cfg: FOGConfig) -> None: super().__init__() self.cfg = cfg self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model) self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model) self.blocks = nn.ModuleList( [ FastMotifBlock( d_model=cfg.d_model, d_compare=cfg.d_compare, d_memory=cfg.d_memory, d_expand=cfg.d_expand, d_gate=cfg.d_gate, n_heads=cfg.n_heads, dropout=cfg.dropout, ) for _ in range(cfg.n_layers) ] ) self.ln_f = nn.LayerNorm(cfg.d_model) self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) self.tok_emb.weight = self.head.weight self.drop = nn.Dropout(cfg.dropout) self.register_buffer( "_causal_mask", torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0), persistent=False, ) def forward( self, input_ids: torch.Tensor, targets: torch.Tensor | None = None, loss_mask: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: b, t = input_ids.shape pos = torch.arange(t, device=input_ids.device).unsqueeze(0) x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos)) mask = self._causal_mask[:, :, :t, :t] for block in self.blocks: x = block(x, mask) x = self.ln_f(x) logits = self.head(x) loss = None if targets is not None: if loss_mask is not None: flat_logits = logits.view(-1, logits.size(-1)) flat_targets = targets.view(-1) flat_mask = loss_mask.view(-1).bool() loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask]) if flat_mask.any() else torch.tensor(0.0, device=logits.device) else: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) return {"logits": logits, "loss": loss} class FastStructuredFFN(nn.Module): def __init__(self, d_model: int, geometry: LayerGeometry, dropout: float) -> None: super().__init__() self.stage = geometry.stage self.d_expand = geometry.d_expand self.d_gate = geometry.d_gate self.fused_in = nn.Linear(d_model, geometry.d_expand + geometry.d_gate) self.gate_up = nn.Linear(geometry.d_gate, geometry.d_expand) self.compress = nn.Linear(geometry.d_expand, d_model) self.drop = nn.Dropout(dropout) if self.stage in ("middle", "late"): self.stage_adapter = nn.Linear(geometry.d_gate, geometry.d_expand) self.stage_scale = nn.Parameter(torch.tensor(0.10 if self.stage == "middle" else 0.08)) else: self.stage_adapter = None self.stage_scale = None def forward(self, x: torch.Tensor) -> torch.Tensor: packed = self.fused_in(x) expanded, gate_seed = packed.split([self.d_expand, self.d_gate], dim=-1) expanded = F.silu(expanded) gate_hidden = F.silu(gate_seed) gate = torch.sigmoid(self.gate_up(gate_hidden)) h = expanded * gate if self.stage_adapter is not None and self.stage_scale is not None: h = h + self.stage_scale * torch.tanh(self.stage_adapter(gate_hidden)) h = self.drop(h) return self.compress(h) class FastStructuredBlock(nn.Module): def __init__(self, d_model: int, n_heads: int, geometry: LayerGeometry, dropout: float) -> None: super().__init__() self.geometry = geometry self.ln1 = nn.LayerNorm(d_model) self.ln2 = nn.LayerNorm(d_model) self.attn = FastGroupedAttention( d_model=d_model, d_compare=geometry.d_compare, d_memory=geometry.d_memory, n_heads=n_heads, kv_heads=_choose_kv_heads(n_heads), ) self.ffn = FastStructuredFFN(d_model=d_model, geometry=geometry, dropout=dropout) self.drop = nn.Dropout(dropout) self.attn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale))) self.ffn_scale = nn.Parameter(torch.tensor(float(geometry.residual_scale))) def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: x = x + self.attn_scale * self.drop(self.attn(self.ln1(x), mask)) x = x + self.ffn_scale * self.drop(self.ffn(self.ln2(x))) return x class FastStructuredMotifTransformer(nn.Module): def __init__(self, cfg: FOGConfig) -> None: super().__init__() self.cfg = cfg self.layer_geometries = build_layer_geometries(cfg) self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model) self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model) self.drop = nn.Dropout(cfg.dropout) self.blocks = nn.ModuleList( [ FastStructuredBlock( d_model=cfg.d_model, n_heads=cfg.n_heads, geometry=geometry, dropout=cfg.dropout, ) for geometry in self.layer_geometries ] ) self.ln_f = nn.LayerNorm(cfg.d_model) self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) self.tok_emb.weight = self.head.weight self.register_buffer( "_causal_mask", torch.tril(torch.ones(cfg.max_seq_len, cfg.max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0), persistent=False, ) def forward( self, input_ids: torch.Tensor, targets: torch.Tensor | None = None, loss_mask: torch.Tensor | None = None, ) -> dict[str, torch.Tensor | list[dict[str, int | str]]]: b, t = input_ids.shape pos = torch.arange(t, device=input_ids.device).unsqueeze(0) x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos)) mask = self._causal_mask[:, :, :t, :t] for block in self.blocks: x = block(x, mask) x = self.ln_f(x) logits = self.head(x) loss = None if targets is not None: if loss_mask is not None: flat_logits = logits.view(-1, logits.size(-1)) flat_targets = targets.view(-1) flat_mask = loss_mask.view(-1).bool() loss = F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask]) if flat_mask.any() else torch.tensor(0.0, device=logits.device) else: loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) geometry_summary = [ { "stage": g.stage, "d_compare": g.d_compare, "d_memory": g.d_memory, "d_expand": g.d_expand, "d_gate": g.d_gate, } for g in self.layer_geometries ] return {"logits": logits, "loss": loss, "geometry": geometry_summary}