Kernels
trimul-gpumode / benchmarks /benchmark.py
danieldk's picture
danieldk HF Staff
Benchmarks uploaded using `kernels`.
cd5a62f verified
import torch
from torch import nn, einsum
from kernels.benchmark import Benchmark
class TriMulReference(nn.Module):
    """Plain PyTorch reference for the gated triangle multiplicative update.

    The input is layer-normalised, projected into two masked and
    sigmoid-gated operands, contracted over a shared axis, normalised
    again, gated once more and projected back to ``dim`` features.
    """

    def __init__(self, dim: int, hidden_dim: int):
        """Create the sub-layers.

        Args:
            dim: Size of the last axis of the input and of the output.
            hidden_dim: Feature size used for the intermediate operands.
        """
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Five bias-free dim -> hidden_dim projections. The registration
        # order is kept fixed so that state-dict key order and random
        # initialisation order do not change.
        for layer_name in (
            "left_proj",
            "right_proj",
            "left_gate",
            "right_gate",
            "out_gate",
        ):
            setattr(self, layer_name, nn.Linear(dim, hidden_dim, bias=False))
        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the update to ``x``.

        Args:
            x: Tensor whose last axis has size ``dim``; the contraction
                below treats its last three axes as ``(i, k, d)``.
            mask: Multiplicative mask with one axis fewer than ``x``; a
                trailing singleton axis is appended so that it broadcasts
                over the feature axis.

        Returns:
            The result of the final ``hidden_dim -> dim`` projection.
        """
        normed = self.norm(x)
        keep = mask.unsqueeze(-1)

        # Mask first, then gate: same operand order as (proj * mask) * gate.
        lhs = self.left_proj(normed) * keep
        rhs = self.right_proj(normed) * keep
        lhs = lhs * self.left_gate(normed).sigmoid()
        rhs = rhs * self.right_gate(normed).sigmoid()
        output_gate = self.out_gate(normed).sigmoid()

        # mixed[..., i, j, d] = sum_k lhs[..., i, k, d] * rhs[..., j, k, d]
        mixed = torch.einsum("... i k d, ... j k d -> ... i j d", lhs, rhs)
        gated = self.to_out_norm(mixed) * output_gate
        return self.to_out(gated)
class TrimulGpumodeBenchmark(Benchmark):
    """Benchmark case that runs the kernel against ``TriMulReference``.

    ``setup`` builds the CUDA inputs and weights, ``benchmark_base`` calls
    the kernel under test and ``verify_base`` computes the reference
    output from the same tensors.
    """

    # NOTE(review): presumably read by the ``Benchmark`` base class to seed
    # the RNG before ``setup`` runs; confirm against the base class.
    seed: int = 42

    def setup(self):
        """Allocate inputs, mask, weights and the output buffer on CUDA."""
        # Note: hidden_dim must be 128 (kernel constraint)
        batch_size, seq_len = 1, 128
        dim = hidden_dim = 128

        tensor_opts = {"device": "cuda", "dtype": torch.float32}
        pair_shape = (batch_size, seq_len, seq_len)

        def small_randn(rows, cols):
            # Standard-normal weight matrix scaled by 0.02.
            return torch.randn(rows, cols, **tensor_opts) * 0.02

        self.config = {"dim": dim, "hidden_dim": hidden_dim}
        self.input_tensor = torch.randn(*pair_shape, dim, **tensor_opts)
        self.mask = torch.ones(*pair_shape, **tensor_opts)
        # Entries are evaluated top to bottom, so random draws happen in
        # this exact order.
        self.weights = {
            "norm.weight": torch.ones(dim, **tensor_opts),
            "norm.bias": torch.zeros(dim, **tensor_opts),
            "left_proj.weight": small_randn(hidden_dim, dim),
            "right_proj.weight": small_randn(hidden_dim, dim),
            "left_gate.weight": small_randn(hidden_dim, dim),
            "right_gate.weight": small_randn(hidden_dim, dim),
            "out_gate.weight": small_randn(hidden_dim, dim),
            "to_out_norm.weight": torch.ones(hidden_dim, **tensor_opts),
            "to_out_norm.bias": torch.zeros(hidden_dim, **tensor_opts),
            "to_out.weight": small_randn(dim, hidden_dim),
        }
        self.out = torch.empty(*pair_shape, dim, **tensor_opts)

    def benchmark_base(self):
        """Run the kernel on the prepared data and store its result."""
        # NOTE(review): ``self.kernel`` is not defined in this file; it is
        # presumably attached by the ``Benchmark`` base class. Confirm.
        self.out = self.kernel.kernel_global(
            (self.input_tensor, self.mask, self.weights, self.config)
        )

    def verify_base(self) -> torch.Tensor:
        """Return the reference output for the tensors built in ``setup``."""
        cfg = self.config
        reference = TriMulReference(
            dim=cfg["dim"], hidden_dim=cfg["hidden_dim"]
        ).cuda()

        # Replace each freshly initialised parameter with the benchmark's
        # own tensor so the reference and the kernel see the same weights.
        for key in (
            "norm.weight",
            "norm.bias",
            "left_proj.weight",
            "right_proj.weight",
            "left_gate.weight",
            "right_gate.weight",
            "out_gate.weight",
            "to_out_norm.weight",
            "to_out_norm.bias",
            "to_out.weight",
        ):
            layer_name, param_name = key.split(".")
            setattr(
                getattr(reference, layer_name),
                param_name,
                nn.Parameter(self.weights[key]),
            )

        with torch.no_grad():
            return reference(self.input_tensor, self.mask)