"""
Inference benchmark - measure actual generation speed
MQA/GQA should shine here due to smaller KV cache
"""
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import time |
| | import math |
| |
|
| | DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | VOCAB = 128256 |
| |
|
| | def alibi_bias(n_heads, n_tokens): |
| | def slopes(n): |
| | start = 2 ** (-2 ** -(math.log2(n) - 3)) |
| | return [start * (start ** i) for i in range(n)] |
| | s = slopes(n_heads) if n_heads > 0 and math.log2(n_heads).is_integer() else slopes(2 ** math.floor(math.log2(max(1, n_heads))))[:n_heads] |
| | s = torch.tensor(s, device=DEV).view(1, n_heads, 1, 1) |
| | i = torch.arange(n_tokens, device=DEV).view(1, 1, n_tokens, 1) |
| | j = torch.arange(n_tokens, device=DEV).view(1, 1, 1, n_tokens) |
| | return -s * (j - i).clamp_min(0).float() |
| |
|
| |
|
| | class StandardAttn(nn.Module): |
| | def __init__(self, d, h): |
| | super().__init__() |
| | self.h, self.dk = h, d // h |
| | self.qkv = nn.Linear(d, 3*d, bias=False) |
| | self.proj = nn.Linear(d, d, bias=False) |
| | |
| | def forward(self, x, kv_cache=None): |
| | B, N, _ = x.shape |
| | qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4) |
| | q, k, v = qkv[0], qkv[1], qkv[2] |
| | |
| | if kv_cache is not None: |
| | k_cache, v_cache = kv_cache |
| | k = torch.cat([k_cache, k], dim=2) |
| | v = torch.cat([v_cache, v], dim=2) |
| | |
| | new_cache = (k, v) |
| | seq_len = k.shape[2] |
| | |
| | att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) |
| | |
| | mask = torch.zeros(1, 1, N, seq_len, device=x.device) |
| | mask[:, :, :, seq_len:] = float('-inf') |
| | att = att + mask |
| | |
| | z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1) |
| | return self.proj(z), new_cache |
| | |
| | def cache_size(self, seq_len, batch): |
| | |
| | return 2 * batch * self.h * seq_len * self.dk |
| |
|
| |
|
| | class MQAAttn(nn.Module): |
| | def __init__(self, d, h): |
| | super().__init__() |
| | self.h, self.dk = h, d // h |
| | self.q = nn.Linear(d, d, bias=False) |
| | self.k = nn.Linear(d, self.dk, bias=False) |
| | self.v = nn.Linear(d, self.dk, bias=False) |
| | self.proj = nn.Linear(d, d, bias=False) |
| | |
| | def forward(self, x, kv_cache=None): |
| | B, N, _ = x.shape |
| | q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2) |
| | k = self.k(x).view(B, N, 1, self.dk).transpose(1, 2) |
| | v = self.v(x).view(B, N, 1, self.dk).transpose(1, 2) |
| | |
| | if kv_cache is not None: |
| | k_cache, v_cache = kv_cache |
| | k = torch.cat([k_cache, k], dim=2) |
| | v = torch.cat([v_cache, v], dim=2) |
| | |
| | new_cache = (k, v) |
| | seq_len = k.shape[2] |
| | |
| | att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk) |
| | mask = torch.zeros(1, 1, N, seq_len, device=x.device) |
| | mask[:, :, :, seq_len:] = float('-inf') |
| | att = att + mask |
| | |
| | z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1) |
| | return self.proj(z), new_cache |
| | |
| | def cache_size(self, seq_len, batch): |
| | |
| | return 2 * batch * 1 * seq_len * self.dk |
| |
|
| |
|
| | class GQAAttn(nn.Module): |
| | def __init__(self, d, h, num_kv_heads=2): |
| | super().__init__() |
| | self.h, self.dk = h, d // h |
| | self.num_kv_heads = num_kv_heads |
| | self.heads_per_group = h // num_kv_heads |
| | self.q = nn.Linear(d, d, bias=False) |
| | self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False) |
| | self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False) |
| | self.proj = nn.Linear(d, d, bias=False) |
| | |
| | def forward(self, x, kv_cache=None): |
| | B, N, _ = x.shape |
| | q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2) |
| | k = self.k(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2) |
| | v = self.v(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2) |
| | |
| | if kv_cache is not None: |
| | k_cache, v_cache = kv_cache |
| | k = torch.cat([k_cache, k], dim=2) |
| | v = torch.cat([v_cache, v], dim=2) |
| | |
| | new_cache = (k, v) |
| | |
| | k_exp = k.repeat_interleave(self.heads_per_group, dim=1) |
| | v_exp = v.repeat_interleave(self.heads_per_group, dim=1) |
| | |
| | seq_len = k.shape[2] |
| | att = (q @ k_exp.transpose(-1, -2)) / math.sqrt(self.dk) |
| | mask = torch.zeros(1, 1, N, seq_len, device=x.device) |
| | mask[:, :, :, seq_len:] = float('-inf') |
| | att = att + mask |
| | |
| | z = (att.softmax(-1) @ v_exp).transpose(1, 2).reshape(B, N, -1) |
| | return self.proj(z), new_cache |
| | |
| | def cache_size(self, seq_len, batch): |
| | return 2 * batch * self.num_kv_heads * seq_len * self.dk |
| |
|
| |
|
| | class Block(nn.Module): |
| | def __init__(self, d, h, attn_type="standard"): |
| | super().__init__() |
| | self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d) |
| | if attn_type == "standard": |
| | self.attn = StandardAttn(d, h) |
| | elif attn_type == "mqa": |
| | self.attn = MQAAttn(d, h) |
| | elif attn_type == "gqa": |
| | self.attn = GQAAttn(d, h, num_kv_heads=2) |
| | self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d)) |
| | |
| | def forward(self, x, kv_cache=None): |
| | attn_out, new_cache = self.attn(self.ln1(x), kv_cache) |
| | x = x + attn_out |
| | x = x + self.ff(self.ln2(x)) |
| | return x, new_cache |
| |
|
| |
|
| | class Model(nn.Module): |
| | def __init__(self, d, layers, h, attn_type="standard"): |
| | super().__init__() |
| | self.emb = nn.Embedding(VOCAB, d) |
| | self.blocks = nn.ModuleList([Block(d, h, attn_type) for _ in range(layers)]) |
| | self.ln = nn.LayerNorm(d) |
| | self.head = nn.Linear(d, VOCAB, bias=False) |
| | self.head.weight = self.emb.weight |
| | self.d, self.layers_n = d, layers |
| | |
| | def forward(self, x, kv_caches=None): |
| | x = self.emb(x) |
| | new_caches = [] |
| | for i, b in enumerate(self.blocks): |
| | cache = kv_caches[i] if kv_caches else None |
| | x, new_cache = b(x, cache) |
| | new_caches.append(new_cache) |
| | return self.head(self.ln(x)), new_caches |
| |
|
| |
|
| | @torch.no_grad() |
| | def benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len): |
| | model = Model(d, layers, h, attn_type).to(DEV).eval() |
| | |
| | |
| | prompt = torch.randint(0, VOCAB, (batch, prompt_len), device=DEV) |
| | |
| | torch.cuda.synchronize() |
| | start = time.time() |
| | |
| | logits, kv_caches = model(prompt) |
| | next_tok = logits[:, -1:].argmax(-1) |
| | |
| | torch.cuda.synchronize() |
| | prefill_time = time.time() - start |
| | |
| | |
| | torch.cuda.synchronize() |
| | start = time.time() |
| | |
| | for _ in range(gen_len): |
| | logits, kv_caches = model(next_tok, kv_caches) |
| | next_tok = logits[:, -1:].argmax(-1) |
| | |
| | torch.cuda.synchronize() |
| | gen_time = time.time() - start |
| | |
| | |
| | cache_size = sum( |
| | b.attn.cache_size(prompt_len + gen_len, batch) |
| | for b in model.blocks |
| | ) * 4 / (1024**2) |
| | |
| | tok_per_sec = gen_len * batch / gen_time |
| | |
| | return { |
| | "type": attn_type, |
| | "prefill_ms": prefill_time * 1000, |
| | "gen_tok_s": tok_per_sec, |
| | "cache_mb": cache_size, |
| | "gen_time": gen_time |
| | } |
| |
|
| |
|
| | def main(): |
| | print(f"Device: {DEV}") |
| | if torch.cuda.is_available(): |
| | print(f"GPU: {torch.cuda.get_device_name()}") |
| | |
| | d, layers, h = 512, 8, 8 |
| | |
| | configs = [ |
| | (1, 128, 128), |
| | (1, 128, 512), |
| | (8, 128, 128), |
| | (16, 64, 64), |
| | ] |
| | |
| | for batch, prompt_len, gen_len in configs: |
| | print(f"\n{'='*60}") |
| | print(f"Batch={batch}, Prompt={prompt_len}, Gen={gen_len}") |
| | print(f"{'='*60}") |
| | |
| | results = [] |
| | for attn_type in ["standard", "mqa", "gqa"]: |
| | try: |
| | r = benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len) |
| | results.append(r) |
| | print(f"{attn_type:10s} | Prefill {r['prefill_ms']:6.1f}ms | Gen {r['gen_tok_s']:6.0f} tok/s | Cache {r['cache_mb']:5.1f}MB") |
| | except Exception as e: |
| | print(f"{attn_type:10s} | FAILED: {e}") |
| | torch.cuda.empty_cache() |
| | |
| | if len(results) >= 2: |
| | std = next((r for r in results if r['type'] == 'standard'), None) |
| | for r in results: |
| | if r['type'] != 'standard' and std: |
| | speedup = r['gen_tok_s'] / std['gen_tok_s'] |
| | cache_ratio = r['cache_mb'] / std['cache_mb'] |
| | print(f" → {r['type']} vs standard: {speedup:.2f}x gen speed, {cache_ratio:.2f}x cache") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|