"""
n_heavy.py - Iterative Refinement Transformer Experiment
Heavier-than-standard attention: tokens get reprocessed based on uncertainty.

Key idea: Instead of single-pass attention, run multiple iterations
where "hard" tokens (high uncertainty) get recomputed while "easy" tokens halt.

This is O(n² × k) where k = average iterations, vs the standard O(n²).
"""
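
# Example invocations (a sketch based on the argparse defaults in main(); the flag
# values are illustrative, not tuned recommendations):
#
#   python n_heavy.py --mode standard --steps 100
#   python n_heavy.py --mode iterative --d 256 --layers 4 --heads 8 --batch 8 --seq 128
#   python n_heavy.py --mode triplet --seq 64    # keep sequences short: the triplet path is O(n³)
#   python n_heavy.py --mode all                 # run all three and print a comparison summary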

from __future__ import annotations
import argparse
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cuda.matmul.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except AttributeError:  # older PyTorch versions lack this helper
    pass


# Llama-3-style vocabulary size and end-of-text id (EOS is currently unused).
VOCAB = 128256
EOS = 128001


# ALiBi (Attention with Linear Biases) positional bias.
def _alibi_slopes(n_heads: int):
    """Per-head geometric slopes used by ALiBi."""
    def pow2slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        ratio = start
        return [start * (ratio ** i) for i in range(n)]
    if math.log2(n_heads).is_integer():
        vals = pow2slopes(n_heads)
    else:
        # Non-power-of-two head counts: interleave slopes from the next power of two.
        closest = 2 ** math.floor(math.log2(n_heads))
        vals = pow2slopes(closest)
        extra = pow2slopes(2 * closest)
        vals += extra[0::2][: n_heads - closest]
    return torch.tensor(vals, device=DEV).view(1, n_heads, 1, 1)


def alibi_bias(n_heads: int, n_tokens: int):
    """Additive attention bias: 0 on the diagonal, increasingly negative into the past."""
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)   # query positions
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)   # key positions
    dist = (i - j).clamp_min(0)  # distance into the past; (j - i) would zero out the causal bias
    return -_alibi_slopes(n_heads) * dist
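
# Worked example (the values follow from the formulas above, shown here for reference):
# with n_heads = 8 the slopes are 1/2, 1/4, ..., 1/256, so for head 0 the bias at
# (query i, key j) is -(i - j)/2 for j <= i; e.g. row i=2 of an n_tokens=3 bias is
# [-1.0, -0.5, 0.0]. Future positions (j > i) get a zero bias here and are blocked
# by the causal mask instead.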


class StandardAttention(nn.Module):
    """Baseline: single-pass multi-head attention."""
    def __init__(self, d: int, h: int):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        B, N, _ = x.shape
        # Fused QKV projection -> (3, B, h, N, dk)
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.drop(self.proj(z))


class IterativeAttention(nn.Module):
    """
    Heavier than standard: iteratively refine representations.

    Each token has a "halting probability" - once it exceeds the threshold,
    that token stops updating. Hard tokens keep getting reprocessed.

    Inspired by Universal Transformers + PonderNet.
    """
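
    # How the weighted halting mixture in forward() plays out for a single token (a
    # worked example; the probabilities are illustrative only): suppose its predicted
    # halt probabilities over iterations are 0.3, 0.4, 0.5 with halt_threshold = 0.9.
    # The cumulative sums are 0.3, 0.7, 1.2, so the token halts at iteration 3 and its
    # output is 0.3*state_1 + 0.4*state_2 + 0.3*state_3, where the final 0.3 is the
    # leftover "remainder" mass (the weights always sum to 1).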
    def __init__(self, d: int, h: int, max_iters: int = 5, halt_threshold: float = 0.9):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.max_iters = max_iters
        self.halt_threshold = halt_threshold

        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.drop = nn.Dropout(0.1)

        # Small MLP that predicts a per-token halting probability in (0, 1).
        self.halt_pred = nn.Sequential(
            nn.Linear(d, d // 4),
            nn.ReLU(),
            nn.Linear(d // 4, 1),
            nn.Sigmoid()
        )

        # Learned embedding added to the input so each refinement pass is distinguishable.
        self.iter_emb = nn.Embedding(max_iters, d)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # Per-token halting state.
        halted = torch.zeros(B, N, 1, device=x.device, dtype=torch.bool)
        cumulative_halt = torch.zeros(B, N, 1, device=x.device)

        # PonderNet-style weighted mixture of per-iteration states.
        output = torch.zeros_like(x)
        remainder = torch.ones(B, N, 1, device=x.device)

        total_compute = 0

        for i in range(self.max_iters):
            # Tag the input with the current iteration index.
            x_iter = x + self.iter_emb.weight[i].unsqueeze(0).unsqueeze(0)

            qkv = self.qkv(x_iter).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]

            att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
            att = att + alibi_bias(self.h, N)
            if mask is not None:
                att = att + mask

            z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
            delta = self.drop(self.proj(z))

            # Predict how likely each token is to halt after this refinement.
            halt_prob = self.halt_pred(x + delta)

            new_cumulative = cumulative_halt + halt_prob * (~halted).float()

            # Tokens whose cumulative halt probability crosses the threshold halt now.
            should_halt = (new_cumulative >= self.halt_threshold) & (~halted)

            # Halting tokens contribute their leftover probability mass, active tokens
            # contribute halt_prob, already-halted tokens contribute nothing.
            contrib_weight = torch.where(
                should_halt,
                remainder,
                torch.where(halted, torch.zeros_like(halt_prob), halt_prob)
            )

            output = output + contrib_weight * (x + delta)
            remainder = remainder - contrib_weight

            halted = halted | should_halt
            cumulative_halt = new_cumulative

            # Only still-active tokens keep refining their representation. Attention is
            # still computed for every token each iteration, so per-iteration FLOPs are
            # unchanged; savings come only from the early exit below.
            x = torch.where(halted.expand_as(x), x, x + delta)

            # Track how many tokens remain active (used for the compute-ratio diagnostic).
            total_compute += (~halted).float().sum().item()

            # Early exit once every token in the batch has halted.
            if halted.all():
                break

        # Any remaining probability mass goes to the final state.
        output = output + remainder * x

        # Diagnostics read by run_experiment().
        self._last_iters = i + 1
        self._last_compute_ratio = total_compute / (B * N * self.max_iters)

        return output


class TripletAttention(nn.Module):
    """
    O(n³) attention: model 3-way interactions.
    "How does token A relate to B in the context of C?"

    This is VERY heavy - use small sequences only.
    """
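
    # Rough cost sketch (assumes the CLI defaults: d=256, h=8 so dk=32, batch=8, and
    # the N=64 triplet cap): each step of the context loop in forward() materializes a
    # (B, h, N, N, 3*dk) tensor, i.e. 8*8*64*64*96 ≈ 25M floats ≈ 100 MB in fp32, and
    # the loop runs N times, which is why max_triplet_n stays small.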
    def __init__(self, d: int, h: int, max_triplet_n: int = 64):
        super().__init__()
        assert d % h == 0
        self.h, self.dk = h, d // h
        self.max_triplet_n = max_triplet_n

        self.qkv = nn.Linear(d, 3 * d, bias=False)

        # Scores a (query, key, context) triple of head vectors -> scalar modulation.
        self.triplet_score = nn.Sequential(
            nn.Linear(3 * d // h, d // h),
            nn.ReLU(),
            nn.Linear(d // h, 1)
        )

        self.proj = nn.Linear(d, d, bias=False)
        self.drop = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        B, N, D = x.shape

        # The cubic path is far too expensive for long inputs; fall back to plain attention.
        if N > self.max_triplet_n:
            return self._standard_forward(x, mask)

        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Ordinary pairwise scores, later modulated by the triplet term.
        pairwise = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)

        triplet_mod = torch.zeros_like(pairwise)

        # Loop over every context token c and score each (query i, key j, context c) triple.
        for c in range(N):
            k_c = k[:, :, c:c+1, :].expand(-1, -1, N, -1)

            q_exp = q.unsqueeze(3)       # (B, h, N, 1, dk)
            k_exp = k.unsqueeze(2)       # (B, h, 1, N, dk)
            k_c_exp = k_c.unsqueeze(3)   # (B, h, N, 1, dk)

            triplet_input = torch.cat([
                q_exp.expand(-1, -1, -1, N, -1),
                k_exp.expand(-1, -1, N, -1, -1),
                k_c_exp.expand(-1, -1, -1, N, -1)
            ], dim=-1)

            mod = self.triplet_score(triplet_input).squeeze(-1)
            triplet_mod = triplet_mod + mod

        # Average over contexts and mix the triplet term in with a small weight.
        triplet_mod = triplet_mod / N
        att = pairwise + 0.1 * triplet_mod

        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask

        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.drop(self.proj(z))

    def _standard_forward(self, x, mask=None):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.drop(self.proj(z))


class StandardBlock(nn.Module):
    def __init__(self, d: int, h: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = StandardAttention(d, h)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))


class IterativeBlock(nn.Module):
    def __init__(self, d: int, h: int, max_iters: int = 5):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = IterativeAttention(d, h, max_iters=max_iters)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))


class TripletBlock(nn.Module):
    def __init__(self, d: int, h: int):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = TripletAttention(d, h)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

    def forward(self, x, mask=None):
        x = x + self.attn(self.ln1(x), mask)
        return x + self.ff(self.ln2(x))


class HeavyTransformer(nn.Module):
    def __init__(self, d: int, layers: int, heads: int, mode: str = "standard"):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)

        if mode == "standard":
            self.blocks = nn.ModuleList([StandardBlock(d, heads) for _ in range(layers)])
        elif mode == "iterative":
            self.blocks = nn.ModuleList([IterativeBlock(d, heads) for _ in range(layers)])
        elif mode == "triplet":
            self.blocks = nn.ModuleList([TripletBlock(d, heads) for _ in range(layers)])
        else:
            raise ValueError(f"Unknown mode: {mode}")

        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB)
        self.mode = mode

        # Tie the output projection to the input embedding matrix.
        self.head.weight = self.emb.weight

    def forward(self, ids, mask=None):
        x = self.emb(ids)
        for blk in self.blocks:
            x = blk(x, mask)
        return self.head(self.ln(x))

    def count_params(self):
        return sum(p.numel() for p in self.parameters())
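
# Minimal shape sketch of how the model is used (it mirrors run_experiment() below;
# the values shown are the CLI defaults, not requirements):
#
#   model = HeavyTransformer(d=256, layers=4, heads=8, mode="iterative").to(DEV)
#   ids = torch.randint(0, VOCAB, (8, 127), device=DEV)    # (batch, seq)
#   logits = model(ids, causal_mask(127))                   # -> (8, 127, VOCAB)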


def causal_mask(n):
    return torch.triu(torch.full((1, 1, n, n), float("-inf"), device=DEV), 1)
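
# For reference, causal_mask(3)[0, 0] looks like this (0 = attend, -inf = blocked):
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]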


def run_experiment(mode: str, d: int, layers: int, heads: int,
                   batch_size: int, seq_len: int, num_steps: int):
    """Run training steps on random token data and measure loss + throughput."""
    print(f"\n{'='*60}")
    print(f"MODE: {mode.upper()}")
    print(f"Config: d={d}, layers={layers}, heads={heads}")
    print(f"{'='*60}")

    model = HeavyTransformer(d, layers, heads, mode=mode).to(DEV)
    print(f"Parameters: {model.count_params():,}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    losses = []
    times = []

    for step in range(num_steps):
        # Synthetic uniform-random tokens: the loss is only a smoke test (it hovers
        # near ln(VOCAB) ≈ 11.76 since the targets are unpredictable), but the timing
        # comparison between modes is still meaningful.
        ids = torch.randint(0, VOCAB, (batch_size, seq_len), device=DEV)
        target = ids[:, 1:]
        input_ids = ids[:, :-1]
        mask = causal_mask(seq_len - 1)

        # Synchronize so wall-clock timings are not skewed by asynchronous CUDA launches.
        if DEV.type == "cuda":
            torch.cuda.synchronize()
        start = time.time()

        optimizer.zero_grad()
        logits = model(input_ids, mask)
        loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
        loss.backward()
        optimizer.step()

        if DEV.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.time() - start
        times.append(elapsed)
        losses.append(loss.item())

        tok_per_sec = (batch_size * seq_len) / elapsed

        if step % 10 == 0 or step == num_steps - 1:
            print(f"Step {step:3d} | Loss: {loss.item():.4f} | {tok_per_sec:.0f} tok/s | {elapsed*1000:.0f}ms")

        # Report refinement diagnostics from the first iterative block.
        if mode == "iterative" and hasattr(model.blocks[0].attn, '_last_iters'):
            if step % 20 == 0:
                last_iters = model.blocks[0].attn._last_iters
                compute_ratio = model.blocks[0].attn._last_compute_ratio
                print(f"  -> Iters used: {last_iters}, Compute ratio: {compute_ratio:.2%}")

    avg_loss = sum(losses[-20:]) / min(20, len(losses))
    avg_time = sum(times[-20:]) / min(20, len(times))
    avg_toks = (batch_size * seq_len) / avg_time

    return {
        "mode": mode,
        "final_loss": losses[-1],
        "avg_loss": avg_loss,
        "avg_tok_per_sec": avg_toks,
        "params": model.count_params()
    }


def main():
    parser = argparse.ArgumentParser(description="Heavy Attention Experiment")
    parser.add_argument("--d", type=int, default=256, help="Model dimension")
    parser.add_argument("--layers", type=int, default=4, help="Number of layers")
    parser.add_argument("--heads", type=int, default=8, help="Number of heads")
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--seq", type=int, default=128, help="Sequence length")
    parser.add_argument("--steps", type=int, default=100, help="Training steps")
    parser.add_argument("--mode", type=str, default="all",
                        choices=["standard", "iterative", "triplet", "all"])
    args = parser.parse_args()

    print(f"Device: {DEV}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    results = []

    modes = ["standard", "iterative", "triplet"] if args.mode == "all" else [args.mode]

    for mode in modes:
        try:
            result = run_experiment(
                mode=mode,
                d=args.d,
                layers=args.layers,
                heads=args.heads,
                batch_size=args.batch,
                seq_len=args.seq,
                num_steps=args.steps
            )
            results.append(result)
        except Exception as e:
            print(f"ERROR in {mode}: {e}")
            import traceback
            traceback.print_exc()

    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    for r in results:
        print(f"{r['mode']:12s} | Loss: {r['avg_loss']:.4f} | {r['avg_tok_per_sec']:6.0f} tok/s | {r['params']:,} params")

    # Compare each heavy variant against the standard baseline (a positive loss
    # improvement and a speed ratio above 1.0 both mean the variant is better).
    if len(results) >= 2:
        baseline = next((r for r in results if r['mode'] == 'standard'), results[0])
        print(f"\n{'='*60}")
        print("RELATIVE TO STANDARD:")
        print(f"{'='*60}")
        for r in results:
            if r['mode'] != 'standard':
                loss_diff = (baseline['avg_loss'] - r['avg_loss']) / baseline['avg_loss'] * 100
                speed_ratio = r['avg_tok_per_sec'] / baseline['avg_tok_per_sec']
                print(f"{r['mode']:12s} | Loss improvement: {loss_diff:+.1f}% | Speed: {speed_ratio:.2f}x")


if __name__ == "__main__":
    main()