# AGILLM-3-large / experiments/infer_bench.py
# NOTE: web-page scrape residue converted to comments so the file parses
# (original header: "OpenTransformer — Add experiments/infer_bench.py — ec068a9 verified")
#!/usr/bin/env python3
"""
Inference benchmark - measure actual generation speed
MQA/GQA should shine here due to smaller KV cache
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VOCAB = 128256
def alibi_bias(n_heads, n_tokens, device=None):
    """Return an ALiBi attention-bias tensor of shape (1, n_heads, n_tokens, n_tokens).

    Each head h gets a slope s_h; a query at position i attending to a key at
    position j <= i receives a linear distance penalty -s_h * (i - j). Future
    positions (j > i) get 0 here — they are expected to be removed by the
    causal mask.

    Args:
        n_heads: number of attention heads.
        n_tokens: sequence length.
        device: torch device for the returned tensor (defaults to module DEV).
    """
    if device is None:
        device = DEV

    def slopes(n):
        # Geometric sequence from the ALiBi paper; closed form for powers of
        # two, and the paper's interleaving fallback otherwise.
        if n > 0 and math.log2(n).is_integer():
            start = 2 ** (-2 ** -(math.log2(n) - 3))
            return [start * (start ** i) for i in range(n)]
        # BUG FIX: the original fallback took slopes(2**floor(log2(n)))[:n],
        # which yields FEWER than n slopes and crashed the .view below.
        base = 2 ** math.floor(math.log2(max(1, n)))
        return slopes(base) + slopes(2 * base)[0::2][: n - base]

    s = torch.tensor(slopes(n_heads), device=device).view(1, n_heads, 1, 1)
    i = torch.arange(n_tokens, device=device).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=device).view(1, 1, 1, n_tokens)
    # BUG FIX: original used (j - i).clamp_min(0), which biased only the
    # (masked-out) future positions and left every past position at 0.
    # ALiBi penalizes distance into the past: -slope * (i - j) for j <= i.
    return -s * (i - j).clamp_min(0).float()
class StandardAttn(nn.Module):
    """Baseline multi-head attention: every head has its own K/V cache."""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3 * d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Attend over x of shape (B, N, d) plus any cached keys/values.

        Returns (output, new_cache); new_cache is the concatenated (k, v)
        pair, each of shape (B, h, total_seq, dk).
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        new_cache = (k, v)
        seq_len = k.shape[2]
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        # BUG FIX: the original wrote -inf into mask[..., seq_len:], an EMPTY
        # slice of a (..., seq_len) tensor, so no masking ever happened and
        # prefill attention was non-causal. Query row i sits at absolute
        # position (seq_len - N + i) and may only see keys j at or before it.
        if N > 1:
            offset = seq_len - N
            i = torch.arange(N, device=x.device).view(N, 1)
            j = torch.arange(seq_len, device=x.device).view(1, seq_len)
            att = att.masked_fill(j > i + offset, float('-inf'))
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z), new_cache

    def cache_size(self, seq_len, batch):
        # K and V each occupy (batch, heads, seq, dk) elements.
        return 2 * batch * self.h * seq_len * self.dk
class MQAAttn(nn.Module):
    """Multi-query attention: h query heads share a single K/V head."""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.v = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Attend over x (B, N, d); cache holds the single shared K/V head."""
        B, N, _ = x.shape
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(B, N, 1, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, 1, self.dk).transpose(1, 2)
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        new_cache = (k, v)
        seq_len = k.shape[2]
        # (B, h, N, dk) @ (B, 1, dk, S): the single KV head broadcasts over h.
        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        # BUG FIX: original -inf write targeted mask[..., seq_len:] — an empty
        # slice — so prefill was non-causal. Apply a real causal mask relative
        # to each query's absolute position.
        if N > 1:
            offset = seq_len - N
            i = torch.arange(N, device=x.device).view(N, 1)
            j = torch.arange(seq_len, device=x.device).view(1, seq_len)
            att = att.masked_fill(j > i + offset, float('-inf'))
        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z), new_cache

    def cache_size(self, seq_len, batch):
        # Only 1 K and 1 V head are cached — 1/h the footprint of MHA.
        return 2 * batch * 1 * seq_len * self.dk
class GQAAttn(nn.Module):
    """Grouped-query attention: h query heads share num_kv_heads K/V heads."""

    def __init__(self, d, h, num_kv_heads=2):
        super().__init__()
        # BUG FIX (robustness): heads must divide evenly into KV groups,
        # otherwise repeat_interleave below silently produces a shape mismatch.
        if num_kv_heads <= 0 or h % num_kv_heads != 0:
            raise ValueError(f"h={h} must be divisible by num_kv_heads={num_kv_heads}")
        self.h, self.dk = h, d // h
        self.num_kv_heads = num_kv_heads
        self.heads_per_group = h // num_kv_heads
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Attend over x (B, N, d); cache holds num_kv_heads K/V heads."""
        B, N, _ = x.shape
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)
        new_cache = (k, v)
        # Expand each KV head to cover its group of query heads.
        k_exp = k.repeat_interleave(self.heads_per_group, dim=1)
        v_exp = v.repeat_interleave(self.heads_per_group, dim=1)
        seq_len = k.shape[2]
        att = (q @ k_exp.transpose(-1, -2)) / math.sqrt(self.dk)
        # BUG FIX: original -inf write hit mask[..., seq_len:] — an empty
        # slice — so prefill attention was not causal. Mask keys j beyond each
        # query's absolute position (seq_len - N + i).
        if N > 1:
            offset = seq_len - N
            i = torch.arange(N, device=x.device).view(N, 1)
            j = torch.arange(seq_len, device=x.device).view(1, seq_len)
            att = att.masked_fill(j > i + offset, float('-inf'))
        z = (att.softmax(-1) @ v_exp).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z), new_cache

    def cache_size(self, seq_len, batch):
        # num_kv_heads K/V heads cached, i.e. num_kv_heads/h of standard MHA.
        return 2 * batch * self.num_kv_heads * seq_len * self.dk
class Block(nn.Module):
    """Pre-LN transformer block: cached attention + 4x GELU MLP, residual adds."""

    def __init__(self, d, h, attn_type="standard"):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        if attn_type == "standard":
            self.attn = StandardAttn(d, h)
        elif attn_type == "mqa":
            self.attn = MQAAttn(d, h)
        elif attn_type == "gqa":
            self.attn = GQAAttn(d, h, num_kv_heads=2)
        else:
            # BUG FIX: an unrecognized attn_type used to leave self.attn
            # undefined, surfacing later as a confusing AttributeError.
            raise ValueError(f"unknown attn_type: {attn_type!r}")
        self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))

    def forward(self, x, kv_cache=None):
        """Return (output, new_kv_cache) for input x of shape (B, N, d)."""
        attn_out, new_cache = self.attn(self.ln1(x), kv_cache)
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        return x, new_cache
class Model(nn.Module):
    """Minimal decoder-only LM: tied input/output embeddings over `layers` blocks."""

    def __init__(self, d, layers, h, attn_type="standard"):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, h, attn_type) for _ in range(layers))
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        # Weight tying: the LM head shares the embedding matrix.
        self.head.weight = self.emb.weight
        self.d, self.layers_n = d, layers

    def forward(self, x, kv_caches=None):
        """Map token ids (B, N) to logits (B, N, VOCAB), threading per-layer KV caches."""
        hidden = self.emb(x)
        updated_caches = []
        for idx, block in enumerate(self.blocks):
            layer_cache = kv_caches[idx] if kv_caches else None
            hidden, fresh_cache = block(hidden, layer_cache)
            updated_caches.append(fresh_cache)
        return self.head(self.ln(hidden)), updated_caches
def _sync():
    """Wait for pending GPU work so timings are accurate; no-op on CPU."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


@torch.no_grad()
def benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len):
    """Time prefill and autoregressive decode for one attention variant.

    Args:
        attn_type: "standard", "mqa", or "gqa".
        d, layers, h: model width, depth, and head count.
        batch, prompt_len, gen_len: workload shape.

    Returns:
        dict with prefill latency (ms), decode throughput (tok/s), total
        KV-cache footprint (MB, float32), and raw decode wall time (s).
    """
    model = Model(d, layers, h, attn_type).to(DEV).eval()
    # Prefill: process the whole prompt in one pass and seed the KV caches.
    prompt = torch.randint(0, VOCAB, (batch, prompt_len), device=DEV)
    # BUG FIX: the original called torch.cuda.synchronize() unconditionally,
    # which raises on CPU-only machines even though DEV falls back to "cpu".
    _sync()
    start = time.time()
    logits, kv_caches = model(prompt)
    next_tok = logits[:, -1:].argmax(-1)
    _sync()
    prefill_time = time.time() - start
    # Decode: one token per step, reusing the cache (where MQA/GQA should win).
    _sync()
    start = time.time()
    for _ in range(gen_len):
        logits, kv_caches = model(next_tok, kv_caches)
        next_tok = logits[:, -1:].argmax(-1)
    _sync()
    gen_time = time.time() - start
    # Total KV-cache elements across layers, converted to MB of float32.
    cache_size = sum(
        b.attn.cache_size(prompt_len + gen_len, batch)
        for b in model.blocks
    ) * 4 / (1024**2)  # MB (float32)
    tok_per_sec = gen_len * batch / gen_time
    return {
        "type": attn_type,
        "prefill_ms": prefill_time * 1000,
        "gen_tok_s": tok_per_sec,
        "cache_mb": cache_size,
        "gen_time": gen_time
    }
def main():
    """Benchmark all three attention variants across several batch/length configs."""
    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")
    d, layers, h = 512, 8, 8
    # (batch, prompt_len, gen_len)
    configs = [
        (1, 128, 128),   # Small batch, short
        (1, 128, 512),   # Small batch, long gen
        (8, 128, 128),   # Medium batch
        (16, 64, 64),    # Large batch, short
    ]
    for batch, prompt_len, gen_len in configs:
        print(f"\n{'='*60}")
        print(f"Batch={batch}, Prompt={prompt_len}, Gen={gen_len}")
        print(f"{'='*60}")
        results = []
        for attn_type in ["standard", "mqa", "gqa"]:
            try:
                r = benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len)
            except Exception as e:
                print(f"{attn_type:10s} | FAILED: {e}")
            else:
                results.append(r)
                print(f"{attn_type:10s} | Prefill {r['prefill_ms']:6.1f}ms | Gen {r['gen_tok_s']:6.0f} tok/s | Cache {r['cache_mb']:5.1f}MB")
            torch.cuda.empty_cache()
        # Summarize MQA/GQA against the standard-attention baseline.
        if len(results) >= 2:
            std = next((r for r in results if r['type'] == 'standard'), None)
            if std:
                for r in results:
                    if r['type'] == 'standard':
                        continue
                    speedup = r['gen_tok_s'] / std['gen_tok_s']
                    cache_ratio = r['cache_mb'] / std['cache_mb']
                    print(f" → {r['type']} vs standard: {speedup:.2f}x gen speed, {cache_ratio:.2f}x cache")


if __name__ == "__main__":
    main()