#!/usr/bin/env python3
"""
Inference benchmark - measure actual generation speed
MQA/GQA should shine here due to smaller KV cache
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math

DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VOCAB = 128256


def _sync():
    """Wait for queued CUDA work to finish; no-op when running on CPU.

    ``torch.cuda.synchronize()`` raises when CUDA is unavailable, so timing
    code must not call it unconditionally on the CPU fallback path.
    """
    if DEV.type == "cuda":
        torch.cuda.synchronize()


def alibi_bias(n_heads, n_tokens):
    """Return an ALiBi attention bias of shape (1, n_heads, n_tokens, n_tokens).

    Entry ``[0, h, i, j]`` is ``-slope_h * max(i - j, 0)``: keys further in the
    past of query ``i`` are penalised linearly, while keys at or after ``i``
    get zero bias (future keys are expected to be removed by a causal mask).

    Slopes follow the ALiBi recipe: a geometric sequence for a power-of-two
    head count; otherwise the slopes of the nearest lower power of two
    followed by every other slope of the next power of two.
    """
    def slopes(n):
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * (start ** i) for i in range(n)]

    if n_heads > 0 and math.log2(n_heads).is_integer():
        s = slopes(n_heads)
    elif n_heads > 0:
        closest = 2 ** math.floor(math.log2(n_heads))
        # Previously this truncated slopes(closest), which has fewer than
        # n_heads entries and made the .view() below fail.
        s = slopes(closest) + slopes(2 * closest)[0::2][: n_heads - closest]
    else:
        s = []
    s = torch.tensor(s, device=DEV).view(1, n_heads, 1, 1)
    i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens)
    # Distance from query i back to key j (0 for j >= i).
    return -s * (i - j).clamp_min(0).float()


def _causal_mask(n_q, n_k, device):
    """Boolean mask (n_q, n_k) that is True where a query must NOT see a key.

    The queries are the last ``n_q`` of ``n_k`` positions (the first
    ``n_k - n_q`` keys come from the KV cache), so query ``i`` may attend to
    keys ``0 .. (n_k - n_q) + i``. Returns None when nothing needs masking
    (a single query at the newest position sees every key).
    """
    if n_q <= 1:
        return None
    past = n_k - n_q
    return torch.ones(n_q, n_k, dtype=torch.bool, device=device).triu(past + 1)


def _attention(q, k, v):
    """Causal scaled dot-product attention shared by all attention variants.

    q: (B, Hq, N, dk). k, v: (B, Hkv, S, dk) where Hkv equals Hq or is 1
    (broadcast across query heads). Returns (B, N, Hq * dk).
    """
    batch, _, n_q, dk = q.shape
    att = (q @ k.transpose(-1, -2)) / math.sqrt(dk)
    mask = _causal_mask(n_q, k.shape[2], q.device)
    if mask is not None:
        att = att.masked_fill(mask, float("-inf"))
    z = att.softmax(-1) @ v
    return z.transpose(1, 2).reshape(batch, n_q, -1)


class StandardAttn(nn.Module):
    """Multi-head attention: every query head has its own K and V head."""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Attend from the new tokens ``x`` (B, N, d) over cached + new keys.

        kv_cache: optional (k, v), each (B, h, S_past, dk).
        Returns (output (B, N, d), updated (k, v) cache).
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        return self.proj(_attention(q, k, v)), (k, v)

    def cache_size(self, seq_len, batch):
        """Number of cached elements: K and V each (batch, heads, seq, dk)."""
        return 2 * batch * self.h * seq_len * self.dk


class MQAAttn(nn.Module):
    """Multi-query attention: all query heads share a single K and V head."""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.v = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Same contract as StandardAttn.forward; cached K/V have 1 head."""
        B, N, _ = x.shape
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(B, N, 1, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, 1, self.dk).transpose(1, 2)
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        # The single K/V head broadcasts across all query heads.
        return self.proj(_attention(q, k, v)), (k, v)

    def cache_size(self, seq_len, batch):
        """Number of cached elements: only 1 K and 1 V head."""
        return 2 * batch * 1 * seq_len * self.dk


class GQAAttn(nn.Module):
    """Grouped-query attention: query heads share ``num_kv_heads`` K/V heads."""

    def __init__(self, d, h, num_kv_heads=2):
        super().__init__()
        if num_kv_heads <= 0 or h % num_kv_heads != 0:
            raise ValueError(
                f"h ({h}) must be a positive multiple of num_kv_heads ({num_kv_heads})"
            )
        self.h, self.dk = h, d // h
        self.num_kv_heads = num_kv_heads
        self.heads_per_group = h // num_kv_heads
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Same contract as StandardAttn.forward; cached K/V have num_kv_heads heads."""
        B, N, _ = x.shape
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        # The cache stores only the grouped heads; expand them per query head
        # just for the attention computation.
        k_exp = k.repeat_interleave(self.heads_per_group, dim=1)
        v_exp = v.repeat_interleave(self.heads_per_group, dim=1)
        return self.proj(_attention(q, k_exp, v_exp)), (k, v)

    def cache_size(self, seq_len, batch):
        """Number of cached elements: num_kv_heads K and V heads."""
        return 2 * batch * self.num_kv_heads * seq_len * self.dk


class Block(nn.Module):
    """Pre-LayerNorm transformer block with a selectable attention variant."""

    _ATTN = {
        "standard": lambda d, h: StandardAttn(d, h),
        "mqa": lambda d, h: MQAAttn(d, h),
        "gqa": lambda d, h: GQAAttn(d, h, num_kv_heads=2),
    }

    def __init__(self, d, h, attn_type="standard"):
        super().__init__()
        if attn_type not in self._ATTN:
            raise ValueError(
                f"unknown attn_type {attn_type!r}; expected one of {sorted(self._ATTN)}"
            )
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.attn = self._ATTN[attn_type](d, h)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x, kv_cache=None):
        """Return (x + attn + ff residuals, updated attention KV cache)."""
        attn_out, new_cache = self.attn(self.ln1(x), kv_cache)
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        return x, new_cache


class Model(nn.Module):
    """Token embedding, stacked Blocks, final LayerNorm, LM head tied to the embedding."""

    def __init__(self, d, layers, h, attn_type="standard"):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList([Block(d, h, attn_type) for _ in range(layers)])
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        self.head.weight = self.emb.weight  # weight tying
        self.d, self.layers_n = d, layers

    def forward(self, x, kv_caches=None):
        """x: (B, N) token ids. Returns (logits (B, N, VOCAB), per-layer caches)."""
        x = self.emb(x)
        new_caches = []
        for i, b in enumerate(self.blocks):
            cache = kv_caches[i] if kv_caches else None
            x, new_cache = b(x, cache)
            new_caches.append(new_cache)
        return self.head(self.ln(x)), new_caches


@torch.no_grad()
def benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len, warmup=1):
    """Time prefill and greedy decoding for one attention variant.

    ``warmup`` untimed prefill+decode rounds run first so one-time CUDA/cuBLAS
    initialisation is not charged to whichever variant happens to run first.
    Returns a dict with prefill latency (ms), decode throughput (tokens/s
    across the batch), final KV-cache size (MiB) and raw decode time (s).
    """
    model = Model(d, layers, h, attn_type).to(DEV).eval()
    prompt = torch.randint(0, VOCAB, (batch, prompt_len), device=DEV)

    for _ in range(warmup):
        logits, caches = model(prompt)
        model(logits[:, -1:].argmax(-1), caches)
        del logits, caches
    _sync()

    # Prefill
    start = time.perf_counter()
    logits, kv_caches = model(prompt)
    next_tok = logits[:, -1:].argmax(-1)
    _sync()
    prefill_time = time.perf_counter() - start

    # Generation: one token per step, reusing the KV cache.
    start = time.perf_counter()
    for _ in range(gen_len):
        logits, kv_caches = model(next_tok, kv_caches)
        next_tok = logits[:, -1:].argmax(-1)
    _sync()
    gen_time = time.perf_counter() - start

    # Cache holds prompt_len + gen_len positions after the loop above.
    bytes_per_elem = next(model.parameters()).element_size()
    cache_elems = sum(
        b.attn.cache_size(prompt_len + gen_len, batch) for b in model.blocks
    )
    cache_size = cache_elems * bytes_per_elem / (1024 ** 2)  # MiB
    tok_per_sec = gen_len * batch / gen_time if gen_time > 0 else 0.0
    return {
        "type": attn_type,
        "prefill_ms": prefill_time * 1000,
        "gen_tok_s": tok_per_sec,
        "cache_mb": cache_size,
        "gen_time": gen_time,
    }


def main():
    """Benchmark standard/MQA/GQA over several batch/prompt/gen sizes and compare."""
    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")

    d, layers, h = 512, 8, 8
    configs = [
        (1, 128, 128),  # Small batch, short
        (1, 128, 512),  # Small batch, long gen
        (8, 128, 128),  # Medium batch
        (16, 64, 64),   # Large batch, short
    ]

    for batch, prompt_len, gen_len in configs:
        print(f"\n{'='*60}")
        print(f"Batch={batch}, Prompt={prompt_len}, Gen={gen_len}")
        print(f"{'='*60}")

        results = []
        for attn_type in ["standard", "mqa", "gqa"]:
            try:
                r = benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len)
            except Exception as e:  # report per-variant failures (e.g. OOM) and keep going
                print(f"{attn_type:10s} | FAILED: {e}")
            else:
                results.append(r)
                print(f"{attn_type:10s} | Prefill {r['prefill_ms']:6.1f}ms | Gen {r['gen_tok_s']:6.0f} tok/s | Cache {r['cache_mb']:5.1f}MB")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if len(results) < 2:
            continue
        std = next((r for r in results if r['type'] == 'standard'), None)
        if std is None:
            continue
        for r in results:
            if r['type'] == 'standard':
                continue
            speedup = r['gen_tok_s'] / std['gen_tok_s'] if std['gen_tok_s'] else float("nan")
            cache_ratio = r['cache_mb'] / std['cache_mb'] if std['cache_mb'] else float("nan")
            print(f"  → {r['type']} vs standard: {speedup:.2f}x gen speed, {cache_ratio:.2f}x cache")


if __name__ == "__main__":
    main()