| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from contextlib import contextmanager |
| from typing import Tuple |
|
|
| import torch |
| import triton |
| import triton.language as tl |
| from torch.library import triton_op, wrap_triton |
|
|
| from ._ops import add_op_namespace_prefix, ops |
|
|
| |
|
|
| |
# Torch dtype used for FP8 (E4M3) quantized outputs allocated in this module.
FP8_DTYPE = torch.float8_e4m3fn

# NOTE(review): presumably the number of 4-bit values packed into one byte of
# FP4 storage; not used in the visible part of the file, confirm at use sites.
FP4_VALUES_PER_BYTE = 2
# NOTE(review): presumably the number of K elements sharing one FP4 scale;
# not used in the visible part of the file, confirm at use sites.
FP4_SCALE_GROUP_K = 32
|
|
|
|
| |
|
|
|
|
@contextmanager
def device_context(device: torch.device):
    """Make ``device`` the active device for the duration of the ``with`` body.

    Looks up the backend module named after ``device.type`` on ``torch``
    (e.g. ``torch.cuda``, ``torch.xpu``). When that module exposes a
    ``device`` context manager it is entered around the body; otherwise the
    body runs unchanged.
    """
    backend_module = getattr(torch, device.type, None)

    # No such backend module, or it has no ``device`` guard: nothing to enter.
    if backend_module is None or not hasattr(backend_module, "device"):
        yield
        return

    with backend_module.device(device):
        yield
|
|
|
|
def adaptive_block_size_m(target_m: int) -> int:
    """Pick the M-tile size for a workload of roughly ``target_m`` rows.

    The result is ``target_m`` rounded up to a power of two and then clamped
    to the range ``[16, 128]``. All matmul wrappers (single / batched /
    grouped) share this sizing rule: a small per-expert M gets a small tile,
    while a large M stops at 128 to keep register pressure bounded. Pass
    ``M`` for a single matmul, or ``(S + E - 1) // E`` (average tokens per
    expert) for batched/grouped.
    """
    rounded_up = triton.next_power_of_2(target_m)
    if rounded_up < 16:
        return 16
    if rounded_up > 128:
        return 128
    return rounded_up
|
|
|
|
def grouped_tile_layout(
    tokens_per_expert: torch.Tensor,
    block_size_m: int,
    S: int,
    E: int,
) -> Tuple[torch.Tensor, int]:
    """Build the M-tile layout consumed by the grouped kernels.

    Args:
        tokens_per_expert: Per-expert token counts.
        block_size_m: Height of one M tile.
        S: Total number of tokens.
        E: Number of experts.

    Returns:
        ``(tile_offsets, max_m_tiles)`` where

        - ``tile_offsets`` is the int32 running total of tiles per expert
          (cumulative tile-end), which ``grouped_expert_lookup`` searches to
          find the expert owning a given M-tile;
        - ``max_m_tiles`` is an upper bound on the total number of M-tiles,
          used as the grid axis-0 size. It is computed from ``S``,
          ``block_size_m`` and ``E`` only, never from tensor contents, so the
          grid stays data-independent (cuda-graph / torch.compile safe);
          surplus programs are expected to early-return inside the kernel.
    """
    # Ceil-divide each expert's token count by the tile height.
    per_expert_tiles = (tokens_per_expert + block_size_m - 1) // block_size_m
    running_tile_end = per_expert_tiles.cumsum(dim=0)
    tile_offsets = running_tile_end.to(torch.int32)

    # Each expert can add at most one partially filled tile on top of the
    # dense tile count.
    grid_upper_bound = triton.cdiv(S, block_size_m) + E
    return tile_offsets, grid_upper_bound
|
|
|
|
| |
|
|
|
|
@triton.jit
def fp8_act_quant_inline(a_raw):
    """Quantize an activation tile to FP8 (E4M3) inline, for the W8A8 block-scale path.

    For each row, the scale is the row's absolute maximum divided by 448.
    Values are divided by that scale (floored at 1e-12 so all-zero rows do not
    divide by zero) and cast to FP8. Returns ``(a_fp8, a_s)`` with shapes
    ``(M, K)`` and ``(M,)``; the returned ``a_s`` is the unfloored scale.
    """
    row_amax = tl.max(tl.abs(a_raw), axis=1)
    a_s = row_amax / 448.0
    # Only the divisor is floored; the returned scale keeps its raw value.
    safe_divisor = tl.maximum(a_s[:, None], 1e-12)
    a_fp8 = (a_raw / safe_divisor).to(tl.float8e4nv)
    return a_fp8, a_s
|
|
|
|
@triton.jit
def fp4_act_quant_inline(
    a_raw,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    SCALE_GROUP_K: tl.constexpr,
):
    """Quantize an activation tile to FP8 (E4M3) inline, for the W4A8 path.

    The tile is split along K into groups of ``SCALE_GROUP_K`` elements. Each
    (row, group) gets a power-of-two scale stored as a UE8M0 exponent byte:
    ``amax / 448`` rounded up to the next power of two. Values are divided by
    that power-of-two scale and cast to FP8. Returns ``(a_fp8, a_scale_u8)``
    with shapes ``(M, K)`` and ``(M, K // SCALE_GROUP_K)``.
    """
    grouped = tl.reshape(
        a_raw, (BLOCK_SIZE_M, BLOCK_SIZE_K // SCALE_GROUP_K, SCALE_GROUP_K)
    )
    group_scale = tl.max(tl.abs(grouped), axis=2) / 448.0

    # Reinterpret the scale's bits to read its exponent and mantissa fields.
    raw_bits = group_scale.to(tl.int32, bitcast=True)
    biased_exp = (raw_bits >> 23) & 0xFF
    # A non-zero mantissa means the scale is above 2**exp, so bump the
    # exponent by one to round up to the next power of two.
    round_up = ((raw_bits & 0x7FFFFF) != 0).to(tl.int32)
    exp_ceil = tl.minimum(tl.maximum(biased_exp + round_up, 1), 254)

    a_scale_u8 = exp_ceil.to(tl.uint8)
    # Rebuild the power-of-two scale as a float with a zero mantissa.
    pow2_scale = (exp_ceil << 23).to(tl.float32, bitcast=True)
    # NOTE(review): with exp_ceil clamped to >= 1 the power-of-two scale is
    # already non-zero; for scales below 1e-12 this floor makes the divisor
    # differ from the exponent stored in a_scale_u8. Confirm this is intended.
    scaled = grouped / tl.maximum(pow2_scale[:, :, None], 1e-12)
    a_fp8 = tl.reshape(scaled, (BLOCK_SIZE_M, BLOCK_SIZE_K)).to(tl.float8e4nv)
    return a_fp8, a_scale_u8
|
|
|
|
@triton.jit
def grouped_expert_lookup(
    pid_m,
    Offsets,
    TileOffsets,
    stride_offs,
    stride_tile,
    NUM_EXPERTS: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    """Find the expert that owns M-tile ``pid_m`` and derive its row offsets.

    ``TileOffsets`` is read as a cumulative tile-end per expert and
    ``Offsets`` as a cumulative row-end per expert; each is indexed through
    its own stride.

    Returns ``(expert_id, offs_global_m, row_mask)``:
    - ``expert_id``: int64 index of the owning expert
    - ``offs_global_m``: ``(BLOCK_SIZE_M,)`` global row indices into A
    - ``row_mask``: ``(BLOCK_SIZE_M,)`` True where the row lies inside the
      expert's M range

    The caller is expected to have early-returned already when ``pid_m``
    exceeds ``total_tiles`` (``TileOffsets[(NUM_EXPERTS - 1) * stride_tile]``).
    """
    # Binary search with a fixed trip count: find the first expert whose
    # cumulative tile-end is strictly greater than pid_m.
    left = 0
    right = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        probe = (left + right) >> 1
        probe_tile_end = tl.load(TileOffsets + probe * stride_tile)
        go_right = probe_tile_end <= pid_m
        left = tl.where(go_right, probe + 1, left)
        right = tl.where(go_right, right, probe)

    expert_id = left.to(tl.int64)
    is_first_expert = expert_id == 0
    # Clamp so the "previous expert" loads stay in bounds for expert 0; the
    # loaded values are then replaced by 0 through tl.where.
    prev_eid = tl.maximum(expert_id - 1, 0)

    # Tiles and rows that belong to all earlier experts.
    tiles_before = tl.where(
        is_first_expert, 0, tl.load(TileOffsets + prev_eid * stride_tile)
    )
    rows_before = tl.where(
        is_first_expert, 0, tl.load(Offsets + prev_eid * stride_offs)
    )
    rows_end = tl.load(Offsets + expert_id * stride_offs)
    rows_in_expert = rows_end - rows_before

    # Row indices of this tile, relative to the expert's first row.
    local_rows = (pid_m - tiles_before) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    row_mask = local_rows < rows_in_expert
    offs_global_m = rows_before + local_rows
    return expert_id, offs_global_m, row_mask
|
|
|
|
| |
|
|
|
|
@triton.jit
def _fp8_act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
    """Quantize one ``BLOCK_SIZE``-element block per program.

    Each program loads a contiguous block of ``x``, computes
    ``scale = max(|x|) / 448`` in float32, writes the scaled values to ``y``
    in ``y``'s element type, and writes the scale to ``s`` at the program
    index. Loads and stores are unmasked, so the element count must be a
    multiple of ``BLOCK_SIZE``.
    """
    block_id = tl.program_id(axis=0)
    elem_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    vals = tl.load(x_ptr + elem_idx).to(tl.float32)
    scale = tl.max(tl.abs(vals)) / 448.0
    # Floor only the divisor so an all-zero block does not divide by zero;
    # the stored scale keeps its raw value.
    quantized = vals / tl.maximum(scale, 1e-12)
    tl.store(y_ptr + elem_idx, quantized.to(y_ptr.dtype.element_ty))
    tl.store(s_ptr + block_id, scale)
|
|
|
|
@triton_op(add_op_namespace_prefix("fp8_act_quant"), mutates_args=())
def _fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Registered op: per-block dynamic FP8 quantization of ``x``.

    Launches one kernel program per ``block_size`` elements of ``x``.

    Args:
        x: Contiguous input tensor whose last dimension is a multiple of
            ``block_size``.
        block_size: Number of elements per quantization block.

    Returns:
        ``(y, s)``: ``y`` has the shape of ``x`` with dtype ``FP8_DTYPE``;
        ``s`` is float32 with shape
        ``(*x.shape[:-1], x.shape[-1] // block_size)``.

    Raises:
        ValueError: If ``block_size`` is not positive, ``x`` is not
            contiguous, or the last dimension of ``x`` is not divisible by
            ``block_size``.
    """
    # Explicit raises rather than ``assert``: the kernel loads and stores full
    # blocks without masking, so these checks must survive ``python -O`` to
    # prevent out-of-bounds memory access.
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if not x.is_contiguous():
        raise ValueError("fp8_act_quant requires a contiguous input tensor")
    if x.shape[-1] % block_size != 0:
        raise ValueError(
            f"last dimension ({x.shape[-1]}) must be divisible by "
            f"block_size ({block_size})"
        )

    y = torch.empty_like(x, dtype=FP8_DTYPE)
    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
    # One program per block of ``block_size`` contiguous elements.
    grid = (triton.cdiv(x.numel(), block_size),)

    with device_context(x.device):
        wrap_triton(_fp8_act_quant_kernel)[grid](x, y, s, BLOCK_SIZE=block_size)

    return y, s
|
|
|
|
def fp8_act_quant(
    x: torch.Tensor, block_size: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dynamically quantize activations to FP8, one scale per block.

    The last dimension of ``x`` is partitioned into consecutive blocks of
    ``block_size`` elements. Each block gets ``scale = max(|x_block|) / 448``
    and its values are quantized to ``float8_e4m3fn``.

    Args:
        x: Contiguous bf16/fp16/fp32 tensor whose last dimension is divisible
            by ``block_size``.
        block_size: Elements per quantization block (default: 128).

    Returns:
        ``(quantized, scales)``: ``quantized`` is ``float8_e4m3fn`` with the
        shape of ``x``; ``scales`` is float32 with shape
        ``(*x.shape[:-1], x.shape[-1] // block_size)``.
    """
    # Dispatch through the registered op.
    quantized, scales = ops.fp8_act_quant(x, block_size)
    return quantized, scales
|
|