| """ |
| Baseline attention mechanisms for comparison with H4 Polytopic Attention. |
| |
| Implements standard softmax attention and linear attention (Katharopoulos et al. 2020) |
| with the SAME model wrapper (embeddings, FFN, LM head) so the only variable is attention. |
| |
| Usage: |
| model = BaselineLanguageModel(vocab_size=128, d_model=128, n_heads=8, |
| n_layers=4, d_value=16, d_ffn=512, |
| attention_type='softmax') # or 'linear' |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
| from utils.phi_positional import PhiPositionalEncoding |
| from bitlinear import BitLinear |
|
|
|
|
| |
| |
| |
|
|
class SoftmaxAttention(nn.Module):
    """Standard multi-head scaled dot-product attention with a causal mask.

    Queries and keys use ``d_model // n_heads`` dims per head; values use a
    separate ``d_value`` dims per head so the value width can differ from
    the key/query width.

    Args:
        d_model: input/output embedding dimension.
        n_heads: number of attention heads; must divide ``d_model`` evenly.
        d_value: value dimension per head.
        dropout: dropout rate applied to attention weights (0 disables it).
        use_bitlinear: if True, use ternary-weight BitLinear projections.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, d_value, dropout=0.0, use_bitlinear=False):
        super().__init__()
        # Previously a non-divisible d_model was silently truncated to
        # d_head * n_heads != d_model; fail loudly instead.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.d_value = d_value
        self.scale = 1.0 / math.sqrt(self.d_head)  # 1/sqrt(d_k), Vaswani et al.

        Linear = BitLinear if use_bitlinear else nn.Linear

        self.W_q = Linear(d_model, self.d_head * n_heads, bias=False)
        self.W_k = Linear(d_model, self.d_head * n_heads, bias=False)
        self.W_v = Linear(d_model, d_value * n_heads, bias=False)
        self.W_out = Linear(d_value * n_heads, d_model, bias=False)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x, **kwargs):
        """Apply causal self-attention.

        Args:
            x: (batch, seq_len, d_model) input.
        Returns:
            (batch, seq_len, d_model) attended output.
        """
        B, T, D = x.shape
        H = self.n_heads

        # Project and split heads: (B, H, T, d_head) / (B, H, T, d_value).
        Q = self.W_q(x).view(B, T, H, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, H, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, H, self.d_value).transpose(1, 2)

        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale

        # Mask strictly-future positions so position t attends only to
        # positions <= t. Built per call so it follows the input's device
        # and sequence length; the (T, T) bool mask broadcasts over (B, H).
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge heads back to (B, T, H*d_value).
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, T, H * self.d_value)
        return self.W_out(out)
|
|
|
|
| |
| |
| |
|
|
def elu_feature_map(x):
    """Positive feature map phi(x) = ELU(x) + 1 for linear attention.

    Since ELU is bounded below by -1, the result is strictly positive,
    which keeps the linear-attention normalizer away from zero.
    """
    shifted = F.elu(x)
    return shifted + 1.0
|
|
|
|
class LinearAttention(nn.Module):
    """Causal linear attention (Katharopoulos et al. 2020).

    Replaces softmax(QK^T)V with the kernelized form phi(Q) (phi(K)^T V),
    where phi is the ELU+1 feature map. Causality then reduces to prefix
    sums over time:

        S_t = sum_{i<=t} phi(K_i)^T V_i   (running key-value state)
        z_t = sum_{i<=t} phi(K_i)         (running normalizer)
        out_t = (phi(Q_t) S_t) / (phi(Q_t) . z_t)

    Time is linear in sequence length; the cumulative state tensor of shape
    (B, H, T, d_head, d_value) is fully materialized, trading memory for a
    simple vectorized implementation.
    """

    def __init__(self, d_model, n_heads, d_value, dropout=0.0, use_bitlinear=False):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.d_value = d_value

        Linear = BitLinear if use_bitlinear else nn.Linear

        self.W_q = Linear(d_model, self.d_head * n_heads, bias=False)
        self.W_k = Linear(d_model, self.d_head * n_heads, bias=False)
        self.W_v = Linear(d_model, d_value * n_heads, bias=False)
        self.W_out = Linear(d_value * n_heads, d_model, bias=False)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

    def forward(self, x, **kwargs):
        """Apply causal linear attention to x of shape (B, T, d_model)."""
        batch, seq_len, _ = x.shape
        heads = self.n_heads

        # Project, split heads, and apply the ELU+1 feature map to Q and K
        # (phi inlined here; strictly positive, so the denominator stays > 0).
        q = F.elu(self.W_q(x).view(batch, seq_len, heads, self.d_head)) + 1.0
        k = F.elu(self.W_k(x).view(batch, seq_len, heads, self.d_head)) + 1.0
        v = self.W_v(x).view(batch, seq_len, heads, self.d_value)

        # Head-major layout: (B, H, T, d).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Per-step outer products phi(K_t) V_t^T, accumulated along time.
        outer = torch.einsum('bhti,bhtj->bhtij', k, v)
        state = torch.cumsum(outer, dim=2)

        # Running normalizer z_t.
        normalizer = torch.cumsum(k, dim=2)

        numerator = torch.einsum('bhti,bhtij->bhtj', q, state)
        # Clamp guards against an (in principle) vanishing denominator.
        denominator = torch.einsum('bhti,bhti->bht', q, normalizer).unsqueeze(-1).clamp(min=1e-6)

        attended = (numerator / denominator).transpose(1, 2)
        attended = attended.contiguous().view(batch, seq_len, heads * self.d_value)
        return self.W_out(self.dropout(attended))
|
|
|
|
| |
| |
| |
|
|
class SoftmaxTransformerBlock(nn.Module):
    """Pre-norm transformer block: softmax attention plus a GELU feed-forward
    network, each wrapped in a residual connection."""

    def __init__(self, d_model, n_heads, d_value, d_ffn=None, dropout=0.0,
                 use_bitlinear=False):
        super().__init__()
        d_ffn = 4 * d_model if d_ffn is None else d_ffn
        Linear = BitLinear if use_bitlinear else nn.Linear

        # Submodule registration order is kept stable (ln1, attn, ln2, ffn).
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = SoftmaxAttention(d_model, n_heads, d_value, dropout, use_bitlinear)
        self.ln2 = nn.LayerNorm(d_model)
        ffn_layers = [
            Linear(d_model, d_ffn, bias=False),
            nn.GELU(),
            Linear(d_ffn, d_model, bias=False),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

    def forward(self, x, **kwargs):
        # Residual around attention, then residual around the feed-forward.
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        return x + ffn_out
|
|
|
|
class LinearTransformerBlock(nn.Module):
    """Pre-norm transformer block using linear attention
    (Katharopoulos et al. 2020) plus a GELU feed-forward network,
    each wrapped in a residual connection."""

    def __init__(self, d_model, n_heads, d_value, d_ffn=None, dropout=0.0,
                 use_bitlinear=False):
        super().__init__()
        d_ffn = 4 * d_model if d_ffn is None else d_ffn
        Linear = BitLinear if use_bitlinear else nn.Linear

        # Submodule registration order is kept stable (ln1, attn, ln2, ffn).
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = LinearAttention(d_model, n_heads, d_value, dropout, use_bitlinear)
        self.ln2 = nn.LayerNorm(d_model)
        ffn_layers = [
            Linear(d_model, d_ffn, bias=False),
            nn.GELU(),
            Linear(d_ffn, d_model, bias=False),
            nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
        ]
        self.ffn = nn.Sequential(*ffn_layers)

    def forward(self, x, **kwargs):
        # Residual around attention, then residual around the feed-forward.
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        return x + ffn_out
|
|
|
|
| |
| |
| |
|
|
class BaselineLanguageModel(nn.Module):
    """
    Language model with swappable attention mechanism.

    Same architecture as H4LanguageModel (same embeddings, FFN, LM head)
    but with standard softmax or linear attention instead of H4 geometric attention.
    This ensures the only variable in comparisons is the attention mechanism.

    Args:
        vocab_size: vocabulary size
        d_model: model dimension
        n_heads: number of attention heads
        n_layers: number of transformer blocks
        d_value: value dimension per head
        d_ffn: FFN hidden dimension (default: 4 * d_model)
        max_seq_len: max sequence length for positional encoding
        dropout: dropout rate
        attention_type: 'softmax' or 'linear'
        use_bitlinear: if True, use ternary weights

    Raises:
        ValueError: if attention_type is not 'softmax' or 'linear'
    """

    def __init__(
        self,
        vocab_size,
        d_model=128,
        n_heads=8,
        n_layers=4,
        d_value=16,
        d_ffn=None,
        max_seq_len=512,
        dropout=0.0,
        attention_type='softmax',
        use_bitlinear=False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.attention_type = attention_type

        if d_ffn is None:
            d_ffn = d_model * 4

        # Token embeddings, scaled by sqrt(d_model) at use time
        # (standard Transformer convention).
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.emb_scale = math.sqrt(d_model)

        # Project-local positional encoding. NOTE(review): max_cached
        # presumably bounds the longest sequence it serves -- confirm
        # against PhiPositionalEncoding whether longer T extends the cache
        # or fails.
        self.pos_enc = PhiPositionalEncoding(d_model, max_cached=max_seq_len)

        self.emb_dropout = nn.Dropout(dropout)

        # Select the block type; everything else is identical across types.
        if attention_type == 'softmax':
            BlockClass = SoftmaxTransformerBlock
        elif attention_type == 'linear':
            BlockClass = LinearTransformerBlock
        else:
            raise ValueError(f"Unknown attention_type: {attention_type}")

        self.blocks = nn.ModuleList([
            BlockClass(
                d_model=d_model,
                n_heads=n_heads,
                d_value=d_value,
                d_ffn=d_ffn,
                dropout=dropout,
                use_bitlinear=use_bitlinear,
            )
            for _ in range(n_layers)
        ])

        # Final layer norm before the LM head (pre-norm architecture).
        self.ln_f = nn.LayerNorm(d_model)

        # LM head with weight tying: shares the embedding matrix.
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.token_emb.weight

        self._init_weights()

    def _init_weights(self):
        """Initialize weights following GPT-2 conventions.

        NOTE(review): lm_head.weight is tied to token_emb.weight, so the
        shared tensor is re-initialized twice below (once via the nn.Linear
        branch, once via the nn.Embedding branch); the last write wins,
        which is redundant but harmless.
        """
        for module in self.modules():
            if isinstance(module, BitLinear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, **kwargs):
        """
        Args:
            input_ids: (batch, seq_len) token indices
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        B, T = input_ids.shape

        tok_emb = self.token_emb(input_ids) * self.emb_scale
        # pos_enc(T) is assumed to return a (T, d_model) tensor -- TODO
        # confirm; it is moved to the input's device on every call, which
        # incurs a per-forward transfer when running off-CPU.
        pos_emb = self.pos_enc(T).unsqueeze(0).to(tok_emb.device)
        x = self.emb_dropout(tok_emb + pos_emb)

        for block in self.blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits

    def count_params(self):
        """Count trainable and frozen parameters.

        Returns:
            dict with 'trainable', 'frozen', 'buffers', and 'total' counts.
            Tied weights (lm_head/token_emb) are counted once, since
            nn.Module.parameters() deduplicates shared tensors.
        """
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        frozen = sum(p.numel() for p in self.parameters() if not p.requires_grad)
        buffers = sum(b.numel() for b in self.buffers())
        return {
            'trainable': trainable,
            'frozen': frozen,
            'buffers': buffers,
            'total': trainable + frozen,
        }

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=100, temperature=1.0, top_k_sample=0):
        """Autoregressive generation.

        Args:
            input_ids: (batch, seq_len) prompt token indices.
            max_new_tokens: number of tokens to sample.
            temperature: softmax temperature; must be > 0 (logits are
                divided by it).
            top_k_sample: if > 0, restrict sampling to the top-k logits.
        Returns:
            (batch, seq_len + max_new_tokens) tensor of token indices.

        NOTE(review): the context is never cropped, so generating past the
        positional-encoding cache length (max_seq_len) may fail -- verify
        against PhiPositionalEncoding's behavior.
        """
        for _ in range(max_new_tokens):
            logits = self.forward(input_ids)
            logits = logits[:, -1, :] / temperature

            # Top-k filtering: suppress (set to -inf) every logit below the
            # k-th largest per row before sampling.
            if top_k_sample > 0:
                v, _ = torch.topk(logits, min(top_k_sample, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_id], dim=1)

        return input_ids
|
|