drbh HF Staff committed on
Commit
45b9b60
·
verified ·
1 Parent(s): f3b474c

Update benchmarks/benchmark.py

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +2 -12
benchmarks/benchmark.py CHANGED
@@ -4,20 +4,11 @@ import torch.nn.functional as F
4
  from kernels.benchmark import Benchmark
5
 
6
 
7
- def get_device():
8
- if torch.cuda.is_available():
9
- return "cuda"
10
- elif torch.backends.mps.is_available():
11
- return "mps"
12
- return "cpu"
13
-
14
-
15
  class ReluBenchmark(Benchmark):
16
  seed: int = 42
17
 
18
  def setup(self):
19
- device = get_device()
20
- self.x = torch.randn(1024, 1024, device=device, dtype=torch.float32)
21
  self.out = torch.empty_like(self.x)
22
 
23
  def benchmark_base(self):
@@ -27,8 +18,7 @@ class ReluBenchmark(Benchmark):
27
  return F.relu(self.x)
28
 
29
  def setup_large(self):
30
- device = get_device()
31
- self.x = torch.randn(4096, 4096, device=device, dtype=torch.float32)
32
  self.out = torch.empty_like(self.x)
33
 
34
  def benchmark_large(self):
 
4
  from kernels.benchmark import Benchmark
5
 
6
 
 
 
 
 
 
 
 
 
7
  class ReluBenchmark(Benchmark):
8
  seed: int = 42
9
 
10
  def setup(self):
11
+ self.x = torch.randn(1024, 1024, device=self.device, dtype=torch.float32)
 
12
  self.out = torch.empty_like(self.x)
13
 
14
  def benchmark_base(self):
 
18
  return F.relu(self.x)
19
 
20
  def setup_large(self):
21
+ self.x = torch.randn(4096, 4096, device=self.device, dtype=torch.float32)
 
22
  self.out = torch.empty_like(self.x)
23
 
24
  def benchmark_large(self):