#!/usr/bin/env python3 """ n_flex.py — Flexible Attention Mechanisms Constraint: Must support AR (causal), SAT (block), and NAR (bidirectional) Testing: 1. Linear Attention - O(n) instead of O(n²) 2. Cosine Attention - Different similarity metric 3. Differential Attention - Noise cancellation (Microsoft 2024) 4. Local + Global - Sparse hybrid 5. Multi-Query Attention (MQA) - Inference efficient 6. Grouped Query Attention (GQA) - Between MHA and MQA 7. Retention - RetNet style (recurrent + parallel) 8. Gated Linear Attention - Recent efficient attention 9. ReLU Attention - Simpler activation 10. Sigmoid Attention - Bounded attention """ from __future__ import annotations import argparse, math, time import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Literal DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.backends.cuda.matmul.allow_tf32 = True VOCAB = 128256 # ═══════════════════════════════════════════════════════════════ # Masking utilities for AR/SAT/NAR # ═══════════════════════════════════════════════════════════════ def get_mask(n: int, mode: str = "ar", block_size: int = 2): """ AR (autoregressive): causal, see only past SAT (semi-autoregressive): see within block + all past blocks NAR (non-autoregressive): bidirectional, see everything """ if mode == "nar": return None # No mask elif mode == "ar": return torch.triu(torch.full((n, n), float("-inf"), device=DEV), 1) elif mode == "sat": # Block-wise: can see within same block and all previous blocks idx = torch.arange(n, device=DEV) block_idx = idx // block_size # Allow if same block OR target block is earlier mask = torch.where( (block_idx.unsqueeze(0) <= block_idx.unsqueeze(1)), torch.tensor(0.0, device=DEV), torch.tensor(float("-inf"), device=DEV) ) return mask else: raise ValueError(f"Unknown mode: {mode}") def alibi_bias(n_heads: int, n_tokens: int): def slopes(n): start = 2 ** (-2 ** -(math.log2(n) - 3)) return [start * (start ** i) for i 
in range(n)] if n_heads > 0 and math.log2(n_heads).is_integer(): s = slopes(n_heads) else: closest = 2 ** math.floor(math.log2(max(1, n_heads))) s = slopes(closest)[:n_heads] s = torch.tensor(s, device=DEV).view(1, n_heads, 1, 1) i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1) j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens) return -s * (j - i).clamp_min(0).float() # ═══════════════════════════════════════════════════════════════ # 1. STANDARD (baseline) # ═══════════════════════════════════════════════════════════════ class StandardAttention(nn.Module): """Standard multi-head attention - O(n²)""" def __init__(self, d: int, h: int): super().__init__() self.h, self.dk = h, d // h self.qkv = nn.Linear(d, 3 * d, bias=False) self.proj = nn.Linear(d, d, bias=False) def forward(self, x, mask=None): B, N, _ = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) att = att + alibi_bias(self.h, N) if mask is not None: att = att + mask.unsqueeze(0).unsqueeze(0) z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1) return self.proj(z) # ═══════════════════════════════════════════════════════════════ # 2. 
# LINEAR ATTENTION - O(n) via kernel trick
# ═══════════════════════════════════════════════════════════════
class LinearAttention(nn.Module):
    """
    Linear attention: O(n) instead of O(n²)

    Uses feature map φ(x) so that φ(q)φ(k)^T ≈ softmax(qk^T)
    Key insight: (QK^T)V = Q(K^TV) - compute K^TV first for O(n)

    Works with AR/SAT/NAR: NAR uses one global K^TV sum, AR/SAT use prefix
    sums read at the last key each query is allowed to see.
    """

    def __init__(self, d: int, h: int, feature_map: str = "elu"):
        """Create projections for ``h`` heads of width ``d // h``.

        Args:
            d: Model width.
            h: Number of heads.
            feature_map: ``"elu"``, ``"relu"`` or ``"softmax"``; any other
                value selects the identity map.
        """
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.feature_map = feature_map
        self.eps = 1e-6  # lower bound for the normaliser

    def _phi(self, x):
        """Feature map for linear attention"""
        if self.feature_map == "elu":
            return F.elu(x) + 1
        elif self.feature_map == "relu":
            return F.relu(x)
        elif self.feature_map == "softmax":
            return F.softmax(x, dim=-1)
        else:  # identity
            return x

    def forward(self, x, mask=None):
        """Apply linear attention.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` for bidirectional attention, or an ``(N, N)``
                additive mask (0 = visible, -inf = hidden). Every row of the
                mask must expose a non-empty prefix ``0..r`` of the keys, as
                causal and block-causal masks do; only the prefix length of
                each row is used.

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, H, N, dk)

        # Apply feature map
        q = self._phi(q)
        k = self._phi(k)

        if mask is None:
            # NAR: Full bidirectional - O(n) via associativity
            # (Q @ K^T) @ V = Q @ (K^T @ V)
            kv = torch.einsum('bhnd,bhnv->bhdv', k, v)  # (B, H, dk, dv)
            out = torch.einsum('bhnd,bhdv->bhnv', q, kv)  # (B, H, N, dv)
            # Normalize
            k_sum = k.sum(dim=2, keepdim=True)  # (B, H, 1, dk)
            normalizer = torch.einsum('bhnd,bhkd->bhnk', q, k_sum).clamp(min=self.eps)
            out = out / normalizer
        else:
            # AR/SAT: prefix sums over keys, still O(n).
            kv_cumsum = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
            k_cumsum = torch.cumsum(k, dim=2)
            # Index of the last visible key per query row. For a causal mask
            # this is the row index itself; for a block mask it is the end of
            # the query's block, so SAT is honoured instead of being treated
            # as plain AR.
            last = ((mask == 0).sum(dim=-1) - 1).to(x.device)
            kv_visible = kv_cumsum.index_select(2, last)
            k_visible = k_cumsum.index_select(2, last)
            out = torch.einsum('bhnd,bhndv->bhnv', q, kv_visible)
            normalizer = torch.einsum('bhnd,bhnd->bhn', q, k_visible).unsqueeze(-1).clamp(min=self.eps)
            out = out / normalizer

        return self.proj(out.transpose(1, 2).reshape(B, N, -1))


# ═══════════════════════════════════════════════════════════════
# 3.
# COSINE ATTENTION - Different similarity metric
# ═══════════════════════════════════════════════════════════════
class CosineAttention(nn.Module):
    """
    Use cosine similarity instead of dot product.
    More stable, bounded [-1, 1] before scaling.
    """

    def __init__(self, d: int, h: int, temp: float = 10.0):
        """Create projections for ``h`` heads and a learnable temperature.

        Args:
            d: Model width.
            h: Number of heads (head width is ``d // h``).
            temp: Initial value of the learnable scale applied to the cosine
                similarities.
        """
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.temp = nn.Parameter(torch.tensor(temp))

    def forward(self, x, mask=None):
        """Apply softmax attention over temperature-scaled cosine similarities.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` or an ``(N, N)`` additive mask (0 / -inf).

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Normalize for cosine similarity
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        att = self.temp * (q @ k.transpose(-1, -2))  # Cosine sim scaled by temp
        if mask is not None:
            att = att + mask.unsqueeze(0).unsqueeze(0)
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# 4. DIFFERENTIAL ATTENTION - Noise cancellation
# ═══════════════════════════════════════════════════════════════
class DifferentialAttention(nn.Module):
    """
    From Microsoft's "Differential Transformer" (2024)

    Compute two attention patterns and subtract:
    Attn = softmax(Q1 K1^T) - λ * softmax(Q2 K2^T)

    Cancels noise, improves signal.
    """

    def __init__(self, d: int, h: int):
        """Create two query/key projection pairs, one value projection and λ."""
        super().__init__()
        self.h, self.dk = h, d // h

        # Two sets of Q, K projections
        self.q1 = nn.Linear(d, d, bias=False)
        self.k1 = nn.Linear(d, d, bias=False)
        self.q2 = nn.Linear(d, d, bias=False)
        self.k2 = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)

        # Learnable lambda for subtraction weight (squashed by sigmoid in forward)
        self.lambda_param = nn.Parameter(torch.tensor(0.5))

        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Apply the difference of two softmax attention maps.

        The difference is passed through ReLU and renormalised so every row is
        non-negative and sums to (almost) one before being applied to V.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` or an ``(N, N)`` additive mask (0 / -inf), applied
                to both attention maps.

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape

        q1 = self.q1(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k1 = self.k1(x).view(B, N, self.h, self.dk).transpose(1, 2)
        q2 = self.q2(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k2 = self.k2(x).view(B, N, self.h, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, self.h, self.dk).transpose(1, 2)

        scale = math.sqrt(self.dk)

        # First attention
        att1 = (q1 @ k1.transpose(-1, -2)) / scale
        if mask is not None:
            att1 = att1 + mask.unsqueeze(0).unsqueeze(0)
        att1 = att1.softmax(-1)

        # Second attention
        att2 = (q2 @ k2.transpose(-1, -2)) / scale
        if mask is not None:
            att2 = att2 + mask.unsqueeze(0).unsqueeze(0)
        att2 = att2.softmax(-1)

        # Differential: subtract weighted second from first
        lam = torch.sigmoid(self.lambda_param)
        att = att1 - lam * att2

        # ReLU to ensure non-negative (optional, can remove)
        att = F.relu(att)
        att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)

        z = (att @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# 5. MULTI-QUERY ATTENTION (MQA) - Inference efficient
# ═══════════════════════════════════════════════════════════════
class MultiQueryAttention(nn.Module):
    """
    MQA: Multiple query heads, single K/V head.
    Massive inference speedup (smaller KV cache).
    Same training cost as standard.
    """

    def __init__(self, d: int, h: int):
        """Create ``h`` query heads that share one key head and one value head."""
        super().__init__()
        self.h, self.dk = h, d // h

        # H query heads, but only 1 K and 1 V head
        self.q = nn.Linear(d, d, bias=False)  # H heads
        self.k = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.v = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Apply softmax attention with shared K/V, the ALiBi bias and an optional mask.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` or an ``(N, N)`` additive mask (0 / -inf).

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape

        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)  # (B, H, N, dk)
        k = self.k(x).view(B, N, 1, self.dk).transpose(1, 2)  # (B, 1, N, dk)
        v = self.v(x).view(B, N, 1, self.dk).transpose(1, 2)  # (B, 1, N, dk)

        # K, V broadcast across heads
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask.unsqueeze(0).unsqueeze(0)

        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# 6. GROUPED QUERY ATTENTION (GQA) - Between MHA and MQA
# ═══════════════════════════════════════════════════════════════
class GroupedQueryAttention(nn.Module):
    """
    GQA: Groups of query heads share K/V heads.
    Llama 2 uses this. Balance between quality and inference speed.
    """

    def __init__(self, d: int, h: int, num_kv_heads: int = 2):
        """Create ``h`` query heads that share ``num_kv_heads`` key/value heads.

        Args:
            d: Model width.
            h: Number of query heads (head width is ``d // h``).
            num_kv_heads: Requested number of K/V heads. It must divide ``h``
                for the groups to be equal-sized; when it does not, the
                largest common divisor ``gcd(h, num_kv_heads)`` is used
                instead and a warning is emitted.

        Raises:
            ValueError: If ``num_kv_heads < 1``.
        """
        super().__init__()
        if num_kv_heads < 1:
            raise ValueError(f"num_kv_heads must be >= 1, got {num_kv_heads}")
        if h % num_kv_heads != 0:
            # An uneven split would leave K/V with fewer heads than Q after
            # repeat_interleave and break the matmul in forward().
            import warnings

            adjusted = math.gcd(h, num_kv_heads)
            warnings.warn(
                f"num_kv_heads={num_kv_heads} does not divide h={h}; "
                f"using num_kv_heads={adjusted} instead"
            )
            num_kv_heads = adjusted

        self.h = h
        self.num_kv_heads = num_kv_heads
        self.dk = d // h
        self.heads_per_group = h // num_kv_heads

        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Apply softmax attention with grouped K/V, the ALiBi bias and an optional mask.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` or an ``(N, N)`` additive mask (0 / -inf).

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape

        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)

        # Repeat K, V for each group
        k = k.repeat_interleave(self.heads_per_group, dim=1)
        v = v.repeat_interleave(self.heads_per_group, dim=1)

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask.unsqueeze(0).unsqueeze(0)

        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# 7.
# RETENTION - RetNet style
# ═══════════════════════════════════════════════════════════════
class RetentionAttention(nn.Module):
    """
    From RetNet: Retentive Network

    Parallel mode (training): Like linear attention
    Recurrent mode (inference): O(1) per step

    Key: exponential decay instead of softmax

    Only the parallel form is implemented here, and it is always causal:
    the lower-triangular constraint is applied even when ``mask`` is None.
    """

    def __init__(self, d: int, h: int, gamma: float = 0.9):
        """Create projections and one learnable decay rate per head.

        Args:
            d: Model width.
            h: Number of heads (head width is ``d // h``).
            gamma: Initial per-step decay, expected in the open interval (0, 1).
        """
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

        # Per-head decay rates. forward() squashes the parameter with a
        # sigmoid, so store the logit of gamma; that way the effective initial
        # decay really is `gamma`. Clamp first to keep the logit finite.
        g = min(max(float(gamma), 1e-4), 1.0 - 1e-4)
        self.gamma = nn.Parameter(torch.full((h,), math.log(g / (1.0 - g))))

    def forward(self, x, mask=None):
        """Apply decayed, causally-restricted retention.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` or an ``(N, N)`` additive mask (0 = visible); its
                visible set is intersected with the causal lower triangle.

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Build decay matrix D[i,j] = gamma^(i-j) for i >= j
        gamma = torch.sigmoid(self.gamma).view(1, self.h, 1, 1)
        positions = torch.arange(N, device=x.device).float()
        # Rows index the query i, columns the key j, so the exponent is i - j.
        distance = (positions.unsqueeze(1) - positions.unsqueeze(0)).clamp(min=0)
        decay = gamma ** distance

        # Apply causal mask via decay (future positions get 0)
        causal = torch.tril(torch.ones(N, N, device=x.device))
        decay = decay * causal.unsqueeze(0).unsqueeze(0)

        # If SAT/NAR mask provided, incorporate it
        if mask is not None:
            mask_binary = (mask == 0).float().unsqueeze(0).unsqueeze(0)
            decay = decay * mask_binary

        # Retention = (Q @ K^T) * D @ V
        att = (q @ k.transpose(-1, -2)) * decay

        # Normalize per row. Scores are signed, so the raw row sum can be zero
        # or negative; divide by max(|sum|, 1) as in RetNet's retention-score
        # normalisation instead of by the signed sum.
        att = att / att.sum(dim=-1, keepdim=True).abs().clamp(min=1.0)

        z = (att @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# 8. GATED LINEAR ATTENTION
# ═══════════════════════════════════════════════════════════════
class GatedLinearAttention(nn.Module):
    """
    Linear attention with gating for better gradient flow.
    From "Gated Linear Attention Transformers" (2024)
    """

    def __init__(self, d: int, h: int):
        """Create QKV, output-gate and output projections for ``h`` heads."""
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.gate = nn.Linear(d, d)
        self.proj = nn.Linear(d, d, bias=False)
        self.eps = 1e-6  # lower bound for the normaliser

    def forward(self, x, mask=None):
        """Apply ELU+1 linear attention followed by a sigmoid output gate.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` for bidirectional attention, or an ``(N, N)``
                additive mask (0 = visible, -inf = hidden). Every row must
                expose a non-empty prefix ``0..r`` of the keys, as causal and
                block-causal masks do; only the prefix length is used.

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Feature map (ELU + 1 for positivity)
        q = F.elu(q) + 1
        k = F.elu(k) + 1

        if mask is None:
            # Bidirectional
            kv = torch.einsum('bhnd,bhnv->bhdv', k, v)
            out = torch.einsum('bhnd,bhdv->bhnv', q, kv)
            normalizer = torch.einsum('bhnd,bhd->bhn', q, k.sum(dim=2)).unsqueeze(-1).clamp(min=self.eps)
        else:
            # Causal / block-causal via prefix sums
            kv_cumsum = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
            k_cumsum = torch.cumsum(k, dim=2)
            # Read each prefix sum at the last key the query may see, so a SAT
            # mask is honoured instead of being treated as plain AR. For a
            # causal mask this index is the row index itself.
            last = ((mask == 0).sum(dim=-1) - 1).to(x.device)
            kv_visible = kv_cumsum.index_select(2, last)
            k_visible = k_cumsum.index_select(2, last)
            out = torch.einsum('bhnd,bhndv->bhnv', q, kv_visible)
            normalizer = torch.einsum('bhnd,bhnd->bhn', q, k_visible).unsqueeze(-1).clamp(min=self.eps)

        out = out / normalizer
        out = out.transpose(1, 2).reshape(B, N, -1)

        # Gating
        gate = torch.sigmoid(self.gate(x))
        out = out * gate

        return self.proj(out)


# ═══════════════════════════════════════════════════════════════
# 9. RELU ATTENTION - Simpler activation
# ═══════════════════════════════════════════════════════════════
class ReLUAttention(nn.Module):
    """
    Replace softmax with ReLU + normalization.
    Simpler, faster, sometimes works as well.
    From "ReLU Attention" papers.
    """

    def __init__(self, d: int, h: int):
        """Create fused QKV and output projections for ``h`` heads."""
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        """Apply ReLU-activated, row-normalised attention with the ALiBi bias.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` or an ``(N, N)`` additive mask (0 / -inf); masked
                scores become 0 after the ReLU.

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)

        if mask is not None:
            att = att + mask.unsqueeze(0).unsqueeze(0)

        # ReLU instead of softmax
        att = F.relu(att)
        att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)

        z = (att @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# 10. SIGMOID ATTENTION - Bounded
# ═══════════════════════════════════════════════════════════════
class SigmoidAttention(nn.Module):
    """
    Sigmoid attention: each position independently decides attention weight.
    Not normalized to sum to 1 - allows variable "total attention".
    """

    def __init__(self, d: int, h: int):
        """Create projections and a learnable per-head score bias."""
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.bias = nn.Parameter(torch.zeros(h, 1, 1))

    def forward(self, x, mask=None):
        """Apply element-wise sigmoid attention.

        Args:
            x: Input of shape ``(B, N, d)``.
            mask: ``None`` or an ``(N, N)`` additive mask (0 / -inf).

        Returns:
            Tensor of shape ``(B, N, d)``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) + self.bias

        if mask is not None:
            att = att + mask.unsqueeze(0).unsqueeze(0)

        # Sigmoid instead of softmax - each weight independent
        att = torch.sigmoid(att)

        # Optional: mask out future for AR
        if mask is not None:
            att = att * (mask == 0).float().unsqueeze(0).unsqueeze(0)

        z = (att @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# Block and Model
# ═══════════════════════════════════════════════════════════════
ATTN_REGISTRY = {
    "standard": StandardAttention,
    "linear": LinearAttention,
    "cosine": CosineAttention,
    "differential": DifferentialAttention,
    "mqa": MultiQueryAttention,
    "gqa": GroupedQueryAttention,
    "retention": RetentionAttention,
    "gated_linear": GatedLinearAttention,
    "relu": ReLUAttention,
    "sigmoid": SigmoidAttention,
}


class Block(nn.Module):
    """Pre-norm transformer block: attention and a 4x GELU MLP, both residual."""

    def __init__(self, d: int, h: int, attn_type: str = "standard"):
        """Build the block; ``attn_type`` must be a key of ``ATTN_REGISTRY``."""
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = ATTN_REGISTRY[attn_type](d, h)
        self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))

    def forward(self, x, mask=None):
        """Run attention then the MLP, each on a layer-normed input with a skip connection."""
        x = x + self.attn(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))


class FlexModel(nn.Module):
    """Token embedding, a stack of ``Block``s, final LayerNorm and a tied output head."""

    def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard"):
        """Build ``layers`` blocks of width ``d`` with ``h`` heads each."""
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, attn_type) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x, mask=None):
        """Map token ids ``(B, N)`` to logits ``(B, N, VOCAB)``."""
        x = self.emb(x)
        for b in self.blocks:
            x = b(x, mask)
        return self.head(self.ln(x))

    def count_params(self):
        """Return the total number of elements over ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())


# ═══════════════════════════════════════════════════════════════
# Training with AR/SAT/NAR modes
# ═══════════════════════════════════════════════════════════════
def train(attn_type: str, mode: str, d: int, layers: int, h: int,
          batch: int, seq: int, steps: int, block_size: int = 4):
    """Train a fresh ``FlexModel`` on random tokens and report loss and throughput.

    Args:
        attn_type: Key of ``ATTN_REGISTRY``.
        mode: ``"ar"``, ``"sat"``; any other value is treated as NAR.
        d, layers, h: Model width, depth and head count.
        batch, seq: Batch size and sequence length of the random token ids.
        steps: Number of optimisation steps.
        block_size: SAT block width passed to ``get_mask``.

    Returns:
        A dict with keys ``"attn"``, ``"mode"``, ``"loss"`` and ``"tok_s"``
        (averages over the last up-to-20 steps), or ``None`` when a step
        raised or no step was run.

    NOTE(review): in SAT mode the targets are the next tokens while the mask
    also exposes later tokens of the same block, so the target can be visible
    to the model; confirm this is the intended SAT objective.
    """
    print(f"\n{'='*60}")
    print(f"ATTENTION: {attn_type.upper()} | MODE: {mode.upper()}")
    print(f"{'='*60}")

    model = FlexModel(d, layers, h, attn_type).to(DEV)
    print(f"Parameters: {model.count_params():,}")

    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # CUDA kernels run asynchronously; without a sync the wall-clock interval
    # would only measure kernel launch time.
    on_cuda = DEV.type == "cuda"

    losses, times = [], []

    for step in range(steps):
        ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)

        if mode == "ar":
            # Standard AR: predict next token
            target = ids[:, 1:]
            input_ids = ids[:, :-1]
            mask = get_mask(seq - 1, "ar")
        elif mode == "sat":
            # SAT: predict within blocks
            target = ids[:, 1:]
            input_ids = ids[:, :-1]
            mask = get_mask(seq - 1, "sat", block_size)
        else:  # nar
            # NAR: predict all from [MASK] or noisy input
            target = ids
            # Add noise to input for NAR (simple version)
            noise_mask = torch.rand(batch, seq, device=DEV) < 0.15
            input_ids = ids.clone()
            input_ids[noise_mask] = torch.randint(0, VOCAB, (noise_mask.sum().item(),), device=DEV)
            mask = get_mask(seq, "nar")

        if on_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        opt.zero_grad()

        # Boundary of the sweep: a failing variant is reported and skipped so
        # the remaining variants can still run.
        try:
            logits = model(input_ids, mask)
            loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if on_cuda:
                torch.cuda.synchronize()
        except Exception as e:
            print(f"Step {step} failed: {e}")
            return None

        # Guard against a zero interval on very coarse clocks.
        elapsed = max(time.perf_counter() - start, 1e-9)
        losses.append(loss.item())
        times.append(elapsed)

        if step % 20 == 0 or step == steps - 1:
            tok_s = batch * seq / elapsed
            print(f"Step {step:3d} | Loss {loss.item():.4f} | {tok_s:.0f} tok/s")

    if not losses:
        # steps <= 0: nothing to average.
        return None

    avg_loss = sum(losses[-20:]) / min(20, len(losses))
    avg_toks = batch * seq / (sum(times[-20:]) / min(20, len(times)))

    return {"attn": attn_type, "mode": mode, "loss": avg_loss, "tok_s": avg_toks}


def main():
    """Parse CLI arguments, train every requested attention type/mode pair and print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--d", type=int, default=256)
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--seq", type=int, default=128)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--mode", type=str, default="ar", choices=["ar", "sat", "nar", "all"])
    parser.add_argument("--types", type=str, default="all")
    args = parser.parse_args()

    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")

    if args.types == "all":
        types = list(ATTN_REGISTRY.keys())
    else:
        types = [t.strip() for t in args.types.split(",")]

    # Reject typos up front instead of failing with a KeyError mid-sweep.
    unknown = [t for t in types if t not in ATTN_REGISTRY]
    if unknown:
        parser.error(
            f"unknown attention type(s): {', '.join(unknown)} "
            f"(choose from: {', '.join(ATTN_REGISTRY)})"
        )

    modes = ["ar", "sat", "nar"] if args.mode == "all" else [args.mode]

    results = []
    for mode in modes:
        for attn_type in types:
            r = train(attn_type, mode, args.d, args.layers, args.heads,
                      args.batch, args.seq, args.steps)
            if r:
                results.append(r)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")

    for mode in modes:
        print(f"\n--- MODE: {mode.upper()} ---")
        mode_results = [r for r in results if r['mode'] == mode]
        baseline = next((r for r in mode_results if r['attn'] == 'standard'), None)

        for r in sorted(mode_results, key=lambda x: x['loss']):
            rel = ""
            if baseline and r['attn'] != 'standard':
                loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
                speed_ratio = r['tok_s'] / baseline['tok_s']
                rel = f" | vs std: {loss_diff:+.1f}%, {speed_ratio:.2f}x"
            print(f"{r['attn']:15s} | Loss {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s{rel}")


if __name__ == "__main__":
    main()