# Source header (scrape residue): OpenTransformer — "Add experiments/n_flex.py", commit 8a88d0a (verified)
#!/usr/bin/env python3
"""
n_flex.py β€” Flexible Attention Mechanisms
Constraint: Must support AR (causal), SAT (block), and NAR (bidirectional)
Testing:
1. Linear Attention - O(n) instead of O(nΒ²)
2. Cosine Attention - Different similarity metric
3. Differential Attention - Noise cancellation (Microsoft 2024)
4. Local + Global - Sparse hybrid
5. Multi-Query Attention (MQA) - Inference efficient
6. Grouped Query Attention (GQA) - Between MHA and MQA
7. Retention - RetNet style (recurrent + parallel)
8. Gated Linear Attention - Recent efficient attention
9. ReLU Attention - Simpler activation
10. Sigmoid Attention - Bounded attention
"""
from __future__ import annotations
import argparse, math, time
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Literal
# Single global device; masks and bias tensors below are created directly on it.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Allow TF32 matmuls (faster on Ampere+ GPUs, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
# Vocabulary size — presumably sized to the Llama-3 tokenizer (128,256); confirm.
VOCAB = 128256
# ═══════════════════════════════════════════════════════════════
# Masking utilities for AR/SAT/NAR
# ═══════════════════════════════════════════════════════════════
def get_mask(n: int, mode: str = "ar", block_size: int = 2):
    """Build an additive attention mask for the given decoding mode.

    Args:
        n: sequence length.
        mode: "ar" (causal), "sat" (block-causal), or "nar" (no mask).
        block_size: block width for SAT masking.

    Returns:
        An (n, n) float tensor of 0 / -inf on DEV, or None for "nar".

    Raises:
        ValueError: for an unrecognized mode string.
    """
    if mode == "nar":
        # Bidirectional: every position attends everywhere.
        return None
    if mode == "ar":
        # Strictly causal: -inf strictly above the diagonal.
        return torch.triu(torch.full((n, n), float("-inf"), device=DEV), 1)
    if mode == "sat":
        # Block-causal: query i sees key j iff j's block index <= i's.
        blocks = torch.arange(n, device=DEV) // block_size
        visible = blocks.unsqueeze(0) <= blocks.unsqueeze(1)
        return torch.where(
            visible,
            torch.tensor(0.0, device=DEV),
            torch.tensor(float("-inf"), device=DEV)
        )
    raise ValueError(f"Unknown mode: {mode}")
def alibi_bias(n_heads: int, n_tokens: int):
    """Return an ALiBi positional bias of shape (1, n_heads, n_tokens, n_tokens).

    Each head h gets a slope s_h and bias -s_h * |i - j|, penalizing attention
    by distance. Two fixes vs the original:
      * The original used (j - i).clamp_min(0), which is 0 on the entire
        visible lower triangle — under AR/SAT causal masks the bias was
        identically zero, leaving the model with no positional signal at all
        (there are no positional embeddings elsewhere). The symmetric |i - j|
        penalty works for AR, SAT, and NAR.
      * For non-power-of-2 head counts, slopes(closest)[:n_heads] produced
        fewer than n_heads slopes and the .view() below crashed. We now use
        the ALiBi paper's fallback: slopes of the closest power of two plus
        every other slope from 2*closest.
    """
    def slopes(n):
        # Geometric sequence 2^(-8/n), 2^(-16/n), ... from the ALiBi paper.
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * (start ** i) for i in range(n)]
    if n_heads > 0 and math.log2(n_heads).is_integer():
        s = slopes(n_heads)
    else:
        closest = 2 ** math.floor(math.log2(max(1, n_heads)))
        # Canonical interleaved fallback — always yields exactly n_heads slopes.
        s = slopes(closest) + slopes(2 * closest)[0::2][: n_heads - closest]
    s = torch.tensor(s, device=DEV).view(1, n_heads, 1, 1)
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    # Symmetric distance penalty: non-zero for both past and future offsets.
    return -s * (j - i).abs().float()
# ═══════════════════════════════════════════════════════════════
# 1. STANDARD (baseline)
# ═══════════════════════════════════════════════════════════════
class StandardAttention(nn.Module):
    """Vanilla multi-head softmax attention with ALiBi bias — O(n^2) in N."""
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        # One fused projection, then split into per-head q/k/v: each (B, H, N, dk).
        packed = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = packed.unbind(0)
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        # ALiBi supplies the positional signal (the model has no pos. embeddings).
        scores = scores + alibi_bias(self.h, N)
        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)
        ctx = (scores.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(ctx)
# ═══════════════════════════════════════════════════════════════
# 2. LINEAR ATTENTION - O(n) via kernel trick
# ═══════════════════════════════════════════════════════════════
class LinearAttention(nn.Module):
    """Kernelized attention in O(n).

    Replaces softmax(QK^T)V with phi(Q)(phi(K)^T V): by associativity the
    K^T V state is computed first, so cost is linear in sequence length.
    Bidirectional (mask=None) uses a single global state; any non-None mask
    is handled with the strictly causal prefix-sum recurrence — note this
    means a SAT block mask is approximated as plain causal here.
    """
    def __init__(self, d: int, h: int, feature_map: str = "elu"):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.feature_map = feature_map
        self.eps = 1e-6  # guards the normalizer against division by ~0
    def _phi(self, x):
        """Non-negative feature map approximating the softmax kernel."""
        fm = self.feature_map
        if fm == "elu":
            return F.elu(x) + 1
        if fm == "relu":
            return F.relu(x)
        if fm == "softmax":
            return F.softmax(x, dim=-1)
        return x  # identity fallback
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = self._phi(packed[0]), self._phi(packed[1]), packed[2]  # (B, H, N, dk)
        if mask is not None:
            # Causal: running prefix sums of k⊗v and k give each position
            # access to everything at or before it. Still O(n) overall.
            state = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
            ksum = torch.cumsum(k, dim=2)
            num = torch.einsum('bhnd,bhndv->bhnv', q, state)
            den = torch.einsum('bhnd,bhnd->bhn', q, ksum).unsqueeze(-1).clamp(min=self.eps)
        else:
            # Bidirectional: one shared (dk x dv) state for the whole sequence.
            state = torch.einsum('bhnd,bhnv->bhdv', k, v)
            num = torch.einsum('bhnd,bhdv->bhnv', q, state)
            ksum = k.sum(dim=2, keepdim=True)  # (B, H, 1, dk)
            den = torch.einsum('bhnd,bhkd->bhnk', q, ksum).clamp(min=self.eps)
        out = num / den
        return self.proj(out.transpose(1, 2).reshape(B, N, -1))
# ═══════════════════════════════════════════════════════════════
# 3. COSINE ATTENTION - Different similarity metric
# ═══════════════════════════════════════════════════════════════
class CosineAttention(nn.Module):
    """Attention over cosine similarity instead of raw dot products.

    q and k are unit-normalized so scores lie in [-1, 1] before being scaled
    by a learnable temperature — more stable than unbounded dot products.
    """
    def __init__(self, d: int, h: int, temp: float = 10.0):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.temp = nn.Parameter(torch.tensor(temp))  # learnable sharpness
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = packed.unbind(0)
        # Unit-normalize so q @ k^T is cosine similarity.
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        scores = self.temp * (q @ k.transpose(-1, -2))
        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)
        ctx = (scores.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(ctx)
# ═══════════════════════════════════════════════════════════════
# 4. DIFFERENTIAL ATTENTION - Noise cancellation
# ═══════════════════════════════════════════════════════════════
class DifferentialAttention(nn.Module):
    """
    Differential attention (Microsoft's "Differential Transformer", 2024).

    Two independent softmax attention maps are computed and the second is
    subtracted, weighted by sigmoid(lambda), from the first:
        Attn = softmax(Q1 K1^T) - lambda * softmax(Q2 K2^T)
    The subtraction cancels common-mode attention noise. Negative entries
    are clipped with ReLU and each row is renormalized.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h, self.dk = h, d // h
        # Two independent Q/K projection pairs, one shared V.
        self.q1 = nn.Linear(d, d, bias=False)
        self.k1 = nn.Linear(d, d, bias=False)
        self.q2 = nn.Linear(d, d, bias=False)
        self.k2 = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        # Learnable subtraction weight, squashed through sigmoid at use time.
        self.lambda_param = nn.Parameter(torch.tensor(0.5))
        self.proj = nn.Linear(d, d, bias=False)
    def _softmax_map(self, q, k, mask):
        """One scaled-dot-product softmax attention map."""
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)
        return scores.softmax(-1)
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        def heads(t):
            # (B, N, d) -> (B, H, N, dk)
            return t.view(B, N, self.h, self.dk).transpose(1, 2)
        a1 = self._softmax_map(heads(self.q1(x)), heads(self.k1(x)), mask)
        a2 = self._softmax_map(heads(self.q2(x)), heads(self.k2(x)), mask)
        lam = torch.sigmoid(self.lambda_param)
        # Subtract, clip to non-negative, renormalize rows.
        att = F.relu(a1 - lam * a2)
        att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)
        ctx = (att @ heads(self.v(x))).transpose(1, 2).reshape(B, N, -1)
        return self.proj(ctx)
# ═══════════════════════════════════════════════════════════════
# 5. MULTI-QUERY ATTENTION (MQA) - Inference efficient
# ═══════════════════════════════════════════════════════════════
class MultiQueryAttention(nn.Module):
    """MQA: H query heads share a single K/V head.

    Training cost matches standard attention; at inference the KV cache is
    H times smaller, which is the main practical win.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h, self.dk = h, d // h
        self.q = nn.Linear(d, d, bias=False)        # h query heads
        self.k = nn.Linear(d, self.dk, bias=False)  # one shared key head
        self.v = nn.Linear(d, self.dk, bias=False)  # one shared value head
        self.proj = nn.Linear(d, d, bias=False)
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        queries = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)  # (B, H, N, dk)
        keys = self.k(x).view(B, N, 1, self.dk).transpose(1, 2)          # (B, 1, N, dk)
        values = self.v(x).view(B, N, 1, self.dk).transpose(1, 2)        # (B, 1, N, dk)
        # The singleton head dim of keys/values broadcasts over all H heads.
        scores = (queries @ keys.transpose(-1, -2)) / math.sqrt(self.dk)
        scores = scores + alibi_bias(self.h, N)
        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)
        ctx = (scores.softmax(-1) @ values).transpose(1, 2).reshape(B, N, -1)
        return self.proj(ctx)
# ═══════════════════════════════════════════════════════════════
# 6. GROUPED QUERY ATTENTION (GQA) - Between MHA and MQA
# ═══════════════════════════════════════════════════════════════
class GroupedQueryAttention(nn.Module):
    """
    GQA: groups of query heads share K/V heads (Llama 2 style).
    A balance between MHA quality and MQA inference speed.

    Args:
        d: model width.
        h: number of query heads.
        num_kv_heads: number of shared K/V heads; must divide h.

    Raises:
        ValueError: if h is not divisible by num_kv_heads (the original
        silently built a mismatched number of K/V heads and failed later
        with an opaque matmul shape error).
    """
    def __init__(self, d: int, h: int, num_kv_heads: int = 2):
        super().__init__()
        if h % num_kv_heads != 0:
            raise ValueError(
                f"h ({h}) must be divisible by num_kv_heads ({num_kv_heads})"
            )
        self.h = h
        self.num_kv_heads = num_kv_heads
        self.dk = d // h
        self.heads_per_group = h // num_kv_heads
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        # Expand each K/V head to serve its group of query heads.
        k = k.repeat_interleave(self.heads_per_group, dim=1)
        v = v.repeat_interleave(self.heads_per_group, dim=1)
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        att = att + alibi_bias(self.h, N)
        if mask is not None:
            att = att + mask.unsqueeze(0).unsqueeze(0)
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)
# ═══════════════════════════════════════════════════════════════
# 7. RETENTION - RetNet style
# ═══════════════════════════════════════════════════════════════
class RetentionAttention(nn.Module):
    """
    RetNet-style retention: dot-product scores weighted by an exponential
    decay matrix D[i, j] = gamma^(i - j) for j <= i, instead of softmax.

    Args:
        d: model width.
        h: number of heads.
        gamma: initial per-head decay rate; must lie strictly in (0, 1).

    Fixes vs the original:
      * The decay exponent used (j - i).clamp(min=0), which is 0 on the
        entire visible lower triangle — so every visible position got
        weight gamma^0 = 1 and there was no decay at all. We use (i - j)
        so weight decays with distance into the past.
      * gamma was stored raw and sigmoided in forward, so requesting
        gamma=0.9 gave an effective decay of sigmoid(0.9) ~= 0.71. We now
        store pre-sigmoid logits so sigmoid(param) == gamma at init, while
        keeping the parameter unconstrained during training.
    """
    def __init__(self, d: int, h: int, gamma: float = 0.9):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        # Inverse sigmoid (logit) so the effective decay starts at `gamma`.
        logit = math.log(gamma / (1.0 - gamma))
        self.gamma = nn.Parameter(torch.full((h,), logit))
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        gamma = torch.sigmoid(self.gamma).view(1, self.h, 1, 1)  # (0, 1) per head
        positions = torch.arange(N, device=x.device).float()
        # dist[i, j] = max(i - j, 0): distance into the past, 0 for future.
        dist = (positions.unsqueeze(1) - positions.unsqueeze(0)).clamp(min=0)
        decay = gamma ** dist
        # Zero out future positions entirely (decay alone leaves them at 1).
        causal = torch.tril(torch.ones(N, N, device=x.device))
        decay = decay * causal.unsqueeze(0).unsqueeze(0)
        # Intersect with any externally supplied SAT/NAR visibility mask.
        if mask is not None:
            mask_binary = (mask == 0).float().unsqueeze(0).unsqueeze(0)
            decay = decay * mask_binary
        # Retention: (Q K^T) ⊙ D, then row-normalize instead of softmax.
        att = (q @ k.transpose(-1, -2)) * decay
        att = att / (att.sum(dim=-1, keepdim=True) + 1e-6)
        z = (att @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z)
# ═══════════════════════════════════════════════════════════════
# 8. GATED LINEAR ATTENTION
# ═══════════════════════════════════════════════════════════════
class GatedLinearAttention(nn.Module):
    """
    Linear attention whose output is modulated by a sigmoid gate computed
    from the input ("Gated Linear Attention Transformers", 2024). The gate
    improves gradient flow through the O(n) attention path.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.gate = nn.Linear(d, d)  # with bias, unlike the projections
        self.proj = nn.Linear(d, d, bias=False)
        self.eps = 1e-6  # normalizer floor
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        # ELU + 1 keeps features strictly positive.
        q = F.elu(packed[0]) + 1
        k = F.elu(packed[1]) + 1
        v = packed[2]
        if mask is not None:
            # Any non-None mask is handled as strictly causal via prefix sums.
            state = torch.cumsum(torch.einsum('bhnd,bhnv->bhndv', k, v), dim=2)
            ksum = torch.cumsum(k, dim=2)
            num = torch.einsum('bhnd,bhndv->bhnv', q, state)
            den = torch.einsum('bhnd,bhnd->bhn', q, ksum).unsqueeze(-1).clamp(min=self.eps)
        else:
            # Bidirectional: single shared state.
            state = torch.einsum('bhnd,bhnv->bhdv', k, v)
            num = torch.einsum('bhnd,bhdv->bhnv', q, state)
            den = torch.einsum('bhnd,bhd->bhn', q, k.sum(dim=2)).unsqueeze(-1).clamp(min=self.eps)
        ctx = (num / den).transpose(1, 2).reshape(B, N, -1)
        # Elementwise sigmoid gate on the attended output.
        return self.proj(ctx * torch.sigmoid(self.gate(x)))
# ═══════════════════════════════════════════════════════════════
# 9. RELU ATTENTION - Simpler activation
# ═══════════════════════════════════════════════════════════════
class ReLUAttention(nn.Module):
    """Softmax-free attention: ReLU of the scores, then row-sum normalization."""
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = packed.unbind(0)
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        scores = scores + alibi_bias(self.h, N)
        if mask is not None:
            scores = scores + mask.unsqueeze(0).unsqueeze(0)
        # relu(-inf) == 0, so masked entries drop out before normalization;
        # the epsilon keeps all-zero rows from dividing by zero.
        weights = F.relu(scores)
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-6)
        ctx = (weights @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(ctx)
# ═══════════════════════════════════════════════════════════════
# 10. SIGMOID ATTENTION - Bounded
# ═══════════════════════════════════════════════════════════════
class SigmoidAttention(nn.Module):
    """
    Per-pair sigmoid attention: each (query, key) weight is an independent
    sigmoid, so rows need not sum to 1 — total attention is variable.
    """
    def __init__(self, d: int, h: int):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)
        self.bias = nn.Parameter(torch.zeros(h, 1, 1))  # per-head logit offset
    def forward(self, x, mask=None):
        B, N, _ = x.shape
        packed = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = packed.unbind(0)
        logits = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) + self.bias
        if mask is not None:
            logits = logits + mask.unsqueeze(0).unsqueeze(0)
        weights = torch.sigmoid(logits)
        if mask is not None:
            # sigmoid(-inf) is already 0; this explicit re-zeroing is a
            # belt-and-braces guard kept from the original implementation.
            weights = weights * (mask == 0).float().unsqueeze(0).unsqueeze(0)
        ctx = (weights @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(ctx)
# ═══════════════════════════════════════════════════════════════
# Block and Model
# ═══════════════════════════════════════════════════════════════
# Name -> attention-module class. Block looks implementations up here, so
# every entry must accept (d, h) in its constructor. The CLI's "--types"
# flag selects keys from this table ("all" runs every entry in this order).
ATTN_REGISTRY = {
    "standard": StandardAttention,
    "linear": LinearAttention,
    "cosine": CosineAttention,
    "differential": DifferentialAttention,
    "mqa": MultiQueryAttention,
    "gqa": GroupedQueryAttention,
    "retention": RetentionAttention,
    "gated_linear": GatedLinearAttention,
    "relu": ReLUAttention,
    "sigmoid": SigmoidAttention,
}
class Block(nn.Module):
    """Pre-norm transformer block: registry-selected attention, then an MLP."""
    def __init__(self, d: int, h: int, attn_type: str = "standard"):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = ATTN_REGISTRY[attn_type](d, h)
        self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))
    def forward(self, x, mask=None):
        # Residual around attention, then residual around the feed-forward.
        hidden = x + self.attn(self.ln1(x), mask)
        return hidden + self.ff(self.ln2(hidden))
class FlexModel(nn.Module):
    """Stack of Blocks over token embeddings, with a weight-tied LM head."""
    def __init__(self, d: int, layers: int, h: int, attn_type: str = "standard"):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, h, attn_type) for _ in range(layers))
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        self.head.weight = self.emb.weight  # tie output head to embeddings
    def forward(self, x, mask=None):
        hidden = self.emb(x)
        for block in self.blocks:
            hidden = block(hidden, mask)
        return self.head(self.ln(hidden))
    def count_params(self):
        """Total parameter count (the tied head/emb tensor is counted once,
        since Module.parameters() de-duplicates shared tensors)."""
        return sum(p.numel() for p in self.parameters())
# ═══════════════════════════════════════════════════════════════
# Training with AR/SAT/NAR modes
# ═══════════════════════════════════════════════════════════════
def train(attn_type: str, mode: str, d: int, layers: int, h: int,
          batch: int, seq: int, steps: int, block_size: int = 4):
    """Train one (attention type, decoding mode) combination on random tokens.

    Args:
        attn_type: key into ATTN_REGISTRY.
        mode: "ar", "sat", or "nar" — selects objective and mask.
        d, layers, h: model width, depth, head count.
        batch, seq, steps: batch size, sequence length, optimizer steps.
        block_size: SAT block width.

    Returns:
        {"attn", "mode", "loss", "tok_s"} averaged over the last 20 steps,
        or None if any step raised (e.g. OOM / shape error) or steps == 0.

    Note: the data is uniform random, so "loss" measures optimization
    mechanics and throughput, not real modeling quality.
    """
    print(f"\n{'='*60}")
    print(f"ATTENTION: {attn_type.upper()} | MODE: {mode.upper()}")
    print(f"{'='*60}")
    model = FlexModel(d, layers, h, attn_type).to(DEV)
    print(f"Parameters: {model.count_params():,}")
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    losses, times = [], []
    for step in range(steps):
        ids = torch.randint(0, VOCAB, (batch, seq), device=DEV)
        if mode == "ar":
            # Next-token prediction under a causal mask.
            target = ids[:, 1:]
            input_ids = ids[:, :-1]
            mask = get_mask(seq - 1, "ar")
        elif mode == "sat":
            # Next-token prediction under a block-causal mask.
            target = ids[:, 1:]
            input_ids = ids[:, :-1]
            mask = get_mask(seq - 1, "sat", block_size)
        else:  # nar
            # Denoising objective: corrupt 15% of tokens, predict originals.
            target = ids
            noise_mask = torch.rand(batch, seq, device=DEV) < 0.15
            input_ids = ids.clone()
            input_ids[noise_mask] = torch.randint(0, VOCAB, (noise_mask.sum().item(),), device=DEV)
            mask = get_mask(seq, "nar")
        # CUDA kernels launch asynchronously: without synchronization the
        # timer measures launch latency, not actual step cost, and the
        # reported tok/s is meaningless on GPU.
        if DEV.type == "cuda":
            torch.cuda.synchronize()
        start = time.time()
        opt.zero_grad()
        try:
            logits = model(input_ids, mask)
            loss = F.cross_entropy(logits.view(-1, VOCAB), target.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        except Exception as e:
            # Best-effort sweep: report the failure and skip this config
            # instead of aborting the whole benchmark run.
            print(f"Step {step} failed: {e}")
            return None
        if DEV.type == "cuda":
            torch.cuda.synchronize()
        elapsed = time.time() - start
        losses.append(loss.item())
        times.append(elapsed)
        if step % 20 == 0 or step == steps - 1:
            tok_s = batch * seq / elapsed
            print(f"Step {step:3d} | Loss {loss.item():.4f} | {tok_s:.0f} tok/s")
    if not losses:
        # steps == 0: nothing to average (original raised ZeroDivisionError).
        return None
    avg_loss = sum(losses[-20:]) / min(20, len(losses))
    avg_toks = batch * seq / (sum(times[-20:]) / min(20, len(times)))
    return {"attn": attn_type, "mode": mode, "loss": avg_loss, "tok_s": avg_toks}
def main():
    """CLI entry: sweep attention types x decoding modes and print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--d", type=int, default=256)
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--seq", type=int, default=128)
    parser.add_argument("--steps", type=int, default=100)
    parser.add_argument("--mode", type=str, default="ar", choices=["ar", "sat", "nar", "all"])
    parser.add_argument("--types", type=str, default="all")
    args = parser.parse_args()
    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
    # Resolve which attention types and decoding modes to sweep.
    if args.types == "all":
        attn_types = list(ATTN_REGISTRY.keys())
    else:
        attn_types = [t.strip() for t in args.types.split(",")]
    modes = ["ar", "sat", "nar"] if args.mode == "all" else [args.mode]
    results = []
    for mode in modes:
        for attn_type in attn_types:
            summary = train(attn_type, mode, args.d, args.layers, args.heads,
                            args.batch, args.seq, args.steps)
            if summary:
                results.append(summary)
            # Free cached GPU memory between configs (no-op on CPU).
            torch.cuda.empty_cache()
    # Per-mode summary table, best loss first, relative to the "standard" baseline.
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    for mode in modes:
        print(f"\n--- MODE: {mode.upper()} ---")
        in_mode = [r for r in results if r['mode'] == mode]
        baseline = next((r for r in in_mode if r['attn'] == 'standard'), None)
        for r in sorted(in_mode, key=lambda row: row['loss']):
            rel = ""
            if baseline and r['attn'] != 'standard':
                loss_diff = (baseline['loss'] - r['loss']) / baseline['loss'] * 100
                speed_ratio = r['tok_s'] / baseline['tok_s']
                rel = f" | vs std: {loss_diff:+.1f}%, {speed_ratio:.2f}x"
            print(f"{r['attn']:15s} | Loss {r['loss']:.4f} | {r['tok_s']:6.0f} tok/s{rel}")
# Script entry point.
if __name__ == "__main__":
    main()