kernels-bot's picture
Uploaded using `kernel-builder`.
2b537bb verified
raw
history blame
29.6 kB
# Copyright (c) 2025, Tri Dao.
import math
import hashlib
import inspect
import os
from typing import Type, Callable, Optional, Tuple, overload
import cutlass
import cutlass.cute as cute
from cutlass import Float32, Int32, const_expr
from cutlass.cute import FastDivmodDivisor
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm
from cutlass.cute.runtime import from_dlpack
from .quack import activation
# Default attribute names that hash_callable reads from a callable and mixes into its
# source/bytecode hash, so metadata that can change without the source changing
# still produces a different hash.
_MIXER_ATTRS = ("__vec_size__",)
# Polynomial approximations of 2**x on [0, 1], keyed by polynomial degree d.
# POLY_EX2[d] is a tuple of d + 1 coefficients, constant term first, i.e. the
# layout evaluate_polynomial / evaluate_polynomial_2 expect:
#   2**x ~= POLY_EX2[d][0] + POLY_EX2[d][1] * x + ... + POLY_EX2[d][d] * x**d
# Every entry's constant term is exactly 1.0.
# Obtained from sollya, e.g. for degree 1:
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
POLY_EX2 = {
    # Must be a 1-tuple: a bare ``(1.0)`` is just a float and breaks len()/indexing.
    0: (1.0,),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}
_fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1"
_fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1"
def _get_use_clc_scheduler_default() -> bool:
return _fa_clc_enabled
def _get_disable_2cta_default() -> bool:
return _fa_disable_2cta_enabled
def _compute_base_hash(func: Callable) -> str:
"""Compute hash from source code or bytecode and closure values."""
try:
data = inspect.getsource(func).encode()
except (OSError, TypeError):
if hasattr(func, "__code__") and func.__code__ is not None:
data = func.__code__.co_code
else:
data = repr(func).encode()
hasher = hashlib.sha256(data)
if hasattr(func, "__closure__") and func.__closure__ is not None:
for cell in func.__closure__:
hasher.update(repr(cell.cell_contents).encode())
return hasher.hexdigest()
def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable from its source code (or bytecode) and closure values.

    Base hash: if ``func`` already carries a ``__cute_hash__`` attribute, that value is
    used directly. Otherwise ``func`` is unwrapped through ``__wrapped__`` (e.g. a
    ``cute.jit`` wrapper), and the base function's ``__cute_hash__`` is used when
    present; if not, it is computed with ``_compute_base_hash``.

    Mixing: the attributes named in ``mixer_attrs`` are read from ``func`` itself (not
    the unwrapped base). If all are ``None``, the base hash is returned unchanged.
    Otherwise each ``name=repr(value)`` pair is mixed into a new SHA-256 digest, so
    metadata that changes independently of the source yields a different key.

    Args:
        func: the callable to hash.
        mixer_attrs: names of attributes whose values are mixed into the result.
        set_cute_hash: cache a freshly computed base hash on the unwrapped function
            as ``__cute_hash__``. This is best effort: callables that reject new
            attributes (builtins, bound methods, ...) are simply not cached.

    Returns:
        A hex-encoded SHA-256 digest.
    """
    # Resolve base hash
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)
        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)
            if set_cute_hash:
                try:
                    base_func.__cute_hash__ = base_hash
                except (AttributeError, TypeError):
                    # Caching is only an optimization; some callables refuse attributes.
                    pass
    # Mix in mutable metadata dunders
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)
    if all(v is None for v in mixer_values):
        return base_hash
    hasher = hashlib.sha256(base_hash.encode())
    # Pair values with the attribute names they were actually read from.
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())
    return hasher.hexdigest()
def create_softcap_scoremod(softcap_val):
    """Build a jitted ``score_mod`` callable parameterized by ``softcap_val``.

    The returned function computes ``s = acc_S_SSA / softcap_val`` (as a multiply by
    the precomputed reciprocal) and returns ``s * tanh(s)`` using fast-math tanh.

    NOTE(review): textbook logit soft-capping is ``softcap * tanh(x / softcap)``; the
    ``s * tanh(s)`` form here differs. Confirm the caller compensates (e.g. via the
    softmax scale) before relying on this as a soft-cap.
    """
    reciprocal = 1.0 / softcap_val

    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        # batch/head/q/kv indices and aux_tensors are unused here; presumably they
        # are required by the score_mod calling convention.
        scaled = acc_S_SSA * reciprocal
        return scaled * cute.math.tanh(scaled, fastmath=True)

    return scoremod_premask_fn
# log2(e): the change-of-base factor that lets exp(x) be evaluated as exp2(x * LOG2_E).
LOG2_E = math.log2(math.e)
def compute_softmax_scale_log2(softmax_scale, score_mod):
    """Split softmax scaling between an up-front multiply and the exp2 exponent factor.

    Returns ``(softmax_scale_log2, softmax_scale)``:

    * ``score_mod`` present: ``(LOG2_E, softmax_scale)``. The scale stays separate so
      it can be applied before ``score_mod``; only the change-of-base constant
      remains in the exponent factor.
    * ``score_mod`` is None: ``(softmax_scale * LOG2_E, None)``. The whole scale is
      folded into the exponent factor and nothing is applied up front.
    """
    if const_expr(score_mod is not None):
        return LOG2_E, softmax_scale
    return softmax_scale * LOG2_E, None
def compute_fastdiv_mods(mQ, mK, qhead_per_kvhead, pack_gqa, aux_tensors, mPageTable=None):
    """Build FastDivmodDivisor objects for the query and key sequence lengths.

    They are only needed for aux_tensors index computation, so ``None`` is returned
    when ``aux_tensors`` is None. Otherwise returns
    ``(FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k))``, where:

    * ``seqlen_q`` is the size of ``mQ``'s mode 0, divided by ``qhead_per_kvhead``
      when ``pack_gqa`` is set;
    * ``seqlen_k`` is the size of ``mK``'s mode 0, or, with a page table,
      ``mK.shape[0] * mPageTable.shape[1]`` (presumably page size times pages per
      sequence — confirm against the paged-KV layout).
    """
    if const_expr(aux_tensors is None):
        return None
    q_packing = qhead_per_kvhead if const_expr(pack_gqa) else 1
    seqlen_q = cute.size(mQ.shape[0]) // q_packing
    if const_expr(mPageTable is None):
        seqlen_k = cute.size(mK.shape[0])
    else:
        seqlen_k = mK.shape[0] * mPageTable.shape[1]
    return FastDivmodDivisor(seqlen_q), FastDivmodDivisor(seqlen_k)
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap a DLPack-compatible tensor ``x`` as a ``cute.Tensor`` with a dynamic layout.

    The pointer is assumed ``alignment``-byte aligned. The layout is marked dynamic
    with ``leading_dim`` as the leading dimension. The shape of mode ``leading_dim``
    is then marked compact-dynamic using ``x.dim_order()`` as the stride order and
    the given ``divisibility``.
    """
    wrapped = from_dlpack(x, assumed_align=alignment)
    wrapped = wrapped.mark_layout_dynamic(leading_dim=leading_dim)
    return wrapped.mark_compact_shape_dynamic(
        mode=leading_dim,
        stride_order=x.dim_order(),
        divisibility=divisibility,
    )
def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Wrap ``x`` as a ``cute.Tensor`` whose leading dim (and ``static_modes``) stay static.

    Every mode other than ``leading_dim`` and those listed in ``static_modes`` is
    marked compact-dynamic, using ``stride_order`` (defaulting to ``x.dim_order()``).
    """
    order = x.dim_order() if stride_order is None else stride_order
    wrapped = from_dlpack(x, assumed_align=alignment)
    for mode in range(x.ndim):
        if mode == leading_dim:
            continue
        if static_modes is not None and mode in static_modes:
            continue
        wrapped = wrapped.mark_compact_shape_dynamic(mode=mode, stride_order=order)
    return wrapped
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Tiled copy for MMA operand A, or for operand B when ``swapAB`` is set."""
    builder = cute.make_tiled_copy_B if const_expr(swapAB) else cute.make_tiled_copy_A
    return builder(copy_atom, tiled_mma)
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Tiled copy for MMA operand B, or for operand A when ``swapAB`` is set."""
    builder = cute.make_tiled_copy_A if const_expr(swapAB) else cute.make_tiled_copy_B
    return builder(copy_atom, tiled_mma)
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Make the MMA A-operand fragment for ``smem`` (the B-operand one when ``swapAB``)."""
    if const_expr(swapAB):
        return mma_make_fragment_B(smem, thr_mma)
    partitioned = thr_mma.partition_A(smem)
    return thr_mma.make_fragment_A(partitioned)
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Make the MMA B-operand fragment for ``smem`` (the A-operand one when ``swapAB``)."""
    if const_expr(swapAB):
        return mma_make_fragment_A(smem, thr_mma)
    partitioned = thr_mma.partition_B(smem)
    return thr_mma.make_fragment_B(partitioned)
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the copy atom used for shared-memory stores.

    With ``arch >= 90`` and 16-bit elements, this is a 4-matrix
    ``StMatrix8x8x16bOp`` honouring ``transpose``. Otherwise it is a universal copy
    moving ``2 * element_type.width`` bits (two elements) per copy; ``transpose`` is
    ignored in that case.
    """
    if const_expr(arch >= 90 and element_type.width == 16):
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=2 * element_type.width,
    )
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` with the binary ``op`` across each group of ``width`` lanes.

    Args:
        val: a scalar, or a TensorSSA whose elements are reduced independently
            (element i is combined only with element i of the other lanes).
        op: binary combiner applied as ``op(own, partner)``. It should be associative
            and commutative, because different lanes combine operands in different orders.
        width: lanes per reduction group. Assumed to be a power of two, since
            ``int(math.log2(width))`` truncates otherwise (not checked).

    Returns:
        Same kind as ``val``. The scalar path is an XOR butterfly
        (``shuffle_sync_bfly`` with offsets 1, 2, 4, ..., width / 2), so every lane
        of a group ends up holding the group's reduced value.
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        # Materialize the SSA value in a fragment so each element can be reduced on
        # its own (recursing into the scalar path), then load it back as SSA.
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        # log2(width) butterfly rounds, unrolled at trace time.
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val
@dsl_user_op
def smid(*, loc=None, ip=None) -> Int32:
    """Return the PTX ``%smid`` special register (id of the SM running this thread)."""
    sm_id = llvm.inline_asm(
        T.i32(),
        [],
        "mov.u32 $0, %smid;",
        "=r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return Int32(sm_id)
@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Return ``max(a, b)``, or the 3-input ``max(a, b, c)`` when ``c`` is given, via nvvm.fmax.

    The NVVM binding shipped with CUDA 12.9 takes the result type as an explicit
    leading positional argument; other versions infer it.
    """
    from cutlass import CUDA_VERSION

    needs_result_type = CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9
    leading_args = (T.f32(),) if needs_result_type else ()
    lhs = Float32(a).ir_value(loc=loc, ip=ip)
    rhs = Float32(b).ir_value(loc=loc, ip=ip)
    third = Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None
    return Float32(nvvm.fmax(*leading_args, lhs, rhs, c=third, loc=loc, ip=ip))
@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the maximum element of ``x``, also max'ed with ``init_val`` when given.

    Two paths, selected at trace time:

    * ``arch < 100`` or a size that is not a multiple of 8: 2-input ``fmax`` into four
      independent accumulators, which gives shorter dependency chains than one serial
      max. Elements past the last multiple of 4 are folded into the accumulators
      afterwards. Inputs with fewer than 4 elements use a plain serial chain.
    * otherwise (``arch >= 100`` and size a multiple of 8): 3-input ``fmax`` into four
      accumulators, two new elements per accumulator per step.

    Args:
        x: values to reduce; must contain at least one element.
        init_val: optional extra operand of the max.
        arch: architecture number compared against 100 to pick the path (default 80).
    """
    num_elems = cute.size(x.shape)
    res = cute.make_fragment(x.shape, Float32)
    res.store(x)
    if const_expr(arch < 100 or num_elems % 8 != 0):
        # Largest multiple of 4 <= num_elems; these elements use the 4-way unrolled loop.
        num_vec = num_elems - num_elems % 4
        if const_expr(num_vec == 0):
            # Fewer than 4 elements: serial max.
            result = res[0]
            for i in cutlass.range_constexpr(1, num_elems):
                result = fmax(result, res[i])
        else:
            local_max = [res[0], res[1], res[2], res[3]]
            for i in cutlass.range_constexpr(4, num_vec, 4):
                local_max[0] = fmax(local_max[0], res[i + 0])
                local_max[1] = fmax(local_max[1], res[i + 1])
                local_max[2] = fmax(local_max[2], res[i + 2])
                local_max[3] = fmax(local_max[3], res[i + 3])
            # Fold the (at most 3) leftover elements into distinct accumulators.
            for i in cutlass.range_constexpr(num_vec, num_elems):
                local_max[i - num_vec] = fmax(local_max[i - num_vec], res[i])
            local_max[0] = fmax(local_max[0], local_max[1])
            local_max[2] = fmax(local_max[2], local_max[3])
            result = fmax(local_max[0], local_max[2])
        return result if const_expr(init_val is None) else fmax(result, init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        for i in cutlass.range_constexpr(8, num_elems, 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])
@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the sum of all elements of ``x`` plus ``init_val`` (0 when not given).

    Two paths, selected at trace time:

    * ``arch < 100`` or a size that is not a multiple of 8: delegates to
      ``x.reduce(cute.ReductionOp.ADD, init_val, 0)``.
    * otherwise: packed ``cute.arch.add_packed_f32x2`` additions into four independent
      (lo, hi) pair accumulators, combined pairwise at the end. The size is a
      multiple of 8 here, so every step consumes exactly 8 elements.

    The paths add in different orders, so results may differ in the last bits.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
        # Disabled alternative (unreachable after the return): scalar adds with four
        # independent accumulators.
        # res = cute.make_fragment(x.shape, Float32)
        # res.store(x)
        # local_sum = [res[0], res[1], res[2], res[3]]
        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
        #     local_sum[0] += res[i + 0]
        #     local_sum[1] += res[i + 1]
        #     local_sum[2] += res[i + 2]
        #     local_sum[3] += res[i + 3]
        # local_sum[0] += local_sum[1]
        # local_sum[2] += local_sum[3]
        # local_sum[0] += local_sum[2]
        # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # Seed the first pair accumulator; init_val is added to its low lane only.
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        # One packed pair per accumulator per step (8 elements per step).
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        # Tree-combine the four pair accumulators, then add the two lanes of the result.
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        return local_sum[0][0] + local_sum[0][1]
@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add ``a`` to the fp32 value pointed to by ``gmem_ptr``.

    Emits an ``nvvm.atomicrmw`` with ``AtomicOpKind.FADD``; the returned old value is
    discarded. ``loc``/``ip`` are forwarded so the op carries the caller's source
    location and insertion point, as the other ops in this module do.

    An inline-PTX alternative is ``red.global.add.f32 [$0], $1;`` (optionally
    ``red.global.add.L2::cache_hint.f32`` with a 64-bit cache-hint operand), with
    the pointer passed as an integer via ``gmem_ptr.toint(...)``.
    """
    nvvm.atomicrmw(
        res=T.f32(),
        op=nvvm.AtomicOpKind.FADD,
        ptr=gmem_ptr.llvm_ptr,
        a=Float32(a).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Pointer to element ``coord`` of ``x``: its iterator advanced by ``crd2idx(coord, layout)``."""
    base = x.iterator
    offset = cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
    return base + offset
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a Boolean predicate fragment for the "k" dimension of a partitioned copy.

    Args:
        tAcA: partitioned coordinate tensor; indexing it yields a coordinate tuple
            whose component [1] is compared with ``limit``. Presumably a partitioned
            identity tensor (TODO confirm at call sites).
        limit: exclusive upper bound for coordinate component [1].

    Returns:
        A Boolean fragment of shape ``(size(tAcA, mode=[0, 1]), size(tAcA, mode=[1]),
        size(tAcA, mode=[2]))``. Mode 1 has stride 0, so the single predicate
        computed per ``(rest_v, rest_k)`` is shared by every mode-1 index.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            # Only vector element 0 and mode-1 index 0 are inspected; the result is
            # reused for the rest via the stride-0 mode.
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Return ``thread_idx.x // 128``, the index of this thread's 128-thread group.

    When ``sync`` is True the value is passed through ``make_warp_uniform``.
    """
    tidx = cute.arch.thread_idx()[0]
    group = tidx // 128
    return cute.arch.make_warp_uniform(group) if const_expr(sync) else group
# @dsl_user_op
# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
# mask = cutlass.Int32(-1)
# return cutlass.Boolean(
# llvm.inline_asm(
# T.i32(),
# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
# ".pred p1, p2;\n"
# "setp.lt.f32 p1, $1, $2;\n"
# "vote.sync.any.pred p2, p1, $3;\n"
# "selp.u32 $0, 1, 0, p2;",
# # "selp.u32 $0, 1, 0, p1;",
# "=r,f,f,r",
# has_side_effects=False,
# is_align_stack=False,
# asm_dialect=llvm.AsmDialect.AD_ATT,
# )
# )
@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Warp-shuffle a value of any bit width that is a multiple of 32.

    The value is stored in a 1-element tensor, reinterpreted as Int32 words, and each
    word is moved with ``cute.arch.shuffle_sync``. The same storage is then read back
    as ``type(value)``.

    Args:
        value: numeric whose ``width`` is a multiple of 32 bits (asserted).
        offset: forwarded unchanged to ``cute.arch.shuffle_sync``; presumably the
            source lane (TODO confirm against that helper's semantics).
        width: segment size in lanes. It is encoded as segment mask
            ``WARP_SIZE - width`` in bits 8 and up of ``mask_and_clamp``, with clamp
            ``WARP_SIZE - 1`` in the low bits.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    # << binds tighter than |, so this is (mask << 8) | clamp.
    mask_and_clamp = mask << 8 | clamp
    # important: need stride 1 and not 0 for recast_tensor to work
    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    val[0] = value
    # Int32 view of the same storage; writes through it update val.
    val_i32 = cute.recast_tensor(val, cutlass.Int32)
    for i in cutlass.range_constexpr(cute.size(val_i32)):
        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
    return val[0]
@dsl_user_op
def shl_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """
    Shift ``val`` left by ``shift`` bits with the PTX ``shl.b32`` instruction.

    ``.b32`` is sign-agnostic; the ``_u32`` suffix only mirrors the unsigned
    Python type annotations, which do distinguish signedness.

    Why inline PTX: PTX (§9.7.8.8) clamps shift amounts greater than the register
    width N to N, so ``shl.b32 d, a, 32`` is well defined and yields 0. In C/C++ and
    LLVM IR, shifting by >= the type width is undefined behavior. CuTeDSL lowers
    through MLIR -> LLVM IR, so a plain ``Uint32(x) << Uint32(n)`` inherits that UB:
    the optimizer may treat the result as poison and drop dependent code. Emitting
    the instruction verbatim as PTX bypasses the LLVM shift, so every shift amount
    is safe.
    """
    result_type = T.i32()
    operands = [
        cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
        cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
    ]
    shifted = llvm.inline_asm(
        result_type,
        operands,
        "shl.b32 $0, $1, $2;",
        "=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Uint32(shifted)
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """
    Logical (zero-filling) right shift of ``val`` by ``shift`` bits via PTX ``shr.u32``.

    Inline PTX is used for the same reason as in ``shl_u32``: a plain CuTeDSL shift
    goes through LLVM IR, where shifting by >= the type width is undefined, while
    PTX clamps the shift amount.
    """
    result_type = T.i32()
    operands = [
        cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
        cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
    ]
    shifted = llvm.inline_asm(
        result_type,
        operands,
        "shr.u32 $0, $1, $2;",
        "=r,r,r",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Uint32(shifted)
@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Inclusive prefix sum of ``val`` across the lanes of a warp.

    Round k (offset ``2**k``, for k < log2(WARP_SIZE)): every lane with
    ``lane >= offset`` adds the running value shuffled up from ``lane - offset``.
    Afterwards lane ``i`` holds ``val[0] + ... + val[i]``.

    Args:
        val: this lane's input value.
        lane: this thread's lane index; defaults to ``cute.arch.lane_idx()``.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val)
    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << i
        # Very important that we set mask_and_clamp to 0
        # NOTE(review): 0 presumably selects a full-warp segment with lower clamp
        # lane 0 for the up-shuffle; the helper's default may be a value meant for
        # other shuffle modes. Confirm before changing.
        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        # Lanes below `offset` have no source lane this round and keep their value.
        if lane >= offset:
            val += partial_sum
        # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
    return val
@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Convert two fp32 values (round-to-nearest) to a packed f16x2/bf16x2, returned as Int32.

    ``a`` lands in the low 16 bits and ``b`` in the high 16 bits. PTX
    ``cvt.rn.{f16x2,bf16x2}.f32 d, a, b`` stores its *first* source in the upper half
    of ``d``, which is why the asm passes ``$2`` (b) before ``$1`` (a).
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    packed_type = "bf16x2" if to_dtype is cutlass.BFloat16 else "f16x2"
    result_type = T.i32()
    operands = [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)]
    packed = llvm.inline_asm(
        result_type,
        operands,
        f"cvt.rn.{packed_type}.f32 $0, $2, $1;",
        "=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(packed)
# Typing-only overloads: write into an existing tensor, or allocate and return one.
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...
@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...
@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert a Float32 tensor to Float16/BFloat16, two elements at a time.

    Args:
        src: source tensor with Float32 elements; its size must be even.
        dst_or_dtype: either a destination tensor (Float16 or BFloat16, same size as
            ``src``), or a dtype, in which case a new fragment of that dtype and
            ``src``'s shape is allocated and filled.

    Returns:
        None when a destination tensor is given; the new fragment when a dtype is given.
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: create new tensor and call the tensor variant
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # tensor variant: write to dst
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View dst as Int32 words, each holding one converted (even, odd) element pair.
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate ``sum(poly[k] * x**k)`` with Horner's rule; ``poly[0]`` is the constant term."""
    degree = len(poly) - 1
    acc = poly[degree]
    # Fold in coefficients from degree-1 down to the constant term (unrolled).
    for k in cutlass.range_constexpr(degree - 1, -1, -1):
        acc = acc * x + poly[k]
    return acc
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` with packed f32x2 Horner steps.

    ``poly[0]`` is the constant term. Returns ``(p(x), p(y))``.
    """
    degree = len(poly) - 1
    top = poly[degree]
    acc = (top, top)
    point = (x, y)
    # Each step is one packed FMA: acc = acc * (x, y) + (c_k, c_k).
    for k in cutlass.range_constexpr(degree - 1, -1, -1):
        acc = cute.arch.fma_packed_f32x2(acc, point, (poly[k], poly[k]))
    return acc
@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return ``x + y`` rounded toward -inf, with flush-to-zero (PTX ``add.rm.ftz.f32``)."""
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    result_type = T.f32()
    lhs = Float32(x).ir_value(loc=loc, ip=ip)
    rhs = Float32(y).ir_value(loc=loc, ip=ip)
    total = llvm.inline_asm(
        result_type,
        [lhs, rhs],
        "add.rm.ftz.f32 $0, $1, $2;",
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(total)
@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Scale ``frac_ex2`` by ``2**n``, where ``n`` is the integer in ``x_rounded``'s low bits.

    ``x_rounded`` is expected to come from the round-down ``x + (2**23 + 2**22)``
    trick used by ``ex2_emulation``, which leaves floor(x) in its low mantissa bits.
    Shifting that bit pattern left by 23 moves those bits into the fp32 exponent
    field (higher bits are shifted out). An integer add onto ``frac_ex2``'s bit
    pattern then adds ``n`` to its exponent. No range or overflow checks are done.
    """
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Float32(x_rounded).ir_value(loc=loc, ip=ip),
                Float32(frac_ex2).ir_value(loc=loc, ip=ip),
            ],
            "{\n\t"
            ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
            "mov.b32 x_rounded_i, $1;\n\t"
            "mov.b32 frac_ex_i, $2;\n\t"
            "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
            # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
            # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
            "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
            "mov.b32 $0, out_i;\n\t"
            "}\n",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Approximate ``2**x`` with FP arithmetic and a ``POLY_EX2`` polynomial.

    Steps:
    1. Clamp ``x`` to >= -127.
    2. Split it into floor and fraction with the ``2**23 + 2**22`` round-down trick.
    3. Evaluate ``POLY_EX2[poly_degree]`` on the fraction.
    4. Fold the integer part into the exponent via ``combine_int_frac_ex2``.

    Args:
        x: input value; assumed <= 127.0 (not checked).
        poly_degree: key into ``POLY_EX2`` (asserted present).
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # We assume x <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    x_clamped = cute.arch.fmax(x, -127.0)
    # We want to round down here, so that the fractional part is in [0, 1)
    x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
    # The integer floor of x is now in the last 8 bits of x_rounded
    # We assume the next 2 ops round to nearest even. The rounding mode is important.
    x_rounded_back = x_rounded - fp32_round_int
    x_frac = x_clamped - x_rounded_back
    x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)
# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Two-input ``ex2_emulation``: approximate ``(2**x, 2**y)`` using packed f32x2 ops.

    The rounding add, both subtractions and the polynomial are done on (x, y) pairs;
    the integer parts are folded in per element with ``combine_int_frac_ex2``.

    Args:
        x, y: inputs; assumed <= 127.0 (not checked).
        poly_degree: key into ``POLY_EX2``. Unlike ``ex2_emulation``, this is not
            asserted here, so a missing key raises KeyError.
    """
    # We assume x <= 127.0 and y <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # We want to round down here, so that the fractional part is in [0, 1)
    xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm")
    # The integer floor of x & y are now in the last 8 bits of xy_rounded
    # We want the next 2 ops to round to nearest even. The rounding mode is important.
    xy_rounded_back = activation.sub_packed_f32x2(
        xy_rounded, (fp32_round_int, fp32_round_int)
    )
    xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)
    xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
    y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
    return x_out, y_out
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Approximate ``(2**x, 2**y)`` in one hand-written inline-PTX block.

    This is the degree-3 ``ex2_emulation_2`` scheme written directly in PTX with
    packed f32x2 ops:

    * clamp each input to >= -127.0 (``0fC2FE0000``);
    * split into floor and fraction by adding ``2**23 + 2**22`` (``0f4B400000``) with
      round-down, then subtracting it back;
    * evaluate the cubic on the fraction with packed FMAs. The coefficients
      ``0f3F800000, 0f3F31F519, 0f3E6906A4, 0f3D9DF09D`` are the fp32 encodings of
      ``POLY_EX2[3]``;
    * add ``floor << 23`` to the polynomial's bit pattern to scale it by ``2**floor``.

    Inputs are assumed <= 127.0 (not checked).

    NOTE(review): the outputs use the ``=r`` (32-bit integer register) constraint
    while the declared result struct is ``{f32, f32}``. Confirm the LLVM/NVPTX
    version in use accepts that pairing (``=f`` would match the type).
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        # loc/ip go to ir_value() for both operands.
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Return a tensor with ``tensor``'s layout whose origin is the element at ``coord``.

    The new pointer reuses the original pointer's ``alignment`` as ``assumed_align``.
    The caller must guarantee that the offset preserves that alignment; this is not
    checked. ``tensor.iterator`` must be a ``cute.Pointer`` (asserted).
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    # We assume that applying the offset does not change the pointer alignment
    offset_ptr = elem_pointer(tensor, coord, loc=loc, ip=ip)
    new_ptr = cute.make_ptr(
        tensor.element_type,
        offset_ptr.toint(loc=loc, ip=ip),
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Wrap the scalar ``a`` into a 1-element TensorSSA of element type ``dtype``."""
    frag = cute.make_fragment(1, dtype)
    frag[0] = a
    return frag.load()
def ssa_to_scalar(val):
    """Return ``val[0]``; kept as a named function to mirror ``scalar_to_ssa``."""
    first = val[0]
    return first