import torch
import torch.nn.functional as F
from kernels.benchmark import Benchmark
class CausalConv1dBenchmark(Benchmark):
    """Benchmark a depthwise causal 1-D convolution kernel at three sizes.

    Each workload allocates fp16 activations with fp32 weight/bias (the
    kernel's mixed-precision contract, as set up by the original code) and
    verifies the kernel's output against a grouped ``F.conv1d`` reference
    computed in fp32 and cast back to the activation dtype.
    """

    seed: int = 42

    # Workload shapes: name -> (batch_size, dim, seqlen, width).
    # Factoring these out removes the triplicated setup/verify bodies.
    _CONFIGS: dict = {
        "base": (2, 64, 128, 4),
        "large": (8, 256, 512, 4),
        "xlarge": (16, 512, 1024, 4),
    }

    def _configure(self, batch_size: int, dim: int, seqlen: int, width: int) -> None:
        """Allocate inputs and the output buffer for one workload size."""
        self.x = torch.randn(
            batch_size, dim, seqlen, device=self.device, dtype=torch.float16
        )
        self.weight = torch.randn(dim, width, device=self.device, dtype=torch.float32)
        self.bias = torch.randn(dim, device=self.device, dtype=torch.float32)
        self.out = torch.empty(
            batch_size, dim, seqlen, device=self.device, dtype=torch.float16
        )
        self.dim = dim
        self.width = width
        self.seqlen = seqlen

    def _reference(self) -> torch.Tensor:
        """Causal-conv reference: grouped conv1d in fp32, truncated to seqlen.

        ``padding=width - 1`` pads both sides; conv1d then emits
        ``seqlen + width - 1`` positions, and keeping the first ``seqlen``
        retains exactly the causal outputs (each position sees only
        current and past inputs).
        """
        x_fp32 = self.x.to(self.weight.dtype)
        out = F.conv1d(
            x_fp32,
            self.weight.unsqueeze(1),  # (dim, 1, width) for depthwise groups=dim
            self.bias,
            padding=self.width - 1,
            groups=self.dim,
        )
        return out[..., : self.seqlen].to(self.x.dtype)

    def setup(self):
        self._configure(*self._CONFIGS["base"])

    def benchmark_base(self):
        self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)

    def verify_base(self) -> torch.Tensor:
        return self._reference()

    def setup_large(self):
        self._configure(*self._CONFIGS["large"])

    def benchmark_large(self):
        self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)

    def verify_large(self) -> torch.Tensor:
        return self._reference()

    def setup_xlarge(self):
        self._configure(*self._CONFIGS["xlarge"])

    def benchmark_xlarge(self):
        self.out = self.kernel.causal_conv1d_fn(self.x, self.weight, self.bias)

    def verify_xlarge(self) -> torch.Tensor:
        return self._reference()