import torch

from kernels.benchmark import Benchmark


def setup_silu_tensors(
    self,
    num_tokens: int,
    hidden_dim: int,
    dtype: torch.dtype = torch.float16,
) -> None:
    """Allocate the input and output tensors for a SiLU-and-mul workload.

    Stores two CUDA tensors on ``self``:

    * ``self.x``   -- random-normal input of shape
      ``(num_tokens, 2 * hidden_dim)``.
    * ``self.out`` -- uninitialised output buffer of shape
      ``(num_tokens, hidden_dim)``.

    Args:
        self: Object that receives the ``x`` and ``out`` attributes (the
            benchmark classes in this file pass their own instance).
        num_tokens: Size of the leading dimension of both tensors.
        hidden_dim: Width of the output; the input is twice as wide.
        dtype: Element type used for both tensors.
    """
    input_shape = (num_tokens, 2 * hidden_dim)
    output_shape = (num_tokens, hidden_dim)

    self.x = torch.randn(input_shape, device="cuda", dtype=dtype)
    self.out = torch.empty(output_shape, device="cuda", dtype=dtype)


def verify_silu(self) -> bool:
    """Check ``self.out`` against a PyTorch SiLU-and-mul reference.

    The last dimension of ``self.x`` is cut at its midpoint; the reference
    is ``silu(first_half) * second_half``.

    Returns:
        ``True`` when ``self.out`` matches the reference within an absolute
        and relative tolerance of ``1e-3``, ``False`` otherwise.
    """
    half = self.x.shape[-1] // 2
    gate = self.x[..., :half]
    multiplier = self.x[..., half:]

    expected = torch.nn.functional.silu(gate) * multiplier
    return torch.allclose(self.out, expected, atol=1e-3, rtol=1e-3)


class SiluWorkloads(Benchmark):
    """SiLU-and-mul workloads where every workload has a verify step.

    NOTE(review): ``self.kernel`` is not defined in this file; it is
    presumably supplied by ``Benchmark`` from ``kernel_id`` -- confirm.
    """

    kernel_id = "kernels-community/activation"
    seed = 42
    x: torch.Tensor  # input tensor consumed by the kernel
    out: torch.Tensor  # output buffer written by the kernel

    # --- Workload 1: small (32 tokens, hidden_dim 256) ---
    def setup_small(self):
        """Allocate tensors for the small workload."""
        setup_silu_tensors(self, num_tokens=32, hidden_dim=256)

    def benchmark_small(self):
        """Run ``silu_and_mul`` from ``self.x`` into ``self.out``."""
        self.kernel.silu_and_mul(self.out, self.x)  # type: ignore

    def verify_small(self):
        """Return whether ``self.out`` matches the PyTorch reference."""
        return verify_silu(self)

    # --- Workload 2: medium (1024 tokens, hidden_dim 2048) ---
    def setup_medium(self):
        """Allocate tensors for the medium workload."""
        setup_silu_tensors(self, num_tokens=1024, hidden_dim=2048)

    def benchmark_medium(self):
        """Run ``silu_and_mul`` from ``self.x`` into ``self.out``."""
        self.kernel.silu_and_mul(self.out, self.x)  # type: ignore

    def verify_medium(self):
        """Return whether ``self.out`` matches the PyTorch reference."""
        return verify_silu(self)


class SiluWorkloads2(Benchmark):
    """SiLU-and-mul workloads where the medium workload has no verify step.

    Uses the same workload sizes as ``SiluWorkloads``; only the small
    workload defines a verify method.

    NOTE(review): ``self.kernel`` is not defined in this file; it is
    presumably supplied by ``Benchmark`` from ``kernel_id`` -- confirm.
    """

    kernel_id = "kernels-community/activation"
    seed = 42
    x: torch.Tensor  # input tensor consumed by the kernel
    out: torch.Tensor  # output buffer written by the kernel

    # --- Workload 1: small (32 tokens, hidden_dim 256) ---
    def setup_small(self):
        """Allocate tensors for the small workload."""
        setup_silu_tensors(self, num_tokens=32, hidden_dim=256)

    def benchmark_small(self):
        """Run ``silu_and_mul`` from ``self.x`` into ``self.out``."""
        self.kernel.silu_and_mul(self.out, self.x)  # type: ignore

    def verify_small(self):
        """Return whether ``self.out`` matches the PyTorch reference."""
        return verify_silu(self)

    # --- Workload 2: medium (1024 tokens, hidden_dim 2048) ---
    def setup_medium(self):
        """Allocate tensors for the medium workload."""
        setup_silu_tensors(self, num_tokens=1024, hidden_dim=2048)

    def benchmark_medium(self):
        """Run ``silu_and_mul`` from ``self.x`` into ``self.out``."""
        self.kernel.silu_and_mul(self.out, self.x)  # type: ignore

    # Note: this workload intentionally has no verify method, to showcase
    # a benchmark without verification.