JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
from typing import Tuple
import torch
import triton
import triton.language as tl
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max
def ceil_div(x: int, y: int) -> int:
    """
    Perform ceiling division of two integers.

    Implemented as ``-(-x // y)``, which rounds toward +infinity for every
    sign combination of ``x`` and ``y``. The ``(x + y - 1) // y`` form is
    only correct when the divisor is positive.

    Args:
        x: the dividend.
        y: the divisor; must be non-zero.
    Returns:
        The smallest integer ``q`` such that ``q >= x / y``.
    Raises:
        ZeroDivisionError: if ``y`` is zero.
    """
    return -(-x // y)
@triton.jit
def _blockwise_cast_to_fp8_triton(
    X,
    Y,
    S,
    stride_xm,
    stride_xn,
    stride_ym,
    stride_yn,
    stride_sm,
    stride_sn,
    M,
    N,
    eps,
    fp8_min,
    fp8_max,
    BLOCK_M: tl.constexpr = 32,
    BLOCK_N: tl.constexpr = 128,
):
    """Quantize one BLOCK_M x BLOCK_N tile of X with a single per-tile scale.

    Program (pid_m, pid_n) handles rows [pid_m * BLOCK_M, (pid_m + 1) * BLOCK_M)
    and columns [pid_n * BLOCK_N, (pid_n + 1) * BLOCK_N) of the M x N input.
    It writes the quantized tile to Y and the tile's scale to
    S[pid_m, pid_n]. The scale is ``max(absmax(tile), eps) / fp8_max``, so
    each element satisfies ``X ~= Y * scale``.

    Args:
        X: pointer to the input matrix; values are converted to float32.
        Y: pointer to the output matrix; values are cast to its element type.
        S: pointer to the scale matrix (one value written per program).
        stride_*: element strides of X, Y and S along rows (m) and columns (n).
        M, N: row/column count of X and Y, used for bounds masking.
        eps: lower bound on the tile absmax so the scale is never zero.
        fp8_min, fp8_max: clamp range applied to the scaled values.
        BLOCK_M, BLOCK_N: tile size per program (compile-time constants);
            must be powers of two because they are passed to ``tl.arange``.
    """
    # Tile coordinates, cast to int64 so the offset arithmetic below is done
    # in 64-bit (presumably to avoid int32 overflow on very large tensors).
    pid_m = tl.cast(tl.program_id(axis=0), tl.int64)
    pid_n = tl.cast(tl.program_id(axis=1), tl.int64)
    # Row / column indices covered by this program's tile.
    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Mask off rows/columns past the matrix edge (partial edge tiles).
    mask_m = off_m < M
    mask_n = off_n < N
    mask = mask_m[:, None] & mask_n[None, :]
    # Masked-out elements read as 0.0, so they cannot raise the absmax below.
    x = tl.load(X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn, mask=mask, other=0.0).to(tl.float32)
    # One absmax for the whole tile, floored at eps so x_s is never zero.
    _absmax = tl.maximum(tl.max(tl.abs(x)), eps)
    # Scale that maps the tile's absmax onto fp8_max.
    x_s = _absmax / fp8_max
    s_inv = 1.0 / x_s
    # Clamp presumably guards against rounding pushing |x * s_inv| just past
    # the representable range before the cast to Y's element type.
    y_q = tl.clamp(x * s_inv, fp8_min, fp8_max).to(Y.dtype.element_ty)
    tl.store(Y + off_m[:, None] * stride_ym + off_n[None, :] * stride_yn, y_q, mask=mask)
    # One scale per tile, stored without a mask: S must have an entry for
    # every launched (pid_m, pid_n). NOTE(review): confirm callers size S to the grid.
    tl.store(S + pid_m * stride_sm + pid_n * stride_sn, x_s)
def blockwise_cast_to_fp8_triton(
    x: torch.Tensor, block_size=None, eps: float = 1e-10
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 with one float32 scale per tile.

    The input is split into ``BLOCK_M x BLOCK_N`` tiles. Each tile is divided
    by ``max(absmax(tile), eps) / fp8_max``, clamped to ``[fp8_min, fp8_max]``
    and stored as ``fp8_dtype``. ``x`` is therefore approximately ``y``
    multiplied by the matching tile entry of ``s``.

    Args:
        x: 2-D input tensor of shape ``(M, N)``; arbitrary strides are allowed.
            It must live on a device Triton can launch on (presumably CUDA).
        block_size: optional ``(BLOCK_M, BLOCK_N)``. When falsy (``None`` or
            empty) the default ``(128, 128)`` is used. Both values must be
            positive powers of two, since the kernel passes them to ``tl.arange``.
        eps: lower bound on each tile's absmax, so all-zero tiles still get a
            non-zero scale. Defaults to ``1e-10``.

    Returns:
        ``(y, s)`` where:
            - ``y`` has shape ``(M, N)`` and dtype ``fp8_dtype``.
            - ``s`` has shape ``(ceil(M / BLOCK_M), ceil(N / BLOCK_N))`` and
              dtype ``torch.float32``.

    Raises:
        ValueError: if a block dimension is not a positive power of two, or if
            ``x`` is not 2-D.
    """
    BLOCK_M, BLOCK_N = 128, 128
    if block_size:
        BLOCK_M, BLOCK_N = block_size[0], block_size[1]
    for name, size in (("BLOCK_M", BLOCK_M), ("BLOCK_N", BLOCK_N)):
        # Fail early with a clear message: zero would divide by zero in
        # ceil_div, and tl.arange in the kernel rejects non-power-of-two ranges.
        if size <= 0 or size & (size - 1):
            raise ValueError(f"{name} must be a positive power of two, got {size}")
    if x.dim() != 2:
        raise ValueError(f"expected a 2-D tensor, got shape {tuple(x.shape)}")
    M, N = x.shape

    # Use the module-level dtype so it stays in sync with fp8_min / fp8_max.
    y = torch.empty(M, N, device=x.device, dtype=fp8_dtype)
    s = torch.empty(ceil_div(M, BLOCK_M), ceil_div(N, BLOCK_N), dtype=torch.float32, device=x.device)
    # One program per tile. The grid must match the shape of `s`, because the
    # kernel writes one scale per program without a mask.
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), triton.cdiv(N, meta["BLOCK_N"]))
    # Contiguous and strided inputs use different launch configurations.
    num_warps, num_stages = (8, 2) if x.is_contiguous() else (1, 4)
    _blockwise_cast_to_fp8_triton[grid](
        x, y, s, *x.stride(), *y.stride(), *s.stride(), M, N, eps, fp8_min, fp8_max,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages,
    )
    return y, s