Kernels
danieldk HF Staff committed on
Commit
a608493
·
verified ·
1 Parent(s): fb85517

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +263 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from kernels.benchmark import Benchmark
4
+
5
+
6
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Compute multi-head scaled dot-product attention with plain PyTorch ops.

    Despite the name, this function applies no mask: every query position
    attends to every key position.

    Args:
        query: Queries laid out as (q, h, d).
        key: Keys laid out as (k, h, d).
        value: Values laid out as (k, h, d).
        scale: Factor multiplied into the raw query-key dot products.

    Returns:
        Attention output laid out as (q, h, d), in ``value``'s dtype.
    """
    # Move the head axis to the front so matmul batches over heads.
    heads_q, heads_k, heads_v = (
        t.transpose(0, 1) for t in (query, key, value)
    )

    # (h, q, d) @ (h, d, k) -> (h, q, k); softmax runs in float32 and the
    # probabilities are cast back to the value dtype afterwards.
    scores = (scale * torch.matmul(heads_q, heads_k.transpose(-1, -2))).float()
    probs = torch.softmax(scores, dim=-1).to(value.dtype)

    # (h, q, k) @ (h, k, d) -> (h, q, d), then restore the (q, h, d) layout.
    return torch.matmul(probs, heads_v).transpose(0, 1)
27
+
28
+
29
def ref_paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Reference single-query attention over a paged KV cache, in pure PyTorch.

    Each sequence contributes one query vector that attends to the first
    ``seq_lens[i]`` cached tokens of that sequence; the tokens are located
    through ``block_tables``.

    Args:
        query: (num_seqs, num_heads, head_size).
        key_cache: (num_blocks, num_heads, head_size // x, block_size, x).
        value_cache: (num_blocks, num_heads, head_size, block_size).
        block_tables: (num_seqs, blocks_per_seq) cache block number for each
            logical block of each sequence.
        seq_lens: (num_seqs,) number of valid cached tokens per sequence.
        scale: Factor multiplied into the raw query-key dot products.

    Returns:
        (num_seqs, num_heads, head_size) tensor in the value cache's dtype.
    """
    batch, heads, dim = query.shape[0], query.shape[1], query.shape[2]
    tokens_per_block = value_cache.shape[3]
    longest = int(seq_lens.max().item())

    # For every token position up to the longest sequence: which column of
    # the block table it lives in, and which slot inside that block.
    pos = torch.arange(longest, device=query.device)
    table_cols = (pos // tokens_per_block).long()
    slots = pos % tokens_per_block

    # One (block number, slot) pair per (sequence, position), flattened.
    physical_blocks = block_tables[:, table_cols].reshape(-1)
    slot_idx = slots.repeat(batch)

    def _per_sequence(rows: torch.Tensor) -> torch.Tensor:
        # (batch * longest, heads, ...) -> (batch, heads, longest, dim)
        return rows.reshape(batch, longest, heads, dim).transpose(1, 2)

    # The two index tensors are separated by slices, so the gathered axis
    # comes first in the result; the reshape then merges the split head_size
    # axes of the key cache back into one.
    keys = _per_sequence(key_cache[physical_blocks, :, :, slot_idx, :])
    values = _per_sequence(value_cache[physical_blocks, :, :, slot_idx])

    # (batch, heads, 1, dim) @ (batch, heads, dim, longest)
    scores = (
        scale * torch.matmul(query.unsqueeze(2), keys.transpose(-1, -2))
    ).float()

    # Length mask (not a causal mask): positions at or beyond a sequence's
    # own length are padding and must receive zero attention weight.
    past_end = pos.unsqueeze(0) >= seq_lens.unsqueeze(1)  # (batch, longest)
    scores = scores.masked_fill(past_end[:, None, None, :], float("-inf"))

    probs = torch.softmax(scores, dim=-1).to(values.dtype)

    # (batch, heads, 1, longest) @ (batch, heads, longest, dim), then drop
    # the singleton query axis.
    return torch.matmul(probs, values).squeeze(2)
86
+
87
+
88
class PagedAttentionBenchmark(Benchmark):
    """Benchmark ``paged_attention_v1`` and check it against a PyTorch reference.

    Two workloads are defined: the default one (``setup`` /
    ``benchmark_base`` / ``verify_base``) and a larger one (``setup_large`` /
    ``benchmark_large`` / ``verify_large``). Both share the same tensor
    construction, kernel call and reference computation, which live in the
    private helpers below so the two workloads cannot drift apart.

    NOTE(review): ``self.device`` and ``self.kernel`` are not defined in this
    file; they are presumably provided by the ``Benchmark`` base class, which
    presumably also pairs the setup/benchmark/verify methods by name --
    confirm against the ``kernels`` documentation.
    """

    seed: int = 42

    def _init_workload(
        self,
        *,
        num_seqs: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        max_seq_len: int,
        num_blocks: int,
        seq_lens=None,
        dtype: torch.dtype = torch.float16,
    ) -> None:
        """Allocate every tensor one workload needs and store it on ``self``.

        Args:
            num_seqs: Number of sequences (one query token each).
            num_heads: Number of attention heads; also passed to the kernel
                as ``num_kv_heads``.
            head_size: Size of each attention head.
            block_size: Number of token slots per cache block.
            max_seq_len: Upper bound on sequence length; sizes the block table.
            num_blocks: Number of blocks in the key/value caches.
            seq_lens: Explicit per-sequence lengths. When ``None``, lengths
                are drawn uniformly from ``[64, max_seq_len]``.
            dtype: Element type of the query and both caches.

        Random tensors are created in a fixed order (query, key cache, value
        cache, block tables, then sequence lengths) so a seeded run yields the
        same data as the original, non-factored setup code.
        """
        self.num_heads = num_heads
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.scale = 1.0 / (head_size**0.5)

        # Query tensor (current token of each sequence).
        self.query = torch.randn(
            num_seqs, num_heads, head_size, device=self.device, dtype=dtype
        )

        # KV cache with the layout the kernel expects: the key cache splits
        # head_size into chunks of x elements, where x = 16 // element_size
        # (x = 8 for float16).
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        self.key_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size // x,
            block_size,
            x,
            device=self.device,
            dtype=dtype,
        )
        self.value_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size,
            block_size,
            device=self.device,
            dtype=dtype,
        )

        # Block tables: mapping from each sequence's logical blocks to cache
        # blocks. Ceil-division so a partially filled last block is covered.
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        self.block_tables = torch.randint(
            0,
            num_blocks,
            (num_seqs, max_num_blocks_per_seq),
            device=self.device,
            dtype=torch.int32,
        )

        # Sequence lengths: fixed when given, otherwise random.
        if seq_lens is None:
            self.seq_lens = torch.randint(
                64,
                max_seq_len + 1,
                (num_seqs,),
                device=self.device,
                dtype=torch.int32,
            )
        else:
            self.seq_lens = torch.tensor(
                seq_lens, device=self.device, dtype=torch.int32
            )

        # KV scales (1.0 for both keys and values).
        self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)

        # Output tensor the kernel call receives as its first argument.
        self.out = torch.empty_like(self.query)

    def _run_kernel(self) -> None:
        """Call ``paged_attention_v1`` on the tensors of the current workload."""
        self.kernel.paged_attention_v1(
            self.out,
            self.query,
            self.key_cache,
            self.value_cache,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            block_tables=self.block_tables,
            seq_lens=self.seq_lens,
            block_size=self.block_size,
            max_seq_len=self.max_seq_len,
            alibi_slopes=None,
            kv_cache_dtype="auto",
            k_scale=self.k_scale,
            v_scale=self.v_scale,
        )

    def _reference(self) -> torch.Tensor:
        """Return the PyTorch reference output for the current workload."""
        return ref_paged_attention(
            self.query,
            self.key_cache,
            self.value_cache,
            self.block_tables,
            self.seq_lens,
            self.scale,
        )

    def setup(self):
        """Prepare the default workload: 4 sequences with fixed lengths."""
        self._init_workload(
            num_seqs=4,
            num_heads=8,
            head_size=64,
            block_size=16,
            max_seq_len=128,
            num_blocks=64,
            seq_lens=[64, 96, 48, 128],
        )

    def benchmark_base(self):
        """Run the kernel on the default workload."""
        self._run_kernel()

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the default workload."""
        return self._reference()

    def setup_large(self):
        """Prepare the large workload: 16 sequences with random lengths."""
        self._init_workload(
            num_seqs=16,
            num_heads=32,
            head_size=128,
            block_size=16,
            max_seq_len=512,
            num_blocks=256,
        )

    def benchmark_large(self):
        """Run the kernel on the large workload."""
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the large workload."""
        return self._reference()