File size: 14,009 Bytes
# GPU Mode: MLA Decode (Multi-Head Latent Attention) Triton Kernel
max_iterations: 100
checkpoint_interval: 1
log_level: "INFO"
llm:
models:
- name: "gpt-5"
weight: 1.0
api_base: https://api.openai.com/v1
temperature: 1.0
# top_p: 0.95 # omitted by default; some providers (e.g. Anthropic) reject requests that set both temperature and top_p
max_tokens: 32000
timeout: 600
prompt:
system_message: |
You are an expert Triton engineer tasked with translating PyTorch code into highly optimized Triton kernel code.
Below is a PyTorch implementation of the multi-head latent attention (MLA) module. You will want to implement a Triton kernel for the operations in the forward call:
```python
import math
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
class RoPE(nn.Module):
def __init__(self, d_model: int):
super().__init__()
self.d_model = d_model
theta = 10000 ** (-torch.arange(0, d_model//2,dtype=torch.bfloat16) / (d_model//2))
self.register_buffer("theta", theta)
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
seq_len = x.size(-2)
d_model = x.size(-1)
assert d_model == self.d_model
seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device)
idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta)
idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
cos = idx_theta2.cos().to(torch.bfloat16)
sin = idx_theta2.sin().to(torch.bfloat16)
return x * cos + self.rotate_half(x) * sin
class KVCache(nn.Module):
def __init__(self, kv_cache_shape: tuple) -> None:
super().__init__()
self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16, device='cuda'))
self.seq_len = 0
self.zero()
def zero(self) -> None:
self.data.zero_()
def get_data(self) -> torch.Tensor:
return self.data
def forward(self, c_kv: torch.Tensor) -> torch.Tensor:
assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded"
self.data = self.data.to(c_kv.dtype)
self.data[
:, self.seq_len : self.seq_len + c_kv.size(1), :
] = c_kv
self.seq_len += c_kv.size(1)
return self.data[:, :self.seq_len], self.seq_len
@dataclass
class Config:
batch_size: int
dim: int
n_heads: int
q_lora_rank: int
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
seq_len: int
max_seq_len: int
kv_cache_shape: tuple
Q_proj_down_weight: torch.Tensor
Q_proj_up_weight: torch.Tensor
KV_proj_down_weight: torch.Tensor
KV_proj_up_weight: torch.Tensor
wo_weight: torch.Tensor
class MLA(nn.Module):
def __init__(self, config: Config):
super().__init__()
self.dim = config.dim
self.n_heads = config.n_heads
self.q_lora_rank = config.q_lora_rank
self.kv_lora_rank = config.kv_lora_rank
self.nope_head_dim = config.qk_nope_head_dim
self.rope_head_dim = config.qk_rope_head_dim
self.v_head_dim = config.v_head_dim
# Down-projection matrices
self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, bias=False, dtype=torch.bfloat16)
self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, bias=False, dtype=torch.bfloat16)
# Up-projection and rope projection matrices
self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16)
self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16)
# RoPE on half embeddings
self.q_rope = RoPE(self.rope_head_dim)
self.k_rope = RoPE(self.rope_head_dim)
# Output projection
self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, dtype=torch.bfloat16, bias=False)
self.eps = 1e-6
def forward(self, x: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
# seq_len = 1 always here
batch_size, seq_len, model_dim = x.size()
## Step 1: Handle down-projection + KV cache ##
q_lora = self.Q_proj_down(x)
kv_lora = self.KV_proj_down(x)
kv_lora, kv_len = kv_cache(kv_lora)
query_pos = kv_len - 1
## Step 2: Up-project and prepare NoPE + RoPE ##
# Handle queries Q first
q_nope_and_rope = self.Q_proj_up(q_lora).view(
batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim)
q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1)
# Handle keys and values K/V. V does not need RoPE
kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1)
kv_nope = self.KV_proj_up(kv_nope).view(
batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1)
## Step 3: Handle RoPE Stream ##
# Compute RoPE for queries and combine with no-RoPE part
q_rope = q_rope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x rope_head_dim
q_rope = self.q_rope(q_rope, start_pos=query_pos)
q_nope = q_nope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x nope_head_dim
q = torch.concat([q_nope, q_rope], dim=-1)
# Compute RoPE for keys and combine with no-RoPE part
k_rope = k_rope[:, None, :, :]
k_rope = self.k_rope(k_rope).expand(-1,self.n_heads,-1,-1)
k_nope = k_nope.permute(0, 2, 1, 3) # bs x n_heads x kv_len x nope_head_dim
k = torch.concat([k_nope, k_rope], dim=-1)
## Step 4: Compute Multi-head Attention ##
v = v.permute(0, 2, 1, 3) # bs x n_heads x kv_len x v_head_dim
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim)
attn = F.softmax(scores, dim=-1).to(torch.bfloat16)
y = torch.matmul(attn, v).view(batch_size, 1, -1)
y = self.wo(y)
return y, kv_cache.get_data()
```
Your function should be defined as 'custom_kernel' (skeleton provided below)
```python
### DO NOT CHANGE THIS IMPORT STATEMENTS BLOCK ###
import os
import math
from typing import Tuple
import torch
import torch.nn.functional as F
import triton
from reference import KVCache, Config # Definition of KVCache and Config classes are shown above. Must import this way. Do not rewrite yourself.
### END OF IMPORT STATEMENTS BLOCK ###
### Import other packages here if needed
def custom_kernel(data: Tuple[Config, torch.Tensor, KVCache]) -> Tuple[torch.Tensor, KVCache]:
config, x, kv_cache = data
bs = config.batch_size
sl = config.seq_len
pl = kv_cache.seq_len
msl = config.max_seq_len
nh = config.n_heads
d = config.dim
dq = config.q_lora_rank
dkv = config.kv_lora_rank
dnope = config.qk_nope_head_dim
drope = config.qk_rope_head_dim
dv = config.v_head_dim
wDQ = config.Q_proj_down_weight
wDKV = config.KV_proj_down_weight
wUQ = config.Q_proj_up_weight
wUKV = config.KV_proj_up_weight
wO = config.wo_weight
# Perform MLA operations to process data into output and updated kv_cache
return output, kv_cache.data
```
with the following signature:
Input:
- `data`: Tuple of (config: Config, x: torch.Tensor, kv_cache: KVCache)
- config: An instance of class `Config` containing model configurations and weights
- x: Input tensor of shape [batch_size, seq_len, dim]
- kv_cache: An instance of KVCache class for caching the keys and values
Output:
- output: Output tensor [batch_size, seq_len, dim]
- kv_cache.data: The data field of the updated `KVCache` instance with the new keys and values added
To warm you up for writing optimized Triton code, here is some example code that is correct for your task but very unoptimized. Your code should be as optimized as possible but still correct.
```python
import os
import math
from typing import Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from reference import KVCache, Config
@triton.jit
def rope_swap_halves_kernel(
x_ptr,
cos_ptr, sin_ptr,
B: tl.constexpr,
T: tl.constexpr,
D: tl.constexpr,
stride_xb, stride_xt, stride_xd,
stride_cos_t, stride_cos_d,
stride_sin_t, stride_sin_d,
BLOCK_HALF: tl.constexpr,
):
pid = tl.program_id(0)
bt = pid
b = bt // T
t = bt - b * T
half = D // 2
off = tl.arange(0, BLOCK_HALF)
mask = off < half
x_base = x_ptr + b * stride_xb + t * stride_xt
x0_ptr = x_base + off * stride_xd
x1_ptr = x_base + (half + off) * stride_xd
cos_base = cos_ptr + t * stride_cos_t
sin_base = sin_ptr + t * stride_sin_t
c_ptr = cos_base + off * stride_cos_d
s_ptr = sin_base + off * stride_sin_d
x0 = tl.load(x0_ptr, mask=mask, other=0.0).to(tl.float32)
x1 = tl.load(x1_ptr, mask=mask, other=0.0).to(tl.float32)
c = tl.load(c_ptr, mask=mask, other=0.0).to(tl.float32)
s = tl.load(s_ptr, mask=mask, other=0.0).to(tl.float32)
out0 = x0 * c - x1 * s
out1 = x1 * c + x0 * s
tl.store(x0_ptr, out0.to(tl.bfloat16), mask=mask)
tl.store(x1_ptr, out1.to(tl.bfloat16), mask=mask)
# ... (see initial_program.py for full working baseline)
```
Below are the different configs that your kernel will be tested on:
Common configs:
- {"batch_size": 128, "seq_len": 1, "kv_lora_rank": 512, "qk_rope_head_dim": 64, "v_head_dim": 128, "n_heads": 128, "dim": 7168, "q_lora_rank": 1536, "max_seq_len": 8192}
For correctness check:
- {"prefill": 128}
- {"prefill": 512}
- {"prefill": 1024}
- {"prefill": 2048}
For performance benchmark (optimize runtime for these):
- {"prefill": 6144}
Rules:
- The tensor arguments passed in will already be on your CUDA device.
- The weights for all parameters in the MLA will be given as input.
- All weights and data will be in `torch.bfloat16` format.
- Define all of your code in one final ```python ``` block.
- The entrypoint to your code must be named 'custom_kernel'.
- You will be using triton 3.4.0 and your kernels will be run on an Nvidia H200 GPU.
- Consider optimizing multiple operations with triton, not just limited to softmax. E.g., rope, attention, etc.
- You are allowed to use torch.compile().
Important rules in triton 3.4.0:
- `tl.load` does not have an argument called `dtype`. Never use it like `tl.load(..., dtype=...)`.
- Triton dtypes are not callable, so never use them like `tl.float16(1.0)`, `tl.float32(0.0)`.
- `tl.arange(start, end)`:
- range length (end - start) must be power-of-2
- start, end must be of type `tl.constexpr`
- `tl.range(start, end, step, num_stages)`:
- keep loop index type stable, don't reassign it
- start, end, step do not have to be `tl.constexpr` but must stay scalar integer types
- num_stages must be `tl.constexpr`
- Do not do something like x[0] or offs[0] inside a Triton kernel. Triton tensors are SIMD vectors; scalar indexing like [0] is not generally supported.
Here's a simple example correctly following these rules:
```python
import torch
import triton
import triton.language as tl
@triton.jit
def kernel_right(
x_ptr, y_ptr, out_ptr,
n_elements: tl.constexpr,
BLOCK: tl.constexpr,
ROW_STEP: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
pid = tl.program_id(axis=0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n_elements
x = tl.load(x_ptr + offs, mask=mask, other=0.0)
y = tl.load(y_ptr + offs, mask=mask, other=0.0)
one_f32 = tl.full([], 1.0, tl.float32)
acc = tl.zeros((BLOCK,), dtype=tl.float32)
acc = tl.cast(x, tl.float32) + tl.cast(y, tl.float32) + one_f32
base = tl.full([], pid * BLOCK, tl.int32)
x0 = tl.load(x_ptr + base, mask=(base < n_elements), other=0.0)
x0_vec = tl.full((BLOCK,), x0, tl.float32)
out_vec = acc + x0_vec
n_rows = tl.full([], 4, tl.int32)
extra = tl.zeros((BLOCK,), dtype=tl.float32)
for r in tl.range(0, n_rows, ROW_STEP, num_stages=NUM_STAGES):
shift = r * tl.full([], 1, tl.int32)
offs_r = offs + shift
xr = tl.load(x_ptr + offs_r, mask=(offs_r < n_elements), other=0.0)
extra += tl.cast(xr, tl.float32)
out_vec = out_vec + extra
tl.store(out_ptr + offs, tl.cast(out_vec, tl.float16), mask=mask)
```
evaluator:
timeout: 600
max_retries: 3
cascade_evaluation: true
cascade_thresholds: [0.4, 0.3]
diff_based_generation: true
max_solution_length: 60000
random_seed: 42
|