| import torch |
| from torch import nn, einsum |
|
|
| from kernels.benchmark import Benchmark |
|
|
|
|
class TriMulReference(nn.Module):
    """Reference implementation of Triangle Multiplicative Module."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # Pre-norm on the input, post-norm on the pairwise product.
        self.norm = nn.LayerNorm(dim)
        self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the triangle-multiplicative update.

        Args:
            x: pair representation whose last dimension is ``dim``.
            mask: multiplicative mask matching ``x`` minus the channel
                dimension; broadcast over the hidden channels.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.norm(x)

        # All three gates are sigmoids over the normalized input.
        gate_left = torch.sigmoid(self.left_gate(x))
        gate_right = torch.sigmoid(self.right_gate(x))
        gate_out = torch.sigmoid(self.out_gate(x))

        # Project the two edge representations, then mask and gate them.
        m = mask.unsqueeze(-1)
        a = self.left_proj(x) * m * gate_left
        b = self.right_proj(x) * m * gate_right

        # Pairwise contraction over the shared index k.
        prod = einsum("... i k d, ... j k d -> ... i j d", a, b)

        prod = self.to_out_norm(prod)
        prod = prod * gate_out
        return self.to_out(prod)
|
|
|
|
class TrimulGpumodeBenchmark(Benchmark):
    """Benchmark harness that runs a TriMul kernel and verifies it
    against the PyTorch ``TriMulReference`` module."""

    # NOTE(review): `seed` is declared but never read in this class;
    # presumably the Benchmark base class seeds RNGs from it — confirm.
    seed: int = 42

    def setup(self):
        """Allocate CUDA inputs, weights, and the output buffer.

        Weight keys in ``self.weights`` are exactly the state-dict keys
        of ``TriMulReference`` so ``verify_base`` can load them directly.
        """
        batch_size = 1
        seq_len = 128
        dim = 128
        hidden_dim = 128

        self.config = {"dim": dim, "hidden_dim": hidden_dim}

        self.input_tensor = torch.randn(
            batch_size, seq_len, seq_len, dim, device="cuda", dtype=torch.float32
        )
        self.mask = torch.ones(
            batch_size, seq_len, seq_len, device="cuda", dtype=torch.float32
        )

        def _rand_weight(out_dim: int, in_dim: int) -> torch.Tensor:
            # Small-scale random init shared by every linear layer.
            return (
                torch.randn(out_dim, in_dim, device="cuda", dtype=torch.float32)
                * 0.02
            )

        # LayerNorm params are identity (ones/zeros); linear weights are
        # random. randn calls are kept in the original order so a seeded
        # RNG produces the same weights as before this refactor.
        self.weights = {
            "norm.weight": torch.ones(dim, device="cuda", dtype=torch.float32),
            "norm.bias": torch.zeros(dim, device="cuda", dtype=torch.float32),
        }
        for layer in ("left_proj", "right_proj", "left_gate", "right_gate", "out_gate"):
            self.weights[f"{layer}.weight"] = _rand_weight(hidden_dim, dim)
        self.weights["to_out_norm.weight"] = torch.ones(
            hidden_dim, device="cuda", dtype=torch.float32
        )
        self.weights["to_out_norm.bias"] = torch.zeros(
            hidden_dim, device="cuda", dtype=torch.float32
        )
        self.weights["to_out.weight"] = _rand_weight(dim, hidden_dim)

        # Pre-allocated output buffer; kernels may write into it or
        # replace it (benchmark_base rebinds self.out to the return value).
        self.out = torch.empty(
            batch_size, seq_len, seq_len, dim, device="cuda", dtype=torch.float32
        )

    def benchmark_base(self):
        """Run the kernel under test on the prepared inputs."""
        data = (self.input_tensor, self.mask, self.weights, self.config)
        self.out = self.kernel.kernel_global(data)

    def verify_base(self) -> torch.Tensor:
        """Recompute the output with the reference module.

        Returns:
            The reference output tensor to compare against ``self.out``.
        """
        ref = TriMulReference(
            dim=self.config["dim"], hidden_dim=self.config["hidden_dim"]
        ).cuda()

        # load_state_dict (strict by default) copies every weight and
        # fails loudly on a missing or unexpected key, unlike the
        # previous ten manual nn.Parameter reassignments.
        ref.load_state_dict(self.weights)

        with torch.no_grad():
            return ref(self.input_tensor, self.mask)
|
|