| |
| from .task import input_t, output_t |
|
|
| import torch |
| from torch import nn, einsum |
| import math |
| import os |
| import requests |
|
|
| import triton |
| import triton.language as tl |
|
|
| |
| |
# Opt in to TF32 tensor-core math for fp32 matmul/conv — trades exact fp32
# precision for throughput; consistent with allow_tf32=True in the kernels below.
torch.backends.cuda.matmul.allow_tf32 = True


torch.backends.cudnn.allow_tf32 = True


def alloc_fn(size: int, alignment: int, stream=None):
    # Allocator hook for Triton's on-device scratch (e.g. TMA workspace):
    # return a raw CUDA byte buffer.  `alignment` and `stream` are part of the
    # allocator signature; torch's caching allocator already returns suitably
    # aligned memory, so they are intentionally unused here.
    return torch.empty(size, device="cuda", dtype=torch.int8)


# Route Triton's device-memory requests through PyTorch's allocator.
triton.set_allocator(alloc_fn)
|
|
| |
| |
|
|
| |
class TriMul(nn.Module):
    """Triangular multiplicative update ("outgoing" edges), AlphaFold-style.

    Eager PyTorch implementation; serves as the ground-truth reference for
    the fused Triton path implemented further down in this file.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        # Pre-normalization over the input channel dimension.
        self.norm = nn.LayerNorm(dim)

        # Value projections for the two triangle edges.
        self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False)

        # Sigmoid gates (applied multiplicatively).
        self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False)

        # Post-normalization and final projection back to `dim`.
        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x: [bs, seq_len, seq_len, dim]
        mask: [bs, seq_len, seq_len]

        Returns:
            output: [bs, seq_len, seq_len, dim]
        """
        x = self.norm(x)

        # Broadcast the pair mask over the hidden channels and zero out
        # masked positions of both edge projections.
        mask = mask.unsqueeze(-1)
        left = self.left_proj(x) * mask
        right = self.right_proj(x) * mask

        # Gate each edge; keep the output gate for the epilogue.
        left = left * self.left_gate(x).sigmoid()
        right = right * self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        # Triangle contraction: sum over the shared index k, per channel d.
        out = einsum('... i k d, ... j k d -> ... i j d', left, right)

        out = self.to_out_norm(out)
        return self.to_out(out * out_gate)
|
|
@triton.jit
def triton_sigmoid(x):
    """
    Compute sigmoid function: 1 / (1 + exp(-x))
    """
    # Spelled out explicitly so the fused GEMM kernels below can apply it
    # in-register to their fp32 accumulators during the epilogue.
    return 1.0 / (1.0 + tl.exp(-x))
|
|
def two_mm_kernel_configs_wrapper():
    """Select the per-architecture launch-config provider for the fused kernel.

    Returns one of two kinds of callable, matching how the call sites use it:
      * sm90 (Hopper): ``f(B, seq_len, dim) -> (BLOCK_M, BLOCK_N, BLOCK_K,
        num_stages, num_warps)`` — a lookup of pre-tuned configs.  ``two_mm``
        calls this directly because autotuning is skipped for sm90 in
        ``two_mm_kernel_wrapper``.
      * every other architecture: ``f() -> list[triton.Config]`` — a search
        space that ``two_mm_kernel_wrapper`` hands to ``triton.autotune``.
    """
    capability = torch.cuda.get_device_capability()

    if capability == (12, 0):
        # sm120: sweep small tiles (limited shared memory / SM count).
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [16, 32]:
                for BLOCK_N in [16, 32, 64]:
                    for BLOCK_K in [16, 32, 64]:
                        for num_stages in [2, 3]:
                            configs.append(triton.Config({
                                'BLOCK_M': BLOCK_M,
                                'BLOCK_N': BLOCK_N,
                                'BLOCK_K': BLOCK_K,
                                'GROUP_SIZE_M': 8
                            }, num_stages=num_stages, num_warps=8))
            return configs
        return two_mm_kernel_configs

    if capability[0] == 9:
        # Pre-tuned H100 launch configs keyed by (batch, seq_len, dim).
        def get_optimal_two_mm_config_h100(B, seq_len, dim):
            configs = {
                (1, 128, 128): (128, 64, 128, 2, 8),
                (1, 128, 256): (128, 64, 128, 2, 8),
                (1, 128, 384): (128, 64, 64, 3, 8),
                (1, 128, 512): (128, 64, 64, 3, 8),
                (1, 128, 768): (128, 64, 64, 3, 8),
                (1, 128, 1024): (128, 64, 64, 3, 8),
                (1, 256, 128): (128, 64, 128, 2, 8),
                (1, 256, 256): (128, 64, 128, 2, 8),
                (1, 256, 384): (128, 64, 64, 3, 8),
                (1, 256, 512): (128, 64, 64, 3, 8),
                (1, 256, 768): (128, 64, 64, 3, 8),
                (1, 256, 1024): (128, 64, 64, 3, 8),
                (1, 512, 128): (128, 64, 128, 2, 8),
                (1, 512, 256): (128, 64, 128, 2, 8),
                (1, 512, 384): (128, 64, 128, 2, 8),
                (1, 512, 512): (128, 64, 128, 2, 8),
                (1, 512, 768): (128, 64, 64, 3, 8),
                (1, 512, 1024): (128, 64, 64, 3, 8),
                (1, 1024, 128): (128, 64, 128, 2, 8),
                (1, 1024, 256): (128, 64, 64, 2, 8),
                (1, 1024, 384): (128, 64, 128, 2, 8),
                (1, 1024, 512): (128, 64, 128, 2, 8),
                (1, 1024, 768): (128, 64, 128, 2, 8),
                (1, 1024, 1024): (128, 64, 128, 2, 8),
                (2, 128, 128): (128, 64, 128, 2, 8),
                (2, 128, 256): (128, 64, 128, 2, 8),
                (2, 128, 384): (128, 64, 64, 3, 8),
                (2, 128, 512): (128, 64, 64, 3, 8),
                (2, 128, 768): (128, 64, 64, 3, 8),
                (2, 128, 1024): (128, 64, 64, 3, 8),
                (2, 256, 128): (128, 64, 128, 2, 8),
                (2, 256, 256): (128, 64, 128, 2, 8),
                (2, 256, 384): (128, 64, 128, 2, 8),
                (2, 256, 512): (128, 64, 128, 2, 8),
                (2, 256, 768): (128, 64, 64, 3, 8),
                (2, 256, 1024): (128, 64, 64, 3, 8),
                (2, 512, 128): (128, 64, 128, 2, 8),
                (2, 512, 256): (128, 64, 128, 2, 8),
                (2, 512, 384): (128, 64, 128, 2, 8),
                (2, 512, 512): (128, 64, 128, 2, 8),
                (2, 512, 768): (128, 64, 128, 2, 8),
                (2, 512, 1024): (128, 64, 128, 2, 8),
                (2, 1024, 128): (128, 64, 128, 2, 8),
                (2, 1024, 256): (128, 64, 128, 2, 8),
                (2, 1024, 384): (128, 64, 128, 2, 8),
                (2, 1024, 512): (128, 64, 128, 2, 8),
                (2, 1024, 768): (128, 64, 128, 2, 8),
                (2, 1024, 1024): (128, 64, 128, 2, 8),
            }
            # Conservative fallback for shapes outside the tuned table.
            return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8))

        # BUG FIX: this wrapper previously always returned a zero-argument
        # config-list builder, but the sm90 call site in `two_mm` invokes the
        # result as f(B, seq_len, K) — which raised TypeError on H100.
        # Return the tuned-lookup function, as the call site expects.
        return get_optimal_two_mm_config_h100

    if capability[0] == 8:
        # Ampere: shared memory is tighter — sweep small BLOCK_K tiles and
        # also vary num_warps.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [64]:
                for BLOCK_N in [64, 128]:
                    for BLOCK_K in [16]:
                        for num_stages in [3, 4]:
                            for num_warps in [4, 8]:
                                configs.append(triton.Config({
                                    'BLOCK_M': BLOCK_M,
                                    'BLOCK_N': BLOCK_N,
                                    'BLOCK_K': BLOCK_K,
                                    'GROUP_SIZE_M': 8
                                }, num_stages=num_stages, num_warps=num_warps))
            return configs
        return two_mm_kernel_configs

    # Default search space for all remaining architectures.  (This also covers
    # sm100: a pre-tuned sm100 table used to sit here behind a literal
    # `capability[0] == 10 and False`, i.e. it was unreachable and has been
    # removed — sm100 falls through to autotuning, exactly as before.)
    def two_mm_kernel_configs():
        configs = []
        for BLOCK_M in [64, 128]:
            for BLOCK_N in [64, 128]:
                for BLOCK_K in [64, 128]:
                    for num_stages in [2, 3]:
                        configs.append(triton.Config({
                            'BLOCK_M': BLOCK_M,
                            'BLOCK_N': BLOCK_N,
                            'BLOCK_K': BLOCK_K,
                            'GROUP_SIZE_M': 8
                        }, num_stages=num_stages, num_warps=8))
        return configs
    return two_mm_kernel_configs
|
|
def two_mm_kernel_wrapper():
    """Build the persistent fused-GEMM kernel for the current GPU.

    The kernel computes five matmuls that share the same A operand
    (A @ b1..b5, with each b given row-major as [N, K] and used transposed),
    then fuses the TriMul epilogue:

        C1 = mask * (A@b1) * sigmoid(A@b3)
        C2 = mask * (A@b2) * sigmoid(A@b4)
        D  = sigmoid(A@b5)

    Two variants: sm8x uses plain masked pointer loads; every other
    architecture uses (experimental) TMA tensor descriptors.  The grid is
    persistent — NUM_SMS programs loop over output tiles.
    """
    if torch.cuda.get_device_capability()[0] == 8:
        @triton.jit
        def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):

            # Persistent-kernel bookkeeping (see Triton's persistent matmul
            # tutorial): each program strides over tiles by NUM_SMS.
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_M)
            num_pid_n = tl.cdiv(N, BLOCK_N)
            k_tiles = tl.cdiv(K, BLOCK_K)
            num_tiles = num_pid_m * num_pid_n

            # tile_id_c trails the loop variable by one NUM_SMS step; it is
            # re-advanced in the epilogue to recompute this tile's (m, n).
            tile_id_c = start_pid - NUM_SMS
            num_pid_in_group = GROUP_SIZE_M * num_pid_n

            for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
                # Grouped (swizzled) tile ordering for L2 reuse.
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
                offs_k = tl.arange(0, BLOCK_K)

                # One fp32 accumulator per output (left, right, two gates, out-gate).
                accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                # K-loop: load one A tile, reuse it against all five B tiles.
                for ki in range(k_tiles):
                    k_start = ki * BLOCK_K
                    k_offsets = k_start + offs_k

                    a_ptrs = a_ptr + offs_am[:, None] * stride_a2 + k_offsets[None, :] * stride_a3
                    a_mask = (offs_am[:, None] < M) & (k_offsets[None, :] < K)

                    # All weights share one stride pair (asserted host-side).
                    b1_ptrs = b1_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b2_ptrs = b2_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b3_ptrs = b3_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b4_ptrs = b4_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b5_ptrs = b5_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b_mask = (offs_bn[:, None] < N) & (k_offsets[None, :] < K)

                    a = tl.load(a_ptrs, mask=a_mask, other=0.0)
                    b1 = tl.load(b1_ptrs, mask=b_mask, other=0.0)
                    b2 = tl.load(b2_ptrs, mask=b_mask, other=0.0)
                    b3 = tl.load(b3_ptrs, mask=b_mask, other=0.0)
                    b4 = tl.load(b4_ptrs, mask=b_mask, other=0.0)
                    b5 = tl.load(b5_ptrs, mask=b_mask, other=0.0)

                    # B tiles are loaded [N, K]; transpose in-register for the dot.
                    accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True)
                    accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True)
                    accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True)
                    accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True)
                    accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True)

                # Epilogue: recompute this tile's coordinates from tile_id_c.
                tile_id_c += NUM_SMS
                group_id = tile_id_c // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id_c % group_size_m)
                pid_n = (tile_id_c % num_pid_in_group) // group_size_m

                offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

                d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

                # (Unused aliases — kept for parity with the signature.)
                stride_cm = stride_c2
                stride_cn = stride_c3

                stride_dm = stride_d2
                stride_dn = stride_d3

                # Decompose the flat row index into (batch, i, j): C1/C2 are
                # strided 4-D tensors, not flat [M, N].
                off_c_batch = offs_cm // (seq_len * seq_len)
                off_c_sl1 = (offs_cm // seq_len) % seq_len
                off_c_sl2 = offs_cm % seq_len
                off_c_dim = offs_cn

                c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
                c_mask = d_mask

                c1_ptrs = c1_ptr + c_offsets
                c2_ptrs = c2_ptr + c_offsets
                d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]

                # One mask value per row (mask is flattened [B*S*S]).
                mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))

                mask_2d = mask[:, None]

                # Zero masked rows of the two edge projections (D stays unmasked).
                accumulator1 = tl.where(mask_2d, accumulator1, 0)
                accumulator2 = tl.where(mask_2d, accumulator2, 0)

                # Fused gates: sigmoid applied to the fp32 accumulators.
                left_gate_sigmoid = triton_sigmoid(accumulator3)
                right_gate_sigmoid = triton_sigmoid(accumulator4)
                accumulator_d = triton_sigmoid(accumulator_d)

                accumulator1 = accumulator1 * left_gate_sigmoid
                accumulator2 = accumulator2 * right_gate_sigmoid

                # Downcast to the output dtype (fp16 on the host side).
                c1 = accumulator1.to(c1_ptr.dtype.element_ty)
                c2 = accumulator2.to(c2_ptr.dtype.element_ty)
                d = accumulator_d.to(d_ptr.dtype.element_ty)

                tl.store(c1_ptrs, c1, mask=c_mask)
                tl.store(c2_ptrs, c2, mask=c_mask)
                tl.store(d_ptrs, d, mask=d_mask)
    else:
        @triton.jit
        def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):

            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_M)
            num_pid_n = tl.cdiv(N, BLOCK_N)
            k_tiles = tl.cdiv(K, BLOCK_K)
            num_tiles = num_pid_m * num_pid_n

            # TMA tensor descriptors for A and the five weights (this is why
            # triton.set_allocator is configured at module import).
            a_desc = tl._experimental_make_tensor_descriptor(
                a_ptr,
                shape=[M, K],
                strides=[stride_a2, stride_a3],
                block_shape=[BLOCK_M, BLOCK_K],
            )
            b1_desc = tl._experimental_make_tensor_descriptor(
                b1_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b2_desc = tl._experimental_make_tensor_descriptor(
                b2_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b3_desc = tl._experimental_make_tensor_descriptor(
                b3_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b4_desc = tl._experimental_make_tensor_descriptor(
                b4_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b5_desc = tl._experimental_make_tensor_descriptor(
                b5_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )

            # Same persistent-tile bookkeeping as the sm8x variant.
            tile_id_c = start_pid - NUM_SMS
            num_pid_in_group = GROUP_SIZE_M * num_pid_n

            for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                # Descriptor loads take element offsets, not pointer arithmetic.
                offs_am = pid_m * BLOCK_M
                offs_bn = pid_n * BLOCK_N

                accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                # K-loop: one A tile reused against all five B tiles.
                for ki in range(k_tiles):
                    offs_k = ki * BLOCK_K

                    a = a_desc.load([offs_am, offs_k])
                    b1 = b1_desc.load([offs_bn, offs_k])
                    b2 = b2_desc.load([offs_bn, offs_k])
                    b3 = b3_desc.load([offs_bn, offs_k])
                    b4 = b4_desc.load([offs_bn, offs_k])
                    b5 = b5_desc.load([offs_bn, offs_k])

                    accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True)
                    accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True)
                    accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True)
                    accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True)
                    accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True)

                # Epilogue (identical to the sm8x variant).
                tile_id_c += NUM_SMS
                group_id = tile_id_c // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id_c % group_size_m)
                pid_n = (tile_id_c % num_pid_in_group) // group_size_m

                offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

                d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

                # (Unused aliases — kept for parity with the signature.)
                stride_cm = stride_c2
                stride_cn = stride_c3

                stride_dm = stride_d2
                stride_dn = stride_d3

                # Decompose the flat row index into (batch, i, j) for the
                # strided (channels-first-allocated) C1/C2 outputs.
                off_c_batch = offs_cm // (seq_len * seq_len)
                off_c_sl1 = (offs_cm // seq_len) % seq_len
                off_c_sl2 = offs_cm % seq_len
                off_c_dim = offs_cn

                c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3

                c_mask = d_mask

                c1_ptrs = c1_ptr + c_offsets
                c2_ptrs = c2_ptr + c_offsets

                d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]

                # One mask value per output row.
                mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))

                mask_2d = mask[:, None]

                accumulator1 = tl.where(mask_2d, accumulator1, 0)
                accumulator2 = tl.where(mask_2d, accumulator2, 0)

                left_gate_sigmoid = triton_sigmoid(accumulator3)
                right_gate_sigmoid = triton_sigmoid(accumulator4)
                accumulator_d = triton_sigmoid(accumulator_d)

                accumulator1 = accumulator1 * left_gate_sigmoid
                accumulator2 = accumulator2 * right_gate_sigmoid

                c1 = accumulator1.to(c1_ptr.dtype.element_ty)
                c2 = accumulator2.to(c2_ptr.dtype.element_ty)
                d = accumulator_d.to(d_ptr.dtype.element_ty)

                tl.store(c1_ptrs, c1, mask=c_mask)
                tl.store(c2_ptrs, c2, mask=c_mask)
                tl.store(d_ptrs, d, mask=d_mask)

    # Skip autotuning on sm90 (pre-tuned configs are supplied at launch).
    # NOTE(review): capability[0] is an int, so it can never equal 10.2 —
    # this looks like a typo for 10.  As written, sm100 IS autotuned, which
    # is consistent with the config wrapper returning a search space for it;
    # confirm intent before changing.
    if torch.cuda.get_device_capability()[0] not in [9, 10.2]:
        two_mm_kernel = triton.autotune(
            (two_mm_kernel_configs_wrapper())(), key=["M", "N", "K"]
        )(two_mm_kernel)

    return two_mm_kernel
|
|
|
|
def two_mm(A, left_proj, right_proj, left_gate, right_gate, out_gate, mask):
    """
    Persistent fused matmul of A against all five TriMul weight matrices.

    Args:
        A: [B, seq_len, seq_len, K] input tensor
        left_proj: [N, K] matrix (used transposed, like nn.Linear)
        right_proj: [N, K] matrix (used transposed)
        left_gate: [N, K] left gate weight matrix
        right_gate: [N, K] right gate weight matrix
        out_gate: [N, K] output gate weight matrix
        mask: mask tensor, one value per (batch, i, j) row

    Returns:
        (C1, C2, D): Tuple of result tensors [B, seq_len, seq_len, N]:
            C1 = (A @ left_proj.T) * sigmoid(A @ left_gate.T)   (masked)
            C2 = (A @ right_proj.T) * sigmoid(A @ right_gate.T) (masked)
            D  = sigmoid(A @ out_gate.T)                        (unmasked)
    """
    assert A.shape[-1] == left_proj.shape[1] == right_proj.shape[1], "Incompatible K dimensions"
    assert A.dtype == left_proj.dtype == right_proj.dtype, "Incompatible dtypes"

    # The kernel receives a single (stride_bk, stride_bn) pair shared by all
    # five weights, so their layouts must match.
    assert left_proj.stride() == right_proj.stride() == left_gate.stride() == right_gate.stride() == out_gate.stride(), \
        "All weight matrices must have identical strides"

    original_shape = A.shape[:-1]
    K = A.shape[-1]
    N = left_proj.shape[0]
    B, seq_len, _, _ = A.shape

    # Flatten leading dims: the kernel treats A as a plain [M, K] operand.
    A_2d = A.view(-1, K)
    M = A_2d.shape[0]

    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # Persistent launch: at most one program per SM.
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)

    A_strides = A.stride()

    output_shape = original_shape + (N,)

    # C1/C2 are allocated channels-first ([B, N, S, S]) and permuted back to
    # [B, S, S, N]: the downstream bmm consumes them channels-first, so that
    # later permute becomes free.  fp16 outputs halve epilogue bandwidth.
    C1 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
    C2 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
    D = torch.empty(output_shape, device=A.device, dtype=torch.float16)

    C_strides = C1.stride()
    D_strides = D.stride()

    # Build the device-specific JIT kernel once and memoize it on the function
    # object: previously `two_mm_kernel_wrapper()` was re-run on every call,
    # re-creating (and re-autotuning) the kernel each time.
    kernel = getattr(two_mm, "_kernel", None)
    if kernel is None:
        kernel = two_mm._kernel = two_mm_kernel_wrapper()

    if torch.cuda.get_device_capability()[0] == 9:
        # sm90 skips autotuning: look up a pre-tuned launch config and pass
        # the block sizes / stages explicitly.
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
        grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))

        kernel[(grid_size,)](
            A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
            C1, C2, D, mask,
            M, N, K,
            *A_strides,
            left_proj.stride(1), left_proj.stride(0),
            *C_strides,
            seq_len,
            *D_strides,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
            num_stages=num_stages, num_warps=num_warps
        )
    else:
        # BUG FIX: sm100 previously took the manual-config path, but the
        # config provider for non-sm90 devices is a zero-arg autotune search
        # space — invoking it with (B, seq_len, K) raised TypeError.  All
        # non-sm90 devices now use the autotuned launch, matching the
        # autotune guard in two_mm_kernel_wrapper.
        kernel[grid](
            A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
            C1, C2, D, mask,
            M, N, K,
            *A_strides,
            left_proj.stride(1), left_proj.stride(0),
            *C_strides,
            seq_len,
            *D_strides,
            NUM_SMS=NUM_SMS
        )

    return C1, C2, D
|
|
|
|
def second_layernorm_mul(inp, hidden_dim, weight, bias, mul_operand):
    """LayerNorm over the last `hidden_dim` features, then elementwise product.

    Weight/bias are cast to the input dtype so the affine step does not
    upcast the activation.
    """
    w = weight.to(inp.dtype)
    b = bias.to(inp.dtype)
    normed = torch.nn.functional.layer_norm(inp, (hidden_dim,), eps=1e-5, weight=w, bias=b)
    return normed * mul_operand
|
|
| ''' |
| @triton.autotune( |
| [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=4, num_stages=3)], |
| key=["R", "C"] |
| ) |
| ''' |
@triton.jit
def layernorm_kernel_first(
    X,
    Y,
    Weight,
    Bias,
    R,
    C,
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise LayerNorm over a contiguous [R, C] view.

    Each program normalizes ROW_BLOCK_SIZE rows at once; BLOCK_SIZE is the
    next power of two >= C, so one column tile covers a full row.
    """
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    mask_row = row < R
    mask_col = cols < C

    # Load the whole row block; compute statistics in fp32 for stability.
    x = tl.load(
        X + row[:, None] * C + cols[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    # Biased (population) variance, matching nn.LayerNorm; the tl.where
    # zeroes padding columns so they don't contribute to the sum.
    mean = tl.sum(x, axis=1) / C
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    # Normalize, then apply the affine transform.
    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    # Store downcasts implicitly to Y's element type (fp16 on the host side).
    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y,
        mask=mask_row[:, None] & mask_col[None, :]
    )
|
|
|
|
def get_optimal_config_ln(dim):
    """Pick (ROW_BLOCK_SIZE, num_warps) for the first LayerNorm kernel.

    Tuned values exist for sm90 (Hopper); everything else — including large
    dims on sm90 — falls back to (16, 4).
    """
    if torch.cuda.get_device_capability()[0] == 9:
        if dim <= 256:
            return (16, 1)
        if dim <= 512:
            return (16, 2)
        if dim <= 1024:
            return (16, 4)
    # Default for non-Hopper devices and out-of-range dims.
    return (16, 4)
|
|
|
|
def triton_layernorm_first(x, weight, bias, eps=1e-5, num_warps=None, ROW_BLOCK_SIZE=None):
    """LayerNorm over the last dim of a [B, S, S, D] tensor, emitting fp16.

    Launches `layernorm_kernel_first` on the flattened [B*S*S, D] view.
    Launch parameters may be supplied explicitly; otherwise they come from
    `get_optimal_config_ln`.
    """
    B, seq_len, seq_len2, dim = x.shape
    assert(seq_len == seq_len2)

    # Flatten all leading dims into rows; normalize each row of length `dim`.
    R = B * seq_len * seq_len
    C = dim

    # fp16 output feeds the fp16 fused-projection kernel downstream.
    out = torch.empty_like(x, dtype=torch.float16)

    if not num_warps or not ROW_BLOCK_SIZE:
        ROW_BLOCK_SIZE, num_warps = get_optimal_config_ln(dim)

    # One column tile must span the whole row.
    BLOCK_SIZE = triton.next_power_of_2(C)
    assert(BLOCK_SIZE <= 1024)

    grid = lambda meta: (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_first[grid](
        x, out, weight, bias,
        R, C, eps,
        ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=3,
    )
    return out
|
|
| ''' |
| def triton_layernorm_first(x, weight, bias, eps=1e-5): |
| B, seq_len, seq_len2, dim = x.shape |
| assert(seq_len == seq_len2) |
| |
| R = B * seq_len * seq_len |
| C = dim |
| |
| out = torch.empty_like(x) |
| |
| BLOCK_SIZE = triton.next_power_of_2(C) |
| assert(BLOCK_SIZE <= 1024) |
| |
| def grid(meta): |
| return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),) |
| |
| layernorm_kernel_first[grid]( |
| x, out, weight, bias, |
| R, C, eps, |
| BLOCK_SIZE=BLOCK_SIZE |
| ) |
| |
| return out |
| ''' |
|
|
|
|
@triton.autotune(
    [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=1, num_stages=3)],
    key=[]
)
@triton.jit
def layernorm_kernel_eltwise(
    X,
    Y,
    Weight,
    Bias,
    OutGate,
    seq_len,
    stride_batch,
    stride_dim,
    R,
    C,
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused LayerNorm(X) * OutGate.

    X is read through explicit (batch, dim) strides because the bmm output
    it consumes is a channels-first tensor viewed as [B, S, S, C]; OutGate
    and Y are plain contiguous [R, C] row-major.
    """
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    # A row block must not straddle a batch boundary: `batch` below is a
    # single scalar per program, which is only valid because seq_len^2 is a
    # multiple of ROW_BLOCK_SIZE (checked here).
    tl.device_assert(seq_len*seq_len % ROW_BLOCK_SIZE == 0)

    batch = tl.program_id(0) * ROW_BLOCK_SIZE // (seq_len * seq_len)
    seqs_off = row % (seq_len * seq_len)

    # Strided addressing into the permuted X; within a batch the (i, j)
    # positions are unit-stride, channels move by stride_dim.
    off_r = batch * stride_batch + seqs_off
    off_c = cols * stride_dim

    mask_row = row < R
    mask_col = cols < C

    out_gate = tl.load(
        OutGate + row[:, None] * C + cols[None, :],
        mask = mask_row[:, None] & mask_col[None, :],
    )

    # Statistics computed in fp32 for stability.
    x = tl.load(
        X + off_r[:, None] + off_c[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    # Biased variance, matching nn.LayerNorm; padding columns are zeroed.
    mean = tl.sum(x, axis=1) / C
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    # Fuse the gate multiply into the store.
    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y * out_gate,
        mask=mask_row[:, None] & mask_col[None, :]
    )
|
|
|
|
def triton_layernorm_eltwise(x, weight, bias, out_gate, eps=1e-5):
    """Fused `to_out_norm(x) * out_gate` over the triangular-update output.

    Expects `x` to be the bmm result viewed back as [B, S, S, D] from a
    channels-first allocation (hence the stride check below); `out_gate`
    must be a contiguous tensor of the same logical shape.  Returns fp32.
    """
    B, seq_len, seq_len2, dim = x.shape
    assert(seq_len == seq_len2)
    rows = B * seq_len * seq_len
    # Channels-first view: moving one channel jumps a whole S*S plane.
    assert(x.stride(3) == seq_len*seq_len)
    assert(out_gate.is_contiguous())

    out = torch.empty_like(out_gate, dtype=torch.float32)

    block = triton.next_power_of_2(dim)
    # Only the dim == 128 case has been exercised with this kernel.
    assert(block == 128)

    grid = lambda meta: (triton.cdiv(rows, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_eltwise[grid](
        x, out, weight, bias, out_gate,
        seq_len,
        x.stride(0), x.stride(3),
        rows, dim, eps,
        BLOCK_SIZE=block,
    )
    return out
|
|
|
|
def kernel_global(data: input_t) -> output_t:
    """
    Optimized TriMul forward pass (Triton-fused pipeline).

    Pipeline: fused LayerNorm (fp16 out) -> one fused kernel computing all
    five projections/gates with masking + sigmoid in the epilogue -> batched
    matmul for the triangular contraction -> fused LayerNorm * out_gate ->
    final linear projection.

    Args:
        data: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
            - input: Input tensor of shape [batch_size, seq_len, seq_len, dim]
            - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
            - weights: Dictionary containing model weights
            - config: Dictionary containing model configuration parameters
    """
    input_tensor, mask, weights, config = data

    # fp16 weights for the fused projection kernel (it accumulates in fp32,
    # so the downcast costs little accuracy while halving weight bandwidth).
    left_proj_weight = weights["left_proj.weight"].to(torch.float16)
    right_proj_weight = weights["right_proj.weight"].to(torch.float16)
    left_gate_weight = weights["left_gate.weight"].to(torch.float16)
    right_gate_weight = weights["right_gate.weight"].to(torch.float16)
    out_gate_weight = weights["out_gate.weight"].to(torch.float16)

    hidden_dim = config["hidden_dim"]

    x = input_tensor

    batch_size, seq_len, _, dim = x.shape

    # Input LayerNorm (returns fp16).
    x = triton_layernorm_first(x, weights['norm.weight'], weights['norm.bias'])

    # Fused: left/right projections (masked, gated) and sigmoid(out_gate).
    left, right, out_gate = two_mm(x, left_proj_weight, right_proj_weight, left_gate_weight, right_gate_weight, out_gate_weight, mask)

    '''
    left = left.to(torch.float32)
    right = right.to(torch.float32)
    x = x.to(torch.float32)

    left_gate = left_gate.sigmoid()
    right_gate = right_gate.sigmoid()
    out_gate = out_gate.sigmoid()
    '''

    # Triangular contraction as a batched matmul.  `left`/`right` were
    # allocated channels-first in two_mm, so these permutes are free:
    # [B, S, S, H] -> [B*H, S, S] and the einsum '...ikd,...jkd->...ijd'
    # becomes (B*H) independent [S, S] @ [S, S]^T products.
    out = torch.bmm(left.permute(0, 3, 1, 2).view(-1, left.shape[1], left.shape[2]), right.permute(0, 3, 2, 1).view(-1, right.shape[2], right.shape[1]))
    out = out.view(batch_size, hidden_dim, seq_len, seq_len).permute(0, 2, 3, 1)

    # Fused to_out_norm + out_gate multiply (reads the strided bmm output
    # directly; returns a contiguous fp32 tensor).
    out = triton_layernorm_eltwise(out, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)

    # Final projection back to `dim`.
    return torch.nn.functional.linear(out, weights['to_out.weight'])
|
|
| ''' |
| # Fill in the given weights of the model |
| trimul.norm.weight = nn.Parameter(weights['norm.weight']) |
| trimul.norm.bias = nn.Parameter(weights['norm.bias']) |
| trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight']) |
| trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight']) |
| trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight']) |
| trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight']) |
| trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight']) |
| trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight']) |
| trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias']) |
| trimul.to_out.weight = nn.Parameter(weights['to_out.weight']) |
| |
| output = trimul(input_tensor, mask) |
| |
| return output |
| ''' |