| | import torch |
| |
|
| | from kernels.benchmark import Benchmark |
| |
|
| |
|
def mm_to_sparse_reference(
    dense_A: torch.Tensor,
    dense_B: torch.Tensor,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Compute the selected block products ``B_block.T @ A_block`` one at a time.

    Every entry of ``indices`` encodes an (A block, B block) pair as
    ``A_idx * B_num_block + B_idx``, where ``B_num_block = dense_B.size(1)``.
    For each batch row ``b`` and slot ``blk`` the result holds
    ``dense_B[b, B_idx].T @ dense_A[b, A_idx]``.

    Args:
        dense_A: Tensor indexed as ``[batch, A block, inner dim, A width]``.
        dense_B: Tensor indexed as ``[batch, B block, inner dim, B width]``.
        indices: Integer tensor of shape ``[batch, num_block]`` holding the
            flattened block-pair ids.

    Returns:
        Tensor of shape ``[batch, num_block, B width, A width]`` created on
        ``dense_A``'s device with ``dense_A``'s dtype.
    """
    batch_size = dense_A.size(0)
    B_num_block = dense_B.size(1)
    num_block = indices.size(1)

    # The tile shape follows from the operands (rows come from B's width,
    # columns from A's width) instead of assuming 32x32 blocks.
    tile_rows = dense_B.size(3)
    tile_cols = dense_A.size(3)

    sparse_C = torch.zeros(
        batch_size,
        num_block,
        tile_rows,
        tile_cols,
        device=dense_A.device,
        dtype=dense_A.dtype,
    )

    # Pull every index to the host in one transfer rather than calling
    # .item() once per (batch, block) entry inside the loop.
    pair_ids = indices.tolist()

    for b in range(batch_size):
        for blk in range(num_block):
            # Decode the flattened pair id back into its A and B block ids.
            A_idx, B_idx = divmod(pair_ids[b][blk], B_num_block)

            A_block = dense_A[b, A_idx]
            B_block = dense_B[b, B_idx]

            sparse_C[b, blk] = B_block.T @ A_block

    return sparse_C
| |
|
| |
|
class MRABenchmark(Benchmark):
    """Benchmark cases for the kernel's ``mm_to_sparse`` operation.

    Two workloads are defined, ``base`` and ``large``. Each has a ``setup_*``
    method that builds random inputs, a ``benchmark_*`` method that runs the
    kernel and stores its result in ``self.out``, and a ``verify_*`` method
    that returns the output of ``mm_to_sparse_reference`` for the same inputs.

    NOTE(review): ``self.device`` and ``self.kernel`` are not defined here;
    presumably the ``Benchmark`` base class provides them -- confirm.
    """

    # NOTE(review): presumably read by the Benchmark base class to seed the
    # random generators before setup runs -- confirm against that class.
    seed: int = 42

    def _populate_inputs(
        self,
        *,
        batch_size: int,
        a_blocks: int,
        b_blocks: int,
        picks_per_row: int,
        num_heads: int = 8,
        head_dim: int = 64,
        block_size: int = 32,
    ) -> None:
        """Create ``dense_a``, ``dense_b`` and ``indices`` for one workload.

        The tensors are generated in the order dense_a, dense_b, indices so
        that a given random state always produces the same inputs.
        """
        rows = batch_size * num_heads
        self.batch_heads = rows

        float_opts = {"device": self.device, "dtype": torch.float32}
        self.dense_a = torch.randn(
            rows, a_blocks, head_dim, block_size, **float_opts
        )
        self.dense_b = torch.randn(
            rows, b_blocks, head_dim, block_size, **float_opts
        )

        # Each index is a flattened (A block, B block) pair id, so the valid
        # range is [0, a_blocks * b_blocks).
        self.indices = torch.randint(
            0,
            a_blocks * b_blocks,
            (rows, picks_per_row),
            device=self.device,
            dtype=torch.int32,
        )

    def setup(self):
        """Build the base workload: 2x8 batch-heads, 4x4 blocks, 4 picks."""
        self._populate_inputs(
            batch_size=2, a_blocks=4, b_blocks=4, picks_per_row=4
        )

    def benchmark_base(self):
        """Run the kernel on the base workload and keep the result."""
        run = self.kernel.mm_to_sparse
        self.out = run(self.dense_a, self.dense_b, self.indices)

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the base workload."""
        inputs = (self.dense_a, self.dense_b, self.indices)
        return mm_to_sparse_reference(*inputs)

    def setup_large(self):
        """Build the large workload: 4x8 batch-heads, 8x8 blocks, 8 picks."""
        self._populate_inputs(
            batch_size=4, a_blocks=8, b_blocks=8, picks_per_row=8
        )

    def benchmark_large(self):
        """Run the kernel on the large workload and keep the result."""
        run = self.kernel.mm_to_sparse
        self.out = run(self.dense_a, self.dense_b, self.indices)

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the large workload."""
        inputs = (self.dense_a, self.dense_b, self.indices)
        return mm_to_sparse_reference(*inputs)
| |
|