# Extracted from revision a608493 (reported file size: 8,527 bytes).
import torch
from kernels.benchmark import Benchmark
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Reference scaled dot-product attention over token-major tensors.

    Despite the name, this function applies no mask: every query position
    attends to every key position.

    Args:
        query: Tensor laid out as ``(q, h, d)``.
        key: Tensor laid out as ``(k, h, d)``.
        value: Tensor laid out as ``(k, h, d)``.
        scale: Multiplier applied to the raw dot products before softmax.

    Returns:
        Attention output laid out as ``(q, h, d)``, in ``value``'s dtype.
    """
    # Bring the head axis to the front so matmul batches over heads.
    q_by_head = query.permute(1, 0, 2)  # (h, q, d)
    k_by_head = key.permute(1, 0, 2)  # (h, k, d)
    v_by_head = value.permute(1, 0, 2)  # (h, k, d)

    # (h, q, d) @ (h, d, k) -> (h, q, k); scaled in the input dtype, then
    # promoted to float32 so the softmax is computed at full precision.
    logits = torch.matmul(q_by_head, k_by_head.permute(0, 2, 1)) * scale
    probs = torch.softmax(logits.float(), dim=-1)
    probs = probs.to(value.dtype)

    # (h, q, k) @ (h, k, d) -> (h, q, d), then back to token-major (q, h, d).
    context = torch.matmul(probs, v_by_head)
    return context.permute(1, 0, 2)
def ref_paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_tables: torch.Tensor,
    seq_lens: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Reference single-token attention over a paged (block-based) KV cache.

    The cached keys and values of every sequence are gathered through its
    block table, padded to the longest sequence in the batch, and attended
    to with one batched matmul. Padding positions are masked out.

    The number of KV heads is read from the cache. When it is smaller than
    the number of query heads (grouped-query / multi-query attention), each
    KV head is shared by ``num_heads // num_kv_heads`` consecutive query
    heads.

    Args:
        query: ``(num_seqs, num_heads, head_size)`` query for the current
            token of each sequence.
        key_cache: ``(num_blocks, num_kv_heads, head_size // x, block_size, x)``.
        value_cache: ``(num_blocks, num_kv_heads, head_size, block_size)``.
        block_tables: ``(num_seqs, max_blocks_per_seq)`` integer tensor
            mapping each sequence's logical block index to a cache block.
        seq_lens: ``(num_seqs,)`` integer tensor with the number of cached
            tokens per sequence.
        scale: Multiplier applied to the raw dot products before softmax.

    Returns:
        ``(num_seqs, num_heads, head_size)`` attention output in the value
        cache's dtype.

    Raises:
        ValueError: If the query head count is not a multiple of the KV
            head count found in ``value_cache``.
    """
    num_seqs, num_heads, head_size = query.shape
    num_kv_heads = value_cache.shape[1]
    block_size = value_cache.shape[3]

    # An empty batch has no maximum sequence length to reduce over.
    if num_seqs == 0:
        return torch.empty_like(query)

    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be a multiple of "
            f"num_kv_heads ({num_kv_heads})"
        )
    queries_per_kv = num_heads // num_kv_heads

    max_seq_len = int(seq_lens.max().item())

    # Token positions 0..max_seq_len-1 split into (logical block, offset).
    positions = torch.arange(max_seq_len, device=query.device)
    block_indices = positions // block_size  # (max_seq_len,)
    block_offsets = positions % block_size  # (max_seq_len,)

    # Physical block per (sequence, position); cast to int64 so it is a
    # valid index tensor whatever integer dtype the block table uses.
    block_numbers = block_tables[:, block_indices].long()  # (num_seqs, max_seq_len)
    flat_block_numbers = block_numbers.reshape(-1)
    flat_offsets = block_offsets.repeat(num_seqs)

    # The two index tensors are separated by slices, so the gathered axis
    # comes first: (num_seqs * max_seq_len, num_kv_heads, head_size // x, x).
    keys = key_cache[flat_block_numbers, :, :, flat_offsets, :]
    keys = keys.reshape(num_seqs, max_seq_len, num_kv_heads, head_size)
    keys = keys.transpose(1, 2)  # (num_seqs, num_kv_heads, max_seq_len, head_size)

    # Same gather for values: (num_seqs * max_seq_len, num_kv_heads, head_size).
    values = value_cache[flat_block_numbers, :, :, flat_offsets]
    values = values.reshape(num_seqs, max_seq_len, num_kv_heads, head_size)
    values = values.transpose(1, 2)  # (num_seqs, num_kv_heads, max_seq_len, head_size)

    # Expand shared KV heads so query head h reads KV head h // queries_per_kv.
    if queries_per_kv > 1:
        keys = torch.repeat_interleave(keys, queries_per_kv, dim=1)
        values = torch.repeat_interleave(values, queries_per_kv, dim=1)

    # (num_seqs, num_heads, 1, head_size) @ (num_seqs, num_heads, head_size, max_seq_len)
    q = query.unsqueeze(2)
    attn_weights = (scale * torch.matmul(q, keys.transpose(-1, -2))).float()

    # Padding mask (not a causal mask): hide positions at or beyond each
    # sequence's own length so they receive zero softmax weight.
    pad_mask = positions.unsqueeze(0) >= seq_lens.unsqueeze(1)  # (num_seqs, max_seq_len)
    pad_mask = pad_mask.unsqueeze(1).unsqueeze(2)  # (num_seqs, 1, 1, max_seq_len)
    attn_weights = attn_weights.masked_fill(pad_mask, float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1).to(values.dtype)

    # (num_seqs, num_heads, 1, max_seq_len) @ (num_seqs, num_heads, max_seq_len, head_size)
    out = torch.matmul(attn_weights, values)
    return out.squeeze(2)  # (num_seqs, num_heads, head_size)
class PagedAttentionBenchmark(Benchmark):
    """Benchmark workloads for the ``paged_attention_v1`` kernel.

    Two workloads are defined, ``base`` and ``large``. Each has a setup
    method that allocates the inputs, a ``benchmark_*`` method that invokes
    the kernel, and a ``verify_*`` method that returns the PyTorch
    reference result from ``ref_paged_attention``.

    NOTE(review): ``self.device`` and ``self.kernel`` are not defined in
    this file; they are assumed to be provided by ``Benchmark`` or the
    harness that drives it -- confirm against ``kernels.benchmark``.
    """

    # Presumably consumed by the harness to seed the RNG before setup;
    # verify against the Benchmark base class.
    seed: int = 42

    def _build_inputs(
        self,
        *,
        num_seqs: int,
        num_heads: int,
        head_size: int,
        block_size: int,
        max_seq_len: int,
        num_blocks: int,
        seq_lens: "list[int] | None" = None,
        min_seq_len: int = 64,
        dtype: torch.dtype = torch.float16,
    ) -> None:
        """Allocate every tensor and scalar a workload needs on ``self``.

        Random tensors are drawn in a fixed order (query, key cache, value
        cache, block tables, then sequence lengths when they are random) so
        a given RNG seed always produces the same inputs.

        Args:
            num_seqs: Number of sequences in the batch.
            num_heads: Number of attention heads (also used as KV heads).
            head_size: Size of each head.
            block_size: Tokens stored per cache block.
            max_seq_len: Upper bound on sequence length; sizes the block
                table and is forwarded to the kernel.
            num_blocks: Number of blocks in the KV cache.
            seq_lens: Explicit per-sequence lengths. When ``None``, lengths
                are drawn uniformly from ``[min_seq_len, max_seq_len]``.
            min_seq_len: Smallest length drawn when ``seq_lens`` is ``None``.
            dtype: Element type of the query and KV cache.

        Raises:
            ValueError: If ``seq_lens`` is given with a length other than
                ``num_seqs``.
        """
        if seq_lens is not None and len(seq_lens) != num_seqs:
            raise ValueError(
                f"expected {num_seqs} sequence lengths, got {len(seq_lens)}"
            )

        self.num_heads = num_heads
        self.block_size = block_size
        self.max_seq_len = max_seq_len
        self.scale = 1.0 / (head_size**0.5)

        # Query tensor (current token)
        self.query = torch.randn(
            num_seqs, num_heads, head_size, device=self.device, dtype=dtype
        )

        # KV cache with proper layout for the kernel
        # x = 16 // element_size, for float16 x = 8
        x = 16 // torch.tensor([], dtype=dtype).element_size()
        self.key_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size // x,
            block_size,
            x,
            device=self.device,
            dtype=dtype,
        )
        self.value_cache = torch.randn(
            num_blocks,
            num_heads,
            head_size,
            block_size,
            device=self.device,
            dtype=dtype,
        )

        # Block tables: mapping from sequences to memory blocks
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        self.block_tables = torch.randint(
            0,
            num_blocks,
            (num_seqs, max_num_blocks_per_seq),
            device=self.device,
            dtype=torch.int32,
        )

        # Sequence lengths: fixed when supplied, otherwise random
        if seq_lens is not None:
            self.seq_lens = torch.tensor(
                seq_lens, device=self.device, dtype=torch.int32
            )
        else:
            self.seq_lens = torch.randint(
                min_seq_len,
                max_seq_len + 1,
                (num_seqs,),
                device=self.device,
                dtype=torch.int32,
            )

        # KV scales
        self.k_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor(1.0, dtype=torch.float32, device=self.device)

        # Output tensor
        self.out = torch.empty_like(self.query)

    def _run_kernel(self) -> None:
        """Invoke ``paged_attention_v1`` on the prepared inputs, writing ``self.out``."""
        self.kernel.paged_attention_v1(
            self.out,
            self.query,
            self.key_cache,
            self.value_cache,
            num_kv_heads=self.num_heads,
            scale=self.scale,
            block_tables=self.block_tables,
            seq_lens=self.seq_lens,
            block_size=self.block_size,
            max_seq_len=self.max_seq_len,
            alibi_slopes=None,
            kv_cache_dtype="auto",
            k_scale=self.k_scale,
            v_scale=self.v_scale,
        )

    def _reference(self) -> torch.Tensor:
        """Return the PyTorch reference output for the prepared inputs."""
        return ref_paged_attention(
            self.query,
            self.key_cache,
            self.value_cache,
            self.block_tables,
            self.seq_lens,
            self.scale,
        )

    def setup(self):
        """Prepare the ``base`` workload: 4 sequences with fixed lengths."""
        self._build_inputs(
            num_seqs=4,
            num_heads=8,
            head_size=64,
            block_size=16,
            max_seq_len=128,
            num_blocks=64,
            seq_lens=[64, 96, 48, 128],
        )

    def benchmark_base(self):
        """Run the kernel on the ``base`` workload."""
        self._run_kernel()

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the ``base`` workload."""
        return self._reference()

    def setup_large(self):
        """Prepare the ``large`` workload: 16 sequences with random lengths in [64, 512]."""
        self._build_inputs(
            num_seqs=16,
            num_heads=32,
            head_size=128,
            block_size=16,
            max_seq_len=512,
            num_blocks=256,
        )

    def benchmark_large(self):
        """Run the kernel on the ``large`` workload."""
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the ``large`` workload."""
        return self._reference()
|