File size: 14,009 Bytes
b0e88cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# GPU Mode: MLA Decode (Multi-Head Latent Attention) Triton Kernel

max_iterations: 100
checkpoint_interval: 1
log_level: "INFO"

llm:
  models:
    - name: "gpt-5"
      weight: 1.0
  api_base: https://api.openai.com/v1
  temperature: 1.0
  # top_p: 0.95  # omitted by default; some providers (e.g. Anthropic) reject requests that set both temperature and top_p
  max_tokens: 32000
  timeout: 600

prompt:
  system_message: |
    You are an expert Triton engineer tasked with translating PyTorch code into highly optimized Triton kernel code.

    Below is a PyTorch implementation of the multi-head latent attention (MLA) module. Your task is to implement Triton kernels for the operations in its `forward` method:

    ```python
    import math
    from dataclasses import dataclass
    import torch
    from torch import nn
    import torch.nn.functional as F

    class RoPE(nn.Module):
        def __init__(self, d_model: int):
            super().__init__()
            self.d_model = d_model
            theta = 10000 ** (-torch.arange(0, d_model//2,dtype=torch.bfloat16) / (d_model//2))
            self.register_buffer("theta", theta)

        def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)

        def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
            seq_len = x.size(-2)
            d_model = x.size(-1)
            assert d_model == self.d_model
            seq_idx = torch.arange(start_pos, start_pos + seq_len, device=x.device)
            idx_theta = torch.einsum('s,d->sd', seq_idx, self.theta)
            idx_theta2 = torch.cat([idx_theta, idx_theta], dim=-1)
            cos = idx_theta2.cos().to(torch.bfloat16)
            sin = idx_theta2.sin().to(torch.bfloat16)
            return x * cos + self.rotate_half(x) * sin

    class KVCache(nn.Module):
        def __init__(self, kv_cache_shape: tuple) -> None:
            super().__init__()
            self.register_buffer('data', torch.zeros(kv_cache_shape, dtype=torch.bfloat16, device='cuda'))
            self.seq_len = 0
            self.zero()

        def zero(self) -> None:
            self.data.zero_()

        def get_data(self) -> torch.Tensor:
            return self.data

        def forward(self, c_kv: torch.Tensor) -> torch.Tensor:
            assert self.seq_len + c_kv.size(1) <= self.data.size(1), "KV Cache Exceeded"

            self.data = self.data.to(c_kv.dtype)
            self.data[
                :, self.seq_len : self.seq_len + c_kv.size(1), :
            ] = c_kv
            self.seq_len += c_kv.size(1)

            return self.data[:, :self.seq_len], self.seq_len

    @dataclass
    class Config:
        batch_size: int
        dim: int
        n_heads: int
        q_lora_rank: int
        kv_lora_rank: int
        qk_nope_head_dim: int
        qk_rope_head_dim: int
        v_head_dim: int
        seq_len: int
        max_seq_len: int
        kv_cache_shape: tuple
        Q_proj_down_weight: torch.Tensor
        Q_proj_up_weight: torch.Tensor
        KV_proj_down_weight: torch.Tensor
        KV_proj_up_weight: torch.Tensor
        wo_weight: torch.Tensor

    class MLA(nn.Module):
        def __init__(self, config: Config):
            super().__init__()
            self.dim = config.dim
            self.n_heads = config.n_heads
            self.q_lora_rank = config.q_lora_rank
            self.kv_lora_rank = config.kv_lora_rank
            self.nope_head_dim = config.qk_nope_head_dim
            self.rope_head_dim = config.qk_rope_head_dim
            self.v_head_dim = config.v_head_dim
            # Down-projection matrices
            self.Q_proj_down = nn.Linear(self.dim, self.q_lora_rank, bias=False, dtype=torch.bfloat16)
            self.KV_proj_down = nn.Linear(self.dim, self.kv_lora_rank + self.rope_head_dim, bias=False, dtype=torch.bfloat16)

            # Up-projection and rope projection matrices
            self.Q_proj_up = nn.Linear(self.q_lora_rank, (self.nope_head_dim + self.rope_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16)
            self.KV_proj_up = nn.Linear(self.kv_lora_rank, (self.nope_head_dim + self.v_head_dim) * self.n_heads, bias=False, dtype=torch.bfloat16)

            # RoPE on half embeddings
            self.q_rope = RoPE(self.rope_head_dim)
            self.k_rope = RoPE(self.rope_head_dim)

            # Output projection
            self.wo = nn.Linear(self.v_head_dim * self.n_heads, self.dim, dtype=torch.bfloat16, bias=False)
            self.eps = 1e-6

        def forward(self, x: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
            # seq_len = 1 always here
            batch_size, seq_len, model_dim = x.size()

            ## Step 1: Handle down-projection + KV cache ##

            q_lora = self.Q_proj_down(x)
            kv_lora = self.KV_proj_down(x)
            kv_lora, kv_len = kv_cache(kv_lora)
            query_pos = kv_len - 1

            ## Step 2: Up-project and prepare NoPE + RoPE ##

            # Handle queries Q first
            q_nope_and_rope = self.Q_proj_up(q_lora).view(
                batch_size, seq_len, self.n_heads, self.nope_head_dim + self.rope_head_dim)
            q_nope, q_rope = torch.split(q_nope_and_rope, [self.nope_head_dim, self.rope_head_dim], dim=-1)

            # Handle keys and values K/V. V does not need RoPE
            kv_nope, k_rope = torch.split(kv_lora, [self.kv_lora_rank, self.rope_head_dim], dim=-1)
            kv_nope = self.KV_proj_up(kv_nope).view(
                batch_size, kv_len, self.n_heads, self.nope_head_dim + self.v_head_dim)
            k_nope, v = torch.split(kv_nope, [self.nope_head_dim, self.v_head_dim], dim=-1)

            ## Step 3: Handle RoPE Stream ##

            # Compute RoPE for queries and combine with no-RoPE part
            q_rope = q_rope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x rope_head_dim
            q_rope = self.q_rope(q_rope, start_pos=query_pos)

            q_nope = q_nope.permute(0, 2, 1, 3) # bs x n_heads x seq_len x rope_head_dim
            q = torch.concat([q_nope, q_rope], dim=-1)

            # Compute RoPE for keys and combine with no-RoPE part
            k_rope = k_rope[:, None, :, :]
            k_rope = self.k_rope(k_rope).expand(-1,self.n_heads,-1,-1)
            k_nope = k_nope.permute(0, 2, 1, 3) # bs x kv_len x n_heads x rope_head_dim
            k = torch.concat([k_nope, k_rope], dim=-1)

            ## Step 4: Compute Multi-head Attention ##

            v = v.permute(0, 2, 1, 3) # bs x n_heads x kv_len x v_head_dim
            scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.rope_head_dim + self.nope_head_dim)
            attn = F.softmax(scores, dim=-1).to(torch.bfloat16)
            y = torch.matmul(attn, v).view(batch_size, 1, -1)
            y = self.wo(y)

            return y, kv_cache.get_data()
    ```

    Your function must be named `custom_kernel` (a skeleton is provided below):

    ```python
    ### DO NOT CHANGE THIS IMPORT STATEMENTS BLOCK ###
    import os
    import math
    from typing import Tuple
    import torch
    import torch.nn.functional as F
    import triton
    from reference import KVCache, Config  # Definition of KVCache and Config classes are shown above. Must import this way. Do not rewrite yourself.
    ### END OF IMPORT STATEMENTS BLOCK ###

    ### Import other packages here if needed

    def custom_kernel(data: Tuple[Config, torch.Tensor, KVCache]) -> Tuple[torch.Tensor, KVCache]:
        config, x, kv_cache = data

        bs = config.batch_size
        sl = config.seq_len
        pl = kv_cache.seq_len
        msl = config.max_seq_len
        nh = config.n_heads
        d =  config.dim
        dq = config.q_lora_rank
        dkv = config.kv_lora_rank
        dnope = config.qk_nope_head_dim
        drope = config.qk_rope_head_dim
        dv = config.v_head_dim

        wDQ  = config.Q_proj_down_weight
        wDKV = config.KV_proj_down_weight
        wUQ  = config.Q_proj_up_weight
        wUKV = config.KV_proj_up_weight
        wO   = config.wo_weight

        # Perform MLA operations to process data into output and updated kv_cache

        return output, kv_cache.data
    ```

    The function has the following signature:

    Input:
    - `data`: Tuple of (config: Config, x: torch.Tensor, kv_cache: KVCache)
        - config: An instance of class `Config` containing model configurations and weights
        - x: Input tensor of shape [batch_size, seq_len, dim]
        - kv_cache: An instance of KVCache class for caching the keys and values

    Output:
    - output: Output tensor [batch_size, seq_len, dim]
    - kv_cache.data: The data field of the updated `KVCache` instance with the new keys and values added

    To help you get started writing optimized Triton code, here is example code that is correct for your task but very unoptimized. Your code should be as optimized as possible while remaining correct.

    ```python
    import os
    import math
    from typing import Tuple
    import torch
    import torch.nn.functional as F
    import triton
    import triton.language as tl
    from reference import KVCache, Config

    @triton.jit
    def rope_swap_halves_kernel(
        x_ptr,
        cos_ptr, sin_ptr,
        B: tl.constexpr,
        T: tl.constexpr,
        D: tl.constexpr,
        stride_xb, stride_xt, stride_xd,
        stride_cos_t, stride_cos_d,
        stride_sin_t, stride_sin_d,
        BLOCK_HALF: tl.constexpr,
    ):
        pid = tl.program_id(0)
        bt = pid
        b = bt // T
        t = bt - b * T
        half = D // 2
        off = tl.arange(0, BLOCK_HALF)
        mask = off < half
        x_base = x_ptr + b * stride_xb + t * stride_xt
        x0_ptr = x_base + off * stride_xd
        x1_ptr = x_base + (half + off) * stride_xd
        cos_base = cos_ptr + t * stride_cos_t
        sin_base = sin_ptr + t * stride_sin_t
        c_ptr = cos_base + off * stride_cos_d
        s_ptr = sin_base + off * stride_sin_d
        x0 = tl.load(x0_ptr, mask=mask, other=0.0).to(tl.float32)
        x1 = tl.load(x1_ptr, mask=mask, other=0.0).to(tl.float32)
        c  = tl.load(c_ptr,  mask=mask, other=0.0).to(tl.float32)
        s  = tl.load(s_ptr,  mask=mask, other=0.0).to(tl.float32)
        out0 = x0 * c - x1 * s
        out1 = x1 * c + x0 * s
        tl.store(x0_ptr, out0.to(tl.bfloat16), mask=mask)
        tl.store(x1_ptr, out1.to(tl.bfloat16), mask=mask)

    # ... (see initial_program.py for full working baseline)
    ```

    Below are the different configs that your kernel will be tested on:

    Common configs:
      - {"batch_size": 128, "seq_len": 1, "kv_lora_rank": 512, "qk_rope_head_dim": 64, "v_head_dim": 128, "n_heads": 128, "dim": 7168, "q_lora_rank": 1536, "max_seq_len": 8192}

    For the correctness check:
      - {"prefill": 128}
      - {"prefill": 512}
      - {"prefill": 1024}
      - {"prefill": 2048}

    For the performance benchmark (optimize runtime for these):
      - {"prefill": 6144}

    Rules:
    - The tensor arguments passed in will already be on your CUDA device.
    - The weights for all parameters in the MLA will be given as input.
    - All weights and data will be in `torch.bfloat16` format.
    - Define all of your code in one final ```python ``` block.
    - The entrypoint to your code must be named 'custom_kernel'.
    - You will be using Triton 3.4.0, and your kernels will run on an NVIDIA H200 GPU.
    - Consider optimizing multiple operations with Triton, not just softmax (e.g., RoPE, attention).
    - You are allowed to use torch.compile().

    Important rules for Triton 3.4.0:
    - `tl.load` does not have an argument called `dtype`. Never use it like `tl.load(..., dtype=...)`.
    - Triton dtypes are not callable, so never use them like `tl.float16(1.0)` or `tl.float32(0.0)`; create scalar constants with `tl.full([], 1.0, tl.float32)` instead.
    - `tl.arange(start, end)`:
        - the range length (end - start) must be a power of 2
        - start, end must be of type `tl.constexpr`
    - `tl.range(start, end, step, num_stages)`:
        - keep loop index type stable, don't reassign it
        - start, end, step do not have to be `tl.constexpr` but must stay scalar integer types
        - num_stages must be `tl.constexpr`
    - Do not write something like `x[0]` or `offs[0]` inside a Triton kernel. Triton tensors are SIMD vectors, so scalar indexing like `[0]` is not generally supported.

    Here's a simple example that correctly follows these rules:

    ```python
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def kernel_right(
        x_ptr, y_ptr, out_ptr,
        n_elements: tl.constexpr,
        BLOCK: tl.constexpr,
        ROW_STEP: tl.constexpr,
        NUM_STAGES: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)
        y = tl.load(y_ptr + offs, mask=mask, other=0.0)
        one_f32 = tl.full([], 1.0, tl.float32)
        acc = tl.zeros((BLOCK,), dtype=tl.float32)
        acc = tl.cast(x, tl.float32) + tl.cast(y, tl.float32) + one_f32
        base = tl.full([], pid * BLOCK, tl.int32)
        x0 = tl.load(x_ptr + base, mask=(base < n_elements), other=0.0)
        x0_vec = tl.full((BLOCK,), x0, tl.float32)
        out_vec = acc + x0_vec
        n_rows = tl.full([], 4, tl.int32)
        extra = tl.zeros((BLOCK,), dtype=tl.float32)
        for r in tl.range(0, n_rows, ROW_STEP, num_stages=NUM_STAGES):
            shift = r * tl.full([], 1, tl.int32)
            offs_r = offs + shift
            xr = tl.load(x_ptr + offs_r, mask=(offs_r < n_elements), other=0.0)
            extra += tl.cast(xr, tl.float32)
        out_vec = out_vec + extra
        tl.store(out_ptr + offs, tl.cast(out_vec, tl.float16), mask=mask)
    ```
evaluator:
  timeout: 600
  max_retries: 3
  cascade_evaluation: true
  cascade_thresholds: [0.4, 0.3]

diff_based_generation: true
max_solution_length: 60000
random_seed: 42