# EVOLVE-BLOCK-START
"""
Initial MLA Decode submission — optimised baseline with Triton softmax
and RoPE kernels.
"""

import math
import os
from typing import Tuple

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

from reference import KVCache, Config


@triton.jit
def rope_swap_halves_kernel(
    x_ptr,
    cos_ptr,
    sin_ptr,
    B: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    stride_xb,
    stride_xt,
    stride_xd,
    stride_cos_t,
    stride_cos_d,
    stride_sin_t,
    stride_sin_d,
    BLOCK_HALF: tl.constexpr,
):
    """Apply rotary position embedding in place on a (B, T, D) bf16 tensor.

    Uses the rotate-half formulation: with x split into halves (x0, x1),
        out0 = x0*cos - x1*sin
        out1 = x1*cos + x0*sin
    One program handles one (b, t) row; BLOCK_HALF must be a power of two
    >= D // 2.  Math is done in float32 and stored back as bfloat16, so the
    target tensor MUST be bfloat16 (enforced by the Python-side wrapper).
    """
    pid = tl.program_id(0)
    # Decompose the flat program id into (batch, row) coordinates.
    b = pid // T
    t = pid - b * T

    half = D // 2
    off = tl.arange(0, BLOCK_HALF)
    mask = off < half

    # Pointers to the two halves of this row.
    x_base = x_ptr + b * stride_xb + t * stride_xt
    x0_ptr = x_base + off * stride_xd
    x1_ptr = x_base + (half + off) * stride_xd

    # cos/sin tables indexed by row t (stride_cos_t/stride_sin_t may be 0
    # when a single position's table is broadcast to every row).
    c_ptr = cos_ptr + t * stride_cos_t + off * stride_cos_d
    s_ptr = sin_ptr + t * stride_sin_t + off * stride_sin_d

    x0 = tl.load(x0_ptr, mask=mask, other=0.0).to(tl.float32)
    x1 = tl.load(x1_ptr, mask=mask, other=0.0).to(tl.float32)
    c = tl.load(c_ptr, mask=mask, other=0.0).to(tl.float32)
    s = tl.load(s_ptr, mask=mask, other=0.0).to(tl.float32)

    out0 = x0 * c - x1 * s
    out1 = x1 * c + x0 * s

    tl.store(x0_ptr, out0.to(tl.bfloat16), mask=mask)
    tl.store(x1_ptr, out1.to(tl.bfloat16), mask=mask)


def rope_inplace_query(q_rope: torch.Tensor, cos_q: torch.Tensor, sin_q: torch.Tensor):
    """Rotate the query's RoPE slice in place for a single decode position.

    q_rope: (batch, n_heads, d_rope) bfloat16 CUDA tensor (may be a strided
        view into a larger projection — explicit strides are passed through).
    cos_q / sin_q: 1-D tables of length d_rope for the current position;
        broadcast across all heads via zero t-strides.
    """
    assert q_rope.is_cuda
    # The kernel stores tl.bfloat16 unconditionally; any other dtype would
    # be silently corrupted, so fail fast here.
    assert q_rope.dtype == torch.bfloat16, "rope_inplace_query requires bfloat16"
    assert q_rope.shape[-1] % 2 == 0

    bs, nh, d_rope = q_rope.shape
    half = d_rope // 2
    # Smallest power of two covering the half-dimension (Triton arange needs
    # a power-of-two extent).
    BLOCK_HALF = 1 << (half - 1).bit_length()

    grid = (bs * nh,)
    rope_swap_halves_kernel[grid](
        q_rope,
        cos_q,
        sin_q,
        B=bs,
        T=nh,
        D=d_rope,
        stride_xb=q_rope.stride(0),
        stride_xt=q_rope.stride(1),
        stride_xd=q_rope.stride(2),
        # Same position for every head: t-stride 0 broadcasts the tables.
        stride_cos_t=0,
        stride_cos_d=cos_q.stride(0),
        stride_sin_t=0,
        stride_sin_d=sin_q.stride(0),
        BLOCK_HALF=BLOCK_HALF,
        num_warps=4,
    )


# Cache of (cos, sin) RoPE tables keyed by (dim, max_seq_len, device).
_rope_cache = {}


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return x with its halves swapped and the (new) first half negated."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def _get_rope_tables(dim: int, max_seq_len: int, device: torch.device):
    """Return cached (cos, sin) tables of shape (max_seq_len, dim), bfloat16.

    FIX: the angles `pos * theta` are now computed entirely in float32.
    The previous version cast theta to bfloat16 *before* the multiply;
    with only ~8 mantissa bits the rotary phase for large positions lost
    nearly all precision.  Only the final cos/sin values are downcast.
    """
    key = (dim, max_seq_len, device)
    if key not in _rope_cache:
        half = dim // 2
        theta = 10000.0 ** (
            -torch.arange(half, dtype=torch.float32, device=device) / half
        )
        pos = torch.arange(max_seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        idx = pos * theta[None, :]
        # Duplicate so the table covers both halves of the head dimension.
        idx = torch.cat([idx, idx], dim=-1)
        _rope_cache[key] = (idx.cos().to(torch.bfloat16), idx.sin().to(torch.bfloat16))
    return _rope_cache[key]


@triton.jit
def _softmax_kernel(
    out_ptr,
    in_ptr,
    stride_out,
    stride_in,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    NUM_STAGES: tl.constexpr,  # kept for launch-interface compatibility; unused
):
    """Row-wise numerically-stable softmax over bf16 rows, fp32 accumulation.

    One program per row; three passes over the row in BLOCK_SIZE chunks:
      1. running max, 2. exp(x - max) stored unnormalised + running sum,
      3. divide the stored values by the row sum.
    """
    row = tl.program_id(0)
    row_off_in = row * stride_in
    row_off_out = row * stride_out
    col = tl.arange(0, BLOCK_SIZE)

    # Pass 1: row maximum (for numerical stability).
    max_val = tl.full([BLOCK_SIZE], -float("inf"), tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cur = start + col
        mask = cur < n_cols
        val = tl.load(in_ptr + row_off_in + cur, mask=mask, other=-float("inf"))
        max_val = tl.maximum(max_val, tl.cast(val, tl.float32))
    row_max = tl.max(max_val)

    # Pass 2: write unnormalised exponentials, accumulate their sum.
    sum_val = tl.full([BLOCK_SIZE], 0.0, tl.float32)
    for start in range(0, n_cols, BLOCK_SIZE):
        cur = start + col
        mask = cur < n_cols
        val = tl.load(in_ptr + row_off_in + cur, mask=mask, other=-float("inf"))
        exp_val = tl.exp(tl.cast(val, tl.float32) - row_max)
        tl.store(out_ptr + row_off_out + cur, tl.cast(exp_val, tl.bfloat16), mask=mask)
        sum_val += exp_val
    row_sum = tl.sum(sum_val)

    # Pass 3: normalise in place.
    for start in range(0, n_cols, BLOCK_SIZE):
        cur = start + col
        mask = cur < n_cols
        val = tl.load(out_ptr + row_off_out + cur, mask=mask, other=0.0)
        norm = tl.cast(val, tl.float32) / row_sum
        tl.store(out_ptr + row_off_out + cur, tl.cast(norm, tl.bfloat16), mask=mask)


def _triton_softmax(x: torch.Tensor) -> torch.Tensor:
    """Softmax along dim=-1 of a 2-D bfloat16 CUDA tensor via _softmax_kernel."""
    assert x.is_cuda and x.dtype == torch.bfloat16
    n_rows, n_cols = x.shape

    # Pick a power-of-two block: small rows fit in one chunk, large rows are
    # tiled with a block capped at 1024 lanes.
    if n_cols <= 32:
        BLOCK_SIZE = 32
    elif n_cols <= 64:
        BLOCK_SIZE = 64
    elif n_cols <= 128:
        BLOCK_SIZE = 128
    else:
        BLOCK_SIZE = 1 << (n_cols - 1).bit_length()
        BLOCK_SIZE = min(BLOCK_SIZE, 1024)

    out = torch.empty_like(x)
    grid = (n_rows,)
    _softmax_kernel[grid](
        out,
        x,
        out.stride(0),
        x.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
        NUM_STAGES=2,
        num_warps=4,
    )
    return out


def custom_kernel(data: Tuple[Config, torch.Tensor, KVCache]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Optimised forward step of the Multi-head Latent Attention (MLA) module.

    Args:
        data: (config, x, kv_cache) where x is the current token's hidden
            state (decode step — the query path squeezes a seq dim of 1) and
            kv_cache appends the new latent KV entry and returns the full
            latent history plus its length.

    Returns:
        (output, kv_cache.data): the attention output after the final
        projection, and the (mutated) cache backing tensor.
    """
    config, x, kv_cache = data
    bs = config.batch_size
    sl = config.seq_len
    nh = config.n_heads
    dq = config.q_lora_rank
    dkv = config.kv_lora_rank
    d_nope = config.qk_nope_head_dim
    d_rope = config.qk_rope_head_dim
    dv = config.v_head_dim
    msl = config.max_seq_len

    wDQ = config.Q_proj_down_weight
    wDKV = config.KV_proj_down_weight
    wUQ = config.Q_proj_up_weight
    wUKV = config.KV_proj_up_weight
    wO = config.wo_weight

    # Down-projections into the low-rank latent spaces; the KV latent is
    # appended to the cache, which hands back the full history.
    q_lora = F.linear(x, wDQ)
    kv_lora_input = F.linear(x, wDKV)
    kv_lora, kv_len = kv_cache(kv_lora_input)
    query_pos = kv_len - 1  # absolute position of the token being decoded

    # Up-project the query and split into no-position / rotary parts.
    q_up = F.linear(q_lora.squeeze(1), wUQ)
    q_up = q_up.view(bs, nh, d_nope + d_rope)
    q_nope = q_up[..., :d_nope]
    q_rope = q_up[..., d_nope:]

    # The cached latent splits into the compressed KV part and the shared
    # (per-position, head-agnostic) RoPE key part.
    kv_nope_input = kv_lora[..., :dkv]
    k_rope_input = kv_lora[..., dkv:]

    cos_table, sin_table = _get_rope_tables(d_rope, msl, x.device)

    # Rotate the query in place (single position, Triton kernel)...
    cos_q = cos_table[query_pos].view(d_rope).contiguous()
    sin_q = sin_table[query_pos].view(d_rope).contiguous()
    rope_inplace_query(q_rope, cos_q, sin_q)

    # ...and the whole key history with the broadcasted rotate-half identity.
    cos_k = cos_table[:kv_len]
    sin_k = sin_table[:kv_len]
    k_rope = k_rope_input * cos_k + _rotate_half(k_rope_input) * sin_k

    # Absorb the K up-projection into the query (weight-absorption trick):
    # score against the latent cache directly instead of materialising K.
    wUKV_view = wUKV.view(nh, d_nope + dv, dkv)
    wK = wUKV_view[:, :d_nope, :]
    q_nope_latent = torch.einsum("bhd,hdk->bhk", q_nope, wK)

    kv_nope_T = kv_nope_input.transpose(1, 2)
    scores_nope = torch.matmul(q_nope_latent, kv_nope_T)
    scores_rope = torch.matmul(q_rope, k_rope.transpose(-2, -1))

    # Scale by the full QK head dimension, as in standard attention.
    scale = 1.0 / math.sqrt(d_nope + d_rope)
    scores = (scores_nope + scores_rope) * scale

    scores_flat = scores.reshape(bs * nh, kv_len)
    attn_flat = _triton_softmax(scores_flat)
    attn = attn_flat.view(bs, nh, kv_len)

    # Aggregate latents, then apply the absorbed V up-projection per head.
    M = torch.matmul(attn, kv_nope_input)
    wV = wUKV_view[:, d_nope:, :]
    wV_T = wV.permute(0, 2, 1)
    y_head = torch.einsum("bhd,hdk->bhk", M, wV_T)

    y = y_head.reshape(bs, nh * dv)
    y = y.unsqueeze(1)
    output = F.linear(y, wO)
    return output, kv_cache.data
# EVOLVE-BLOCK-END