#!/usr/bin/env python3
"""
n_heavy.py β€” Iterative Refinement Transformer Experiment
Heavier-than-standard-attention: tokens get reprocessed based on uncertainty
Key idea: Instead of single-pass attention, run multiple iterations
where "hard" tokens (high uncertainty) get recomputed while "easy" tokens halt.
This is O(nΒ² Γ— k) where k = average iterations, vs standard O(nΒ²).
"""
from __future__ import annotations
import argparse, json, math, pathlib, random, time, os, sys
from contextlib import nullcontext
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timezone
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─────────────────────────── Globals ───────────────────────────
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:  # setter is missing on older PyTorch builds
    pass
VOCAB = 128256  # Llama 3 vocab size
EOS = 128001    # Llama 3 <|end_of_text|> token id
# ─────────────────────────── ALiBi ───────────────────────────
def _alibi_slopes(n_heads: int):
def pow2slopes(n):
start = 2 ** (-2 ** -(math.log2(n) - 3))
ratio = start
return [start * (ratio ** i) for i in range(n)]
if math.log2(n_heads).is_integer(): vals = pow2slopes(n_heads)
else:
closest = 2 ** math.floor(math.log2(n_heads))
vals = pow2slopes(closest)
extra = pow2slopes(2 * closest)
vals += extra[0::2][: n_heads - closest]
return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)
def alibi_bias(n_heads: int, n_tokens: int):
i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    dist = (i - j).clamp_min(0)  # distance into the past; future positions (j > i) stay 0 and are removed by the causal mask
return -_alibi_slopes(n_heads) * dist
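# Sanity sketch for the bias above (a hypothetical check, never called by the
# experiment; note the slope/bias tensors are rebuilt on every forward call):
def _alibi_sanity(h: int = 4, n: int = 3):
    b = alibi_bias(h, n)                             # (1, h, n, n)
    assert b.shape == (1, h, n, n)
    assert torch.all(b[..., 0, 0] == 0)              # zero distance -> zero bias
    assert torch.all(b[..., 2, 0] <= b[..., 2, 1])   # farther into the past -> larger penalty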
# ─────────────────────────── Standard Attention ───────────────────────────
class StandardAttention(nn.Module):
"""Baseline: single-pass multi-head attention"""
def __init__(self, d: int, h: int):
super().__init__()
assert d % h == 0
self.h, self.dk = h, d // h
self.qkv = nn.Linear(d, 3 * d, bias=False)
self.proj = nn.Linear(d, d, bias=False)
self.drop = nn.Dropout(0.1)
def forward(self, x, mask=None):
B, N, _ = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
att = att + alibi_bias(self.h, N)
if mask is not None:
att = att + mask
z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
return self.drop(self.proj(z))
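# Usage sketch for the baseline (assumed toy sizes; a smoke test that main()
# never calls):
def _smoke_standard_attention():
    attn = StandardAttention(d=64, h=4).to(DEV)
    x = torch.randn(2, 16, 64, device=DEV)
    mask = torch.triu(torch.full((1, 1, 16, 16), float("-inf"), device=DEV), 1)
    out = attn(x, mask)
    assert out.shape == x.shape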
# ─────────────────────────── HEAVY: Iterative Refinement Attention ───────────────────────────
class IterativeAttention(nn.Module):
"""
Heavier-than-standard: iteratively refine representations.
Each token has a "halting probability" - once it exceeds threshold,
that token stops updating. Hard tokens keep getting reprocessed.
Inspired by Universal Transformers + PonderNet.
"""
def __init__(self, d: int, h: int, max_iters: int = 5, halt_threshold: float = 0.9):
super().__init__()
assert d % h == 0
self.h, self.dk = h, d // h
self.max_iters = max_iters
self.halt_threshold = halt_threshold
# Shared attention weights across iterations (Universal Transformer style)
self.qkv = nn.Linear(d, 3 * d, bias=False)
self.proj = nn.Linear(d, d, bias=False)
self.drop = nn.Dropout(0.1)
# Halting predictor: per-token probability of "done processing"
self.halt_pred = nn.Sequential(
nn.Linear(d, d // 4),
nn.ReLU(),
nn.Linear(d // 4, 1),
nn.Sigmoid()
)
# Iteration embedding: tell model which iteration we're on
self.iter_emb = nn.Embedding(max_iters, d)
def forward(self, x, mask=None):
B, N, D = x.shape
# Track halting state
halted = torch.zeros(B, N, 1, device=x.device, dtype=torch.bool)
cumulative_halt = torch.zeros(B, N, 1, device=x.device)
# Accumulate outputs weighted by when each token halted
output = torch.zeros_like(x)
remainder = torch.ones(B, N, 1, device=x.device)
total_compute = 0
for i in range(self.max_iters):
# Add iteration embedding
x_iter = x + self.iter_emb.weight[i].unsqueeze(0).unsqueeze(0)
# Standard attention on current state
qkv = self.qkv(x_iter).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
att = att + alibi_bias(self.h, N)
if mask is not None:
att = att + mask
z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
delta = self.drop(self.proj(z))
# Compute halting probability for each token
halt_prob = self.halt_pred(x + delta) # p(halt | current state)
# Update cumulative halt probability
new_cumulative = cumulative_halt + halt_prob * (~halted).float()
# Tokens that should halt this iteration
should_halt = (new_cumulative >= self.halt_threshold) & (~halted)
# For halting tokens, use remainder; for already halted, 0; for continuing, halt_prob
contrib_weight = torch.where(
should_halt,
remainder,
torch.where(halted, torch.zeros_like(halt_prob), halt_prob)
)
# Accumulate output
output = output + contrib_weight * (x + delta)
# Update remainder
remainder = remainder - contrib_weight
# Update halted status
halted = halted | should_halt
cumulative_halt = new_cumulative
# Update x for next iteration (only for non-halted)
x = torch.where(halted.expand_as(x), x, x + delta)
# Track compute
total_compute += (~halted).float().sum().item()
# Early exit if all halted
if halted.all():
break
# Final remainder goes to last state
output = output + remainder * x
# Store stats for analysis
self._last_iters = i + 1
self._last_compute_ratio = total_compute / (B * N * self.max_iters)
return output
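# Usage sketch (assumed toy sizes; surfaces the per-forward halting stats that
# run_experiment() later reads off block 0; not called anywhere by default):
def _smoke_iterative_attention():
    attn = IterativeAttention(d=64, h=4, max_iters=3).to(DEV)
    out = attn(torch.randn(2, 16, 64, device=DEV))
    # forward() stashes these after every call
    print(f"iters={attn._last_iters}, compute={attn._last_compute_ratio:.2%}")
    assert out.shape == (2, 16, 64)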
# ─────────────────────────── HEAVY: Triplet Attention ───────────────────────────
class TripletAttention(nn.Module):
"""
O(nΒ³) attention: model 3-way interactions.
"How does token A relate to B in context of C?"
This is VERY heavy - use small sequences only.
"""
def __init__(self, d: int, h: int, max_triplet_n: int = 64):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
self.max_triplet_n = max_triplet_n
# Standard pairwise attention
self.qkv = nn.Linear(d, 3 * d, bias=False)
# Triplet scoring: takes concatenated (q_i, k_j, k_c) and outputs score modifier
self.triplet_score = nn.Sequential(
nn.Linear(3 * d // h, d // h),
nn.ReLU(),
nn.Linear(d // h, 1)
)
self.proj = nn.Linear(d, d, bias=False)
self.drop = nn.Dropout(0.1)
def forward(self, x, mask=None):
B, N, D = x.shape
# For large N, fall back to standard attention
if N > self.max_triplet_n:
return self._standard_forward(x, mask)
qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # Each: (B, H, N, dk)
# Pairwise scores
pairwise = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) # (B, H, N, N)
# Triplet modulation: for each (i,j) pair, average influence from all contexts c
# This is O(nΒ³) - compute triplet score for each (i, j, c) triple
triplet_mod = torch.zeros_like(pairwise)
for c in range(N): # Context position
# For each (i,j), compute how context c modifies the attention
# q_i: (B, H, N, dk), k_j: (B, H, N, dk), k_c: (B, H, 1, dk)
k_c = k[:, :, c:c+1, :].expand(-1, -1, N, -1) # (B, H, N, dk)
# Broadcast: q (B,H,N,1,dk), k (B,H,1,N,dk), k_c (B,H,N,1,dk)
q_exp = q.unsqueeze(3) # (B, H, N, 1, dk)
k_exp = k.unsqueeze(2) # (B, H, 1, N, dk)
k_c_exp = k_c.unsqueeze(3) # (B, H, N, 1, dk)
# Concatenate for triplet: (q_i, k_j, k_c)
triplet_input = torch.cat([
q_exp.expand(-1, -1, -1, N, -1),
k_exp.expand(-1, -1, N, -1, -1),
k_c_exp.expand(-1, -1, -1, N, -1)
], dim=-1) # (B, H, N, N, 3*dk)
# Score modification from this context
mod = self.triplet_score(triplet_input).squeeze(-1) # (B, H, N, N)
triplet_mod = triplet_mod + mod
# Average over contexts and combine
triplet_mod = triplet_mod / N
att = pairwise + 0.1 * triplet_mod # Residual triplet contribution
att = att + alibi_bias(self.h, N)
if mask is not None:
att = att + mask
z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
return self.drop(self.proj(z))
def _standard_forward(self, x, mask=None):
B, N, _ = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
att = att + alibi_bias(self.h, N)
if mask is not None:
att = att + mask
z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
return self.drop(self.proj(z))
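# Usage sketch (deliberately tiny: each context position materializes a
# (B, H, N, N, 3*dk) tensor, so memory, not FLOPs, is the first wall you hit;
# a hypothetical smoke test, not called by main()):
def _smoke_triplet_attention():
    attn = TripletAttention(d=64, h=4, max_triplet_n=16).to(DEV)
    out = attn(torch.randn(1, 8, 64, device=DEV))  # N=8 <= 16, so the O(n^3) path runs
    assert out.shape == (1, 8, 64)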
# ─────────────────────────── Block Variants ───────────────────────────
class StandardBlock(nn.Module):
def __init__(self, d: int, h: int):
super().__init__()
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
self.attn = StandardAttention(d, h)
self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
def forward(self, x, mask=None):
x = x + self.attn(self.ln1(x), mask)
return x + self.ff(self.ln2(x))
class IterativeBlock(nn.Module):
def __init__(self, d: int, h: int, max_iters: int = 5):
super().__init__()
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
self.attn = IterativeAttention(d, h, max_iters=max_iters)
self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
def forward(self, x, mask=None):
x = x + self.attn(self.ln1(x), mask)
return x + self.ff(self.ln2(x))
class TripletBlock(nn.Module):
def __init__(self, d: int, h: int):
super().__init__()
self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
self.attn = TripletAttention(d, h)
self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))
def forward(self, x, mask=None):
x = x + self.attn(self.ln1(x), mask)
return x + self.ff(self.ln2(x))
# ─────────────────────────── Models ───────────────────────────
class HeavyTransformer(nn.Module):
def __init__(self, d: int, layers: int, heads: int, mode: str = "standard"):
super().__init__()
self.emb = nn.Embedding(VOCAB, d)
if mode == "standard":
self.blocks = nn.ModuleList([StandardBlock(d, heads) for _ in range(layers)])
elif mode == "iterative":
self.blocks = nn.ModuleList([IterativeBlock(d, heads) for _ in range(layers)])
elif mode == "triplet":
self.blocks = nn.ModuleList([TripletBlock(d, heads) for _ in range(layers)])
else:
raise ValueError(f"Unknown mode: {mode}")
self.ln = nn.LayerNorm(d)
self.head = nn.Linear(d, VOCAB)
self.mode = mode
# Tie weights
self.head.weight = self.emb.weight
def forward(self, ids, mask=None):
x = self.emb(ids)
for blk in self.blocks:
x = blk(x, mask)
return self.head(self.ln(x))
def count_params(self):
return sum(p.numel() for p in self.parameters())
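# Usage sketch (assumed toy config; note the tied embedding alone is
# VOCAB * d parameters, which dominates the count at small d):
def _smoke_heavy_transformer():
    model = HeavyTransformer(d=64, layers=2, heads=4, mode="standard").to(DEV)
    ids = torch.randint(0, VOCAB, (2, 16), device=DEV)
    mask = torch.triu(torch.full((1, 1, 16, 16), float("-inf"), device=DEV), 1)
    logits = model(ids, mask)
    assert logits.shape == (2, 16, VOCAB)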
# ─────────────────────────── Experiment Runner ───────────────────────────
def causal_mask(n):
return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
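# e.g. causal_mask(3)[0, 0] is
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# which, added to the logits before softmax, blocks attention to the future.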
def run_experiment(mode: str, d: int, layers: int, heads: int,
batch_size: int, seq_len: int, num_steps: int):
"""Run training steps and measure loss + throughput"""
print(f"\n{'='*60}")
print(f"MODE: {mode.upper()}")
print(f"Config: d={d}, layers={layers}, heads={heads}")
print(f"{'='*60}")
model = HeavyTransformer(d, layers, heads, mode=mode).to(DEV)
print(f"Parameters: {model.count_params():,}")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
losses = []
times = []
for step in range(num_steps):
# Random batch
ids = torch.randint(0, VOCAB, (batch_size, seq_len), device=DEV)
target = ids[:, 1:]
input_ids = ids[:, :-1]
mask = causal_mask(seq_len - 1)
        if DEV.type == "cuda":
            torch.cuda.synchronize()  # CUDA launches are async; sync so wall-clock timing is honest
        start = time.time()
        optimizer.zero_grad()
        logits = model(input_ids, mask)
        loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
        loss.backward()
        optimizer.step()
        if DEV.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.time() - start
times.append(elapsed)
losses.append(loss.item())
        tok_per_sec = (batch_size * (seq_len - 1)) / elapsed  # the model actually sees seq_len - 1 positions
if step % 10 == 0 or step == num_steps - 1:
print(f"Step {step:3d} | Loss: {loss.item():.4f} | {tok_per_sec:.0f} tok/s | {elapsed*1000:.0f}ms")
# For iterative attention, show extra stats
if mode == "iterative" and hasattr(model.blocks[0].attn, '_last_iters'):
if step % 20 == 0:
                iters = model.blocks[0].attn._last_iters
                compute_ratio = model.blocks[0].attn._last_compute_ratio
                print(f"  └─ Iters (block 0, last fwd): {iters}, Compute ratio: {compute_ratio:.2%}")
avg_loss = sum(losses[-20:]) / min(20, len(losses))
avg_time = sum(times[-20:]) / min(20, len(times))
    avg_toks = (batch_size * (seq_len - 1)) / avg_time
return {
"mode": mode,
"final_loss": losses[-1],
"avg_loss": avg_loss,
"avg_tok_per_sec": avg_toks,
"params": model.count_params()
}
def main():
parser = argparse.ArgumentParser(description="Heavy Attention Experiment")
parser.add_argument("--d", type=int, default=256, help="Model dimension")
parser.add_argument("--layers", type=int, default=4, help="Number of layers")
parser.add_argument("--heads", type=int, default=8, help="Number of heads")
parser.add_argument("--batch", type=int, default=8, help="Batch size")
parser.add_argument("--seq", type=int, default=128, help="Sequence length")
parser.add_argument("--steps", type=int, default=100, help="Training steps")
parser.add_argument("--mode", type=str, default="all",
choices=["standard", "iterative", "triplet", "all"])
args = parser.parse_args()
print(f"Device: {DEV}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
results = []
modes = ["standard", "iterative", "triplet"] if args.mode == "all" else [args.mode]
for mode in modes:
try:
result = run_experiment(
mode=mode,
d=args.d,
layers=args.layers,
heads=args.heads,
batch_size=args.batch,
seq_len=args.seq,
num_steps=args.steps
)
results.append(result)
except Exception as e:
print(f"ERROR in {mode}: {e}")
import traceback
traceback.print_exc()
# Summary
print(f"\n{'='*60}")
print("SUMMARY")
print(f"{'='*60}")
for r in results:
print(f"{r['mode']:12s} | Loss: {r['avg_loss']:.4f} | {r['avg_tok_per_sec']:6.0f} tok/s | {r['params']:,} params")
# Scientific comparison
if len(results) >= 2:
baseline = next((r for r in results if r['mode'] == 'standard'), results[0])
print(f"\n{'='*60}")
print("RELATIVE TO STANDARD:")
print(f"{'='*60}")
for r in results:
if r['mode'] != 'standard':
loss_diff = (baseline['avg_loss'] - r['avg_loss']) / baseline['avg_loss'] * 100
speed_ratio = r['avg_tok_per_sec'] / baseline['avg_tok_per_sec']
print(f"{r['mode']:12s} | Loss: {loss_diff:+.1f}% | Speed: {speed_ratio:.2f}x")
if __name__ == "__main__":
main()