File size: 742 Bytes
021b089
 
 
 
 
 
 
 
 
 
45b9b60
021b089
 
 
 
 
 
 
 
 
45b9b60
021b089
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
import torch.nn.functional as F

from kernels.benchmark import Benchmark


class ReluBenchmark(Benchmark):
    """Compare ``self.kernel.relu`` against ``torch.nn.functional.relu``.

    Two workloads are defined, each as a setup / benchmark / verify trio:

    * base  -- ``setup`` + ``benchmark_base`` + ``verify_base`` on a
      1024 x 1024 float32 input.
    * large -- ``setup_large`` + ``benchmark_large`` + ``verify_large`` on a
      4096 x 4096 float32 input.

    ``self.device`` and ``self.kernel`` are never assigned in this class;
    presumably the ``Benchmark`` base class / harness provides them, and
    presumably it pairs methods by these names -- confirm against
    ``kernels.benchmark``.
    """

    # Nothing in this class reads this value; presumably the harness uses it
    # to seed the RNG before the setup methods run -- TODO confirm.
    seed: int = 42

    # ---------------------------------------------------------------- base

    def setup(self):
        """Create the 1024 x 1024 float32 input and an output placeholder."""
        shape = (1024, 1024)
        self.x = torch.randn(shape, device=self.device, dtype=torch.float32)
        # benchmark_base rebinds self.out to the kernel's return value instead
        # of writing into this tensor; the placeholder presumably exists so
        # self.out is defined before the first timed call -- verify.
        self.out = torch.empty_like(self.x)

    def benchmark_base(self):
        """Run the custom ReLU kernel on the base input; keep its result."""
        result = self.kernel.relu(self.x)
        self.out = result

    def verify_base(self) -> torch.Tensor:
        """Return the PyTorch reference ReLU of the base input."""
        reference = F.relu(self.x)
        return reference

    # --------------------------------------------------------------- large

    def setup_large(self):
        """Create the 4096 x 4096 float32 input and an output placeholder."""
        shape = (4096, 4096)
        self.x = torch.randn(shape, device=self.device, dtype=torch.float32)
        # Same as in setup: benchmark_large replaces this tensor rather than
        # filling it in place.
        self.out = torch.empty_like(self.x)

    def benchmark_large(self):
        """Run the custom ReLU kernel on the large input; keep its result."""
        result = self.kernel.relu(self.x)
        self.out = result

    def verify_large(self) -> torch.Tensor:
        """Return the PyTorch reference ReLU of the large input."""
        reference = F.relu(self.x)
        return reference