# Listing metadata: file size 3,411 bytes
# Listing header: 74d778e
import torch
from kernels.benchmark import Benchmark
def apply_rotary_reference(
    x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate the pair ``(x1, x2)`` by the angle encoded in ``cos``/``sin``.

    This is the pure-PyTorch reference used to check the kernel output. All
    operations are element-wise, so the four tensors only need to be
    broadcastable against each other.

    Args:
        x1: First half of the rotary pair.
        x2: Second half of the rotary pair.
        cos: Cosine factors.
        sin: Sine factors.
        conj: When true, use the opposite sign on the ``sin`` terms, i.e.
            apply the inverse rotation.

    Returns:
        ``(out1, out2)``, two newly allocated tensors; the inputs are not
        modified.
    """
    if not conj:
        # Forward rotation: [cos -sin; sin cos] applied to (x1, x2).
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
    else:
        # Conjugate rotation: same matrix with the sign of sin flipped.
        out1 = x1 * cos + x2 * sin
        out2 = -x1 * sin + x2 * cos
    return out1, out2
class RotaryBenchmark(Benchmark):
    """Benchmark and correctness check for the kernel's ``apply_rotary``.

    Two workloads are defined, each as a setup/benchmark/verify triple:

    * base:  ``setup`` / ``benchmark_base`` / ``verify_base``
    * large: ``setup_large`` / ``benchmark_large`` / ``verify_large``

    ``self.device`` and ``self.kernel`` are read but never assigned here;
    presumably they are provided by the ``Benchmark`` base class (confirm
    against ``kernels.benchmark``).
    """

    seed: int = 42

    def _alloc_rotary_inputs(
        self, batch_size: int, seqlen: int, num_heads: int, rotary_dim: int
    ) -> None:
        """Allocate float32 inputs and output buffers for one workload.

        The ``torch.randn`` calls are issued in a fixed order (x1, x2, cos,
        sin) so that, for a given RNG seed, the generated data is
        reproducible.
        """
        # Query tensor split into rotary parts
        self.x1 = torch.randn(
            batch_size,
            seqlen,
            num_heads,
            rotary_dim,
            device=self.device,
            dtype=torch.float32,
        )
        self.x2 = torch.randn(
            batch_size,
            seqlen,
            num_heads,
            rotary_dim,
            device=self.device,
            dtype=torch.float32,
        )
        # Rotary position embeddings. These are random values rather than
        # real cos/sin pairs; the size-1 middle dimension broadcasts over
        # the heads dimension of x1/x2.
        self.cos = torch.randn(
            seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
        )
        self.sin = torch.randn(
            seqlen, 1, rotary_dim, device=self.device, dtype=torch.float32
        )
        # Output tensors (in-place, so clone inputs)
        self.out1 = self.x1.clone()
        self.out2 = self.x2.clone()

    def _run_rotary(self) -> None:
        """Run the kernel in place on ``out1``/``out2``."""
        # Reset outputs to input values for in-place operation
        self.out1.copy_(self.x1)
        self.out2.copy_(self.x2)
        # out1/out2 are passed both as inputs and as destinations; the
        # final False is the conj flag.
        self.kernel.apply_rotary(
            self.out1, self.out2, self.cos, self.sin, self.out1, self.out2, False
        )

    def _rotary_reference_output(self) -> torch.Tensor:
        """Store the kernel result in ``self.out`` and return the reference.

        Both halves are concatenated along the last dimension so a single
        tensor comparison covers ``out1`` and ``out2``.
        """
        ref_out1, ref_out2 = apply_rotary_reference(
            self.x1, self.x2, self.cos, self.sin, False
        )
        # Concatenate for comparison (benchmark compares self.out with returned tensor)
        self.out = torch.cat([self.out1, self.out2], dim=-1)
        return torch.cat([ref_out1, ref_out2], dim=-1)

    def setup(self):
        """Allocate the base workload: batch 2, seqlen 128, 8 heads, rotary dim 32."""
        self._alloc_rotary_inputs(
            batch_size=2, seqlen=128, num_heads=8, rotary_dim=32
        )

    def benchmark_base(self):
        """Run the kernel on the base workload."""
        self._run_rotary()

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the base workload."""
        return self._rotary_reference_output()

    def setup_large(self):
        """Allocate the large workload: batch 8, seqlen 512, 32 heads, rotary dim 64."""
        self._alloc_rotary_inputs(
            batch_size=8, seqlen=512, num_heads=32, rotary_dim=64
        )

    def benchmark_large(self):
        """Run the kernel on the large workload."""
        self._run_rotary()

    def verify_large(self) -> torch.Tensor:
        """Return the reference output for the large workload."""
        return self._rotary_reference_output()
|