drbh HF Staff committed on
Commit
f3b474c
·
verified ·
1 Parent(s): 021b089

Update benchmarks/benchmark.py

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +12 -2
benchmarks/benchmark.py CHANGED
@@ -4,11 +4,20 @@ import torch.nn.functional as F
4
  from kernels.benchmark import Benchmark
5
 
6
 
 
 
 
 
 
 
 
 
7
  class ReluBenchmark(Benchmark):
8
  seed: int = 42
9
 
10
  def setup(self):
11
- self.x = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
 
12
  self.out = torch.empty_like(self.x)
13
 
14
  def benchmark_base(self):
@@ -18,7 +27,8 @@ class ReluBenchmark(Benchmark):
18
  return F.relu(self.x)
19
 
20
  def setup_large(self):
21
- self.x = torch.randn(4096, 4096, device="cuda", dtype=torch.float32)
 
22
  self.out = torch.empty_like(self.x)
23
 
24
  def benchmark_large(self):
 
4
  from kernels.benchmark import Benchmark
5
 
6
 
7
+ def get_device():
8
+ if torch.cuda.is_available():
9
+ return "cuda"
10
+ elif torch.backends.mps.is_available():
11
+ return "mps"
12
+ return "cpu"
13
+
14
+
15
  class ReluBenchmark(Benchmark):
16
  seed: int = 42
17
 
18
  def setup(self):
19
+ device = get_device()
20
+ self.x = torch.randn(1024, 1024, device=device, dtype=torch.float32)
21
  self.out = torch.empty_like(self.x)
22
 
23
  def benchmark_base(self):
 
27
  return F.relu(self.x)
28
 
29
  def setup_large(self):
30
+ device = get_device()
31
+ self.x = torch.randn(4096, 4096, device=device, dtype=torch.float32)
32
  self.out = torch.empty_like(self.x)
33
 
34
  def benchmark_large(self):