Kernels
File size: 3,411 Bytes
74d778e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch

from kernels.benchmark import Benchmark


def apply_rotary_reference(
    x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for the rotary transform of the pair ``(x1, x2)``.

    With ``conj`` false the result is::

        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos

    With ``conj`` true the sign of both ``sin`` terms is flipped::

        out1 = x1 * cos + x2 * sin
        out2 = -x1 * sin + x2 * cos

    All products are element-wise, so the four tensors only need to be
    broadcast-compatible with each other.

    Returns:
        The tuple ``(out1, out2)`` as newly allocated tensors; the inputs are
        not modified.
    """
    # The cosine products are the same in both branches, so compute them once.
    x1_cos = x1 * cos
    x2_cos = x2 * cos

    if conj:
        return x1_cos + x2 * sin, -x1 * sin + x2_cos
    return x1_cos - x2 * sin, x1 * sin + x2_cos


class RotaryBenchmark(Benchmark):
    """Benchmark the kernel's in-place ``apply_rotary`` against a PyTorch reference.

    Two workloads are defined, ``base`` and ``large``. They issue the same
    kernel call and the same check and differ only in tensor sizes, so the
    shared steps live in private helpers.

    ``self.device`` and ``self.kernel`` are not set here; they are presumably
    provided by the ``Benchmark`` base class / harness -- confirm against
    ``kernels.benchmark``.
    """

    # NOTE(review): presumably read by the ``Benchmark`` base class to seed the
    # RNG before a setup method runs -- confirm against ``kernels.benchmark``.
    seed: int = 42

    def _allocate(
        self, batch_size: int, seqlen: int, num_heads: int, rotary_dim: int
    ) -> None:
        """Create random float32 inputs on ``self.device`` plus output buffers.

        Sets ``x1``/``x2`` with shape ``(batch_size, seqlen, num_heads,
        rotary_dim)``, ``cos``/``sin`` with shape ``(seqlen, 1, rotary_dim)``
        (broadcast over batch and heads), and ``out1``/``out2`` as clones of
        ``x1``/``x2``.
        """
        tensor_opts = {"device": self.device, "dtype": torch.float32}
        half_shape = (batch_size, seqlen, num_heads, rotary_dim)
        table_shape = (seqlen, 1, rotary_dim)

        # Keep the creation order x1, x2, cos, sin: with a fixed RNG seed this
        # yields the same tensors on every run.
        self.x1 = torch.randn(*half_shape, **tensor_opts)
        self.x2 = torch.randn(*half_shape, **tensor_opts)

        # Rotary tables are random values rather than true cos/sin; the check
        # only needs the kernel and the reference to see the same numbers.
        self.cos = torch.randn(*table_shape, **tensor_opts)
        self.sin = torch.randn(*table_shape, **tensor_opts)

        # The kernel is run in place, so it works on clones of the inputs and
        # x1/x2 stay pristine for the reference computation.
        self.out1 = self.x1.clone()
        self.out2 = self.x2.clone()

    def _run_in_place(self) -> None:
        """Reset the output buffers to the inputs and run the kernel on them."""
        # Without the reset, repeated benchmark iterations would rotate an
        # already-rotated buffer and the final check would not match.
        self.out1.copy_(self.x1)
        self.out2.copy_(self.x2)
        # out1/out2 are passed both as the first two and as the fifth/sixth
        # arguments, i.e. the kernel reads and writes the same buffers.
        # NOTE(review): argument order is presumably
        # (x1, x2, cos, sin, out1, out2, conj) -- confirm against the kernel.
        self.kernel.apply_rotary(
            self.out1, self.out2, self.cos, self.sin, self.out1, self.out2, False
        )

    def _verify(self) -> torch.Tensor:
        """Store the kernel result in ``self.out`` and return the reference.

        The benchmark compares ``self.out`` with the returned tensor, so both
        output halves are concatenated along the last dimension on each side.
        """
        ref_out1, ref_out2 = apply_rotary_reference(
            self.x1, self.x2, self.cos, self.sin, False
        )
        self.out = torch.cat([self.out1, self.out2], dim=-1)
        return torch.cat([ref_out1, ref_out2], dim=-1)

    def setup(self) -> None:
        """Allocate tensors for the ``base`` workload."""
        self._allocate(batch_size=2, seqlen=128, num_heads=8, rotary_dim=32)

    def benchmark_base(self) -> None:
        """Run the in-place rotary kernel on the ``base`` workload."""
        self._run_in_place()

    def verify_base(self) -> torch.Tensor:
        """Return the reference result for the ``base`` workload."""
        return self._verify()

    def setup_large(self) -> None:
        """Allocate tensors for the ``large`` workload."""
        self._allocate(batch_size=8, seqlen=512, num_heads=32, rotary_dim=64)

    def benchmark_large(self) -> None:
        """Run the in-place rotary kernel on the ``large`` workload."""
        self._run_in_place()

    def verify_large(self) -> torch.Tensor:
        """Return the reference result for the ``large`` workload."""
        return self._verify()