| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import triton |
| import triton.language as tl |
| from .act_quant import fp8_act_quant |
| from torch.library import triton_op, wrap_triton |
|
|
| from .utils import device_context |
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_block_fp8_matmul_grouped_kernel(
    # --- tensor pointers ---
    A,            # (S, K) raw activations (not yet FP8), rows sorted by expert
    B,            # (E, N, K) FP8 expert weights
    C,            # (S, N) output
    Bs,           # per-block weight scales, indexed by (expert, k-block, n-block)
    Offsets,      # (E,) inclusive cumsum of tokens per expert
    TileOffsets,  # (E,) inclusive cumsum of M-tiles per expert
    # --- problem sizes ---
    S,            # total rows; not read in the kernel body (autotune/launch bookkeeping)
    N,
    K,
    # --- strides ---
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_bs_e,
    stride_bs_k,
    stride_bs_n,
    # --- compile-time constants ---
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,  # fixed trip count for the binary search below
):
    """Block-scale grouped FP8 expert matmul kernel.

    Tokens are assumed sorted by expert. The kernel maps each M-tile to its
    owning expert via ``TileOffsets`` and applies fused activation quantization.
    """
    pid_m = tl.program_id(axis=0)  # M-tile index across all experts
    pid_n = tl.program_id(axis=1)  # N-tile index

    # The grid is launched with an over-estimate of the tile count; programs
    # beyond the true total (last entry of the inclusive cumsum) exit early.
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search over the inclusive tile cumsum: find the first expert
    # whose cumulative tile count exceeds pid_m. The trip count is fixed at
    # NUM_EXPERTS_BIT_LENGTH so tl.static_range can unroll it.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)

    # int64 keeps the pointer arithmetic below out of int32 overflow range.
    expert_id = lo.to(tl.int64)

    # Token range [expert_start, expert_end) owned by this expert.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # This program's tile position within its expert's own tile range.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    # Row indices: local to the expert (for masking) and global into A/C.
    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )
    # One weight scale per (k-block, n-block); this tile covers a single
    # n-block, so only the k index advances in the loop below.
    bs_ptrs = Bs + expert_id * stride_bs_e + pid_n * stride_bs_n

    # K loop. Loads are unmasked along K and N, so the host wrapper must
    # guarantee K % BLOCK_SIZE_K == 0 and N % BLOCK_SIZE_N == 0.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Fused dynamic quantization of this activation tile: scale each row
        # by max(|row|)/448 (FP8 E4M3 finite max) before casting to FP8.
        # The 1e-12 floor guards against division by zero on all-zero rows.
        a_raw = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0).to(tl.float32)
        a_s = tl.max(tl.abs(a_raw), axis=1) / 448.0
        a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)

        b = tl.load(b_ptrs)
        b_s = tl.load(bs_ptrs + k * stride_bs_k)
        # Rescale the FP8 dot product back to real magnitude with both scales.
        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Store; only rows belonging to this expert are written.
    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=w, num_stages=s)
        for w in [2, 4, 8, 16]
        for s in [2, 3, 4, 5]
    ],
    key=["N", "K", "BLOCK_SIZE_M"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_grouped_kernel(
    A,            # (S, K) activations, already FP8-quantized, rows sorted by expert
    B,            # (E, N, K) FP8 expert weights
    C,            # (S, N) output
    As,           # per-row activation scales, indexed by global row
    Bs,           # per-expert (tensor-wise) weight scales, indexed by expert id
    Offsets,      # (E,) inclusive cumsum of tokens per expert
    TileOffsets,  # (E,) inclusive cumsum of M-tiles per expert
    S,            # total rows; not read in the kernel body (autotune/launch bookkeeping)
    N,
    K,
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_as_m,
    stride_bs_e,
    NUM_EXPERTS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    NUM_EXPERTS_BIT_LENGTH: tl.constexpr,  # fixed trip count for the binary search below
):
    """Tensor-scale grouped FP8 expert matmul kernel.

    Uses grouped expert scheduling with pre-quantized activations plus
    per-token activation scales and per-expert tensor weight scales.
    """
    pid_m = tl.program_id(axis=0)  # M-tile index across all experts
    pid_n = tl.program_id(axis=1)  # N-tile index

    # Grid is launched with an over-estimate; surplus programs exit early.
    total_tiles = tl.load(TileOffsets + NUM_EXPERTS - 1)
    if pid_m >= total_tiles:
        return

    # Binary search the inclusive tile cumsum for the first expert whose
    # cumulative tile count exceeds pid_m; fixed trip count so
    # tl.static_range can unroll it.
    lo = 0
    hi = NUM_EXPERTS
    for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
        mid = (lo + hi) >> 1
        mid_val = tl.load(TileOffsets + mid)
        is_left = mid_val <= pid_m
        lo = tl.where(is_left, mid + 1, lo)
        hi = tl.where(is_left, hi, mid)
    # int64 keeps the pointer arithmetic below out of int32 overflow range.
    expert_id = lo.to(tl.int64)

    # Token range [expert_start, expert_end) owned by this expert.
    prev_eid = tl.maximum(expert_id - 1, 0)
    expert_start = tl.where(expert_id == 0, 0, tl.load(Offsets + prev_eid))
    expert_end = tl.load(Offsets + expert_id)
    M_expert = expert_end - expert_start

    # This program's tile position within its expert's own tile range.
    expert_tile_start = tl.where(expert_id == 0, 0, tl.load(TileOffsets + prev_eid))
    local_tile = pid_m - expert_tile_start
    m_off = local_tile * BLOCK_SIZE_M

    # Row indices: local to the expert (for masking) and global into A/As/C.
    offs_am = m_off + tl.arange(0, BLOCK_SIZE_M)
    row_mask = offs_am < M_expert
    offs_global_m = expert_start + offs_am

    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    a_ptrs = A + offs_global_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = (
        B
        + expert_id * stride_be
        + offs_k[:, None] * stride_bk
        + offs_bn[None, :] * stride_bn
    )

    # Scales are constant over the whole K loop: one activation scale per
    # row and a single tensor-wise scale for this expert's weights.
    a_s = tl.load(As + offs_global_m * stride_as_m, mask=row_mask, other=0.0)
    b_s = tl.load(Bs + expert_id * stride_bs_e)

    # K loop. Loads are unmasked along K and N, so the host wrapper must
    # guarantee the tile sizes divide K and N.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=row_mask[:, None], other=0.0)
        b = tl.load(b_ptrs)

        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply both scales once after the reduction (valid because neither
    # varies with K).
    accumulator = accumulator * a_s[:, None] * b_s

    # Cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    # Store; only rows belonging to this expert are written.
    c_ptrs = C + stride_cm * offs_global_m[:, None] + stride_cn * offs_bn[None, :]
    c_mask = row_mask[:, None]
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
@triton_op("finegrained_fp8::w8a8_block_fp8_matmul_grouped", mutates_args=())
def _w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    offsets: (E,) inclusive cumsum of tokens per expert
    tokens_per_expert: (E,) token count per expert

    Returns:
        (S, N) output in A's dtype, rows in expert-sorted order.
    """
    # Layout validation up front.
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    assert len(block_size) == 2, (
        f"block_size must be [block_n, block_k], got {block_size}"
    )
    block_n, block_k = block_size

    # The kernel's loads/stores are unmasked along N and K, so the tile
    # sizes must divide the problem evenly.
    assert N % block_n == 0, f"N ({N}) must be divisible by block_n ({block_n})"
    assert K % block_k == 0, f"K ({K}) must be divisible by block_k ({block_k})"
    assert Bs.ndim == 3, (
        f"Bs must be 3D (E, N//block_n, K//block_k), got ndim={Bs.ndim}"
    )
    assert Bs.shape == (E, N // block_n, K // block_k), (
        f"Bs shape {tuple(Bs.shape)} != expected ({E}, {N // block_n}, {K // block_k})"
    )

    out = torch.empty((S, N), dtype=A.dtype, device=A.device)

    # Tile height: a power of two near the mean tokens-per-expert, clamped
    # to [16, 128].
    mean_tokens = (S + E - 1) // E
    block_m = max(16, min(triton.next_power_of_2(mean_tokens), 128))
    tile_counts = (tokens_per_expert + block_m - 1) // block_m
    tile_offsets = tile_counts.cumsum(dim=0).to(torch.int32)

    # Launch with an upper bound on the M-tile count (each expert can add at
    # most one partial tile); the kernel returns early for surplus programs,
    # which avoids a host-device sync to read the exact total.
    grid = (triton.cdiv(S, block_m) + E, triton.cdiv(N, block_n))
    with device_context(A.device):
        wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
            A,
            B,
            out,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            A.stride(0),
            A.stride(1),
            B.stride(0),
            B.stride(2),  # B is (E, N, K): stride along K
            B.stride(1),  # stride along N
            out.stride(0),
            out.stride(1),
            Bs.stride(0),
            Bs.stride(2),  # Bs is (E, N//block_n, K//block_k): stride along k-blocks
            Bs.stride(1),  # stride along n-blocks
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=block_n,
            BLOCK_SIZE_K=block_k,
            BLOCK_SIZE_M=block_m,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return out
|
|
|
|
@triton_op("finegrained_fp8::w8a8_tensor_fp8_matmul_grouped", mutates_args=())
def _w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul: C = A @ B.T per expert, with fused act quant.

    A: (S, K) raw bf16/fp16 activations, sorted by expert
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    offsets: (E,) inclusive cumsum of tokens per expert
    tokens_per_expert: (E,) token count per expert

    Returns:
        (S, N) output in A's dtype, rows in expert-sorted order.
    """
    assert A.ndim == 2, f"A must be 2D (S, K), got ndim={A.ndim}"
    assert A.is_contiguous(), "A must be contiguous"
    assert B.ndim == 3, f"B must be 3D (E, N, K), got ndim={B.ndim}"
    assert B.is_contiguous(), "B must be contiguous"
    assert A.shape[1] == B.shape[2], (
        f"K mismatch: A has K={A.shape[1]}, B has K={B.shape[2]}"
    )

    S, K = A.shape
    E, N, _ = B.shape

    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128
    # The kernel loads B and stores C unmasked along N, and loads A/B
    # unmasked along K: non-divisible shapes would silently read/write out
    # of bounds. Reject them up front (mirrors the block-scale variant).
    assert N % BLOCK_SIZE_N == 0, (
        f"N ({N}) must be divisible by BLOCK_SIZE_N ({BLOCK_SIZE_N})"
    )
    assert K % BLOCK_SIZE_K == 0, (
        f"K ({K}) must be divisible by BLOCK_SIZE_K ({BLOCK_SIZE_K})"
    )

    # Accept either (E,) or (E, 1, 1) weight scales; normalize to (E, 1, 1).
    if Bs.ndim == 1:
        assert Bs.shape[0] == E, f"Bs shape {tuple(Bs.shape)} != expected ({E},)"
        Bs = Bs.reshape(E, 1, 1)
    else:
        assert Bs.shape == (E, 1, 1), (
            f"Bs shape {tuple(Bs.shape)} != expected ({E}, 1, 1)"
        )

    C = A.new_empty(S, N)
    # Quantize activations once on the host: FP8 values plus per-token
    # scales, both consumed by the kernel.
    qA, As = fp8_act_quant(A, K)
    # Tile height: a power of two near the mean tokens-per-expert, clamped
    # to [16, 128].
    BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
    tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
    tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)

    # Launch with an upper bound on the M-tile count (each expert can add at
    # most one partial tile); the kernel returns early for surplus programs,
    # which avoids a host-device sync to read the exact total.
    max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
    grid = (max_M_tiles, triton.cdiv(N, BLOCK_SIZE_N))
    with device_context(A.device):
        wrap_triton(w8a8_tensor_fp8_matmul_grouped_kernel)[grid](
            qA,
            B,
            C,
            As,
            Bs,
            offsets,
            tile_offsets,
            S,
            N,
            K,
            qA.stride(0),
            qA.stride(1),
            B.stride(0),
            B.stride(2),  # B is (E, N, K): stride along K
            B.stride(1),  # stride along N
            C.stride(0),
            C.stride(1),
            As.stride(0),
            Bs.stride(0),
            NUM_EXPERTS=E,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
        )

    return C
|
|
|
|
def w8a8_block_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int],
) -> torch.Tensor:
    """Block-scale grouped FP8 matmul with fused activation quantization.

    Thin public wrapper that routes through the registered custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E, N // block_n, K // block_k) per-block weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped
    return op(A, B, Bs, offsets, tokens_per_expert, block_size)
|
|
|
|
def w8a8_tensor_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    """Tensor-scale grouped FP8 matmul with fused activation quantization.

    Thin public wrapper that routes through the registered custom op.

    A: (S, K) raw activations sorted by expert, bf16/fp16/fp32
    B: (E, N, K) FP8 expert weights
    Bs: (E,) or (E, 1, 1) per-expert weight scales
    """
    op = torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped
    return op(A, B, Bs, offsets, tokens_per_expert)
|
|
|
|
def w8a8_fp8_matmul_grouped(
    A: torch.Tensor,
    B: torch.Tensor,
    Bs: torch.Tensor,
    offsets: torch.Tensor,
    tokens_per_expert: torch.Tensor,
    block_size: list[int] | None,
) -> torch.Tensor:
    """Unified grouped W8A8 FP8 matmul dispatcher.

    Dispatch rules:
    - tensor mode when ``block_size is None``
    - tensor mode when ``block_size == [N, K]``
    - otherwise block mode

    Returns:
        Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
    """
    tensor_mode = block_size is None
    if not tensor_mode:
        # A single block spanning the full (N, K) weight is equivalent to
        # tensor-wise scaling.
        tensor_mode = block_size[0] == B.size(1) and block_size[1] == B.size(2)

    if tensor_mode:
        return w8a8_tensor_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert)

    return w8a8_block_fp8_matmul_grouped(
        A, B, Bs, offsets, tokens_per_expert, block_size
    )
|
|