# shinka-backup/ccevolve/baselines/thetaevolve/slime/backends/megatron_utils/kernels/fp8_kernel.py
| import torch | |
| import triton | |
| import triton.language as tl | |
| fp8_dtype = torch.float8_e4m3fn | |
| fp8_max = torch.finfo(fp8_dtype).max | |
| fp8_min = -fp8_max | |
def ceil_div(x: int, y: int) -> int:
    """Divide ``x`` by ``y``, rounding the quotient up.

    Implemented as floor division of ``x + y - 1`` by ``y``, which equals the
    mathematical ceiling of ``x / y`` when the divisor is positive.

    Args:
        x: Dividend.
        y: Divisor.

    Returns:
        ``(x + y - 1) // y``.
    """
    padded = x + y - 1
    quotient = padded // y
    return quotient
@triton.jit
def _blockwise_cast_to_fp8_triton(
    X,
    Y,
    S,
    stride_xm,
    stride_xn,
    stride_ym,
    stride_yn,
    stride_sm,
    stride_sn,
    M,
    N,
    eps,
    fp8_min,
    fp8_max,
    BLOCK_M: tl.constexpr = 32,
    BLOCK_N: tl.constexpr = 128,
):
    """Quantize one (BLOCK_M, BLOCK_N) tile of X into Y and write its scale to S.

    Each program handles the tile at grid position (program_id(0), program_id(1)).
    The tile's scale is ``max(max|tile|, eps) / fp8_max``; the stored values are
    ``clamp(tile / scale, fp8_min, fp8_max)`` cast to Y's element type, and the
    scale is written to ``S[pid_m, pid_n]`` (addressed via the S strides).
    """
    # int64 program ids keep offset * stride arithmetic in 64-bit
    # (presumably to avoid int32 overflow on large tensors).
    pid_m = tl.cast(tl.program_id(axis=0), tl.int64)
    pid_n = tl.cast(tl.program_id(axis=1), tl.int64)
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Mask off rows/columns past the tensor edge for partial tiles.
    mask_m = off_m < M
    mask_n = off_n < N
    mask = mask_m[:, None] & mask_n[None, :]
    # Out-of-bounds lanes read 0.0, so they never raise the tile's abs-max.
    x = tl.load(X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn, mask=mask, other=0.0).to(tl.float32)
    # Floor the abs-max at eps so an all-zero tile still gets a nonzero scale.
    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
    x_s = _absmax / fp8_max
    s_inv = 1.0 / x_s
    y_q = tl.clamp(x * s_inv, fp8_min, fp8_max).to(Y.dtype.element_ty)
    tl.store(Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn, y_q, mask=mask)
    # One scale per tile; x is approximately recovered as y_q * x_s.
    tl.store(S + pid_m * stride_sm + pid_n * stride_sn, x_s)
def blockwise_cast_to_fp8_triton(x: torch.Tensor, block_size=None) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 with one float32 scale per tile.

    ``x`` is partitioned into ``BLOCK_M x BLOCK_N`` tiles. For every tile the
    kernel computes ``scale = max(max|tile|, 1e-10) / fp8_max``. It writes
    ``clamp(tile / scale, fp8_min, fp8_max)`` into ``y`` and stores ``scale``
    into ``s``, so a tile is approximately recovered as ``y_tile * scale``.

    Args:
        x: 2-D tensor to quantize. It is addressed through its strides, so it
            does not have to be contiguous.
        block_size: Optional ``(BLOCK_M, BLOCK_N)`` tile shape. Any falsy value
            (``None``, empty tuple) selects the default ``(128, 128)``.
            NOTE(review): Triton's ``tl.arange`` presumably requires powers of
            two here; confirm before passing other sizes.

    Returns:
        ``(y, s)``:
        - ``y`` has ``x``'s shape and dtype ``fp8_dtype``.
        - ``s`` is float32 with shape
          ``(ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N))``.

    Raises:
        ValueError: if ``x`` is not 2-D.
    """
    if x.dim() != 2:
        raise ValueError(
            f"blockwise_cast_to_fp8_triton expects a 2-D tensor, got shape {tuple(x.shape)}"
        )

    BLOCK_M, BLOCK_N = 128, 128
    if block_size:
        BLOCK_M, BLOCK_N = block_size[0], block_size[1]

    M, N = x.shape
    # Use the module-level FP8 dtype so the output buffer always matches the
    # range (fp8_min / fp8_max) that is derived from it.
    y = torch.empty(M, N, device=x.device, dtype=fp8_dtype)
    # One scale per tile, including partial tiles at the end of each dimension.
    s = torch.empty(ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N), dtype=torch.float32, device=x.device)

    def grid(meta):
        # One program per tile, sized from the launch's constexpr block sizes.
        return (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))

    # The launch configuration differs by memory layout. The rationale is not
    # visible here (presumably empirical tuning).
    if x.is_contiguous():
        kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 8, "num_stages": 2}
    else:
        kwargs = {"BLOCK_M": BLOCK_M, "BLOCK_N": BLOCK_N, "num_warps": 1, "num_stages": 4}

    # eps = 1e-10 floors each tile's abs-max so all-zero tiles get a nonzero
    # scale, avoiding a division by zero inside the kernel.
    _blockwise_cast_to_fp8_triton[grid](
        x, y, s, *x.stride(), *y.stride(), *s.stride(), M, N, 1e-10, fp8_min, fp8_max, **kwargs
    )
    return y, s