# Uploaded by JustinTX ("Add files using upload-large-folder tool", commit b0e88cf, verified)
# GPU Mode: MLA Decode (Multi-Head Latent Attention) Triton Kernel
max_iterations: 100
checkpoint_interval: 1
log_level: "INFO"
llm:
models:
- name: "gpt-5"
weight: 1.0
api_base: https://api.openai.com/v1
temperature: 1.0
# top_p: 0.95 # omitted by default; some providers (e.g. Anthropic) reject requests that set both temperature and top_p
max_tokens: 32000
timeout: 600
prompt:
system_message: |
You are an expert Triton engineer tasked with translating PyTorch code into highly optimized Triton kernel code.
Below is a PyTorch implementation of the multi-head latent attention (MLA) module. Your task is to implement Triton kernels for the operations in its forward call:
```python
import math
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
class RoPE(nn.Module):
def __init__(self, d_model: int):
super().__init__()
self.d_model = d_model
theta = 10000 ** (-torch.arange(0, d_model//2,dtype=torch.bfloat16) / (d_model//2))
self.register_buffer("theta", theta)
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
seq_len = x.size(-2)
d_model = x.size(-1)
assert d_model == self.d_model
seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device)
idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta)
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
cos = idx_theta2.cos().to(torch.bfloat16)
sin = idx_theta2.sin().to(torch.bfloat16)
return x * cos + self.rotate_half(x) * sin
class KVCache(nn.Module):
def __init__(self, kv_cache_shape: tuple) -> None:
super().__init__()
self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16, device='cuda'))
self.seq_len = 0
self.zero()
def zero(self) -> None:
self.data.zero_()
def get_data(self) -> torch.Tensor:
return self.data
def forward(self, c_kv: torch.Tensor) -> torch.Tensor:
assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded"
self.data = self.data.to(c_kv.dtype)
self.data[
:, self.seq_len : self.seq_len + c_kv.size(1), :
] = c_kv
self.seq_len += c_kv.size(1)
return self.data[:, :self.seq_len], self.seq_len
@dataclass
class Config:
batch_size: int
dim: int
n_heads: int
q_lora_rank: int
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
seq_len: int
max_seq_len: int
kv_cache_shape: tuple
Q_proj_down_weight: torch.Tensor
Q_proj_up_weight: torch.Tensor
KV_proj_down_weight: torch.Tensor
KV_proj_up_weight: torch.Tensor
wo_weight: torch.Tensor
class MLA(nn.Module):
def __init__(self, config: Config):
super().__init__()
self.dim = config.dim
self.n_heads = config.n_heads
self.q_lora_rank = config.q_lora_rank
self.kv_lora_rank = config.kv_lora_rank
self.nope_head_dim = config.qk_nope_head_dim
self.rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
# Down-projection matrices
self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, bias=False, dtype=torch.bfloat16)
self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, bias=False, dtype=torch.bfloat16)
# Up-projection and rope projection matrices
self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16)
self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16)
# RoPE on half embeddings
self.q_rope = RoPE(self.rope_head_dim)
self.k_rope = RoPE(self.rope_head_dim)
# Output projection
self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, dtype=torch.bfloat16, bias=False)
self.eps = 1e-6
def forward(self, x: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
# seq_len = 1 always here
batch_size, seq_len, model_dim = x.size()
## Step 1: Handle down-projection + KV cache ##
q_lora = self.Q_proj_down(x)
kv_lora = self.KV_proj_down(x)
kv_lora, kv_len = kv_cache(kv_lora)
query_pos = kv_len - 1
## Step 2: Up-project and prepare NoPE + RoPE ##
# Handle queries Q first
q_nope_and_rope = self.Q_proj_up(q_lora).view(
batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim)
q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1)
# Handle keys and values K/V. V does not need RoPE
kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1)
kv_nope = self.KV_proj_up(kv_nope).view(
batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1)
## Step 3: Handle RoPE Stream ##
# Compute RoPE for queries and combine with no-RoPE part
q_rope = q_rope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x rope_head_dim
q_rope = self.q_rope(q_rope, start_pos=query_pos)
q_nope = q_nope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x nope_head_dim
q = torch.concat([q_nope, q_rope], dim=-1)
# Compute RoPE for keys and combine with no-RoPE part
k_rope = k_rope[:, None, :, :]
k_rope = self.k_rope(k_rope).expand(-1,self.n_heads,-1,-1)
k_nope = k_nope.permute(0, 2, 1, 3) # bs x n_heads x kv_len x nope_head_dim
k = torch.concat([k_nope, k_rope], dim=-1)
## Step 4: Compute Multi-head Attention ##
v = v.permute(0, 2, 1, 3) # bs x n_heads x kv_len x v_head_dim
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim)
attn = F.softmax(scores, dim=-1).to(torch.bfloat16)
y = torch.matmul(attn, v).view(batch_size, 1, -1)
y = self.wo(y)
return y, kv_cache.get_data()
```
Your function should be defined as 'custom_kernel' (skeleton provided below)
```python
### DO NOT CHANGE THIS IMPORT STATEMENTS BLOCK ###
import os
import math
from typing import Tuple
import torch
import torch.nn.functional as F
import triton
from reference import KVCache, Config # Definition of KVCache and Config classes are shown above. Must import this way. Do not rewrite yourself.
### END OF IMPORT STATEMENTS BLOCK ###
### Import other packages here if needed
def custom_kernel(data: Tuple[Config, torch.Tensor, KVCache]) -> Tuple[torch.Tensor, KVCache]:
config, x, kv_cache = data
bs = config.batch_size
sl = config.seq_len
pl = kv_cache.seq_len
msl = config.max_seq_len
nh = config.n_heads
d = config.dim
dq = config.q_lora_rank
dkv = config.kv_lora_rank
dnope = config.qk_nope_head_dim
drope = config.qk_rope_head_dim
dv = config.v_head_dim
wDQ = config.Q_proj_down_weight
wDKV = config.KV_proj_down_weight
wUQ = config.Q_proj_up_weight
wUKV = config.KV_proj_up_weight
wO = config.wo_weight
# Perform MLA operations to process data into output and updated kv_cache
return output, kv_cache.data
```
with the following signature:
Input:
- `data`: Tuple of (config: Config, x: torch.Tensor, kv_cache: KVCache)
- config: An instance of class `Config` containing model configurations and weights
- x: Input tensor of shape [batch_size, seq_len, dim]
- kv_cache: An instance of KVCache class for caching the keys and values
Output:
- output: Output tensor [batch_size, seq_len, dim]
- kv_cache.data: The data field of the updated `KVCache` instance with the new keys and values added
To help you get started with writing optimized Triton code, here is example code that is correct for this task but very unoptimized. Your code should be as optimized as possible while remaining correct.
```python
import os
import math
from typing import Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from reference import KVCache, Config
@triton.jit
def rope_swap_halves_kernel(
x_ptr,
cos_ptr, sin_ptr,
B: tl.constexpr,
T: tl.constexpr,
D: tl.constexpr,
stride_xb, stride_xt, stride_xd,
stride_cos_t, stride_cos_d,
stride_sin_t, stride_sin_d,
BLOCK_HALF: tl.constexpr,
):
pid = tl.program_id(0)
bt = pid
b = bt // T
t = bt - b * T
half = D // 2
off = tl.arange(0, BLOCK_HALF)
mask = off < half
x_base = x_ptr + b * stride_xb + t * stride_xt
x0_ptr = x_base + off * stride_xd
x1_ptr = x_base + (half + off) * stride_xd
cos_base = cos_ptr + t * stride_cos_t
sin_base = sin_ptr + t * stride_sin_t
c_ptr = cos_base + off * stride_cos_d
s_ptr = sin_base + off * stride_sin_d
x0 = tl.load(x0_ptr, mask=mask, other=0.0).to(tl.float32)
x1 = tl.load(x1_ptr, mask=mask, other=0.0).to(tl.float32)
c = tl.load(c_ptr, mask=mask, other=0.0).to(tl.float32)
s = tl.load(s_ptr, mask=mask, other=0.0).to(tl.float32)
out0 = x0 * c - x1 * s
out1 = x1 * c + x0 * s
tl.store(x0_ptr, out0.to(tl.bfloat16), mask=mask)
tl.store(x1_ptr, out1.to(tl.bfloat16), mask=mask)
# ... (see initial_program.py for full working baseline)
```
Below are the different configs that your kernel will be tested on:
Common configs:
- {"batch_size": 128, "seq_len": 1, "kv_lora_rank": 512, "qk_rope_head_dim": 64, "v_head_dim": 128, "n_heads": 128, "dim": 7168, "q_lora_rank": 1536, "max_seq_len": 8192}
For correctness check:
- {"prefill": 128}
- {"prefill": 512}
- {"prefill": 1024}
- {"prefill": 2048}
For performance benchmark (optimize runtime for these):
- {"prefill": 6144}
Rules:
- The tensor arguments passed in will already be on your CUDA device.
- The weights for all parameters in the MLA will be given as input.
- All weights and data will be in `torch.bfloat16` format.
- Define all of your code in one final ```python ``` block.
- The entrypoint to your code must be named 'custom_kernel'.
- You will be using Triton 3.4.0, and your kernels will run on an NVIDIA H200 GPU.
- Consider optimizing multiple operations with Triton, not just softmax (e.g., RoPE, attention, etc.).
- You are allowed to use torch.compile().
Important rules in Triton 3.4.0:
- `tl.load` does not have an argument called `dtype`. Never use it like `tl.load(..., dtype=...)`.
- Triton dtypes are not callable, so never use them like `tl.float16(1.0)`, `tl.float32(0.0)`.
- `tl.arange(start, end)`:
- range length (end - start) must be a power of 2
- start, end must be of type `tl.constexpr`
- `tl.range(start, end, step, num_stages)`:
- keep loop index type stable, don't reassign it
- start, end, step do not have to be `tl.constexpr` but must stay scalar integer types
- num_stages must be `tl.constexpr`
- Do not use something like x[0] or offs[0] inside a Triton kernel. Triton tensors are SIMD vectors; scalar indexing like [0] is not generally supported.
Here's a simple example that correctly follows these rules:
```python
import torch
import triton
import triton.language as tl
@triton.jit
def kernel_right(
x_ptr, y_ptr, out_ptr,
n_elements: tl.constexpr,
BLOCK: tl.constexpr,
ROW_STEP: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
pid = tl.program_id(axis=0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n_elements
x = tl.load(x_ptr + offs, mask=mask, other=0.0)
y = tl.load(y_ptr + offs, mask=mask, other=0.0)
one_f32 = tl.full([], 1.0, tl.float32)
acc = tl.zeros((BLOCK,), dtype=tl.float32)
acc = tl.cast(x, tl.float32) + tl.cast(y, tl.float32) + one_f32
base = tl.full([], pid * BLOCK, tl.int32)
x0 = tl.load(x_ptr + base, mask=(base < n_elements), other=0.0)
x0_vec = tl.full((BLOCK,), x0, tl.float32)
out_vec = acc + x0_vec
n_rows = tl.full([], 4, tl.int32)
extra = tl.zeros((BLOCK,), dtype=tl.float32)
for r in tl.range(0, n_rows, ROW_STEP, num_stages=NUM_STAGES):
shift = r * tl.full([], 1, tl.int32)
offs_r = offs + shift
xr = tl.load(x_ptr + offs_r, mask=(offs_r < n_elements), other=0.0)
extra += tl.cast(xr, tl.float32)
out_vec = out_vec + extra
tl.store(out_ptr + offs, tl.cast(out_vec, tl.float16), mask=mask)
```
evaluator:
timeout: 600
max_retries: 3
cascade_evaluation: true
cascade_thresholds: [0.4, 0.3]
diff_based_generation: true
max_solution_length: 60000
random_seed: 42