| |
|
|
| max_iterations: 100 |
| checkpoint_interval: 1 |
| log_level: "INFO" |
|
|
| llm: |
| models: |
| - name: "gpt-5" |
| weight: 1.0 |
| api_base: https://api.openai.com/v1 |
| temperature: 1.0 |
| |
| max_tokens: 32000 |
| timeout: 600 |
|
|
| prompt: |
| system_message: | |
| You are an expert Triton engineer tasked with translating PyTorch code into highly optimized Triton kernel code. |
| |
| Below is a PyTorch implementation of the multi-head latent attention (MLA) module. Your task is to implement Triton kernels for the operations in its forward call: |
|
|
| ```python |
| import math |
| from dataclasses import dataclass |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| class RoPE(nn.Module): |
| def __init__(self, d_model: int): |
| super().__init__() |
| self.d_model = d_model |
| theta = 10000 ** (-torch.arange(0, d_model//2,dtype=torch.bfloat16) / (d_model//2)) |
| self.register_buffer("theta", theta) |
|
|
| def rotate_half(self, x: torch.Tensor) -> torch.Tensor: |
| x1, x2 = x.chunk(2, dim=-1) |
| return torch.cat((-x2, x1), dim=-1) |
|
|
| def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: |
| seq_len = x.size(-2) |
| d_model = x.size(-1) |
| assert d_model == self.d_model |
| seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device) |
| idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta) |
| idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1) |
| cos = idx_theta2.cos().to(torch.bfloat16) |
| sin = idx_theta2.sin().to(torch.bfloat16) |
| return x * cos + self.rotate_half(x) * sin |
|
|
| class KVCache(nn.Module): |
| def __init__(self, kv_cache_shape: tuple) -> None: |
| super().__init__() |
| self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16, device='cuda')) |
| self.seq_len = 0 |
| self.zero() |
|
|
| def zero(self) -> None: |
| self.data.zero_() |
|
|
| def get_data(self) -> torch.Tensor: |
| return self.data |
|
|
| def forward(self, c_kv: torch.Tensor) -> torch.Tensor: |
| assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded" |
|
|
| self.data = self.data.to(c_kv.dtype) |
| self.data[ |
| :, self.seq_len : self.seq_len + c_kv.size(1), : |
| ] = c_kv |
| self.seq_len += c_kv.size(1) |
|
|
| return self.data[:, :self.seq_len], self.seq_len |
|
|
| @dataclass |
| class Config: |
| batch_size: int |
| dim: int |
| n_heads: int |
| q_lora_rank: int |
| kv_lora_rank: int |
| qk_nope_head_dim: int |
| qk_rope_head_dim: int |
| v_head_dim: int |
| seq_len: int |
| max_seq_len: int |
| kv_cache_shape: tuple |
| Q_proj_down_weight: torch.Tensor |
| Q_proj_up_weight: torch.Tensor |
| KV_proj_down_weight: torch.Tensor |
| KV_proj_up_weight: torch.Tensor |
| wo_weight: torch.Tensor |
|
|
| class MLA(nn.Module): |
| def __init__(self, config: Config): |
| super().__init__() |
| self.dim = config.dim |
| self.n_heads = config.n_heads |
| self.q_lora_rank = config.q_lora_rank |
| self.kv_lora_rank = config.kv_lora_rank |
| self.nope_head_dim = config.qk_nope_head_dim |
| self.rope_head_dim = config.qk_rope_head_dim |
| self.v_head_dim = config.v_head_dim |
| |
| self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, bias=False, dtype=torch.bfloat16) |
| self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, bias=False, dtype=torch.bfloat16) |
|
|
| |
| self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16) |
| self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16) |
|
|
| |
| self.q_rope = RoPE(self.rope_head_dim) |
| self.k_rope = RoPE(self.rope_head_dim) |
|
|
| |
| self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, dtype=torch.bfloat16, bias=False) |
| self.eps = 1e-6 |
|
|
| def forward(self, x: torch.Tensor, kv_cache: KVCache) -> torch.Tensor: |
| |
| batch_size, seq_len, model_dim = x.size() |
|
|
| |
|
|
| q_lora = self.Q_proj_down(x) |
| kv_lora = self.KV_proj_down(x) |
| kv_lora, kv_len = kv_cache(kv_lora) |
| query_pos = kv_len - 1 |
|
|
| |
|
|
| |
| q_nope_and_rope = self.Q_proj_up(q_lora).view( |
| batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim) |
| q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1) |
|
|
| |
| kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1) |
| kv_nope = self.KV_proj_up(kv_nope).view( |
| batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim) |
| k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1) |
|
|
| |
|
|
| |
| q_rope = q_rope.permute(0, 2, 1, 3) |
| q_rope = self.q_rope(q_rope, start_pos=query_pos) |
|
|
| q_nope = q_nope.permute(0, 2, 1, 3) |
| q = torch.concat([q_nope, q_rope], dim=-1) |
|
|
| |
| k_rope = k_rope[:, None, :, :] |
| k_rope = self.k_rope(k_rope).expand(-1,self.n_heads,-1,-1) |
| k_nope = k_nope.permute(0, 2, 1, 3) |
| k = torch.concat([k_nope, k_rope], dim=-1) |
|
|
| |
|
|
| v = v.permute(0, 2, 1, 3) |
| scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim) |
| attn = F.softmax(scores, dim=-1).to(torch.bfloat16) |
| y = torch.matmul(attn, v).view(batch_size, 1, -1) |
| y = self.wo(y) |
|
|
| return y, kv_cache.get_data() |
| ``` |
|
|
| Your function must be named 'custom_kernel' (skeleton provided below) |
|
|
| ```python |
| |
| import os |
| import math |
| from typing import Tuple |
| import torch |
| import torch.nn.functional as F |
| import triton |
| from reference import KVCache, Config |
| |
|
|
| |
|
|
| def custom_kernel(data: Tuple[Config, torch.Tensor, KVCache]) -> Tuple[torch.Tensor, KVCache]: |
| config, x, kv_cache = data |
|
|
| bs = config.batch_size |
| sl = config.seq_len |
| pl = kv_cache.seq_len |
| msl = config.max_seq_len |
| nh = config.n_heads |
| d = config.dim |
| dq = config.q_lora_rank |
| dkv = config.kv_lora_rank |
| dnope = config.qk_nope_head_dim |
| drope = config.qk_rope_head_dim |
| dv = config.v_head_dim |
|
|
| wDQ = config.Q_proj_down_weight |
| wDKV = config.KV_proj_down_weight |
| wUQ = config.Q_proj_up_weight |
| wUKV = config.KV_proj_up_weight |
| wO = config.wo_weight |
|
|
| |
|
|
| return output, kv_cache.data |
| ``` |
|
|
| with the following signature: |
|
|
| Input: |
| - `data`: Tuple of (config: Config, x: torch.Tensor, kv_cache: KVCache) |
| - config: An instance of class `Config` containing model configurations and weights |
| - x: Input tensor of shape [batch_size, seq_len, dim] |
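| - kv_cache: An instance of the KVCache class that caches the down-projected (latent) KV activations, i.e. the output of KV_proj_down with kv_lora_rank + qk_rope_head_dim features per token |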
|
|
| Output: |
| - output: Output tensor [batch_size, seq_len, dim] |
| - kv_cache.data: The data field of the updated `KVCache` instance, with the new tokens' latent KV entries appended |
|
|
| To help you get started with writing optimized Triton code, here is example code that is correct for your task but very unoptimized. Your code should be as optimized as possible while remaining correct. |
|
|
| ```python |
| import os |
| import math |
| from typing import Tuple |
| import torch |
| import torch.nn.functional as F |
| import triton |
| import triton.language as tl |
| from reference import KVCache, Config |
|
|
| @triton.jit |
| def rope_swap_halves_kernel( |
| x_ptr, |
| cos_ptr, sin_ptr, |
| B: tl.constexpr, |
| T: tl.constexpr, |
| D: tl.constexpr, |
| stride_xb, stride_xt, stride_xd, |
| stride_cos_t, stride_cos_d, |
| stride_sin_t, stride_sin_d, |
| BLOCK_HALF: tl.constexpr, |
| ): |
| pid = tl.program_id(0) |
| bt = pid |
| b = bt // T |
| t = bt - b * T |
| half = D // 2 |
| off = tl.arange(0, BLOCK_HALF) |
| mask = off < half |
| x_base = x_ptr + b * stride_xb + t * stride_xt |
| x0_ptr = x_base + off * stride_xd |
| x1_ptr = x_base + (half + off) * stride_xd |
| cos_base = cos_ptr + t * stride_cos_t |
| sin_base = sin_ptr + t * stride_sin_t |
| c_ptr = cos_base + off * stride_cos_d |
| s_ptr = sin_base + off * stride_sin_d |
| x0 = tl.load(x0_ptr, mask=mask, other=0.0).to(tl.float32) |
| x1 = tl.load(x1_ptr, mask=mask, other=0.0).to(tl.float32) |
| c = tl.load(c_ptr, mask=mask, other=0.0).to(tl.float32) |
| s = tl.load(s_ptr, mask=mask, other=0.0).to(tl.float32) |
| out0 = x0 * c - x1 * s |
| out1 = x1 * c + x0 * s |
| tl.store(x0_ptr, out0.to(tl.bfloat16), mask=mask) |
| tl.store(x1_ptr, out1.to(tl.bfloat16), mask=mask) |
|
|
| |
| ``` |
|
|
| Below are the different configs that your kernel will be tested on: |
|
|
| Common configs: |
| - {"batch_size": 128, "seq_len": 1, "kv_lora_rank": 512, "qk_rope_head_dim": 64, "v_head_dim": 128, "n_heads": 128, "dim": 7168, "q_lora_rank": 1536, "max_seq_len": 8192} |
|
|
| For correctness check: |
| - {"prefill": 128} |
| - {"prefill": 512} |
| - {"prefill": 1024} |
| - {"prefill": 2048} |
|
|
| For performance benchmark (optimize runtime for these): |
| - {"prefill": 6144} |
|
|
| Rules: |
| - The tensor arguments passed in will already be on your CUDA device. |
| - The weights for all parameters in the MLA will be given as input. |
| - All weights and data will be in `torch.bfloat16` format. |
| - Define all of your code in one final ```python ``` block. |
| - The entrypoint to your code must be named 'custom_kernel'. |
| - You will be using Triton 3.4.0, and your kernels will run on an NVIDIA H200 GPU. |
| - Consider optimizing multiple operations with Triton, not just softmax, e.g., RoPE, attention, etc. |
| - You are allowed to use torch.compile(). |
|
|
| Important rules for Triton 3.4.0: |
| - `tl.load` does not have an argument called `dtype`. Never use it like `tl.load(..., dtype=...)`. |
| - Triton dtypes are not callable, so never use them like `tl.float16(1.0)`, `tl.float32(0.0)`. |
| - `tl.arange(start, end)`: |
| - range length (end - start) must be power-of-2 |
| - start, end must be of type `tl.constexpr` |
| - `tl.range(start, end, step, num_stages)`: |
| - keep loop index type stable, don't reassign it |
| - start, end, step do not have to be `tl.constexpr` but must stay scalar integer types |
| - num_stages must be `tl.constexpr` |
| - Do not use something like x[0] or offs[0] inside a Triton kernel. Triton tensors are SIMD vectors; scalar indexing like [0] is not generally supported. |
|
|
| Here's a simple example that correctly follows these rules: |
|
|
| ```python |
| import torch |
| import triton |
| import triton.language as tl |
|
|
| @triton.jit |
| def kernel_right( |
| x_ptr, y_ptr, out_ptr, |
| n_elements: tl.constexpr, |
| BLOCK: tl.constexpr, |
| ROW_STEP: tl.constexpr, |
| NUM_STAGES: tl.constexpr, |
| ): |
| pid = tl.program_id(axis=0) |
| offs = pid * BLOCK + tl.arange(0, BLOCK) |
| mask = offs < n_elements |
| x = tl.load(x_ptr + offs, mask=mask, other=0.0) |
| y = tl.load(y_ptr + offs, mask=mask, other=0.0) |
| one_f32 = tl.full([], 1.0, tl.float32) |
| acc = tl.zeros((BLOCK,), dtype=tl.float32) |
| acc = tl.cast(x, tl.float32) + tl.cast(y, tl.float32) + one_f32 |
| base = tl.full([], pid * BLOCK, tl.int32) |
| x0 = tl.load(x_ptr + base, mask=(base < n_elements), other=0.0) |
| x0_vec = tl.full((BLOCK,), x0, tl.float32) |
| out_vec = acc + x0_vec |
| n_rows = tl.full([], 4, tl.int32) |
| extra = tl.zeros((BLOCK,), dtype=tl.float32) |
| for r in tl.range(0, n_rows, ROW_STEP, num_stages=NUM_STAGES): |
| shift = r * tl.full([], 1, tl.int32) |
| offs_r = offs + shift |
| xr = tl.load(x_ptr + offs_r, mask=(offs_r < n_elements), other=0.0) |
| extra += tl.cast(xr, tl.float32) |
| out_vec = out_vec + extra |
| tl.store(out_ptr + offs, tl.cast(out_vec, tl.float16), mask=mask) |
| ``` |
| evaluator: |
| timeout: 600 |
| max_retries: 3 |
| cascade_evaluation: true |
| cascade_thresholds: [0.4, 0.3] |
|
|
| diff_based_generation: true |
| max_solution_length: 60000 |
| random_seed: 42 |
|
|