| from typing import Tuple |
|
|
| import torch |
| import triton |
| import triton.language as tl |
|
|
| fp8_dtype = torch.float8_e4m3fn |
| fp8_max = torch.finfo(fp8_dtype).max |
| fp8_min = -fp8_max |
|
|
|
|
def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    The quotient is rounded towards positive infinity. It is computed with
    the negated-floor identity ``ceil(x / y) == -((-x) // y)``, which holds
    for every sign combination; the ``(x + y - 1) // y`` form is only
    correct when the divisor is positive.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        The result of the ceiling division.

    Raises:
        ZeroDivisionError: if ``y`` is zero.
    """
    return -(-x // y)
|
|
|
|
@triton.jit
def _blockwise_cast_to_fp8_triton(
    X,
    Y,
    S,
    stride_xm,
    stride_xn,
    stride_ym,
    stride_yn,
    stride_sm,
    stride_sn,
    M,
    N,
    eps,
    fp8_min,
    fp8_max,
    BLOCK_M: tl.constexpr = 32,
    BLOCK_N: tl.constexpr = 128,
):
    """
    Quantize one (BLOCK_M, BLOCK_N) tile of X into Y and record its scale in S.

    Each program instance handles the tile selected by its two program ids:
    it loads the tile as float32, takes the tile-wide absolute maximum
    (floored at ``eps``), derives ``scale = absmax / fp8_max``, and stores
    ``clamp(x / scale, fp8_min, fp8_max)`` cast to Y's element type. The
    scale itself is written to ``S[pid_m, pid_n]``.

    Args:
        X, Y, S: base pointers of the input, quantized output and scale tensors.
        stride_xm, stride_xn: row/column strides of X.
        stride_ym, stride_yn: row/column strides of Y.
        stride_sm, stride_sn: row/column strides of S.
        M, N: number of rows/columns of X and Y.
        eps: lower bound applied to the tile's absolute maximum.
        fp8_min, fp8_max: clamp range applied before the cast.
        BLOCK_M, BLOCK_N: compile-time tile shape.
    """
    # Program ids are widened to int64 before being turned into offsets
    # (presumably to keep the address arithmetic from overflowing 32 bits).
    tile_row = tl.cast(tl.program_id(axis=0), tl.int64)
    tile_col = tl.cast(tl.program_id(axis=1), tl.int64)

    rows = tile_row * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tile_col * BLOCK_N + tl.arange(0, BLOCK_N)

    # Tiles on the bottom/right edge may stick out past the tensor bounds.
    row_ok = rows < M
    col_ok = cols < N
    in_bounds = row_ok[:, None] & col_ok[None, :]

    # Out-of-bounds lanes read as 0.0, so they cannot raise the absolute max.
    src_ptrs = X + rows[:, None] * stride_xm + cols[None, :] * stride_xn
    tile = tl.load(src_ptrs, mask=in_bounds, other=0.0).to(tl.float32)

    # One scale for the whole tile; eps keeps it non-zero for all-zero tiles.
    tile_absmax = tl.maximum(tl.max(tl.abs(tile)), eps)
    scale = tile_absmax / fp8_max
    inv_scale = 1.0 / scale
    quantized = tl.clamp(tile * inv_scale, fp8_min, fp8_max).to(Y.dtype.element_ty)

    dst_ptrs = Y + rows[:, None] * stride_ym + cols[None, :] * stride_yn
    tl.store(dst_ptrs, quantized, mask=in_bounds)
    tl.store(S + tile_row * stride_sm + tile_col * stride_sn, scale)
|
|
|
|
def blockwise_cast_to_fp8_triton(x: torch.Tensor, block_size=None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 2-D tensor to FP8 (e4m3fn) with one scale per block.

    The input is split into ``(BLOCK_M, BLOCK_N)`` tiles. Every tile is
    scaled by ``absmax / fp8_max`` (with the absolute maximum floored at
    ``1e-10``) and cast to FP8, so ``x`` is approximately ``y * s`` with the
    scale broadcast over its tile.

    Args:
        x: 2-D input tensor of shape ``(M, N)``.
        block_size: optional indexable ``(BLOCK_M, BLOCK_N)``; defaults to
            ``(128, 128)`` when omitted or empty. NOTE(review): the kernel
            builds its offsets with ``tl.arange``, so both entries presumably
            have to be powers of two -- confirm against the Triton version
            in use.

    Returns:
        A tuple ``(y, s)`` where ``y`` is the ``(M, N)`` ``float8_e4m3fn``
        tensor and ``s`` is the float32 tensor of per-block scales with shape
        ``(ceil(M / BLOCK_M), ceil(N / BLOCK_N))``.

    Raises:
        ValueError: if ``x`` is not 2-dimensional.
    """
    # Fail with a readable message instead of an opaque unpacking error.
    if x.dim() != 2:
        raise ValueError(
            f"blockwise_cast_to_fp8_triton expects a 2-D tensor, got shape {tuple(x.shape)}"
        )

    BLOCK_M, BLOCK_N = 128, 128
    if block_size:
        BLOCK_M, BLOCK_N = block_size[0], block_size[1]

    M, N = x.shape
    y = torch.empty(M, N, device=x.device, dtype=torch.float8_e4m3fn)
    s = torch.empty(ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N), dtype=torch.float32, device=x.device)

    # Nothing to quantize: both outputs are already (empty and) correctly
    # shaped, so skip compiling/launching the kernel on a zero-sized grid.
    if M == 0 or N == 0:
        return y, s

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))

    # Launch configuration differs for contiguous vs. strided inputs.
    if x.is_contiguous():
        num_warps, num_stages = 8, 2
    else:
        num_warps, num_stages = 1, 4

    _blockwise_cast_to_fp8_triton[grid](
        x,
        y,
        s,
        *x.stride(),
        *y.stride(),
        *s.stride(),
        M,
        N,
        1e-10,
        fp8_min,
        fp8_max,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return y, s
|
|