from typing import Tuple import torch import triton import triton.language as tl fp8_dtype = torch.float8_e4m3fn fp8_max = torch.finfo(fp8_dtype).max fp8_min = -fp8_max def ceil_div(x: int, y: int) -> int: """ Perform ceiling division of two integers. Args: x: the dividend. y: the divisor. Returns: The result of the ceiling division. """ return (x + y - 1) // y @triton.jit def _blockwise_cast_to_fp8_triton( X, Y, S, stride_xm, stride_xn, stride_ym, stride_yn, stride_sm, stride_sn, M, N, eps, fp8_min, fp8_max, BLOCK_M: tl.constexpr = 32, BLOCK_N: tl.constexpr = 128, ): pid_m = tl.cast(tl.program_id(axis=0), tl.int64) pid_n = tl.cast(tl.program_id(axis=1), tl.int64) off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) mask_m = off_m < M mask_n = off_n < N mask = mask_m[:, None] & mask_n[None, :] x = tl.load(X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn, mask=mask, other=0.0).to(tl.float32) _absmax = tl.maximum(tl.max(tl.abs(x)), eps) x_s = _absmax / fp8_max s_inv = 1.0 / x_s y_q = tl.clamp(x * s_inv, fp8_min, fp8_max).to(Y.dtype.element_ty) tl.store(Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn, y_q, mask=mask) tl.store(S + pid_m * stride_sm + pid_n * stride_sn, x_s) def blockwise_cast_to_fp8_triton(x: torch.Tensor, block_size=None) -> Tuple[torch.Tensor, torch.Tensor]: BLOCK_M, BLOCK_N = 128, 128 if block_size: BLOCK_M, BLOCK_N = block_size[0], block_size[1] M, N = x.shape y = torch.empty(M, N, device=x.device, dtype=torch.float8_e4m3fn) s = torch.empty(ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N), dtype=torch.float32, device=x.device) grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"])) if x.is_contiguous(): kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 8, "num_stages": 2} else: kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 1, "num_stages": 4} _blockwise_cast_to_fp8_triton[grid]( x, y, s, *x.stride(), *y.stride(), *s.stride(), M, N, 1e-10, fp8_min, fp8_max, **kwargs ) return y, s