#!/usr/bin/env python3
"""
n_heavy2.py — Extended Heavy Attention Experiments

Testing mechanisms that use MORE compute than standard attention.

Approaches:
1. Multi-Hop: Explicit k-step reasoning chains
2. Slot Attention: Competitive binding (from object-centric learning)
3. Edge-Compute: Full pairwise MLP, not just weighted sum
4. Memory-Aug: Learned memory bank read via attention
5. Recurrent Depth: Same block applied k times (Universal Transformer)
"""
import argparse
import math
import time
import traceback

import torch
import torch.nn as nn
import torch.nn.functional as F

DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except AttributeError:
    # Older PyTorch releases lack this API; best-effort only.
    pass

VOCAB = 128256
EOS = 128001


# ─────────────────────────── ALiBi ───────────────────────────

def _alibi_slopes(n_heads: int):
    """Return per-head ALiBi slopes as a (1, n_heads, 1, 1) tensor on DEV.

    Uses the geometric sequence from the ALiBi paper. For a head count that is
    not a power of two, the slopes of the nearest lower power of two are
    extended with every other slope of the next power of two.
    """
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * (start ** i) for i in range(n)]

    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][:n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)


def alibi_bias(n_heads: int, n_tokens: int):
    """Return the causal ALiBi bias of shape (1, n_heads, n_tokens, n_tokens).

    Entry [0, h, i, j] is ``-slope_h * (i - j)`` for keys at or before the
    query (j <= i), and 0 for future keys (j > i), which the causal mask is
    expected to remove.
    """
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    # Distance from query i back to key j. (Using j - i here would make the
    # bias zero for every visible key, i.e. ALiBi would do nothing.)
    dist = (i - j).clamp_min(0).float()
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist


def causal_mask(n):
    """Additive mask of shape (1, 1, n, n): -inf above the diagonal, 0 elsewhere."""
    return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)


def _sync():
    """Wait for queued CUDA work so wall-clock timings cover the real compute."""
    if DEV.type == "cuda":
        torch.cuda.synchronize()


# ═══════════════════════════════════════════════════════════════
# BASELINE: Standard Attention
# ═══════════════════════════════════════════════════════════════

class StandardAttention(nn.Module):
    """Multi-head self-attention with an ALiBi bias and optional additive mask.

    forward(x, mask): x is (B, N, d); mask must broadcast to (B, h, N, N).
    Returns (B, N, d).
    """

    def __init__(self, d: int, h: int):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        # (B, N, 3d) -> (3, B, h, N, dk)
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# HEAVY 1: Multi-Hop Attention
# Each "hop" attends to previous hop's output
# Simulates multi-step reasoning chains
# ═══════════════════════════════════════════════════════════════

class MultiHopAttention(nn.Module):
    """
    K explicit reasoning hops. Each hop:
    1. Attend to current state
    2. Update state with attended info
    3. Next hop attends to updated state

    K and V are computed once from the input; only the query changes per hop.
    All hop outputs are concatenated and mixed by a linear layer.

    O(k * n²) - linear in hops, quadratic in sequence
    """

    def __init__(self, d: int, h: int, num_hops: int = 3):
        super().__init__()
        self.h, self.dk = h, d // h
        self.num_hops = num_hops

        # Separate Q projection per hop (K, V shared)
        self.q_projs = nn.ModuleList([nn.Linear(d, d, bias=False) for _ in range(num_hops)])
        self.kv = nn.Linear(d, 2 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

        # Hop mixing: combine info from all hops
        self.hop_gate = nn.Linear(d * num_hops, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # Compute K, V once (shared across hops): each (B, h, N, dk)
        kv = self.kv(x).reshape(B, N, 2, self.h, self.dk).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        bias = alibi_bias(self.h, N)

        hop_outputs = []
        state = x

        for hop in range(self.num_hops):
            # Query from current state
            q = self.q_projs[hop](state).reshape(B, N, self.h, self.dk).transpose(1, 2)

            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            att = att + bias
            if mask is not None:
                att = att + mask

            hop_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            hop_outputs.append(hop_out)

            # Update state for next hop
            state = state + hop_out

        # Combine all hops
        combined = torch.cat(hop_outputs, dim=-1)
        return self.proj(self.hop_gate(combined))


# ═══════════════════════════════════════════════════════════════
# HEAVY 2: Slot Attention
# From "Object-Centric Learning with Slot Attention"
# Slots compete to bind to input positions
# ═══════════════════════════════════════════════════════════════

class SlotAttention(nn.Module):
    """
    Competitive binding: K slots compete for N positions.

    Unlike standard attention (N queries), we have K << N slots.
    Each slot iteratively refines what it attends to, then every position
    reads back from the slots via attention.

    NOTE: `mask` is accepted for interface compatibility but IGNORED. Slots
    read from every position (including future ones) and every position reads
    the slots back, so under next-token training this leaks future tokens.
    Its loss is not directly comparable with the causal variants.

    O(iterations * K * N) where K = num_slots
    """

    def __init__(self, d: int, num_slots: int = 8, num_iters: int = 3, eps: float = 1e-8):
        super().__init__()
        self.num_slots = num_slots
        self.num_iters = num_iters
        self.d = d
        # Added to attention weights before renormalising over positions.
        self.eps = eps

        # Learnable slot initializations
        self.slots_mu = nn.Parameter(torch.randn(1, num_slots, d) * 0.02)
        self.slots_sigma = nn.Parameter(torch.ones(1, num_slots, d) * 0.02)

        # Attention
        self.to_q = nn.Linear(d, d, bias=False)
        self.to_k = nn.Linear(d, d, bias=False)
        self.to_v = nn.Linear(d, d, bias=False)

        # Slot update GRU
        self.gru = nn.GRUCell(d, d)
        self.mlp = nn.Sequential(
            nn.Linear(d, d * 2),
            nn.ReLU(),
            nn.Linear(d * 2, d)
        )

        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

        # Project slots back to sequence
        self.slot_to_seq = nn.Linear(d, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # Initialize slots with noise
        slots = self.slots_mu + self.slots_sigma * torch.randn(B, self.num_slots, D, device=x.device)

        # Pre-compute keys and values
        k = self.to_k(x)  # (B, N, D)
        v = self.to_v(x)  # (B, N, D)

        for _ in range(self.num_iters):
            slots_prev = slots
            slots = self.ln1(slots)

            # Slot attention: slots query, inputs are keys/values
            q = self.to_q(slots)  # (B, K, D)

            # (B, K, D) x (B, N, D) -> (B, K, N)
            attn = torch.einsum('bkd,bnd->bkn', q, k) / math.sqrt(D)

            # Softmax over SLOTS (competition) not positions
            attn = F.softmax(attn, dim=1)  # Slots compete for each position

            # As in the paper: turn each slot's weights into a weighted MEAN
            # over positions, so update magnitude does not grow with N.
            attn = attn + self.eps
            attn = attn / attn.sum(dim=-1, keepdim=True)

            # Weighted mean of values
            updates = torch.einsum('bkn,bnd->bkd', attn, v)  # (B, K, D)

            # GRU update
            slots = self.gru(
                updates.reshape(B * self.num_slots, D),
                slots_prev.reshape(B * self.num_slots, D)
            ).reshape(B, self.num_slots, D)

            # MLP residual
            slots = slots + self.mlp(self.ln2(slots))

        # Project slots back to sequence length
        # Use attention from positions to slots
        q_out = self.to_q(x)  # (B, N, D)
        k_slots = self.to_k(slots)  # (B, K, D)
        attn_out = torch.einsum('bnd,bkd->bnk', q_out, k_slots) / math.sqrt(D)
        attn_out = F.softmax(attn_out, dim=-1)  # (B, N, K)
        output = torch.einsum('bnk,bkd->bnd', attn_out, slots)

        return self.slot_to_seq(output)


# ═══════════════════════════════════════════════════════════════
# HEAVY 3: Edge-Compute Attention
# Instead of weighted sum, compute MLP on each (query, key) pair
# ═══════════════════════════════════════════════════════════════

class EdgeComputeAttention(nn.Module):
    """
    Standard attention: output = softmax(QK^T) @ V
    This is just a weighted sum - no computation on relationships.

    Edge-Compute: For each (i,j) pair, run MLP([q_i; k_j; v_j])
    Then aggregate. Much heavier but captures richer interactions.

    The edge path adds no ALiBi bias; when N > max_seq it falls back to
    standard ALiBi attention using the same qkv/proj weights.

    O(n² * mlp_cost) - quadratic with multiplicative MLP factor
    Note: Only practical for short sequences!
    """

    def __init__(self, d: int, h: int, max_seq: int = 128):
        super().__init__()
        self.h, self.dk = h, d // h
        self.max_seq = max_seq

        self.qkv = nn.Linear(d, 3 * d, bias=False)

        # Edge MLP: processes each (q_i, k_j, v_j) triple
        self.edge_mlp = nn.Sequential(
            nn.Linear(3 * self.dk, 2 * self.dk),
            nn.ReLU(),
            nn.Linear(2 * self.dk, self.dk)
        )

        # Attention for aggregation
        self.score_mlp = nn.Sequential(
            nn.Linear(2 * self.dk, self.dk),
            nn.ReLU(),
            nn.Linear(self.dk, 1)
        )

        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # For long sequences, fall back to standard
        if N > self.max_seq:
            return self._standard_forward(x, mask)

        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]  # Each: (B, N, H, dk)

        outputs = []
        for head in range(self.h):
            q_h = q[:, :, head, :]  # (B, N, dk)
            k_h = k[:, :, head, :]
            v_h = v[:, :, head, :]

            # Expand for pairwise: row i pairs query i with every key/value j
            q_exp = q_h.unsqueeze(2).expand(-1, -1, N, -1)  # (B, N, N, dk)
            k_exp = k_h.unsqueeze(1).expand(-1, N, -1, -1)  # (B, N, N, dk)
            v_exp = v_h.unsqueeze(1).expand(-1, N, -1, -1)  # (B, N, N, dk)

            # Concatenate for edge MLP
            edge_input = torch.cat([q_exp, k_exp, v_exp], dim=-1)  # (B, N, N, 3*dk)

            # Compute edge features
            edge_features = self.edge_mlp(edge_input)  # (B, N, N, dk)

            # Compute attention scores
            score_input = torch.cat([q_exp, k_exp], dim=-1)  # (B, N, N, 2*dk)
            scores = self.score_mlp(score_input).squeeze(-1)  # (B, N, N)

            # Apply causal mask: (1, 1, N, N) -> (1, N, N)
            if mask is not None:
                scores = scores + mask.squeeze(1)

            # Aggregate
            weights = F.softmax(scores, dim=-1)  # (B, N, N)
            head_out = (weights.unsqueeze(-1) * edge_features).sum(dim=2)  # (B, N, dk)
            outputs.append(head_out)

        out = torch.cat(outputs, dim=-1)  # (B, N, D)
        return self.proj(out)

    def _standard_forward(self, x, mask=None):
        """Standard ALiBi attention over this module's qkv/proj weights."""
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)


# ═══════════════════════════════════════════════════════════════
# HEAVY 4: Memory-Augmented Attention
# Learned memory bank read via attention alongside self-attention
# ═══════════════════════════════════════════════════════════════

class MemoryAugmentedAttention(nn.Module):
    """
    Self-attention augmented with a learned memory bank M of shape
    (1, mem_size, d), shared across the batch.

    Each forward:
    1. Read from memory: every position attends over the mem_size slots
    2. Standard self-attention (ALiBi + optional mask)
    3. Concatenate both results, mix with a linear layer, project

    The memory is an nn.Parameter changed only by the optimizer; forward()
    never writes to it, and `write_gate` is currently unused.

    O(n² + n*mem_size) - adds memory interaction cost
    """

    def __init__(self, d: int, h: int, mem_size: int = 64):
        super().__init__()
        self.h, self.dk = h, d // h
        self.mem_size = mem_size

        # Persistent memory (learned)
        self.memory = nn.Parameter(torch.randn(1, mem_size, d) * 0.02)

        # Standard attention
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

        # Memory read
        self.mem_q = nn.Linear(d, d, bias=False)
        self.mem_k = nn.Linear(d, d, bias=False)
        self.mem_v = nn.Linear(d, d, bias=False)

        # Write gate — NOTE(review): not used by forward(); still counted in params
        self.write_gate = nn.Sequential(
            nn.Linear(d * 2, d),
            nn.Sigmoid()
        )

        # Combine self-attention and memory
        self.combine = nn.Linear(d * 2, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # Expand memory for batch
        mem = self.memory.expand(B, -1, -1)  # (B, mem_size, D)

        # 1. Read from memory
        q_mem = self.mem_q(x)  # (B, N, D)
        k_mem = self.mem_k(mem)  # (B, mem_size, D)
        v_mem = self.mem_v(mem)  # (B, mem_size, D)

        mem_attn = torch.einsum('bnd,bmd->bnm', q_mem, k_mem) / math.sqrt(D)
        mem_attn = F.softmax(mem_attn, dim=-1)
        mem_read = torch.einsum('bnm,bmd->bnd', mem_attn, v_mem)  # (B, N, D)

        # 2. Standard self-attention
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        self_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)

        # 3. Combine self-attention and memory read
        combined = self.combine(torch.cat([self_out, mem_read], dim=-1))

        return self.proj(combined)


# ═══════════════════════════════════════════════════════════════
# HEAVY 5: Recurrent Depth (Universal Transformer)
# Same block applied k times with position-in-depth encoding
# ═══════════════════════════════════════════════════════════════

class RecurrentDepthAttention(nn.Module):
    """
    Instead of L different layers, use 1 layer L times.
    Add depth embedding so model knows which iteration it's on.

    O(k * n²) where k = num_recurrences

    Returns the accumulated UPDATE (final state minus input). Block adds the
    attention output to its residual stream, so returning the full state would
    add the (normalized) input to the residual a second time.

    Key insight: Weight sharing + depth embedding = potentially more
    efficient use of parameters for complex reasoning.
    """

    def __init__(self, d: int, h: int, num_recur: int = 4):
        super().__init__()
        self.h, self.dk = h, d // h
        self.num_recur = num_recur

        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

        # Depth embedding
        self.depth_emb = nn.Embedding(num_recur, d)

        # Transition function between recurrences
        self.transition = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, d * 2),
            nn.GELU(),
            nn.Linear(d * 2, d)
        )

    def forward(self, x, mask=None):
        B, N, D = x.shape
        bias = alibi_bias(self.h, N)
        x_in = x

        for r in range(self.num_recur):
            # Add depth embedding (broadcast over batch and sequence)
            x_r = x + self.depth_emb.weight[r].unsqueeze(0).unsqueeze(0)

            # Self-attention
            qkv = self.qkv(x_r).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            att = att + bias
            if mask is not None:
                att = att + mask
            attn_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            attn_out = self.proj(attn_out)

            # Residual + transition
            x = x + attn_out
            x = x + self.transition(x)

        # Return only what the recurrences added; the caller owns the residual.
        return x - x_in


# ═══════════════════════════════════════════════════════════════
# Block and Model wrappers
# ═══════════════════════════════════════════════════════════════

class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + attn(ln1(x)), then x + ff(ln2(x)).

    attn_type selects the attention module; extra kwargs (num_hops, num_slots,
    mem_size, num_recur) are forwarded to the matching module.
    Raises ValueError for an unknown attn_type.
    """

    def __init__(self, d: int, h: int, attn_type: str = "standard", **kwargs):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

        if attn_type == "standard":
            self.attn = StandardAttention(d, h)
        elif attn_type == "multihop":
            self.attn = MultiHopAttention(d, h, num_hops=kwargs.get('num_hops', 3))
        elif attn_type == "slot":
            self.attn = SlotAttention(d, num_slots=kwargs.get('num_slots', 8))
        elif attn_type == "edge":
            self.attn = EdgeComputeAttention(d, h)
        elif attn_type == "memory":
            self.attn = MemoryAugmentedAttention(d, h, mem_size=kwargs.get('mem_size', 64))
        elif attn_type == "recurrent":
            self.attn = RecurrentDepthAttention(d, h, num_recur=kwargs.get('num_recur', 4))
        else:
            raise ValueError(f"Unknown attn_type: {attn_type}")

        self.ff = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.GELU(),
            nn.Linear(4 * d, d)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x


class HeavyModel(nn.Module):
    """Token-level LM: embedding -> `layers` Blocks -> LayerNorm -> tied head.

    forward(ids, mask) maps integer ids (B, N) to logits (B, N, VOCAB).
    """

    def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard", **kwargs):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, attn_type, **kwargs) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        self.head.weight = self.emb.weight  # Tie weights

    def forward(self, x, mask=None):
        x = self.emb(x)
        for blk in self.blocks:
            x = blk(x, mask)
        return self.head(self.ln(x))

    def count_params(self):
        """Number of unique parameters (tied weights are counted once)."""
        return sum(p.numel() for p in self.parameters())


# ═══════════════════════════════════════════════════════════════
# Experiment Runner
# ═══════════════════════════════════════════════════════════════

def run_experiment(attn_type: str, d: int, layers: int, heads: int,
                   batch: int, seq: int, steps: int, **kwargs):
    """Train a HeavyModel on random tokens for `steps` steps and report stats.

    Inputs are i.i.d. uniform token ids, so a truly causal model cannot push
    the loss below about ln(VOCAB); lower values indicate future-token leakage.

    Returns a dict with "type", "loss" (mean over the last <=20 steps),
    "tok_s" (tokens/s from the mean step time of the last <=20 steps) and
    "params". Raises ValueError if steps < 1.
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    print(f"\n{'='*60}")
    print(f"ATTENTION TYPE: {attn_type.upper()}")
    print(f"Config: d={d}, layers={layers}, heads={heads}")
    print(f"{'='*60}")

    model = HeavyModel(d, layers, heads, attn_type, **kwargs).to(DEV)
    print(f"Parameters: {model.count_params():,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    mask = causal_mask(seq - 1)
    # Each row feeds seq-1 positions to the model (last id is only a target).
    tokens_per_step = batch * (seq - 1)

    losses, times = [], []

    for step in range(steps):
        ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)
        target = ids[:, 1:]
        input_ids = ids[:, :-1]

        _sync()  # don't charge earlier queued GPU work to this step
        start = time.perf_counter()

        optimizer.zero_grad()
        logits = model(input_ids, mask)
        loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
        loss.backward()
        optimizer.step()

        _sync()  # CUDA kernels run asynchronously; wait before stopping the clock
        elapsed = time.perf_counter() - start

        losses.append(loss.item())
        times.append(elapsed)

        tok_s = tokens_per_step / elapsed

        if step % 10 == 0 or step == steps - 1:
            print(f"Step {step:3d} | Loss: {loss.item():.4f} | {tok_s:.0f} tok/s | {elapsed*1000:.0f}ms")

    window = min(20, len(losses))
    avg_loss = sum(losses[-20:]) / window
    avg_time = sum(times[-20:]) / window
    avg_toks = tokens_per_step / avg_time

    return {
        "type": attn_type,
        "loss": avg_loss,
        "tok_s": avg_toks,
        "params": model.count_params()
    }


def main():
    """Parse CLI args, run each requested attention variant, print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--d", type=int, default=256)
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--seq", type=int, default=128)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--types", type=str, default="all",
                        help="Comma-separated: standard,multihop,slot,edge,memory,recurrent")
    args = parser.parse_args()
    if args.steps < 1:
        parser.error("--steps must be >= 1")

    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    if args.types == "all":
        types = ["standard", "multihop", "slot", "edge", "memory", "recurrent"]
    else:
        types = [t.strip() for t in args.types.split(",")]

    results = []
    for t in types:
        # Top-level boundary: report a failing variant and keep benchmarking the rest.
        try:
            r = run_experiment(t, args.d, args.layers, args.heads,
                               args.batch, args.seq, args.steps)
            results.append(r)
        except Exception as e:
            print(f"ERROR in {t}: {e}")
            traceback.print_exc()

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")

    baseline = next((r for r in results if r['type'] == 'standard'), None)

    for r in results:
        rel = ""
        if baseline and r['type'] != 'standard':
            # Positive loss_diff means lower loss than the baseline.
            loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
            speed_ratio = r['tok_s'] / baseline['tok_s']
            rel = f" | vs baseline: {loss_diff:+.1f}% loss, {speed_ratio:.2f}x speed"
        print(f"{r['type']:12s} | Loss: {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s | {r['params']:,} params{rel}")


if __name__ == "__main__":
    main()