File size: 3,618 Bytes
c958d25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch

from kernels.benchmark import Benchmark


def mm_to_sparse_reference(
    dense_A: torch.Tensor,
    dense_B: torch.Tensor,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Reference implementation of ``kernel.mm_to_sparse``.

    For every batch row ``b`` and output slot ``k``, the flat index
    ``idx = indices[b, k]`` selects block ``A_idx = idx // B_num_block`` of
    ``dense_A`` and block ``B_idx = idx % B_num_block`` of ``dense_B``. The
    output block is ``dense_B[b, B_idx].T @ dense_A[b, A_idx]``.

    All selected blocks are gathered at once and multiplied with a single
    batched matmul. This avoids a Python loop with a ``.item()`` host sync
    per block.

    Args:
        dense_A: Tensor of shape ``(batch, A_num_block, dim, a_block)``.
        dense_B: Tensor of shape ``(batch, B_num_block, dim, b_block)``.
        indices: Integer tensor of shape ``(batch, num_block)``. Each entry
            encodes a block pair as ``A_idx * B_num_block + B_idx``.

    Returns:
        Tensor of shape ``(batch, num_block, b_block, a_block)``, which is
        ``(batch, num_block, 32, 32)`` for the kernel's fixed 32-wide blocks.
        It has the dtype and device of ``dense_A``.
    """
    batch_size = dense_A.size(0)
    B_num_block = dense_B.size(1)

    # Advanced indexing needs int64 indices on the data's device; the
    # benchmark supplies int32 indices.
    flat_idx = indices.to(device=dense_A.device, dtype=torch.long)
    A_idx = flat_idx // B_num_block
    B_idx = flat_idx % B_num_block

    # (batch, 1) row index that broadcasts against the (batch, num_block)
    # block indices, so each row only gathers from its own batch entry.
    batch_idx = torch.arange(batch_size, device=dense_A.device).unsqueeze(1)

    A_blocks = dense_A[batch_idx, A_idx]  # (batch, num_block, dim, a_block)
    B_blocks = dense_B[batch_idx, B_idx]  # (batch, num_block, dim, b_block)

    # Kernel computes C = B.T @ A: (b_block, dim) @ (dim, a_block) per block.
    return torch.matmul(B_blocks.transpose(-2, -1), A_blocks)


class MRABenchmark(Benchmark):
    """Benchmark ``kernel.mm_to_sparse`` against ``mm_to_sparse_reference``.

    Each ``setup*`` method builds random inputs. The matching ``benchmark*``
    method runs the kernel and stores its result in ``self.out``. The
    matching ``verify*`` method returns the reference result for the same
    inputs.
    """

    seed: int = 42

    def _make_inputs(
        self,
        *,
        batch_size: int,
        num_heads: int,
        head_dim: int,
        A_num_block: int,
        B_num_block: int,
        indices_per_block: int,
        block_size: int = 32,
    ) -> None:
        """Allocate random ``dense_a``, ``dense_b`` and ``indices`` on ``self.device``.

        Batch and head dimensions are flattened into one leading dimension
        of size ``batch_size * num_heads`` (stored as ``self.batch_heads``).
        Tensors are created in the order dense_a, dense_b, indices.
        """
        self.batch_heads = batch_size * num_heads

        # dense_a: [batch_heads, A_num_block, head_dim, block_size]
        self.dense_a = torch.randn(
            self.batch_heads,
            A_num_block,
            head_dim,
            block_size,
            device=self.device,
            dtype=torch.float32,
        )
        # dense_b: [batch_heads, B_num_block, head_dim, block_size]
        self.dense_b = torch.randn(
            self.batch_heads,
            B_num_block,
            head_dim,
            block_size,
            device=self.device,
            dtype=torch.float32,
        )
        # indices: [batch_heads, indices_per_block]. Each entry is a flat
        # (A block, B block) pair index in [0, A_num_block * B_num_block).
        self.indices = torch.randint(
            0,
            A_num_block * B_num_block,
            (self.batch_heads, indices_per_block),
            device=self.device,
            dtype=torch.int32,
        )

    def setup(self):
        # Config matching the kernel's expected format.
        self._make_inputs(
            batch_size=2,
            num_heads=8,
            head_dim=64,
            A_num_block=4,
            B_num_block=4,
            indices_per_block=4,  # Must be divisible by 4
            block_size=32,  # Fixed by kernel
        )

    def benchmark_base(self):
        self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)

    def verify_base(self) -> torch.Tensor:
        return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)

    def setup_large(self):
        self._make_inputs(
            batch_size=4,
            num_heads=8,
            head_dim=64,
            A_num_block=8,
            B_num_block=8,
            indices_per_block=8,  # Must be divisible by 4
            block_size=32,  # Fixed by kernel
        )

    def benchmark_large(self):
        self.out = self.kernel.mm_to_sparse(self.dense_a, self.dense_b, self.indices)

    def verify_large(self) -> torch.Tensor:
        return mm_to_sparse_reference(self.dense_a, self.dense_b, self.indices)