drbh HF Staff committed on
Commit
0341981
·
verified ·
1 Parent(s): 5c6642e

Update benchmark.py

Browse files
Files changed (1) hide show
  1. benchmark.py +3 -64
benchmark.py CHANGED
@@ -1,66 +1,5 @@
1
- import torch
2
- from kernels.benchmark import Benchmark
3
 
4
 
5
- def setup_silu_tensors(self, num_tokens: int, hidden_dim: int, dtype=torch.float16):
6
- self.x = torch.randn(num_tokens, 2 * hidden_dim, device="cuda", dtype=dtype)
7
- self.out = torch.empty(num_tokens, hidden_dim, device="cuda", dtype=dtype)
8
-
9
-
10
- def verify_silu(self):
11
- d = self.x.shape[-1] // 2
12
- ref = torch.nn.functional.silu(self.x[..., :d]) * self.x[..., d:]
13
- return torch.allclose(self.out, ref, atol=1e-3, rtol=1e-3)
14
-
15
-
16
- class SiluWorkloads(Benchmark):
17
- kernel_id = "kernels-community/activation"
18
- seed = 42
19
- x: torch.Tensor # kernel specific input var
20
- out: torch.Tensor # kernel specific output var
21
-
22
- # Workload 1
23
- def setup_small(self):
24
- setup_silu_tensors(self, num_tokens=32, hidden_dim=256)
25
-
26
- def benchmark_small(self):
27
- self.kernel.silu_and_mul(self.out, self.x) # type: ignore
28
-
29
- def verify_small(self):
30
- return verify_silu(self)
31
-
32
- # Workload 2
33
- def setup_medium(self):
34
- setup_silu_tensors(self, num_tokens=1024, hidden_dim=2048)
35
-
36
- def benchmark_medium(self):
37
- self.kernel.silu_and_mul(self.out, self.x) # type: ignore
38
-
39
- def verify_medium(self):
40
- return verify_silu(self)
41
-
42
-
43
- class SiluWorkloads2(Benchmark):
44
- kernel_id = "kernels-community/activation"
45
- seed = 42
46
- x: torch.Tensor # kernel specific input var
47
- out: torch.Tensor # kernel specific output var
48
-
49
- # Workload 1
50
- def setup_small(self):
51
- setup_silu_tensors(self, num_tokens=32, hidden_dim=256)
52
-
53
- def benchmark_small(self):
54
- self.kernel.silu_and_mul(self.out, self.x) # type: ignore
55
-
56
- def verify_small(self):
57
- return verify_silu(self)
58
-
59
- # Workload 2
60
- def setup_medium(self):
61
- setup_silu_tensors(self, num_tokens=1024, hidden_dim=2048)
62
-
63
- def benchmark_medium(self):
64
- self.kernel.silu_and_mul(self.out, self.x) # type: ignore
65
-
66
- # Note: show case without a verify
 
1
+ from kernels.benchmarks import SiluAndMulBenchmark
 
2
 
3
 
4
+ class SiluWorkloads(SiluAndMulBenchmark):
5
+ kernel_id = "kernels-community/activation