from typing import Any, Callable, Dict, Hashable, Tuple

import torch
import triton
import triton.language as tl
from triton.compiler import CompiledKernel
from triton.runtime import JITFunction

# Newer Triton releases expose libdevice math (e.g. rsqrt) as
# `triton.language.math`; fall back to the older `triton.language.libdevice`
# module when it is unavailable.
try:
    import triton.language.math as tlmath
except ImportError:
    import triton.language.libdevice as tlmath


class TritonKernel:
    """Caches compiled Triton kernels per device and constexpr arguments."""

    def __init__(
        self,
        kernel_fn: JITFunction,
        grid_fn: Callable[[Tuple[Any, ...]], Tuple[int, int, int]],
    ) -> None:
        self.kernel_fn_ = kernel_fn
        self.grid_fn_ = grid_fn
        self.kernel_cache_: Dict[Hashable, CompiledKernel] = {}

    def run(self, *args, **kwargs):
        # Launch on the device of the first tensor argument, switching the
        # current CUDA device if necessary and remembering the old one.
        input_device = args[0].device
        prev_dev_idx, cur_dev_idx = -1, torch.cuda.current_device()
        if input_device.index != cur_dev_idx:
            prev_dev_idx = cur_dev_idx
            torch.cuda.set_device(input_device.index)

        grid = self.grid_fn_(args)

        # The cache is keyed on (device, constexpr kwargs): a hit skips JIT
        # specialization and launches the compiled binary directly.
        kernel_key = (input_device,) + tuple(kwargs.items())
        if kernel_key in self.kernel_cache_:
            kernel = self.kernel_cache_[kernel_key]
            kernel[grid](*args)
        else:
            # First call for this key: compile, launch, and cache.
            kernel = self.kernel_fn_[grid](*args, **kwargs)
            self.kernel_cache_[kernel_key] = kernel

        # Restore the previous device only if we actually switched away from it.
        if prev_dev_idx >= 0:
            torch.cuda.set_device(prev_dev_idx)


@triton.jit
def _apply_rope_fwd_kernel(X, Cos, Sin, Y, HEAD_DIM: tl.constexpr):
    # One program per (batch, token, head). X is addressed as a Q/K slice of a
    # fused QKV buffer laid out as (batch, seq_len, 3, num_heads, head_dim),
    # hence the `num_heads * 3` token stride; Y is a contiguous
    # (batch, seq_len, num_heads, head_dim) output.
    batch_idx, tok_idx, head_idx = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    seq_len, num_heads = tl.num_programs(1), tl.num_programs(2)
    block_idx = tl.arange(0, HEAD_DIM)
    x_base_idx = ((batch_idx * seq_len + tok_idx) * num_heads * 3 + head_idx) * HEAD_DIM
    x = tl.load(X + x_base_idx + block_idx)
    # Cos/Sin are (seq_len, head_dim) tables indexed by token position.
    freq_idx = tok_idx * HEAD_DIM + block_idx
    cos = tl.load(Cos + freq_idx)
    # rotate_half: pair each element with the one half a head away and negate
    # the second half's contribution, i.e. [-x2, x1] for x = [x1, x2].
    rot_idx = (HEAD_DIM // 2 + block_idx) % HEAD_DIM
    x_rot = tl.load(X + x_base_idx + rot_idx)
    x_rot = tl.where(block_idx >= HEAD_DIM // 2, x_rot, -x_rot)
    sin = tl.load(Sin + freq_idx)
    y_idx = (
        (batch_idx * seq_len + tok_idx) * num_heads + head_idx
    ) * HEAD_DIM + block_idx
    y = x * cos + x_rot * sin
    tl.store(Y + y_idx, y.to(Y.dtype.element_ty))


# Grid: one program per (batch, token, head), taken from x.shape[:3].
apply_rope_fwd_kernel = TritonKernel(
    _apply_rope_fwd_kernel, lambda args: tuple(args[0].shape[:3])
)


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # x is expected to view the Q or K portion of a fused QKV buffer laid out
    # as (batch, seq_len, 3, num_heads, head_dim); the kernel indexes through
    # that underlying layout.
    y = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    apply_rope_fwd_kernel.run(x, cos, sin, y, HEAD_DIM=x.size(-1))
    return y


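# Reference sketch (an addition for testing, not part of the original API):
# a plain PyTorch rotate-half RoPE matching the arithmetic of
# `_apply_rope_fwd_kernel`, for x of logical shape
# (batch, seq_len, num_heads, head_dim) and cos/sin of shape
# (seq_len, head_dim). The function name is illustrative.
def apply_rotary_emb_torch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)  # rotate_half(x) = [-x2, x1]
    return x * cos[:, None, :] + rotated * sin[:, None, :]

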
@triton.jit
def _rms_norm_fwd_kernel(X, W, Y, eps, hidden_dim, BLOCK_SIZE: tl.constexpr):
    # One program per token row of the flattened (num_tokens, hidden_dim) input.
    tok_idx = tl.program_id(0)

    # First pass: accumulate mean(x^2) in float32, one BLOCK_SIZE tile at a time.
    mean_sq = tl.zeros([BLOCK_SIZE], tl.float32)
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        x = tl.load(
            X + tok_idx * hidden_dim + dim_idx, mask=dim_idx < hidden_dim, other=0
        ).to(tl.float32)
        mean_sq += x * x / hidden_dim
    rrms = tlmath.rsqrt(tl.sum(mean_sq, 0) + eps)

    # Second pass: scale each element by the reciprocal RMS and the weight.
    for offset in range(0, hidden_dim, BLOCK_SIZE):
        dim_idx = offset + tl.arange(0, BLOCK_SIZE)
        dim_mask = dim_idx < hidden_dim
        hidden_idx = tok_idx * hidden_dim + dim_idx
        x = tl.load(X + hidden_idx, mask=dim_mask, other=0)
        w = tl.load(W + dim_idx, mask=dim_mask, other=0)
        y = x * rrms * w
        tl.store(Y + hidden_idx, y.to(Y.dtype.element_ty), mask=dim_mask)


# Grid: one program per token, i.e. the product of all leading dimensions.
rms_norm_fwd_kernel = TritonKernel(
    _rms_norm_fwd_kernel, lambda args: (args[0].shape[:-1].numel(), 1, 1)
)


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    y = torch.empty_like(x)
    hidden_dim = x.size(-1)
    rms_norm_fwd_kernel.run(
        x, weight, y, eps, hidden_dim, BLOCK_SIZE=triton.next_power_of_2(hidden_dim)
    )
    return y


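# Reference sketch (an addition for testing, not part of the original API):
# a plain PyTorch RMSNorm matching `_rms_norm_fwd_kernel` up to rounding; the
# statistics are computed in float32 like the kernel's accumulator. The
# function name is illustrative.
def rms_norm_torch(x: torch.Tensor, weight: torch.Tensor, eps: float):
    x32 = x.float()
    rrms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rrms * weight.float()).to(x.dtype)

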
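# Minimal smoke test (an illustrative addition, assuming a CUDA device):
# checks both Triton paths against the reference sketches above. Note that
# `apply_rotary_emb` must be fed a slice of a fused QKV buffer, matching the
# kernel's indexing.
if __name__ == "__main__" and torch.cuda.is_available():
    torch.manual_seed(0)
    batch, seq_len, num_heads, head_dim, hidden_dim = 2, 16, 4, 64, 256

    # Rotary tables of shape (seq_len, head_dim), duplicated across halves.
    pos = torch.arange(seq_len, device="cuda", dtype=torch.float32)
    inv_freq = 1.0 / (
        10000.0 ** (torch.arange(0, head_dim, 2, device="cuda") / head_dim)
    )
    freqs = torch.outer(pos, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos(), emb.sin()

    qkv = torch.randn(batch, seq_len, 3, num_heads, head_dim, device="cuda")
    q = qkv[:, :, 0]
    torch.testing.assert_close(
        apply_rotary_emb(q, cos, sin),
        apply_rotary_emb_torch(q, cos, sin),
        rtol=1e-4,
        atol=1e-4,
    )

    x = torch.randn(batch, seq_len, hidden_dim, device="cuda")
    w = torch.randn(hidden_dim, device="cuda")
    torch.testing.assert_close(
        rms_norm(x, w, 1e-6), rms_norm_torch(x, w, 1e-6), rtol=1e-4, atol=1e-4
    )
    print("ok: rope + rms_norm match the references")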