| import torch |
|
|
| from kernels.benchmark import Benchmark |
|
|
|
|
def apply_rotary_reference(
    x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, conj: bool
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply a 2-D rotation to the pair ``(x1, x2)`` using plain tensor ops.

    Each element pair is combined with the matching ``cos``/``sin`` entries
    (ordinary broadcasting rules apply to the element-wise products).

    Args:
        x1: First component of every rotated pair.
        x2: Second component of every rotated pair.
        cos: Cosine factors.
        sin: Sine factors.
        conj: When true, the sign of both ``sin`` terms is flipped, i.e. the
            rotation is applied in the opposite direction.

    Returns:
        A ``(first, second)`` tuple holding the rotated components.
    """
    if conj:
        # Opposite-direction rotation: only the signs on the sin terms change.
        rotated_first = x1 * cos + x2 * sin
        rotated_second = -x1 * sin + x2 * cos
        return rotated_first, rotated_second

    rotated_first = x1 * cos - x2 * sin
    rotated_second = x1 * sin + x2 * cos
    return rotated_first, rotated_second
|
|
|
|
class RotaryBenchmark(Benchmark):
    """Benchmark ``self.kernel.apply_rotary`` and check it against a reference.

    Two workloads are defined, each as a ``setup*`` / ``benchmark_*`` /
    ``verify_*`` triple:

    * base:  batch 2, sequence length 128, 8 heads, rotary dim 32
    * large: batch 8, sequence length 512, 32 heads, rotary dim 64

    ``self.device`` and ``self.kernel`` are not defined in this class; they
    are presumably supplied by the ``Benchmark`` base class -- confirm there.
    """

    # NOTE(review): presumably read by the Benchmark base to seed the RNG
    # before the setup methods run -- confirm against the base class.
    seed: int = 42

    def _allocate_inputs(
        self, *, batch_size: int, seqlen: int, num_heads: int, rotary_dim: int
    ) -> None:
        """Create the random inputs, cos/sin tables and output buffers.

        Tensors are drawn in the order x1, x2, cos, sin so that, for a fixed
        RNG state, the generated values do not depend on which public setup
        method was used to reach this helper.

        Args:
            batch_size: Size of the leading dimension of ``x1``/``x2``.
            seqlen: Sequence length; shared by the inputs and the tables.
            num_heads: Number of heads in ``x1``/``x2``.
            rotary_dim: Size of the last dimension of every tensor.
        """
        input_shape = (batch_size, seqlen, num_heads, rotary_dim)
        # The singleton middle dimension lets the tables broadcast over heads
        # (and, being one rank lower, over the batch) in the reference math.
        table_shape = (seqlen, 1, rotary_dim)

        self.x1 = torch.randn(*input_shape, device=self.device, dtype=torch.float32)
        self.x2 = torch.randn(*input_shape, device=self.device, dtype=torch.float32)
        self.cos = torch.randn(*table_shape, device=self.device, dtype=torch.float32)
        self.sin = torch.randn(*table_shape, device=self.device, dtype=torch.float32)

        # Separate buffers that the kernel call is pointed at, so x1/x2 stay
        # untouched for the reference computation.
        self.out1 = self.x1.clone()
        self.out2 = self.x2.clone()

    def _rotate_in_place(self) -> None:
        """Reset the output buffers from the inputs and run the kernel once."""
        # The buffers are handed to the kernel as both the 1st/2nd and the
        # 5th/6th arguments (presumably inputs and outputs -- confirm against
        # the kernel signature), so they are refreshed from x1/x2 before
        # every call to keep repeated runs starting from the same data.
        self.out1.copy_(self.x1)
        self.out2.copy_(self.x2)
        self.kernel.apply_rotary(
            self.out1, self.out2, self.cos, self.sin, self.out1, self.out2, False
        )

    def _reference_output(self) -> torch.Tensor:
        """Store the kernel result in ``self.out`` and return the reference.

        Both halves are concatenated along the last dimension, for the kernel
        result (``self.out``) and for the returned reference alike.
        """
        ref_out1, ref_out2 = apply_rotary_reference(
            self.x1, self.x2, self.cos, self.sin, False
        )
        self.out = torch.cat([self.out1, self.out2], dim=-1)
        return torch.cat([ref_out1, ref_out2], dim=-1)

    def setup(self):
        """Allocate tensors for the base workload."""
        self._allocate_inputs(batch_size=2, seqlen=128, num_heads=8, rotary_dim=32)

    def benchmark_base(self):
        """Run the kernel on the base workload."""
        self._rotate_in_place()

    def verify_base(self) -> torch.Tensor:
        """Return the reference result for the base workload; sets ``self.out``."""
        return self._reference_output()

    def setup_large(self):
        """Allocate tensors for the large workload."""
        self._allocate_inputs(batch_size=8, seqlen=512, num_heads=32, rotary_dim=64)

    def benchmark_large(self):
        """Run the kernel on the large workload."""
        self._rotate_in_place()

    def verify_large(self) -> torch.Tensor:
        """Return the reference result for the large workload; sets ``self.out``."""
        return self._reference_output()
|
|