OpenTransformer committed on
Commit
ec068a9
·
verified ·
1 Parent(s): 764896d

Add experiments/infer_bench.py

Browse files
Files changed (1) hide show
  1. experiments/infer_bench.py +259 -0
experiments/infer_bench.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""
Inference benchmark - measure actual generation speed
MQA/GQA should shine here due to smaller KV cache
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math

# Run on GPU when available; everything below falls back to CPU otherwise.
DEV = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Embedding / LM-head vocabulary size (128256 — presumably the Llama-3
# tokenizer size; confirm against the intended tokenizer).
VOCAB = 128256
15
+
16
def alibi_bias(n_heads, n_tokens):
    """Return an additive ALiBi attention bias of shape (1, n_heads, n_tokens, n_tokens).

    Each head h gets a slope m_h, and the bias for query position i attending
    to key position j <= i is -m_h * (i - j): a linear penalty that grows with
    distance into the past.  Future positions (j > i) get bias 0 here — causal
    masking is assumed to be applied separately.
    """
    def slopes(n):
        # Geometric sequence 2^(-8/n), 2^(-16/n), ... (ALiBi slopes for a
        # power-of-two head count).
        start = 2 ** (-2 ** -(math.log2(n) - 3))
        return [start * (start ** i) for i in range(n)]

    if n_heads > 0 and math.log2(n_heads).is_integer():
        s = slopes(n_heads)
    else:
        # Non-power-of-two head counts: truncate the sequence computed for the
        # next-lower power of two (approximation of the paper's interleaving).
        s = slopes(2 ** math.floor(math.log2(max(1, n_heads))))[:n_heads]

    # Same device choice as the module-level DEV constant, kept local so the
    # function is self-contained.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    s = torch.tensor(s, device=dev).view(1, n_heads, 1, 1)
    i = torch.arange(n_tokens, device=dev).view(1, 1, n_tokens, 1)
    j = torch.arange(n_tokens, device=dev).view(1, 1, 1, n_tokens)
    # Bug fix: the original used (j - i).clamp_min(0), which gave every past
    # position zero bias and penalised *future* positions — the opposite of
    # ALiBi.  The penalty must grow with the distance (i - j) into the past.
    return -s * (i - j).clamp_min(0).float()
25
+
26
+
27
class StandardAttn(nn.Module):
    """Multi-head attention with full per-head K/V (baseline for MQA/GQA)."""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.qkv = nn.Linear(d, 3*d, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Attend over x (B, N, d) plus any cached K/V.

        Returns (output, new_cache) where new_cache is the (k, v) pair
        including both the cached and the freshly computed positions.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.h, self.dk).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        new_cache = (k, v)
        seq_len = k.shape[2]

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        # Causal mask: the query at absolute position past_len + i may only
        # attend to keys j <= past_len + i.  (Bug fix: the original wrote -inf
        # into mask[:, :, :, seq_len:], an empty slice of a dim of size
        # seq_len, so nothing was masked and prefill was bidirectional.)
        past_len = seq_len - N
        q_pos = torch.arange(N, device=x.device).view(N, 1) + past_len
        k_pos = torch.arange(seq_len, device=x.device).view(1, seq_len)
        att = att.masked_fill(k_pos > q_pos, float('-inf'))

        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z), new_cache

    def cache_size(self, seq_len, batch):
        # K and V each hold (batch, heads, seq, dk) floats.
        return 2 * batch * self.h * seq_len * self.dk
59
+
60
+
61
class MQAAttn(nn.Module):
    """Multi-query attention: h query heads share a single K/V head."""

    def __init__(self, d, h):
        super().__init__()
        self.h, self.dk = h, d // h
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.v = nn.Linear(d, self.dk, bias=False)  # 1 head
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Attend over x (B, N, d) plus any cached K/V; returns (out, new_cache)."""
        B, N, _ = x.shape
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        # K/V have a single head; broadcasting against q handles the sharing.
        k = self.k(x).view(B, N, 1, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, 1, self.dk).transpose(1, 2)

        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        new_cache = (k, v)
        seq_len = k.shape[2]

        att = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        # Causal mask fix: the original wrote -inf into mask[:, :, :, seq_len:],
        # an empty slice, so prefill attention was bidirectional.
        past_len = seq_len - N
        q_pos = torch.arange(N, device=x.device).view(N, 1) + past_len
        k_pos = torch.arange(seq_len, device=x.device).view(1, seq_len)
        att = att.masked_fill(k_pos > q_pos, float('-inf'))

        z = (att.softmax(-1) @ v).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z), new_cache

    def cache_size(self, seq_len, batch):
        # Only 1 K and 1 V head!
        return 2 * batch * 1 * seq_len * self.dk
95
+
96
+
97
class GQAAttn(nn.Module):
    """Grouped-query attention: h query heads share num_kv_heads K/V heads."""

    def __init__(self, d, h, num_kv_heads=2):
        super().__init__()
        self.h, self.dk = h, d // h
        self.num_kv_heads = num_kv_heads
        self.heads_per_group = h // num_kv_heads
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.v = nn.Linear(d, num_kv_heads * self.dk, bias=False)
        self.proj = nn.Linear(d, d, bias=False)

    def forward(self, x, kv_cache=None):
        """Attend over x (B, N, d) plus any cached K/V; returns (out, new_cache)."""
        B, N, _ = x.shape
        q = self.q(x).view(B, N, self.h, self.dk).transpose(1, 2)
        k = self.k(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)
        v = self.v(x).view(B, N, self.num_kv_heads, self.dk).transpose(1, 2)

        if kv_cache is not None:
            k_cache, v_cache = kv_cache
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        # Cache the compact (num_kv_heads) tensors, not the expanded copies.
        new_cache = (k, v)

        # Expand each K/V head across its query-head group for the matmul.
        k_exp = k.repeat_interleave(self.heads_per_group, dim=1)
        v_exp = v.repeat_interleave(self.heads_per_group, dim=1)

        seq_len = k.shape[2]
        att = (q @ k_exp.transpose(-1, -2)) / math.sqrt(self.dk)
        # Causal mask fix: the original wrote -inf into mask[:, :, :, seq_len:],
        # an empty slice, so prefill attention was bidirectional.
        past_len = seq_len - N
        q_pos = torch.arange(N, device=x.device).view(N, 1) + past_len
        k_pos = torch.arange(seq_len, device=x.device).view(1, seq_len)
        att = att.masked_fill(k_pos > q_pos, float('-inf'))

        z = (att.softmax(-1) @ v_exp).transpose(1, 2).reshape(B, N, -1)
        return self.proj(z), new_cache

    def cache_size(self, seq_len, batch):
        # K and V each: (batch, num_kv_heads, seq, dk)
        return 2 * batch * self.num_kv_heads * seq_len * self.dk
135
+
136
+
137
class Block(nn.Module):
    """Pre-norm transformer block: attention then a 4x-GELU MLP, each residual."""

    def __init__(self, d, h, attn_type="standard"):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        if attn_type == "standard":
            self.attn = StandardAttn(d, h)
        elif attn_type == "mqa":
            self.attn = MQAAttn(d, h)
        elif attn_type == "gqa":
            self.attn = GQAAttn(d, h, num_kv_heads=2)
        else:
            # Fail fast: the original left self.attn unset for an unknown
            # type, which surfaced later as a confusing AttributeError.
            raise ValueError(f"unknown attn_type: {attn_type!r}")
        self.ff = nn.Sequential(nn.Linear(d, 4*d), nn.GELU(), nn.Linear(4*d, d))

    def forward(self, x, kv_cache=None):
        """Apply attention and MLP with residuals; returns (x, new_kv_cache)."""
        attn_out, new_cache = self.attn(self.ln1(x), kv_cache)
        x = x + attn_out
        x = x + self.ff(self.ln2(x))
        return x, new_cache
154
+
155
+
156
class Model(nn.Module):
    """Tiny decoder-only LM: embedding -> N blocks -> LayerNorm -> tied head."""

    def __init__(self, d, layers, h, attn_type="standard"):
        super().__init__()
        self.emb = nn.Embedding(VOCAB, d)
        self.blocks = nn.ModuleList(Block(d, h, attn_type) for _ in range(layers))
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, VOCAB, bias=False)
        # Tie the output projection to the input embedding matrix.
        self.head.weight = self.emb.weight
        self.d, self.layers_n = d, layers

    def forward(self, x, kv_caches=None):
        """Return (logits, per-layer KV caches) for token ids x of shape (B, N)."""
        hidden = self.emb(x)
        updated_caches = []
        for idx, blk in enumerate(self.blocks):
            past = kv_caches[idx] if kv_caches else None
            hidden, fresh = blk(hidden, past)
            updated_caches.append(fresh)
        return self.head(self.ln(hidden)), updated_caches
174
+
175
+
176
@torch.no_grad()
def benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len):
    """Time prefill and autoregressive generation for one attention variant.

    Builds a Model on DEV, runs a prefill over a random prompt, then generates
    gen_len tokens one at a time reusing the KV cache.  Returns a dict with
    prefill latency (ms), generation throughput (tok/s), total KV-cache
    footprint (MB, float32) and raw generation time (s).
    """
    model = Model(d, layers, h, attn_type).to(DEV).eval()

    def sync():
        # Bug fix: torch.cuda.synchronize() raises on CPU-only machines, yet
        # DEV deliberately falls back to CPU — only synchronize on CUDA.
        if DEV.type == "cuda":
            torch.cuda.synchronize()

    # Prefill
    prompt = torch.randint(0, VOCAB, (batch, prompt_len), device=DEV)

    sync()
    start = time.time()

    logits, kv_caches = model(prompt)
    next_tok = logits[:, -1:].argmax(-1)

    sync()
    prefill_time = time.time() - start

    # Generation: one greedy token per step, reusing the KV cache
    sync()
    start = time.time()

    for _ in range(gen_len):
        logits, kv_caches = model(next_tok, kv_caches)
        next_tok = logits[:, -1:].argmax(-1)

    sync()
    gen_time = time.time() - start

    # Total KV-cache size across all layers, in MB of float32
    cache_size = sum(
        b.attn.cache_size(prompt_len + gen_len, batch)
        for b in model.blocks
    ) * 4 / (1024**2)

    tok_per_sec = gen_len * batch / gen_time

    return {
        "type": attn_type,
        "prefill_ms": prefill_time * 1000,
        "gen_tok_s": tok_per_sec,
        "cache_mb": cache_size,
        "gen_time": gen_time
    }
218
+
219
+
220
def main():
    """Benchmark all three attention variants over several batch/length configs."""
    print(f"Device: {DEV}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")

    d, layers, h = 512, 8, 8

    # (batch, prompt_len, gen_len)
    configs = [
        (1, 128, 128),   # Small batch, short
        (1, 128, 512),   # Small batch, long gen
        (8, 128, 128),   # Medium batch
        (16, 64, 64),    # Large batch, short
    ]

    for batch, prompt_len, gen_len in configs:
        print(f"\n{'='*60}")
        print(f"Batch={batch}, Prompt={prompt_len}, Gen={gen_len}")
        print(f"{'='*60}")

        results = []
        for attn_type in ("standard", "mqa", "gqa"):
            try:
                r = benchmark_generation(attn_type, d, layers, h, batch, prompt_len, gen_len)
            except Exception as e:
                # Best-effort: report the failure (e.g. OOM) and keep going.
                print(f"{attn_type:10s} | FAILED: {e}")
            else:
                results.append(r)
                print(f"{attn_type:10s} | Prefill {r['prefill_ms']:6.1f}ms | Gen {r['gen_tok_s']:6.0f} tok/s | Cache {r['cache_mb']:5.1f}MB")
            torch.cuda.empty_cache()

        # Compare each variant against the standard-attention baseline.
        if len(results) < 2:
            continue
        std = next((r for r in results if r['type'] == 'standard'), None)
        if std is None:
            continue
        for r in results:
            if r['type'] == 'standard':
                continue
            speedup = r['gen_tok_s'] / std['gen_tok_s']
            cache_ratio = r['cache_mb'] / std['cache_mb']
            print(f" → {r['type']} vs standard: {speedup:.2f}x gen speed, {cache_ratio:.2f}x cache")
256
+
257
+
258
# Script entry point: run the benchmark suite when executed directly.
if __name__ == "__main__":
    main()