"""
Initial TriMul submission — PyTorch baseline with dummy Triton kernel.
"""
|
|
| import torch |
| from torch import nn, einsum |
| import triton |
| import triton.language as tl |
|
|
|
|
@triton.jit
def _dummy_kernel(x_ptr, BLOCK_SIZE: tl.constexpr):
    """Placeholder Triton kernel; intentionally a no-op.

    Not launched anywhere in this file — it exists only so the submission
    exercises the Triton toolchain. ``x_ptr`` and ``BLOCK_SIZE`` are unused.
    """
    # NOTE(review): the original also read tl.program_id(0) into an unused
    # local; dropped, since the kernel does no work either way.
    pass
|
|
|
|
class TriMul(nn.Module):
    """Triangular multiplicative update, PyTorch baseline.

    Projects the input pair representation into masked, sigmoid-gated
    ``left``/``right`` hidden tensors, contracts them over the shared
    sequence axis with an einsum, then layer-normalizes, gates, and projects
    the result back to ``dim`` channels.

    Args:
        dim: channel size of the input/output pair representation.
        hidden_dim: channel size of the intermediate projections.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        # Input LayerNorm over the channel dimension.
        self.norm = nn.LayerNorm(dim)

        # Value projections (no bias, fp32 weights).
        self.left_proj = nn.Linear(dim, hidden_dim, bias=False, dtype=torch.float32)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False, dtype=torch.float32)

        # Sigmoid-gate projections for the left/right values and the output.
        self.left_gate = nn.Linear(dim, hidden_dim, bias=False, dtype=torch.float32)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False, dtype=torch.float32)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False, dtype=torch.float32)

        # Output normalization and projection back to `dim`.
        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False, dtype=torch.float32)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply the triangular update.

        Args:
            x: pair representation, shape ``(batch, seq, seq, dim)``.
            mask: multiplicative mask, shape ``(batch, seq, seq)``;
                broadcast over the hidden channel dimension.

        Returns:
            Updated pair representation, shape ``(batch, seq, seq, dim)``,
            fp32.
        """
        x = self.norm(x)
        # Cast once; every projection below runs on this fp32 tensor.
        # (The original re-cast `x` before each of the five projections —
        # all no-ops after this line.)
        x = x.to(torch.float32)

        left = self.left_proj(x)
        right = self.right_proj(x)

        # Zero out masked positions before the contraction.
        mask = mask.unsqueeze(-1)
        left = left * mask
        right = right * mask

        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        left = left * left_gate
        right = right * right_gate

        # Contract over the shared `k` axis in bfloat16 to cut the cost of
        # the O(seq^3) einsum; the result is cast back to fp32 below.
        out = einsum(
            '... i k d, ... j k d -> ... i j d',
            left.to(torch.bfloat16),
            right.to(torch.bfloat16),
        )

        out = out.to(torch.float32)
        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)
|
|
|
|
def custom_kernel(data):
    """Run the TriMul baseline on one benchmark sample.

    Args:
        data: tuple ``(input_tensor, mask, weights, config)`` where
            ``weights`` maps state-dict names (``'norm.weight'``,
            ``'left_proj.weight'``, …) to tensors and ``config`` provides
            ``'dim'`` and ``'hidden_dim'``.

    Returns:
        The fp32 output tensor of ``TriMul.forward``.

    Raises:
        KeyError: if ``weights`` is missing an entry the module requires.
    """
    input_tensor, mask, weights, config = data
    trimul = TriMul(config["dim"], config["hidden_dim"]).to(input_tensor.device)

    # Load the provided weights wholesale, keyed by the module's own
    # state-dict names, instead of assigning ten Parameters by hand —
    # no drift risk if the module definition changes.
    state = {name: weights[name].to(torch.float32) for name in trimul.state_dict()}
    trimul.load_state_dict(state)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = trimul(input_tensor, mask).to(torch.float32)

    return output
| |
|
|