# File size: 3,618 Bytes
# c958d25
import torch
from kernels.benchmark import Benchmark
def mm_to_sparse_reference(
    dense_A: torch.Tensor,
    dense_B: torch.Tensor,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Compute the block-sparse product ``B_block.T @ A_block`` in plain PyTorch.

    Every entry of ``indices`` is a flat block id that is decoded as
    ``A_idx = id // B_num_block`` and ``B_idx = id % B_num_block``, where
    ``B_num_block`` is ``dense_B.size(1)``. For each ``(batch, slot)`` pair the
    selected blocks are multiplied and stored in the output.

    Args:
        dense_A: 4-D tensor indexed as ``[batch, A_block, dim, block]``.
        dense_B: 4-D tensor indexed as ``[batch, B_block, dim, block]``.
        indices: 2-D integer tensor indexed as ``[batch, slot]`` holding the
            flat block ids.

    Returns:
        Tensor of shape ``(batch, slots, dense_B.size(3), dense_A.size(3))``
        on the device and with the dtype of ``dense_A``. With the 32-wide
        blocks used by the kernel this is ``(batch, slots, 32, 32)``.
    """
    batch_size = dense_A.size(0)
    B_num_block = dense_B.size(1)
    num_block = indices.size(1)

    # B.T @ A has shape (B block width, A block width); read both from the
    # inputs instead of hard-coding 32 so other block sizes also work.
    out_rows = dense_B.size(3)
    out_cols = dense_A.size(3)
    sparse_C = torch.zeros(
        batch_size,
        num_block,
        out_rows,
        out_cols,
        device=dense_A.device,
        dtype=dense_A.dtype,
    )

    # Convert the indices to Python ints once rather than calling .item()
    # for every (batch, slot) pair.
    index_rows = indices.tolist()

    for b in range(batch_size):
        row = index_rows[b]
        for blk in range(num_block):
            A_idx, B_idx = divmod(row[blk], B_num_block)
            A_block = dense_A[b, A_idx]  # (dim, block)
            B_block = dense_B[b, B_idx]  # (dim, block)
            # Kernel computes C = B.T @ A: (block, dim) @ (dim, block)
            sparse_C[b, blk] = B_block.T @ A_block
    return sparse_C
class MRABenchmark(Benchmark):
    """Benchmark cases for ``kernel.mm_to_sparse`` with a PyTorch reference.

    Two workloads are defined, ``base`` and ``large``. Each has a setup method
    that builds random inputs, a ``benchmark_*`` method that runs the kernel
    and stores the result in ``self.out``, and a ``verify_*`` method that
    returns the expected result from ``mm_to_sparse_reference``.

    NOTE(review): ``self.device`` and ``self.kernel`` are not defined here;
    presumably the ``Benchmark`` base class provides them -- confirm.
    """

    # NOTE(review): presumably read by the Benchmark base class to seed the
    # RNG before setup runs -- confirm against the framework.
    seed: int = 42

    def _make_inputs(
        self,
        *,
        batch_size: int,
        num_heads: int,
        head_dim: int,
        A_num_block: int,
        B_num_block: int,
        indices_per_block: int,
        block_size: int = 32,
    ) -> None:
        """Create ``dense_a``, ``dense_b`` and ``indices`` for one workload.

        The tensors are drawn in a fixed order (dense_a, dense_b, indices) so
        a given RNG state always yields the same inputs.

        Args:
            batch_size: Number of batches; folded together with ``num_heads``.
            num_heads: Number of heads per batch.
            head_dim: Size of the contracted dimension of every block.
            A_num_block: Number of blocks in ``dense_a``.
            B_num_block: Number of blocks in ``dense_b``.
            indices_per_block: Number of flat block ids per batch-head row
                (must be divisible by 4).
            block_size: Width of each block; 32 is fixed by the kernel.
        """
        total_blocks = A_num_block * B_num_block
        self.batch_heads = batch_size * num_heads

        # dense_a: [batch_heads, A_num_block, head_dim, block_size]
        self.dense_a = torch.randn(
            self.batch_heads,
            A_num_block,
            head_dim,
            block_size,
            device=self.device,
            dtype=torch.float32,
        )
        # dense_b: [batch_heads, B_num_block, head_dim, block_size]
        self.dense_b = torch.randn(
            self.batch_heads,
            B_num_block,
            head_dim,
            block_size,
            device=self.device,
            dtype=torch.float32,
        )
        # indices: [batch_heads, indices_per_block], values in [0, total_blocks)
        self.indices = torch.randint(
            0,
            total_blocks,
            (self.batch_heads, indices_per_block),
            device=self.device,
            dtype=torch.int32,
        )

    def setup(self):
        """Build the inputs for the ``base`` workload."""
        # Config matching the kernel's expected format
        self._make_inputs(
            batch_size=2,
            num_heads=8,
            head_dim=64,
            A_num_block=4,
            B_num_block=4,
            indices_per_block=4,  # Must be divisible by 4
        )

    def benchmark_base(self):
        """Run the kernel on the current inputs and keep the result in ``self.out``."""
        self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)

    def verify_base(self) -> torch.Tensor:
        """Return the reference result for the current inputs."""
        return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)

    def setup_large(self):
        """Build the inputs for the ``large`` workload."""
        self._make_inputs(
            batch_size=4,
            num_heads=8,
            head_dim=64,
            A_num_block=8,
            B_num_block=8,
            indices_per_block=8,  # Must be divisible by 4
        )

    def benchmark_large(self):
        """Run the kernel on the current inputs and keep the result in ``self.out``."""
        self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)

    def verify_large(self) -> torch.Tensor:
        """Return the reference result for the current inputs."""
        return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)
|