drbh HF Staff commited on
Commit
c623dea
·
verified ·
1 Parent(s): d7ef8e6

Upload benchmark.py

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +322 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+ from kernels.benchmark import Benchmark
5
+
6
+
7
+ def _cdiv(a, b):
8
+ return (a + b - 1) // b
9
+
10
+
11
+ def _extract_output(result):
12
+ if isinstance(result, tuple):
13
+ return result[0]
14
+ return result
15
+
16
+
17
def _reference_mla_decode(q, blocked_k, block_table, cache_seqlens, head_dim_v, causal=False):
    """Pure-PyTorch (fp32) reference for MLA decode over a paged KV cache.

    Args:
        q: Queries of shape [b, s_q, h_q, d].
        blocked_k: Paged KV cache of shape [num_blocks, block_size, h_kv, d];
            the leading ``head_dim_v`` channels double as the values.
        block_table: Per-sequence page indices, shape [b, max_blocks_per_seq].
        cache_seqlens: Valid cache length per sequence, shape [b].
        head_dim_v: Width of the value slice taken from each KV entry.
        causal: Apply a causal mask (only takes effect when s_q > 1).

    Returns:
        Attention output of shape [b, s_q, h_q, head_dim_v], cast to q's dtype.
    """
    batch, s_q, h_q, d = q.size()
    block_size = blocked_k.size(1)
    h_kv = blocked_k.size(2)

    out = torch.empty(batch, s_q, h_q, head_dim_v, dtype=torch.float32, device=q.device)

    for idx in range(batch):
        seq_len = int(cache_seqlens[idx].item())
        pages = block_table[idx][: _cdiv(seq_len, block_size)]
        # Gather this sequence's pages and drop padding beyond seq_len.
        kv = blocked_k[pages].reshape(-1, h_kv, d)[:seq_len]

        queries = q[idx].transpose(0, 1).float()  # [h_q, s_q, d]
        keys = kv.transpose(0, 1).float()         # [h_kv, s_k, d]
        if h_kv != h_q:
            # Expand grouped KV heads so every query head has a key stream.
            keys = keys.repeat_interleave(h_q // h_kv, dim=0)

        scores = queries @ keys.transpose(-2, -1) / math.sqrt(d)

        if causal and s_q > 1:
            s_k = keys.size(1)
            keep = torch.ones(s_q, s_k, dtype=torch.bool, device=q.device).tril(
                diagonal=s_k - s_q
            )
            scores = scores.masked_fill(~keep, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        # MLA-style values: the leading head_dim_v channels of the keys.
        ctx = weights @ keys[..., :head_dim_v]
        out[idx] = ctx.transpose(0, 1)

    return out.to(q.dtype)
50
+
51
+
52
+ def _varlen_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False):
53
+ batch_size = cu_seqlens_q.shape[0] - 1
54
+ total_tokens_q = q.shape[0]
55
+ num_heads = q.shape[1]
56
+ head_dim_v = v.shape[2]
57
+ scale = q.shape[-1] ** (-0.5)
58
+
59
+ out = torch.zeros(
60
+ (total_tokens_q, num_heads, head_dim_v), device=q.device, dtype=q.dtype
61
+ )
62
+
63
+ for b in range(batch_size):
64
+ start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
65
+ start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]
66
+
67
+ q_b = q[start_q:end_q].transpose(0, 1).float() # [H, seq_q, D_qk]
68
+ k_b = k[start_k:end_k].transpose(0, 1).float() # [H, seq_k, D_qk]
69
+ v_b = v[start_k:end_k].transpose(0, 1).float() # [H, seq_k, D_v]
70
+
71
+ attn = q_b @ k_b.transpose(-2, -1) * scale
72
+
73
+ if causal:
74
+ seq_q, seq_k = q_b.size(1), k_b.size(1)
75
+ mask = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril(
76
+ diagonal=seq_k - seq_q
77
+ )
78
+ attn.masked_fill_(~mask, float("-inf"))
79
+
80
+ attn = torch.softmax(attn, dim=-1)
81
+ result = attn @ v_b # [H, seq_q, D_v]
82
+ out[start_q:end_q] = result.transpose(0, 1).to(q.dtype)
83
+
84
+ return out
85
+
86
+
87
# MLA decode constants (DeepSeek V3 architecture)
_HEAD_DIM = 576  # Q/K head dimension (presumably 512 latent + 64 RoPE dims — confirm)
_HEAD_DIM_V = 512  # V head dimension; width of the value slice read from KV entries
_NUM_HEADS_K = 1  # MLA uses a single shared KV head
_PAGE_BLOCK_SIZE = 64  # Tokens per KV-cache page/block
92
+
93
+
94
def _setup_mla_decode(bench, batch_size, seq_k, num_heads_q):
    """Stage paged-KV decode inputs on *bench* for the FlashMLA kernel.

    Allocates a single-token query, a paged (blocked) KV cache, a block
    table mapping each sequence to its pages, per-sequence cache lengths,
    tile-scheduler metadata from the kernel, and an output buffer.
    """
    blocks_per_seq = _cdiv(seq_k, _PAGE_BLOCK_SIZE)
    total_blocks = batch_size * blocks_per_seq

    q_shape = (batch_size, 1, num_heads_q, _HEAD_DIM)
    kv_shape = (total_blocks, _PAGE_BLOCK_SIZE, _NUM_HEADS_K, _HEAD_DIM)
    # Divide by 10 to keep attention logits in a tame numeric range.
    bench.q = torch.randn(*q_shape, device="cuda", dtype=torch.bfloat16) / 10
    bench.blocked_k = torch.randn(*kv_shape, device="cuda", dtype=torch.bfloat16) / 10

    # Pages are assigned contiguously: sequence i owns rows
    # [i * blocks_per_seq, (i + 1) * blocks_per_seq) of blocked_k.
    bench.block_table = torch.arange(
        total_blocks, device="cuda", dtype=torch.int32
    ).view(batch_size, blocks_per_seq)

    # Every sequence uses the full seq_k cache length.
    bench.cache_seqlens = torch.full(
        (batch_size,), seq_k, device="cuda", dtype=torch.int32
    )
    bench.tile_scheduler_metadata, _ = bench.kernel.get_mla_metadata()
    bench.out = torch.empty(
        batch_size, 1, num_heads_q, _HEAD_DIM_V, device="cuda", dtype=torch.bfloat16
    )
125
+
126
+
127
def _run_mla_decode(bench, causal=False):
    """Invoke the FlashMLA decode kernel on the tensors staged by _setup_mla_decode.

    Stores the attention output on ``bench.out``; the accompanying LSE
    returned by the kernel is discarded.
    """
    outputs = bench.kernel.flash_mla_with_kvcache(
        q=bench.q,
        k_cache=bench.blocked_k,
        block_table=bench.block_table,
        cache_seqlens=bench.cache_seqlens,
        head_dim_v=_HEAD_DIM_V,
        tile_scheduler_metadata=bench.tile_scheduler_metadata,
        causal=causal,
    )
    bench.out = outputs[0]
138
+
139
+
140
def _verify_mla_decode(bench, causal=False):
    """Compute the pure-PyTorch reference output for the inputs staged on *bench*."""
    return _reference_mla_decode(
        bench.q,
        bench.blocked_k,
        bench.block_table,
        bench.cache_seqlens,
        _HEAD_DIM_V,
        causal,
    )
149
+
150
+
151
class FlashMLABenchmark(Benchmark):
    """FlashMLA paged-decode benchmark (non-causal) at three workload sizes."""

    seed: int = 42

    # Workload: small (batch_size=2, seq_k=256, num_heads_q=64)
    def setup_small(self):
        _setup_mla_decode(self, 2, 256, 64)

    def benchmark_small(self):
        _run_mla_decode(self)

    def verify_small(self) -> torch.Tensor:
        return _verify_mla_decode(self)

    # Workload: medium (batch_size=4, seq_k=1024, num_heads_q=64)
    def setup_medium(self):
        _setup_mla_decode(self, 4, 1024, 64)

    def benchmark_medium(self):
        _run_mla_decode(self)

    def verify_medium(self) -> torch.Tensor:
        return _verify_mla_decode(self)

    # Workload: large (batch_size=8, seq_k=4096, num_heads_q=128)
    def setup_large(self):
        _setup_mla_decode(self, 8, 4096, 128)

    def benchmark_large(self):
        _run_mla_decode(self)

    def verify_large(self) -> torch.Tensor:
        return _verify_mla_decode(self)
183
+
184
+
185
class FlashMLACausalBenchmark(Benchmark):
    """FlashMLA paged-decode benchmark with causal masking, at three sizes."""

    seed: int = 42

    # Workload: small (batch_size=2, seq_k=256, num_heads_q=64)
    def setup_small(self):
        _setup_mla_decode(self, 2, 256, 64)

    def benchmark_small(self):
        _run_mla_decode(self, causal=True)

    def verify_small(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=True)

    # Workload: medium (batch_size=4, seq_k=1024, num_heads_q=64)
    def setup_medium(self):
        _setup_mla_decode(self, 4, 1024, 64)

    def benchmark_medium(self):
        _run_mla_decode(self, causal=True)

    def verify_medium(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=True)

    # Workload: large (batch_size=8, seq_k=4096, num_heads_q=128)
    def setup_large(self):
        _setup_mla_decode(self, 8, 4096, 128)

    def benchmark_large(self):
        _run_mla_decode(self, causal=True)

    def verify_large(self) -> torch.Tensor:
        return _verify_mla_decode(self, causal=True)
217
+
218
+
219
class FlashMLAVarlenBenchmark(Benchmark):
    """Variable-length (packed) flash-attention benchmark at three sizes.

    Each workload packs several sequences into single [total_tokens, H, D]
    Q/K/V tensors; cumulative sequence lengths (cu_seqlens) mark the
    per-sequence boundaries. The setup/run/verify logic was previously
    triplicated verbatim; it is factored into private helpers so each
    workload only states its (num_heads, head_dim, seqlens) configuration.
    """

    seed: int = 42

    def _setup_varlen(self, num_heads, head_dim, seqlens):
        """Allocate packed Q/K/V, cu_seqlens, max_seqlen, and the output buffer.

        Args:
            num_heads: Number of attention heads (same for Q, K, V here).
            head_dim: Per-head dimension (same for QK and V here).
            seqlens: Python list of per-sequence token counts.
        """
        total = sum(seqlens)
        self.q = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
        self.k = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
        self.v = torch.randn(total, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
        # Prefix sums with a leading 0: [0, s0, s0+s1, ...].
        self.cu_seqlens = torch.cumsum(torch.tensor([0] + list(seqlens)), dim=0).to(
            device="cuda", dtype=torch.int32
        )
        self.max_seqlen = max(seqlens)
        self.out = torch.empty(total, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

    def _run_varlen(self):
        """Run the varlen flash-attention kernel and store its output on self.out."""
        self.out = _extract_output(
            self.kernel.flash_attn_varlen_func(
                self.q,
                self.k,
                self.v,
                self.cu_seqlens,
                self.cu_seqlens,
                self.max_seqlen,
                self.max_seqlen,
            )
        )

    def _verify_varlen(self) -> torch.Tensor:
        """Reference (non-causal) attention output for the staged inputs."""
        return _varlen_reference_attention(
            self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False
        )

    # Workload: small (3 sequences, max_seqlen=64)
    def setup_small(self):
        self._setup_varlen(num_heads=8, head_dim=64, seqlens=[32, 48, 64])

    def benchmark_small(self):
        self._run_varlen()

    def verify_small(self) -> torch.Tensor:
        return self._verify_varlen()

    # Workload: medium (5 sequences, max_seqlen=256)
    def setup_medium(self):
        self._setup_varlen(num_heads=16, head_dim=64, seqlens=[128, 192, 256, 200, 150])

    def benchmark_medium(self):
        self._run_varlen()

    def verify_medium(self) -> torch.Tensor:
        return self._verify_varlen()

    # Workload: large (8 sequences, max_seqlen=512)
    def setup_large(self):
        self._setup_varlen(
            num_heads=32, head_dim=128, seqlens=[256, 384, 512, 448, 320, 480, 400, 512]
        )

    def benchmark_large(self):
        self._run_varlen()

    def verify_large(self) -> torch.Tensor:
        return self._verify_varlen()