danieldk HF Staff commited on
Commit
c958d25
·
verified ·
1 Parent(s): bc5b225

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +128 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from kernels.benchmark import Benchmark
4
+
5
+
6
+ def mm_to_sparse_reference(
7
+ dense_A: torch.Tensor,
8
+ dense_B: torch.Tensor,
9
+ indices: torch.Tensor,
10
+ ) -> torch.Tensor:
11
+ batch_size = dense_A.size(0)
12
+ A_num_block = dense_A.size(1)
13
+ B_num_block = dense_B.size(1)
14
+ dim = dense_A.size(2)
15
+ num_block = indices.size(1)
16
+
17
+ # Output: (batch_size, num_block, 32, 32)
18
+ sparse_C = torch.zeros(
19
+ batch_size, num_block, 32, 32, device=dense_A.device, dtype=dense_A.dtype
20
+ )
21
+
22
+ for b in range(batch_size):
23
+ for blk in range(num_block):
24
+ AB_idx = indices[b, blk].item()
25
+ A_idx = AB_idx // B_num_block
26
+ B_idx = AB_idx % B_num_block
27
+
28
+ A_block = dense_A[b, A_idx] # (dim, 32)
29
+ B_block = dense_B[b, B_idx] # (dim, 32)
30
+
31
+ # Kernel computes C = B.T @ A: (32, dim) @ (dim, 32) = (32, 32)
32
+ sparse_C[b, blk] = B_block.T @ A_block
33
+
34
+ return sparse_C
35
+
36
+
37
+ class MRABenchmark(Benchmark):
38
+ seed: int = 42
39
+
40
+ def setup(self):
41
+ # Config matching the kernel's expected format
42
+ batch_size = 2
43
+ num_heads = 8
44
+ head_dim = 64
45
+ block_size = 32 # Fixed by kernel
46
+
47
+ A_num_block = 4
48
+ B_num_block = 4
49
+ total_blocks = A_num_block * B_num_block
50
+ indices_per_block = 4 # Must be divisible by 4
51
+
52
+ self.batch_heads = batch_size * num_heads
53
+
54
+ # dense_A: [batch_size, A_num_block, dim, 32]
55
+ self.dense_a = torch.randn(
56
+ self.batch_heads,
57
+ A_num_block,
58
+ head_dim,
59
+ block_size,
60
+ device=self.device,
61
+ dtype=torch.float32,
62
+ )
63
+ # dense_B: [batch_size, B_num_block, dim, 32]
64
+ self.dense_b = torch.randn(
65
+ self.batch_heads,
66
+ B_num_block,
67
+ head_dim,
68
+ block_size,
69
+ device=self.device,
70
+ dtype=torch.float32,
71
+ )
72
+ # indices: [batch_size, num_block]
73
+ self.indices = torch.randint(
74
+ 0,
75
+ total_blocks,
76
+ (self.batch_heads, indices_per_block),
77
+ device=self.device,
78
+ dtype=torch.int32,
79
+ )
80
+
81
+ def benchmark_base(self):
82
+ self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)
83
+
84
+ def verify_base(self) -> torch.Tensor:
85
+ return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)
86
+
87
+ def setup_large(self):
88
+ batch_size = 4
89
+ num_heads = 8
90
+ head_dim = 64
91
+ block_size = 32
92
+
93
+ A_num_block = 8
94
+ B_num_block = 8
95
+ total_blocks = A_num_block * B_num_block
96
+ indices_per_block = 8 # Must be divisible by 4
97
+
98
+ self.batch_heads = batch_size * num_heads
99
+
100
+ self.dense_a = torch.randn(
101
+ self.batch_heads,
102
+ A_num_block,
103
+ head_dim,
104
+ block_size,
105
+ device=self.device,
106
+ dtype=torch.float32,
107
+ )
108
+ self.dense_b = torch.randn(
109
+ self.batch_heads,
110
+ B_num_block,
111
+ head_dim,
112
+ block_size,
113
+ device=self.device,
114
+ dtype=torch.float32,
115
+ )
116
+ self.indices = torch.randint(
117
+ 0,
118
+ total_blocks,
119
+ (self.batch_heads, indices_per_block),
120
+ device=self.device,
121
+ dtype=torch.int32,
122
+ )
123
+
124
+ def benchmark_large(self):
125
+ self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)
126
+
127
+ def verify_large(self) -> torch.Tensor:
128
+ return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)