Kernels
File size: 3,152 Bytes
15f2bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import torch.nn.functional as F

from kernels.benchmark import Benchmark


class CausalConv1dBenchmark(Benchmark):
    """Benchmark a causal depthwise 1-D convolution kernel at three sizes.

    Each size variant (base / large / xlarge) builds random fp16 inputs with
    fp32 weights, runs the kernel under test, and verifies the result against
    an ``F.conv1d`` reference computed in fp32.  The three variants share
    identical logic and differ only in the problem dimensions, so the common
    work lives in the private ``_setup`` / ``_run`` / ``_verify`` helpers.
    """

    # Fixed seed so the randomly generated inputs are reproducible across runs.
    seed: int = 42

    def _setup(self, batch_size: int, dim: int, seqlen: int, width: int) -> None:
        """Allocate input/weight/bias/output tensors for one problem size.

        NOTE(review): assumes ``self.device`` is provided by the ``Benchmark``
        base class before ``setup*`` is called — confirm against
        ``kernels.benchmark``.
        """
        self.x = torch.randn(
            batch_size, dim, seqlen, device=self.device, dtype=torch.float16
        )
        self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32)
        self.bias = torch.randn(dim, device=self.device, dtype=torch.float32)
        self.out = torch.empty(
            batch_size, dim, seqlen, device=self.device, dtype=torch.float16
        )
        # Kept so the reference implementation can reconstruct the conv shape.
        self.dim = dim
        self.width = width
        self.seqlen = seqlen

    def _run(self) -> None:
        """Invoke the kernel under test; result is stored for verification."""
        self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)

    def _verify(self) -> torch.Tensor:
        """Reference causal conv: depthwise (``groups=dim``) ``F.conv1d`` with
        left padding of ``width - 1``, truncated back to ``seqlen`` samples.
        Computed in fp32 for accuracy, then cast back to the input dtype.
        """
        x_fp32 = self.x.to(self.weight.dtype)
        out = F.conv1d(
            x_fp32,
            self.weight.unsqueeze(1),
            self.bias,
            padding=self.width - 1,
            groups=self.dim,
        )
        return out[..., : self.seqlen].to(self.x.dtype)

    # --- base: small problem size -------------------------------------------
    def setup(self):
        self._setup(batch_size=2, dim=64, seqlen=128, width=4)

    def benchmark_base(self):
        self._run()

    def verify_base(self) -> torch.Tensor:
        return self._verify()

    # --- large ---------------------------------------------------------------
    def setup_large(self):
        self._setup(batch_size=8, dim=256, seqlen=512, width=4)

    def benchmark_large(self):
        self._run()

    def verify_large(self) -> torch.Tensor:
        return self._verify()

    # --- xlarge --------------------------------------------------------------
    def setup_xlarge(self):
        self._setup(batch_size=16, dim=512, seqlen=1024, width=4)

    def benchmark_xlarge(self):
        self._run()

    def verify_xlarge(self) -> torch.Tensor:
        return self._verify()