"""
n_heavy2.py — Extended Heavy Attention Experiments
Testing mechanisms that use MORE compute than standard attention

Approaches:
1. Multi-Hop: Explicit k-step reasoning chains
2. Slot Attention: Competitive binding (from object-centric learning)
3. Edge-Compute: Full pairwise MLP, not just weighted sum
4. Memory-Aug: External memory bank with read/write
5. Recurrent Depth: Same block applied k times (Universal Transformer)
"""
| |
|
| | from __future__ import annotations |
| | import argparse, math, time |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
# Default compute device; prefer CUDA when available.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Allow TF32 matmuls on Ampere+ GPUs: faster training, slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    # Older torch versions lack this API; the TF32 flag above still applies.
    # (Was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit.)
    pass
| |
|
# Vocabulary size and end-of-sequence token id.
# NOTE(review): these values match the Llama-3 tokenizer (128256 entries,
# <|end_of_text|> = 128001) — confirm against the tokenizer actually used;
# EOS is defined here but not referenced in this file.
VOCAB = 128256
EOS = 128001
| |
|
| | |
def _alibi_slopes(n_heads: int):
    """Return per-head ALiBi slopes shaped (1, n_heads, 1, 1) on DEV.

    For a power-of-two head count the slopes form a geometric sequence
    ratio^(k+1); otherwise the slopes for the nearest lower power of two
    are extended with every other slope from the doubled count (the
    interpolation scheme from the ALiBi paper).
    """
    def geometric(count):
        ratio = 2.0 ** (-(2.0 ** -(math.log2(count) - 3)))
        return [ratio ** (k + 1) for k in range(count)]

    log_heads = math.log2(n_heads)
    if log_heads.is_integer():
        slopes = geometric(n_heads)
    else:
        base = 2 ** int(log_heads)
        slopes = geometric(base)
        slopes.extend(geometric(2 * base)[::2][:n_heads - base])
    return torch.tensor(slopes, device=DEV).view(1, n_heads, 1, 1)
| |
|
def alibi_bias(n_heads: int, n_tokens: int):
    """Additive ALiBi bias of shape (1, n_heads, n_tokens, n_tokens).

    Entry (i, j) is -slope_h * max(i - j, 0): each query position i is
    penalised linearly by its distance to keys in its PAST.

    Bug fix: the distance was previously (j - i).clamp_min(0), which is
    zero for every causally visible pair (j <= i) — under the causal mask
    the bias was a complete no-op. ALiBi penalises distance into the past,
    i.e. (i - j).
    """
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (i - j).clamp_min(0).float()
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
| |
|
def causal_mask(n):
    """Additive causal mask (1, 1, n, n): -inf strictly above the diagonal, 0 elsewhere."""
    blocked = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(blocked, diagonal=1)
| |
|
| |
|
| | |
| | |
| | |
class StandardAttention(nn.Module):
    """Vanilla multi-head self-attention with an ALiBi positional bias."""

    def __init__(self, d: int, h: int):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        batch, n_tok, _ = x.shape
        # One fused projection, then split into per-head Q, K, V: (B, h, N, dk).
        q, k, v = (
            self.qkv(x)
            .reshape(batch, n_tok, 3, self.h, self.dk)
            .permute(2, 0, 3, 1, 4)
        )
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        scores = scores + alibi_bias(self.h, n_tok)
        if mask is not None:
            scores = scores + mask
        ctx = (scores.softmax(-1) @ v).transpose(1, 2).reshape(batch, n_tok, -1)
        return self.proj(ctx)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
class MultiHopAttention(nn.Module):
    """
    K explicit reasoning hops. Each hop:
      1. Attend to current state
      2. Update state with attended info
      3. Next hop attends to updated state

    Keys/values are computed once from the input; only the query is
    re-derived from the evolving state, so all hops share one K/V memory.

    O(k * n^2) - linear in hops, quadratic in sequence
    """
    def __init__(self, d: int, h: int, num_hops: int = 3):
        super().__init__()
        self.h, self.dk = h, d // h
        self.num_hops = num_hops

        # One query projection per hop; K/V projection is shared across hops.
        self.q_projs = nn.ModuleList([nn.Linear(d, d, bias=False) for _ in range(num_hops)])
        self.kv = nn.Linear(d, 2 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

        # Mixes the concatenated per-hop outputs back down to width d.
        self.hop_gate = nn.Linear(d * num_hops, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # (B, h, N, dk) keys and values, fixed for all hops.
        kv = self.kv(x).reshape(B, N, 2, self.h, self.dk).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        bias = alibi_bias(self.h, N)  # hoisted: identical for every hop
        hop_outputs = []
        state = x

        for hop in range(self.num_hops):
            # Query derived from the current (residually updated) state.
            q = self.q_projs[hop](state).reshape(B, N, self.h, self.dk).transpose(1, 2)

            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            att = att + bias
            if mask is not None:
                att = att + mask

            hop_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            hop_outputs.append(hop_out)

            # Residual update feeds the next hop's query projection.
            state = state + hop_out

        # Combine ALL hops' outputs, not just the last one.
        combined = torch.cat(hop_outputs, dim=-1)
        return self.proj(self.hop_gate(combined))
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
class SlotAttention(nn.Module):
    """
    Competitive binding: K slots compete for N positions.
    Unlike standard attention (N queries), we have K << N slots.

    Each slot iteratively refines what it attends to.
    Then we project slots back to sequence.

    O(iterations * K * N) where K = num_slots

    NOTE(review): `mask` is accepted but never used — slots aggregate over
    ALL positions, so under causal LM training this leaks future tokens
    into every position's output. Confirm this is intentional.
    NOTE(review): slots are re-sampled with fresh Gaussian noise on every
    forward pass, so outputs are stochastic even in eval mode.
    """
    def __init__(self, d: int, num_slots: int = 8, num_iters: int = 3):
        super().__init__()
        self.num_slots = num_slots
        self.num_iters = num_iters
        self.d = d

        # Learned mean and scale for the random slot initialisation.
        self.slots_mu = nn.Parameter(torch.randn(1, num_slots, d) * 0.02)
        self.slots_sigma = nn.Parameter(torch.ones(1, num_slots, d) * 0.02)

        # Shared projections: queries from slots, keys/values from tokens
        # (to_q/to_k are reused with swapped roles in the read-out below).
        self.to_q = nn.Linear(d, d, bias=False)
        self.to_k = nn.Linear(d, d, bias=False)
        self.to_v = nn.Linear(d, d, bias=False)

        # GRU + residual MLP refine each slot between attention iterations.
        self.gru = nn.GRUCell(d, d)
        self.mlp = nn.Sequential(
            nn.Linear(d, d * 2),
            nn.ReLU(),
            nn.Linear(d * 2, d)
        )
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

        # Final projection after broadcasting slots back to the sequence.
        self.slot_to_seq = nn.Linear(d, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # Sample initial slots from the learned Gaussian (reparameterised).
        slots = self.slots_mu + self.slots_sigma * torch.randn(B, self.num_slots, D, device=x.device)

        # Token keys/values are fixed across all refinement iterations.
        k = self.to_k(x)
        v = self.to_v(x)

        for _ in range(self.num_iters):
            slots_prev = slots
            slots = self.ln1(slots)

            # Queries come from the normalised current slots.
            q = self.to_q(slots)

            # (B, K, N) slot-to-token affinities.
            attn = torch.einsum('bkd,bnd->bkn', q, k) / math.sqrt(D)

            # Softmax over SLOTS (dim=1): slots compete for each position —
            # this is the defining trick of slot attention.
            attn = F.softmax(attn, dim=1)

            # Aggregate token values into per-slot updates.
            updates = torch.einsum('bkn,bnd->bkd', attn, v)

            # GRUCell is 2-D only: fold batch and slot dims together.
            slots = self.gru(
                updates.reshape(B * self.num_slots, D),
                slots_prev.reshape(B * self.num_slots, D)
            ).reshape(B, self.num_slots, D)

            # Residual MLP refinement.
            slots = slots + self.mlp(self.ln2(slots))

        # Read-out: each position attends over the K refined slots.
        q_out = self.to_q(x)
        k_slots = self.to_k(slots)

        attn_out = torch.einsum('bnd,bkd->bnk', q_out, k_slots) / math.sqrt(D)
        attn_out = F.softmax(attn_out, dim=-1)

        output = torch.einsum('bnk,bkd->bnd', attn_out, slots)
        return self.slot_to_seq(output)
| |
|
| |
|
| | |
| | |
| | |
| | |
class EdgeComputeAttention(nn.Module):
    """
    Standard attention: output = softmax(QK^T) @ V
    This is just a weighted sum - no computation on relationships.

    Edge-Compute: For each (i,j) pair, run MLP([q_i; k_j; v_j])
    Then aggregate. Much heavier but captures richer interactions.

    O(n^2 * mlp_cost) - quadratic with multiplicative MLP factor

    Note: Only practical for short sequences! Inputs longer than
    `max_seq` silently fall back to standard attention (with ALiBi).
    """
    def __init__(self, d: int, h: int, max_seq: int = 128):
        super().__init__()
        self.h, self.dk = h, d // h
        self.max_seq = max_seq  # edge-MLP path is used only when N <= max_seq

        self.qkv = nn.Linear(d, 3 * d, bias=False)

        # Per-pair value transform: MLP([q_i; k_j; v_j]) -> dk features.
        self.edge_mlp = nn.Sequential(
            nn.Linear(3 * self.dk, 2 * self.dk),
            nn.ReLU(),
            nn.Linear(2 * self.dk, self.dk)
        )

        # Learned (unnormalised) pair score replacing the dot product.
        self.score_mlp = nn.Sequential(
            nn.Linear(2 * self.dk, self.dk),
            nn.ReLU(),
            nn.Linear(self.dk, 1)
        )

        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # The edge path materialises (B, N, N, 3*dk) activations per head;
        # bail out to the cheap path for long sequences.
        if N > self.max_seq:
            return self._standard_forward(x, mask)

        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk)
        q, k, v = qkv[:,:,0], qkv[:,:,1], qkv[:,:,2]

        outputs = []
        for head in range(self.h):
            q_h = q[:, :, head, :]
            k_h = k[:, :, head, :]
            v_h = v[:, :, head, :]

            # Broadcast to all (i, j) pairs: (B, N, N, dk) expanded views.
            q_exp = q_h.unsqueeze(2).expand(-1, -1, N, -1)
            k_exp = k_h.unsqueeze(1).expand(-1, N, -1, -1)
            v_exp = v_h.unsqueeze(1).expand(-1, N, -1, -1)

            # Per-pair features from the edge MLP.
            edge_input = torch.cat([q_exp, k_exp, v_exp], dim=-1)

            edge_features = self.edge_mlp(edge_input)

            # Learned scores; note: no ALiBi bias and no 1/sqrt(dk) scaling
            # on this path — the MLP can learn its own scale.
            score_input = torch.cat([q_exp, k_exp], dim=-1)
            scores = self.score_mlp(score_input).squeeze(-1)

            # mask is (1, 1, N, N); squeeze to (1, N, N) to broadcast over B.
            if mask is not None:
                scores = scores + mask.squeeze(1)

            # Weighted sum of EDGE FEATURES (not raw values) over keys j.
            weights = F.softmax(scores, dim=-1)
            head_out = (weights.unsqueeze(-1) * edge_features).sum(dim=2)
            outputs.append(head_out)

        out = torch.cat(outputs, dim=-1)
        return self.proj(out)

    def _standard_forward(self, x, mask=None):
        """Fallback for N > max_seq: plain multi-head attention with ALiBi."""
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)
| |
|
| |
|
| | |
| | |
| | |
| | |
class MemoryAugmentedAttention(nn.Module):
    """
    Maintain external memory bank M of size (mem_size, d).
    Each forward:
      1. Read from memory using attention
      2. Standard self-attention with ALiBi
      3. Combine the two streams with a linear layer

    O(n^2 + n*mem_size) - adds memory interaction cost

    NOTE(review): no write-back to memory is implemented — `write_gate`
    is constructed but never called, and the memory bank only changes
    via gradient descent on the parameter itself.
    """
    def __init__(self, d: int, h: int, mem_size: int = 64):
        super().__init__()
        self.h, self.dk = h, d // h
        self.mem_size = mem_size

        # Learnable memory bank, shared across the batch.
        self.memory = nn.Parameter(torch.randn(1, mem_size, d) * 0.02)

        # Standard self-attention projections.
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

        # Projections for the token -> memory read.
        self.mem_q = nn.Linear(d, d, bias=False)
        self.mem_k = nn.Linear(d, d, bias=False)
        self.mem_v = nn.Linear(d, d, bias=False)

        # Currently unused (see class NOTE above).
        self.write_gate = nn.Sequential(
            nn.Linear(d * 2, d),
            nn.Sigmoid()
        )

        # Fuses the self-attention output with the memory read.
        self.combine = nn.Linear(d * 2, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # Broadcast the shared memory over the batch (expanded view).
        mem = self.memory.expand(B, -1, -1)

        # --- Memory read: every position attends over all memory slots ---
        q_mem = self.mem_q(x)
        k_mem = self.mem_k(mem)
        v_mem = self.mem_v(mem)

        mem_attn = torch.einsum('bnd,bmd->bnm', q_mem, k_mem) / math.sqrt(D)
        mem_attn = F.softmax(mem_attn, dim=-1)
        mem_read = torch.einsum('bnm,bmd->bnd', mem_attn, v_mem)

        # --- Standard self-attention with ALiBi bias ---
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        self_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)

        # Fuse both information streams.
        combined = self.combine(torch.cat([self_out, mem_read], dim=-1))

        return self.proj(combined)
| |
|
| |
|
| | |
| | |
| | |
| | |
class RecurrentDepthAttention(nn.Module):
    """
    Instead of L different layers, use 1 layer L times.
    Add depth embedding so model knows which iteration it's on.

    O(k * n^2) where k = num_recurrences

    Key insight: Weight sharing + depth embedding = potentially more
    efficient use of parameters for complex reasoning.
    """
    def __init__(self, d: int, h: int, num_recur: int = 4):
        super().__init__()
        self.h, self.dk = h, d // h
        self.num_recur = num_recur

        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

        # Tells the shared weights which recurrence step is running.
        self.depth_emb = nn.Embedding(num_recur, d)

        # Position-wise transition applied after each attention pass.
        self.transition = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, d * 2),
            nn.GELU(),
            nn.Linear(d * 2, d)
        )

    def forward(self, x, mask=None):
        """Apply the shared attention + transition block num_recur times."""
        B, N, D = x.shape
        bias = alibi_bias(self.h, N)  # hoisted: identical for every recurrence

        for r in range(self.num_recur):
            # Tag the current state with this step's depth embedding.
            x_r = x + self.depth_emb.weight[r].unsqueeze(0).unsqueeze(0)

            qkv = self.qkv(x_r).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]

            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            att = att + bias
            if mask is not None:
                att = att + mask

            attn_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            attn_out = self.proj(attn_out)

            # Residual attention + residual feed-forward with shared weights.
            x = x + attn_out
            x = x + self.transition(x)

        # Fix: the original `return x - x.detach() + x.detach()` was an
        # identity no-op (same value, same gradient); return x directly.
        return x
| |
|
| |
|
| | |
| | |
| | |
class Block(nn.Module):
    """Pre-norm transformer block wrapping a configurable attention module."""

    def __init__(self, d: int, h: int, attn_type: str = "standard", **kwargs):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)

        # Map each attention flavour to a zero-arg factory; evaluated
        # lazily so only the selected module is ever constructed.
        factories = {
            "standard": lambda: StandardAttention(d, h),
            "multihop": lambda: MultiHopAttention(d, h, num_hops=kwargs.get('num_hops', 3)),
            "slot": lambda: SlotAttention(d, num_slots=kwargs.get('num_slots', 8)),
            "edge": lambda: EdgeComputeAttention(d, h),
            "memory": lambda: MemoryAugmentedAttention(d, h, mem_size=kwargs.get('mem_size', 64)),
            "recurrent": lambda: RecurrentDepthAttention(d, h, num_recur=kwargs.get('num_recur', 4)),
        }
        try:
            self.attn = factories[attn_type]()
        except KeyError:
            raise ValueError(f"Unknown attn_type: {attn_type}") from None

        self.ff = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.GELU(),
            nn.Linear(4 * d, d)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x
| |
|
| |
|
class HeavyModel(nn.Module):
    """Token embedding -> stack of Blocks -> tied-weight LM head."""

    def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard", **kwargs):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(
            Block(d, h, attn_type, **kwargs) for _ in range(layers)
        )
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x, mask=None):
        hidden = self.emb(x)
        for block in self.blocks:
            hidden = block(hidden, mask)
        return self.head(self.ln(hidden))

    def count_params(self):
        """Total parameter count (tied weights are counted once)."""
        return sum(p.numel() for p in self.parameters())
| |
|
| |
|
| | |
| | |
| | |
def run_experiment(attn_type: str, d: int, layers: int, heads: int,
                   batch: int, seq: int, steps: int, **kwargs):
    """Train a HeavyModel on random tokens and report loss/throughput.

    Returns a dict with the attention type, mean loss and tok/s over the
    last 20 steps, and the parameter count.

    Fix: CUDA kernels launch asynchronously, so timing with a bare
    time.time() around the step measured launch overhead, not compute —
    we now synchronize around the timed region and use a monotonic clock.
    """
    print(f"\n{'='*60}")
    print(f"ATTENTION TYPE: {attn_type.upper()}")
    print(f"Config: d={d}, layers={layers}, heads={heads}")
    print(f"{'='*60}")

    model = HeavyModel(d, layers, heads, attn_type, **kwargs).to(DEV)
    print(f"Parameters: {model.count_params():,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    # Inputs are ids[:, :-1], so the mask covers seq - 1 positions.
    mask = causal_mask(seq - 1)

    losses, times = [], []

    for step in range(steps):
        # Random tokens: measures throughput/optimisation, not real language.
        ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)
        target = ids[:, 1:]
        input_ids = ids[:, :-1]

        if DEV.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()
        optimizer.zero_grad()
        logits = model(input_ids, mask)
        loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
        loss.backward()
        optimizer.step()
        if DEV.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

        losses.append(loss.item())
        times.append(elapsed)
        tok_s = (batch * seq) / elapsed

        if step % 10 == 0 or step == steps - 1:
            print(f"Step {step:3d} | Loss: {loss.item():.4f} | {tok_s:.0f} tok/s | {elapsed*1000:.0f}ms")

    # Averages over the last 20 steps (or fewer if steps < 20).
    avg_loss = sum(losses[-20:]) / min(20, len(losses))
    avg_time = sum(times[-20:]) / min(20, len(times))
    avg_toks = (batch * seq) / avg_time

    return {
        "type": attn_type,
        "loss": avg_loss,
        "tok_s": avg_toks,
        "params": model.count_params()
    }
| |
|
| |
|
def main():
    """Parse CLI arguments and benchmark each requested attention variant."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--d", type=int, default=256)
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--seq", type=int, default=128)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--types", type=str, default="all",
                        help="Comma-separated: standard,multihop,slot,edge,memory,recurrent")
    args = parser.parse_args()

    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Resolve which attention variants to benchmark.
    if args.types == "all":
        selected = ["standard", "multihop", "slot", "edge", "memory", "recurrent"]
    else:
        selected = [name.strip() for name in args.types.split(",")]

    results = []
    for name in selected:
        # One variant failing (e.g. OOM) should not abort the whole sweep.
        try:
            results.append(run_experiment(name, args.d, args.layers, args.heads,
                                          args.batch, args.seq, args.steps))
        except Exception as e:
            print(f"ERROR in {name}: {e}")
            import traceback
            traceback.print_exc()

    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    # Compare every variant against the standard-attention baseline, if it ran.
    baseline = next((r for r in results if r['type'] == 'standard'), None)

    for r in results:
        rel = ""
        if baseline and r['type'] != 'standard':
            loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
            speed_ratio = r['tok_s'] / baseline['tok_s']
            rel = f" | vs baseline: {loss_diff:+.1f}% loss, {speed_ratio:.2f}x speed"
        print(f"{r['type']:12s} | Loss: {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s | {r['params']:,} params{rel}")


if __name__ == "__main__":
    main()
| |
|