Kernels
danieldk HF Staff commited on
Commit
15f2bcb
·
verified ·
1 Parent(s): 6ce609c

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +92 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ from kernels.benchmark import Benchmark
5
+
6
+
7
class CausalConv1dBenchmark(Benchmark):
    """Benchmark a causal depthwise 1-D convolution kernel at three sizes.

    Each public ``setup_*`` / ``benchmark_*`` / ``verify_*`` triple exercises
    ``self.kernel.causal_conv1d_fn`` on a different problem size and checks it
    against a PyTorch ``F.conv1d`` reference truncated to causal length.

    NOTE(review): ``self.device`` and ``self.kernel`` are assumed to be
    provided by the ``Benchmark`` base class (``kernels.benchmark``) — confirm.
    The public method names/signatures are preserved exactly; the harness
    presumably discovers them by name.
    """

    # Fixed RNG seed; presumably consumed by the Benchmark harness.
    seed: int = 42

    # ------------------------------------------------------------------
    # Private helpers shared by the three problem sizes.
    # ------------------------------------------------------------------
    def _setup(self, batch_size: int, dim: int, seqlen: int, width: int) -> None:
        """Allocate fp16 input/output and fp32 weight/bias for one size.

        Stores ``x``, ``weight``, ``bias``, ``out`` plus the ``dim``,
        ``width`` and ``seqlen`` scalars that the reference needs.
        """
        self.x = torch.randn(
            batch_size, dim, seqlen, device=self.device, dtype=torch.float16
        )
        self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32)
        self.bias = torch.randn(dim, device=self.device, dtype=torch.float32)
        self.out = torch.empty(
            batch_size, dim, seqlen, device=self.device, dtype=torch.float16
        )
        self.dim = dim
        self.width = width
        self.seqlen = seqlen

    def _run(self) -> None:
        """Invoke the kernel under test on the tensors prepared by ``_setup``."""
        self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)

    def _reference(self) -> torch.Tensor:
        """Reference result: grouped (depthwise) ``F.conv1d`` in fp32.

        Returns a tensor matching ``self.out``'s dtype and length for the
        harness to compare against.
        """
        x_fp32 = self.x.to(self.weight.dtype)
        out = F.conv1d(
            x_fp32,
            self.weight.unsqueeze(1),  # (dim, 1, width) for groups=dim
            self.bias,
            padding=self.width - 1,    # left-pad enough for causality
            groups=self.dim,           # one filter per channel
        )
        # conv1d pads both ends, so the output overshoots seqlen; keep the
        # first seqlen steps and cast back to the input dtype.
        return out[..., : self.seqlen].to(self.x.dtype)

    # ------------------------------------------------------------------
    # base: batch=2, dim=64, seqlen=128, width=4
    # ------------------------------------------------------------------
    def setup(self):
        self._setup(batch_size=2, dim=64, seqlen=128, width=4)

    def benchmark_base(self):
        self._run()

    def verify_base(self) -> torch.Tensor:
        return self._reference()

    # ------------------------------------------------------------------
    # large: batch=8, dim=256, seqlen=512, width=4
    # ------------------------------------------------------------------
    def setup_large(self):
        self._setup(batch_size=8, dim=256, seqlen=512, width=4)

    def benchmark_large(self):
        self._run()

    def verify_large(self) -> torch.Tensor:
        return self._reference()

    # ------------------------------------------------------------------
    # xlarge: batch=16, dim=512, seqlen=1024, width=4
    # ------------------------------------------------------------------
    def setup_xlarge(self):
        self._setup(batch_size=16, dim=512, seqlen=1024, width=4)

    def benchmark_xlarge(self):
        self._run()

    def verify_xlarge(self) -> torch.Tensor:
        return self._reference()