# Source: OpenTransformer — experiments/n_heavy2.py (commit 2db758d)
#!/usr/bin/env python3
"""
n_heavy2.py — Extended Heavy Attention Experiments
Testing mechanisms that use MORE compute than standard attention
Approaches:
1. Multi-Hop: Explicit k-step reasoning chains
2. Slot Attention: Competitive binding (from object-centric learning)
3. Edge-Compute: Full pairwise MLP, not just weighted sum
4. Memory-Aug: External memory bank with read/write
5. Recurrent Depth: Same block applied k times (Universal Transformer)
"""
from __future__ import annotations
import argparse, math, time
import torch
import torch.nn as nn
import torch.nn.functional as F
# Compute device and numeric-precision knobs, shared by every module below.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True  # enable TF32 matmuls on Ampere+
try:
    torch.set_float32_matmul_precision("high")
except AttributeError:  # API added in torch 1.12; older versions lack it
    pass
VOCAB = 128256  # vocabulary size (Llama-3 tokenizer size — TODO confirm)
EOS = 128001
# ─────────────────────────── ALiBi ───────────────────────────
def _alibi_slopes(n_heads: int):
    """Return ALiBi head slopes as a (1, n_heads, 1, 1) tensor on DEV.

    For a power-of-two head count the slopes form a geometric sequence;
    otherwise the standard ALiBi recipe is used: take the slopes for the
    largest power of two below n_heads, then fill the remainder with
    every other slope from the doubled head count.
    """
    def _geometric(count):
        ratio = 2 ** (-2 ** -(math.log2(count) - 3))
        return [ratio ** (i + 1) for i in range(count)]

    if math.log2(n_heads).is_integer():
        slopes = _geometric(n_heads)
    else:
        base = 2 ** math.floor(math.log2(n_heads))
        slopes = _geometric(base)
        slopes += _geometric(2 * base)[0::2][:n_heads - base]
    return torch.tensor(slopes, device=DEV).view(1, n_heads, 1, 1)
def alibi_bias(n_heads: int, n_tokens: int):
    """Additive ALiBi bias of shape (1, n_heads, n_tokens, n_tokens).

    Entry [.., i, j] is -slope_h * (i - j) for past/current positions
    (j <= i) and 0 for future positions (which the causal mask removes).

    Bug fix: the original computed `(j - i).clamp_min(0)`, which is zero
    for every causally-visible position (j <= i) and positive only for
    future positions that the -inf causal mask overwrites anyway — i.e.
    the bias was a complete no-op. ALiBi penalizes distance into the
    PAST, so the distance must be (i - j).
    """
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (i - j).clamp_min(0).float()  # distance to the past, 0 for future
    slopes = _alibi_slopes(n_heads)
    return -slopes * dist
def causal_mask(n):
    """(1, 1, n, n) additive mask: -inf strictly above the diagonal, 0 elsewhere."""
    blocked = torch.full((1, 1, n, n), float("-inf"), device=DEV)
    return torch.triu(blocked, diagonal=1)
# ═══════════════════════════════════════════════════════════════
# BASELINE: Standard Attention
# ═══════════════════════════════════════════════════════════════
class StandardAttention(nn.Module):
    """Vanilla multi-head self-attention with an additive ALiBi bias.

    Baseline for the heavy variants: fused QKV projection, scaled
    dot-product scores, optional additive mask, output projection.
    """

    def __init__(self, d: int, h: int):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        # (B, N, 3d) -> (3, B, h, N, dk)
        packed = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = packed.unbind(0)
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        scores = scores + alibi_bias(self.h, N)
        if mask is not None:
            scores = scores + mask
        context = scores.softmax(-1) @ v  # (B, h, N, dk)
        return self.proj(context.transpose(1, 2).reshape(B, N, -1))
# ═══════════════════════════════════════════════════════════════
# HEAVY 1: Multi-Hop Attention
# Each "hop" attends to previous hop's output
# Simulates multi-step reasoning chains
# ═══════════════════════════════════════════════════════════════
class MultiHopAttention(nn.Module):
    """K explicit reasoning hops over shared keys/values.

    Per hop: project a fresh query from the running state, attend, and
    fold the attended output back into the state so the next hop can
    query the refined representation. All hop outputs are concatenated
    and mixed down at the end.
    O(k * n^2): linear in hops, quadratic in sequence length.
    """

    def __init__(self, d: int, h: int, num_hops: int = 3):
        super().__init__()
        self.h, self.dk = h, d // h
        self.num_hops = num_hops
        # One query projection per hop; K and V are computed once and shared.
        self.q_projs = nn.ModuleList([nn.Linear(d, d, bias=False) for _ in range(num_hops)])
        self.kv = nn.Linear(d, 2 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        # Mixes the concatenated per-hop outputs back down to width d.
        self.hop_gate = nn.Linear(d * num_hops, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        # Shared K, V: (B, N, 2d) -> (2, B, h, N, dk)
        kv = self.kv(x).reshape(B, N, 2, self.h, self.dk).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)
        bias = alibi_bias(self.h, N)  # identical for every hop, hoisted
        state = x
        per_hop = []
        for q_proj in self.q_projs:
            q = q_proj(state).reshape(B, N, self.h, self.dk).transpose(1, 2)
            scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) + bias
            if mask is not None:
                scores = scores + mask
            attended = (scores.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            per_hop.append(attended)
            state = state + attended  # next hop queries the refined state
        return self.proj(self.hop_gate(torch.cat(per_hop, dim=-1)))
# ═══════════════════════════════════════════════════════════════
# HEAVY 2: Slot Attention
# From "Object-Centric Learning with Slot Attention"
# Slots compete to bind to input positions
# ═══════════════════════════════════════════════════════════════
class SlotAttention(nn.Module):
    """
    Competitive binding: K slots compete for N positions.
    Unlike standard attention (N queries), we have K << N slots.
    Each slot iteratively refines what it attends to.
    Then we project slots back to sequence.
    O(iterations * K * N) where K = num_slots

    NOTE(review): the `mask` argument is accepted but never used — slots
    aggregate over ALL positions, so in a causal LM this path can leak
    future-token information into earlier positions; confirm intended.
    NOTE(review): fresh Gaussian noise is drawn on every forward pass,
    so outputs are stochastic even in eval mode.
    """

    def __init__(self, d: int, num_slots: int = 8, num_iters: int = 3):
        super().__init__()
        self.num_slots = num_slots
        self.num_iters = num_iters
        self.d = d
        # Learnable mean and per-dim scale of the Gaussian that slot
        # initializations are sampled from each forward pass.
        self.slots_mu = nn.Parameter(torch.randn(1, num_slots, d) * 0.02)
        self.slots_sigma = nn.Parameter(torch.ones(1, num_slots, d) * 0.02)
        # Attention projections (to_q / to_k are reused again at the end
        # for the slots-to-sequence readout).
        self.to_q = nn.Linear(d, d, bias=False)
        self.to_k = nn.Linear(d, d, bias=False)
        self.to_v = nn.Linear(d, d, bias=False)
        # Slot update GRU: each (batch, slot) pair is one GRUCell element.
        self.gru = nn.GRUCell(d, d)
        self.mlp = nn.Sequential(
            nn.Linear(d, d * 2),
            nn.ReLU(),
            nn.Linear(d * 2, d)
        )
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        # Project slots back to sequence
        self.slot_to_seq = nn.Linear(d, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        # Initialize slots with noise around the learned mean.
        slots = self.slots_mu + self.slots_sigma * torch.randn(B, self.num_slots, D, device=x.device)
        # Pre-compute keys and values once; fixed across all iterations.
        k = self.to_k(x)  # (B, N, D)
        v = self.to_v(x)  # (B, N, D)
        for _ in range(self.num_iters):
            slots_prev = slots
            slots = self.ln1(slots)
            # Slot attention: slots query, inputs are keys/values
            q = self.to_q(slots)  # (B, K, D)
            # Attention logits: (B, K, D) x (B, N, D) -> (B, K, N)
            attn = torch.einsum('bkd,bnd->bkn', q, k) / math.sqrt(D)
            # Softmax over SLOTS (competition) not positions
            attn = F.softmax(attn, dim=1)  # Slots compete for each position
            # Weighted sum of values. NOTE(review): unlike the original
            # Slot Attention paper, the weights are not renormalized over
            # positions before this sum — confirm whether intentional.
            updates = torch.einsum('bkn,bnd->bkd', attn, v)  # (B, K, D)
            # GRU update: flatten (batch, slot) into the GRUCell batch dim.
            slots = self.gru(
                updates.reshape(B * self.num_slots, D),
                slots_prev.reshape(B * self.num_slots, D)
            ).reshape(B, self.num_slots, D)
            # MLP residual
            slots = slots + self.mlp(self.ln2(slots))
        # Project slots back to sequence length: each position attends
        # over the K refined slots (to_q / to_k reused on new operands).
        q_out = self.to_q(x)  # (B, N, D)
        k_slots = self.to_k(slots)  # (B, K, D)
        attn_out = torch.einsum('bnd,bkd->bnk', q_out, k_slots) / math.sqrt(D)
        attn_out = F.softmax(attn_out, dim=-1)  # (B, N, K)
        output = torch.einsum('bnk,bkd->bnd', attn_out, slots)
        return self.slot_to_seq(output)
# ═══════════════════════════════════════════════════════════════
# HEAVY 3: Edge-Compute Attention
# Instead of weighted sum, compute MLP on each (query, key) pair
# ═══════════════════════════════════════════════════════════════
class EdgeComputeAttention(nn.Module):
    """
    Standard attention: output = softmax(QK^T) @ V
    This is just a weighted sum - no computation on relationships.
    Edge-Compute: For each (i,j) pair, run MLP([q_i; k_j; v_j])
    Then aggregate. Much heavier but captures richer interactions.
    O(n^2 * mlp_cost) - quadratic with multiplicative MLP factor
    Note: Only practical for short sequences! For N > max_seq this
    module silently falls back to plain dot-product attention.

    NOTE(review): the edge path applies no ALiBi bias — only the
    `_standard_forward` fallback does; confirm that asymmetry is wanted.
    """

    def __init__(self, d: int, h: int, max_seq: int = 128):
        super().__init__()
        self.h, self.dk = h, d // h
        self.max_seq = max_seq  # above this length, use the cheap fallback
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        # Edge MLP: processes each (q_i, k_j, v_j) triple
        self.edge_mlp = nn.Sequential(
            nn.Linear(3 * self.dk, 2 * self.dk),
            nn.ReLU(),
            nn.Linear(2 * self.dk, self.dk)
        )
        # Scalar attention score per (q_i, k_j) pair, also MLP-computed
        self.score_mlp = nn.Sequential(
            nn.Linear(2 * self.dk, self.dk),
            nn.ReLU(),
            nn.Linear(self.dk, 1)
        )
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        # For long sequences, fall back to standard
        if N > self.max_seq:
            return self._standard_forward(x, mask)
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk)
        q, k, v = qkv[:,:,0], qkv[:,:,1], qkv[:,:,2]  # Each: (B, N, H, dk)
        outputs = []
        # Per-head loop keeps peak memory to one (B, N, N, 3*dk) tensor at a time
        for head in range(self.h):
            q_h = q[:, :, head, :]  # (B, N, dk)
            k_h = k[:, :, head, :]
            v_h = v[:, :, head, :]
            # Expand for pairwise: (B, N, 1, dk) and (B, 1, N, dk)
            q_exp = q_h.unsqueeze(2).expand(-1, -1, N, -1)  # (B, N, N, dk)
            k_exp = k_h.unsqueeze(1).expand(-1, N, -1, -1)  # (B, N, N, dk)
            v_exp = v_h.unsqueeze(1).expand(-1, N, -1, -1)  # (B, N, N, dk)
            # Concatenate for edge MLP
            edge_input = torch.cat([q_exp, k_exp, v_exp], dim=-1)  # (B, N, N, 3*dk)
            # Compute edge features
            edge_features = self.edge_mlp(edge_input)  # (B, N, N, dk)
            # Compute attention scores
            score_input = torch.cat([q_exp, k_exp], dim=-1)  # (B, N, N, 2*dk)
            scores = self.score_mlp(score_input).squeeze(-1)  # (B, N, N)
            # Apply causal mask; (1, 1, N, N) -> (1, N, N) broadcasts over batch
            if mask is not None:
                scores = scores + mask.squeeze(1)
            # Aggregate: softmax-weighted sum of edge FEATURES (not raw values)
            weights = F.softmax(scores, dim=-1)  # (B, N, N)
            head_out = (weights.unsqueeze(-1) * edge_features).sum(dim=2)  # (B, N, dk)
            outputs.append(head_out)
        out = torch.cat(outputs, dim=-1)  # (B, N, D)
        return self.proj(out)

    def _standard_forward(self, x, mask=None):
        # Fallback: ALiBi-biased multi-head attention (same as StandardAttention).
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)
# ═══════════════════════════════════════════════════════════════
# HEAVY 4: Memory-Augmented Attention
# External memory bank with read/write operations
# ═══════════════════════════════════════════════════════════════
class MemoryAugmentedAttention(nn.Module):
    """
    Maintain a persistent learned memory bank M of size (mem_size, d).
    Each forward:
      1. Read from memory using attention
      2. Standard self-attention with ALiBi
      3. Combine self-attention output with the memory read
    O(n^2 + n*mem_size) - adds memory interaction cost

    NOTE(review): despite the original description, nothing is ever
    written back to the memory at runtime — it only changes through
    gradient descent. The `write_gate` module is defined (and counted
    in parameters) but never used in forward(); confirm intended.
    """

    def __init__(self, d: int, h: int, mem_size: int = 64):
        super().__init__()
        self.h, self.dk = h, d // h
        self.mem_size = mem_size
        # Persistent memory (learned, shared across the batch)
        self.memory = nn.Parameter(torch.randn(1, mem_size, d) * 0.02)
        # Standard attention
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        # Memory read projections (queries from tokens, keys/values from memory)
        self.mem_q = nn.Linear(d, d, bias=False)
        self.mem_k = nn.Linear(d, d, bias=False)
        self.mem_v = nn.Linear(d, d, bias=False)
        # Write gate (currently unused in forward — see class NOTE)
        self.write_gate = nn.Sequential(
            nn.Linear(d * 2, d),
            nn.Sigmoid()
        )
        # Combine self-attention and memory
        self.combine = nn.Linear(d * 2, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        # Expand memory for batch
        mem = self.memory.expand(B, -1, -1)  # (B, mem_size, D)
        # 1. Read from memory (no mask: memory slots are not positions)
        q_mem = self.mem_q(x)  # (B, N, D)
        k_mem = self.mem_k(mem)  # (B, mem_size, D)
        v_mem = self.mem_v(mem)  # (B, mem_size, D)
        mem_attn = torch.einsum('bnd,bmd->bnm', q_mem, k_mem) / math.sqrt(D)
        mem_attn = F.softmax(mem_attn, dim=-1)
        mem_read = torch.einsum('bnm,bmd->bnd', mem_attn, v_mem)  # (B, N, D)
        # 2. Standard self-attention
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        self_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        # 3. Combine self-attention and memory read
        combined = self.combine(torch.cat([self_out, mem_read], dim=-1))
        return self.proj(combined)
# ═══════════════════════════════════════════════════════════════
# HEAVY 5: Recurrent Depth (Universal Transformer)
# Same block applied k times with position-in-depth encoding
# ═══════════════════════════════════════════════════════════════
class RecurrentDepthAttention(nn.Module):
    """
    Universal-Transformer-style recurrence: one attention block applied
    `num_recur` times, with a learned depth embedding added before each
    pass so the shared weights can condition on the iteration index.

    O(k * n^2) where k = num_recurrences.
    Key insight: weight sharing + depth embedding = potentially more
    efficient use of parameters for complex reasoning.
    """

    def __init__(self, d: int, h: int, num_recur: int = 4):
        super().__init__()
        self.h, self.dk = h, d // h
        self.num_recur = num_recur
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        # One learned embedding per recurrence depth
        self.depth_emb = nn.Embedding(num_recur, d)
        # Position-wise transition applied after each attention pass
        self.transition = nn.Sequential(
            nn.LayerNorm(d),
            nn.Linear(d, d * 2),
            nn.GELU(),
            nn.Linear(d * 2, d)
        )

    def forward(self, x, mask=None):
        B, N, D = x.shape
        bias = alibi_bias(self.h, N)  # identical every pass, hoisted
        for r in range(self.num_recur):
            # Add depth embedding so the model knows which iteration this is
            x_r = x + self.depth_emb.weight[r].view(1, 1, D)
            # Self-attention (shared weights across recurrences)
            qkv = self.qkv(x_r).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            att = att + bias
            if mask is not None:
                att = att + mask
            attn_out = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            attn_out = self.proj(attn_out)
            # Residual + transition
            x = x + attn_out
            x = x + self.transition(x)
        # Bug fix: the original returned `x - x.detach() + x.detach()`,
        # which is identical to `x` in both value AND gradient (the
        # derivative of a detached tensor is zero) — a no-op that only
        # allocated two extra tensors, not a stability trick.
        return x
# ═══════════════════════════════════════════════════════════════
# Block and Model wrappers
# ═══════════════════════════════════════════════════════════════
class Block(nn.Module):
    """Pre-LN transformer block: pluggable attention + 4x-expansion GELU MLP.

    `attn_type` selects one of the experimental attention mechanisms;
    extra keyword arguments are forwarded to the chosen module.
    """

    def __init__(self, d: int, h: int, attn_type: str = "standard", **kwargs):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.ln2 = nn.LayerNorm(d)
        # Dispatch table of lazy constructors — only the chosen one runs.
        factories = {
            "standard": lambda: StandardAttention(d, h),
            "multihop": lambda: MultiHopAttention(d, h, num_hops=kwargs.get('num_hops', 3)),
            "slot": lambda: SlotAttention(d, num_slots=kwargs.get('num_slots', 8)),
            "edge": lambda: EdgeComputeAttention(d, h),
            "memory": lambda: MemoryAugmentedAttention(d, h, mem_size=kwargs.get('mem_size', 64)),
            "recurrent": lambda: RecurrentDepthAttention(d, h, num_recur=kwargs.get('num_recur', 4)),
        }
        if attn_type not in factories:
            raise ValueError(f"Unknown attn_type: {attn_type}")
        self.attn = factories[attn_type]()
        self.ff = nn.Sequential(
            nn.Linear(d, 4 * d),
            nn.GELU(),
            nn.Linear(4 * d, d)
        )

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        x = x + self.ff(self.ln2(x))
        return x
class HeavyModel(nn.Module):
    """Decoder-only LM: tied token embedding/output head over a Block stack."""

    def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard", **kwargs):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, h, attn_type, **kwargs) for _ in range(layers))
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, x, mask=None):
        hidden = self.emb(x)
        for block in self.blocks:
            hidden = block(hidden, mask)
        return self.head(self.ln(hidden))

    def count_params(self):
        """Total parameter count (parameters() deduplicates tied weights)."""
        return sum(p.numel() for p in self.parameters())
# ═══════════════════════════════════════════════════════════════
# Experiment Runner
# ═══════════════════════════════════════════════════════════════
def run_experiment(attn_type: str, d: int, layers: int, heads: int,
                   batch: int, seq: int, steps: int, **kwargs):
    """
    Train one HeavyModel variant on random tokens and benchmark it.

    Args:
        attn_type: standard/multihop/slot/edge/memory/recurrent.
        d, layers, heads: model width, depth, and attention-head count.
        batch, seq, steps: batch size, sequence length, training steps.
        **kwargs: forwarded to the attention module (num_hops, mem_size, ...).

    Returns:
        dict with averaged tail loss, throughput (tok/s), and param count.
    """
    print(f"\n{'='*60}")
    print(f"ATTENTION TYPE: {attn_type.upper()}")
    print(f"Config: d={d}, layers={layers}, heads={heads}")
    print(f"{'='*60}")
    model = HeavyModel(d, layers, heads, attn_type, **kwargs).to(DEV)
    print(f"Parameters: {model.count_params():,}")
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    mask = causal_mask(seq - 1)  # model sees seq-1 inputs (last token is target-only)
    losses, times = [], []
    for step in range(steps):
        ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)
        target = ids[:, 1:]
        input_ids = ids[:, :-1]
        # Bug fix: CUDA kernels launch asynchronously, so timing with the
        # host clock alone measured launch overhead, not actual compute.
        # Synchronize on both sides of the timed region for real tok/s.
        if DEV.type == "cuda":
            torch.cuda.synchronize()
        start = time.perf_counter()  # monotonic, finer-grained than time.time()
        optimizer.zero_grad(set_to_none=True)
        logits = model(input_ids, mask)
        loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
        loss.backward()
        optimizer.step()
        if DEV.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        losses.append(loss.item())
        times.append(elapsed)
        tok_s = (batch * seq) / elapsed  # NOTE: counts seq tokens/step, model sees seq-1
        if step % 10 == 0 or step == steps - 1:
            print(f"Step {step:3d} | Loss: {loss.item():.4f} | {tok_s:.0f} tok/s | {elapsed*1000:.0f}ms")
    # Average over the last 20 steps to skip warm-up noise.
    avg_loss = sum(losses[-20:]) / min(20, len(losses))
    avg_time = sum(times[-20:]) / min(20, len(times))
    avg_toks = (batch * seq) / avg_time
    return {
        "type": attn_type,
        "loss": avg_loss,
        "tok_s": avg_toks,
        "params": model.count_params()
    }
def main():
    """Parse CLI args, benchmark each requested attention variant, print a summary."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--d", type=int, default=256)
    ap.add_argument("--layers", type=int, default=4)
    ap.add_argument("--heads", type=int, default=8)
    ap.add_argument("--batch", type=int, default=16)
    ap.add_argument("--seq", type=int, default=128)
    ap.add_argument("--steps", type=int, default=100)
    ap.add_argument("--types", type=str, default="all",
                    help="Comma-separated: standard,multihop,slot,edge,memory,recurrent")
    args = ap.parse_args()

    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    all_types = ["standard", "multihop", "slot", "edge", "memory", "recurrent"]
    selected = all_types if args.types == "all" else [t.strip() for t in args.types.split(",")]

    results = []
    for kind in selected:
        # Keep going if one variant crashes (e.g. OOM on the edge path).
        try:
            results.append(run_experiment(kind, args.d, args.layers, args.heads,
                                          args.batch, args.seq, args.steps))
        except Exception as exc:
            print(f"ERROR in {kind}: {exc}")
            import traceback
            traceback.print_exc()

    # Summary table, relative to the standard-attention baseline when present.
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    baseline = next((r for r in results if r['type'] == 'standard'), None)
    for r in results:
        rel = ""
        if baseline and r['type'] != 'standard':
            loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
            speed_ratio = r['tok_s'] / baseline['tok_s']
            rel = f" | vs baseline: {loss_diff:+.1f}% loss, {speed_ratio:.2f}x speed"
        print(f"{r['type']:12s} | Loss: {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s | {r['params']:,} params{rel}")


if __name__ == "__main__":
    main()