File size: 742 Bytes
021b089 45b9b60 021b089 45b9b60 021b089 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
import torch.nn.functional as F
from kernels.benchmark import Benchmark
class ReluBenchmark(Benchmark):
    """Benchmark ``self.kernel.relu`` against PyTorch's ``F.relu``.

    Two workloads are defined, each on a standard-normal float32 input placed
    on ``self.device``:

    * ``base``  -- a 1024 x 1024 tensor (prepared by ``setup``)
    * ``large`` -- a 4096 x 4096 tensor (prepared by ``setup_large``)

    Each ``benchmark_*`` method stores the kernel's result on ``self.out``.
    Each ``verify_*`` method returns the PyTorch reference for the same input.
    """

    # NOTE(review): presumably read by the Benchmark base class to seed RNGs
    # before setup runs -- confirm against kernels.benchmark.
    seed: int = 42

    # ----------------------------------------------------------- helpers --

    def _prepare(self, side: int) -> None:
        """Create a random ``side x side`` float32 input plus an output buffer."""
        inp = torch.randn(side, side, device=self.device, dtype=torch.float32)
        self.x = inp
        # NOTE(review): the benchmark methods rebind ``self.out`` to the
        # kernel's return value instead of writing into this buffer; it is
        # kept so ``self.out`` exists immediately after setup, as before.
        self.out = torch.empty_like(inp)

    def _run(self) -> None:
        """Apply the kernel under test to ``self.x`` and keep the result."""
        self.out = self.kernel.relu(self.x)

    def _reference(self) -> torch.Tensor:
        """Return PyTorch's ReLU of the current input, for verification."""
        return F.relu(self.x)

    # -------------------------------------- base workload: 1024 x 1024 --

    def setup(self):
        self._prepare(1024)

    def benchmark_base(self):
        self._run()

    def verify_base(self) -> torch.Tensor:
        return self._reference()

    # ------------------------------------- large workload: 4096 x 4096 --

    def setup_large(self):
        self._prepare(4096)

    def benchmark_large(self):
        self._run()

    def verify_large(self) -> torch.Tensor:
        return self._reference()
|