# Build uploaded by danieldk (HF Staff) using `kernels`.
# Revision 4298e26 (verified).
# Copyright (c) 2025, Tri Dao.
import math
import hashlib
import inspect
from typing import Type, Callable, Optional, Tuple, overload
import cutlass
import cutlass.cute as cute
from cutlass import Float32, const_expr
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm
from cutlass.cute.runtime import from_dlpack
from .quack import activation
# Names of function attributes whose values hash_callable mixes into the final hash on top of the
# source/bytecode-based base hash.
_MIXER_ATTRS = ("__vec_size__",)
# Coefficients of polynomials approximating 2**x for x in [0, 1], keyed by polynomial degree.
# Each entry is ordered lowest degree first: poly[i] multiplies x**i.
# Every entry must be a tuple, because evaluate_polynomial / evaluate_polynomial_2 call len(poly)
# and index into it. (A bare parenthesized float like `(1.0)` is NOT a tuple.)
# Every polynomial has constant term 1.0, so p(0) == 1 == 2**0 exactly.
# Obtained from sollya:
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
POLY_EX2 = {
    0: (1.0,),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}
def _compute_base_hash(func: Callable) -> str:
    """Return a SHA-256 hex digest identifying ``func``.

    The digest is seeded with the function's source text when ``inspect`` can retrieve it,
    otherwise with its raw bytecode (``__code__.co_code``), and as a last resort with its
    ``repr``. The ``repr`` of every closure cell's contents is then mixed in, so two closures
    with identical code but different captured values hash differently.
    """
    try:
        payload = inspect.getsource(func).encode()
    except (OSError, TypeError):
        code_obj = getattr(func, "__code__", None)
        payload = code_obj.co_code if code_obj is not None else repr(func).encode()
    digest = hashlib.sha256(payload)
    for captured in getattr(func, "__closure__", None) or ():
        digest.update(repr(captured.cell_contents).encode())
    return digest.hexdigest()
def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable based on its source code or bytecode and closure values.

    Fast path: if the callable (or its ``__wrapped__`` base) already has a ``__cute_hash__``
    attribute, that value is used as the base hash instead of recomputing it. Then the values
    of the attributes named in ``mixer_attrs`` (read from ``func`` itself) are mixed in to
    produce the final dict-key hash. If none of them is set, the base hash is returned as is.

    Args:
        func: The callable to hash.
        mixer_attrs: Names of attributes on ``func`` whose values are mixed into the final hash.
        set_cute_hash: Whether to cache a freshly computed base hash on the unwrapped callable
            as ``__cute_hash__``. Caching is best-effort: callables that reject attribute
            assignment (e.g. bound methods, builtins) are still hashed, just not cached.

    Returns:
        A SHA-256 hex digest string.
    """
    # Resolve base hash, preferring a previously cached value.
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)
        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)
            if set_cute_hash:
                try:
                    base_func.__cute_hash__ = base_hash
                except (AttributeError, TypeError):
                    # Objects without writable attributes cannot cache the hash; it is simply
                    # recomputed on the next call.
                    pass
    # Mix in mutable metadata dunders (read from func, not the unwrapped base).
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)
    if all(v is None for v in mixer_values):
        return base_hash
    hasher = hashlib.sha256(base_hash.encode())
    # Pair each value with the attribute name it was actually read from.
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())
    return hasher.hexdigest()
def create_softcap_scoremod(softcap_val):
    """Build a score-mod callback that soft-caps attention scores with tanh.

    The returned ``@cute.jit`` function maps each score ``s`` to
    ``softcap_val * tanh(s / softcap_val)``. This is close to ``s`` when ``|s|`` is small
    relative to ``softcap_val`` and saturates smoothly at ``+/- softcap_val``, preserving the
    sign and ordering of scores.

    Args:
        softcap_val: The cap value; must be non-zero (its reciprocal is precomputed).

    Returns:
        A callable with signature
        ``(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors)``; only ``acc_S_SSA``
        is used.
    """
    inv_softcap = 1.0 / softcap_val
    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        scores = acc_S_SSA * inv_softcap
        # Rescale by the cap, not by the scores themselves: (s/c) * tanh(s/c) is an even
        # function and would map large negative scores to large positive ones.
        return cute.math.tanh(scores, fastmath=True) * softcap_val
    return scoremod_premask_fn
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap the DLPack-compatible tensor ``x`` as a ``cute.Tensor`` with a dynamic layout.

    ``alignment`` is forwarded as the assumed pointer alignment. The layout is marked dynamic
    with ``leading_dim`` as its leading dimension. The shape of mode ``leading_dim`` is then
    marked compact-dynamic, using ``x.dim_order()`` as the stride order and the given
    ``divisibility``.
    """
    wrapped = from_dlpack(x, assumed_align=alignment)
    wrapped = wrapped.mark_layout_dynamic(leading_dim=leading_dim)
    dim_order = x.dim_order()
    return wrapped.mark_compact_shape_dynamic(
        mode=leading_dim, stride_order=dim_order, divisibility=divisibility
    )
def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Wrap ``x`` as a ``cute.Tensor``, keeping the leading mode and ``static_modes`` static.

    Every other mode is marked compact-shape-dynamic using ``stride_order``, which defaults to
    ``x.dim_order()``. ``alignment`` is forwarded as the assumed pointer alignment.
    """
    order = x.dim_order() if stride_order is None else stride_order
    result = from_dlpack(x, assumed_align=alignment)
    for mode in range(x.ndim):
        stays_static = mode == leading_dim or (static_modes is not None and mode in static_modes)
        if stays_static:
            continue
        result = result.mark_compact_shape_dynamic(mode=mode, stride_order=order)
    return result
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy matching ``tiled_mma``'s A operand.

    When ``swapAB`` is set, the A and B roles are exchanged and the copy is built for the B
    operand instead.
    """
    builder = cute.make_tiled_copy_B if const_expr(swapAB) else cute.make_tiled_copy_A
    return builder(copy_atom, tiled_mma)
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy matching ``tiled_mma``'s B operand.

    When ``swapAB`` is set, the A and B roles are exchanged and the copy is built for the A
    operand instead.
    """
    builder = cute.make_tiled_copy_A if const_expr(swapAB) else cute.make_tiled_copy_B
    return builder(copy_atom, tiled_mma)
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` with ``thr_mma`` as the A operand and make the matching fragment.

    When ``swapAB`` is set, the roles are exchanged and the B-operand fragment is produced.
    """
    if not const_expr(swapAB):
        partitioned = thr_mma.partition_A(smem)
        return thr_mma.make_fragment_A(partitioned)
    return mma_make_fragment_B(smem, thr_mma)
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` with ``thr_mma`` as the B operand and make the matching fragment.

    When ``swapAB`` is set, the roles are exchanged and the A-operand fragment is produced.
    """
    if not const_expr(swapAB):
        partitioned = thr_mma.partition_B(smem)
        return thr_mma.make_fragment_B(partitioned)
    return mma_make_fragment_A(smem, thr_mma)
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the copy atom used to store ``element_type`` values to shared memory.

    On ``arch >= 90`` with 16-bit elements, this is the ``StMatrix8x8x16bOp`` warp op
    (4 matrices, optionally transposed). Otherwise it is a universal copy moving two elements
    (``2 * element_type.width`` bits) per copy.
    """
    use_stmatrix = const_expr(arch >= 90 and element_type.width == 16)
    if use_stmatrix:
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(),
        element_type,
        num_bits_per_copy=2 * element_type.width,
    )
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` with the binary ``op`` across groups of ``width`` lanes using butterfly shuffles.

    For a scalar, ``int(log2(width))`` rounds of
    ``val = op(val, shuffle_sync_bfly(val, offset=1 << i))`` are applied (offsets 1, 2, 4, ...).
    Assuming ``op`` is associative and commutative, each lane then holds the reduction over its
    group. For a TensorSSA, every element is reduced independently in the same way, and a
    TensorSSA of the same shape is returned.
    ``width`` is presumably a power of two (log2 is truncated to int) -- TODO confirm.
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        # Copy into a fragment so each element can be reduced as a scalar.
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        # Unrolled at compile time: offsets 1, 2, 4, ..., width // 2.
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val
@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Return ``max(a, b)`` -- or the three-input ``max(a, b, c)`` -- as a Float32 via ``nvvm.fmax``.

    The ``nvvm.fmax`` builder signature depends on the CUDA toolkit version. The CUDA 12.9
    binding requires the result type as an explicit leading positional argument; newer ones
    infer it.
    """
    from cutlass import CUDA_VERSION
    operands = [
        Float32(a).ir_value(loc=loc, ip=ip),
        Float32(b).ir_value(loc=loc, ip=ip),
    ]
    third = Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None
    uses_legacy_api = CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9
    if uses_legacy_api:
        # Old API: explicit result type goes first.
        operands.insert(0, T.f32())
    return Float32(nvvm.fmax(*operands, c=third, loc=loc, ip=ip))
@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the maximum over all elements of ``x`` (and ``init_val``, if given) as a Float32.

    Args:
        x: Values to reduce. They are copied into a Float32 fragment and read element by element.
        init_val: Optional extra value folded into the maximum.
        arch: Target architecture number (compile-time). If ``arch >= 100`` and the size of
            ``x`` is a multiple of 8, the three-input form of ``fmax`` is used for every step.

    Several independent running maxima are kept and combined at the end. This shortens the
    dependency chain compared with a single sequential fold.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        if const_expr(cute.size(x.shape) % 4 != 0):
            # The 4-accumulator scheme below reads res[0..3] and then res[i..i+3] in steps of 4.
            # For sizes that are not a multiple of 4 it would index past the end, so fold
            # sequentially instead.
            seq_max = res[0]
            for i in cutlass.range_constexpr(1, cute.size(x.shape)):
                seq_max = fmax(seq_max, res[i])
            return seq_max if const_expr(init_val is None) else fmax(seq_max, init_val)
        # Four independent running maxima over interleaved elements.
        local_max = [res[0], res[1], res[2], res[3]]
        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            local_max[0] = fmax(local_max[0], res[i + 0])
            local_max[1] = fmax(local_max[1], res[i + 1])
            local_max[2] = fmax(local_max[2], res[i + 2])
            local_max[3] = fmax(local_max[3], res[i + 3])
        # Combine the four partial maxima pairwise.
        local_max[0] = fmax(local_max[0], local_max[1])
        local_max[2] = fmax(local_max[2], local_max[3])
        local_max[0] = fmax(local_max[0], local_max[2])
        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])
@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the sum of all elements of ``x`` (plus ``init_val``, if given) as a Float32.

    If ``arch < 100`` or the size of ``x`` is not a multiple of 8, this uses
    ``x.reduce(cute.ReductionOp.ADD, init_val, 0)``, with ``init_val`` defaulting to
    ``Float32.zero``.
    Otherwise it keeps four pairs of partial sums and accumulates them with packed
    ``cute.arch.add_packed_f32x2`` ops, 8 elements per step. The pairs are combined and the two
    lanes of the final pair are added at the end. ``init_val`` is added into lane 0 of the
    first pair.
    Note: the summation order differs between the two paths, so results may differ in the last
    bits.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
        # res = cute.make_fragment(x.shape, Float32)
        # res.store(x)
        # local_sum = [res[0], res[1], res[2], res[3]]
        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
        #     local_sum[0] += res[i + 0]
        #     local_sum[1] += res[i + 1]
        #     local_sum[2] += res[i + 2]
        #     local_sum[3] += res[i + 3]
        # local_sum[0] += local_sum[1]
        # local_sum[2] += local_sum[3]
        # local_sum[0] += local_sum[2]
        # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # First pair: fold init_val into lane 0 (lane 1 gets + 0.0) when provided.
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        # Four independent (lane0, lane1) partial-sum pairs covering 8 consecutive elements.
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        # Combine the four pairs pairwise, then add the two lanes of the result.
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        return local_sum[0][0] + local_sum[0][1]
@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add ``a`` to the fp32 value at ``gmem_ptr``.

    Emits an ``nvvm.atomicrmw`` with ``FADD`` on the pointer's LLVM value. The old value that
    the op yields is not used, so nothing is returned.
    ``loc``/``ip`` are threaded through like the other ops in this file, so the emitted IR
    carries the caller's location.

    NOTE: an inline-PTX ``red.global.add.f32`` variant (optionally with an L2 cache hint) was
    previously sketched here as an alternative; see version-control history if needed.
    """
    nvvm.atomicrmw(
        res=T.f32(),
        op=nvvm.AtomicOpKind.FADD,
        ptr=gmem_ptr.llvm_ptr,
        a=Float32(a).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Return a pointer to the element of ``x`` at logical coordinate ``coord``.

    The coordinate is mapped to a linear offset through ``x.layout`` with ``cute.crd2idx``,
    and that offset is added to ``x.iterator``.
    """
    base = x.iterator
    offset = cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
    return base + offset
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a Boolean predicate fragment for bounds-checking along the "k" mode.

    ``tAcA`` is assumed to be a partitioned coordinate (identity) tensor with modes
    ``((atom, rest_v), m, k)`` -- TODO confirm against callers. For each ``(rest_v, rest_k)``,
    the predicate is ``coord[1] < limit``, where ``coord = tAcA[(0, rest_v), 0, rest_k]``.
    Mode 1 of the result has stride 0, so one predicate is shared across that whole mode.

    Returns:
        A Boolean fragment of shape ``(size(tAcA, mode=[0, 1]), size(tAcA, mode=[1]),
        size(tAcA, mode=[2]))`` with strides ``(size(tAcA, mode=[2]), 0, 1)``.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            # Only index 0 of mode 1 is written; the stride-0 mode aliases it for all m.
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Return the calling thread's warp-group index, ``thread_idx()[0] // 128``.

    When ``sync`` is True, the index is passed through ``cute.arch.make_warp_uniform``.
    Presumably this lets the compiler treat it as uniform across the warp -- confirm.
    """
    tidx = cute.arch.thread_idx()[0]
    group = tidx // 128
    return cute.arch.make_warp_uniform(group) if const_expr(sync) else group
# @dsl_user_op
# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
# mask = cutlass.Int32(-1)
# return cutlass.Boolean(
# llvm.inline_asm(
# T.i32(),
# [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
# ".pred p1, p2;\n"
# "setp.lt.f32 p1, $1, $2;\n"
# "vote.sync.any.pred p2, p1, $3;\n"
# "selp.u32 $0, 1, 0, p2;",
# # "selp.u32 $0, 1, 0, p1;",
# "=r,f,f,r",
# has_side_effects=False,
# is_align_stack=False,
# asm_dialect=llvm.AsmDialect.AD_ATT,
# )
# )
@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Warp-shuffle a scalar whose bit width is a multiple of 32 (e.g. a 64-bit value).

    The value is stored in a 1-element fragment and reinterpreted as Int32 words. Each word is
    shuffled with ``cute.arch.shuffle_sync(word, offset, mask_and_clamp=...)``, and the
    reassembled value is returned.
    ``width`` sets the sub-warp segment via the mask in bits 8 and up of ``mask_and_clamp``;
    the clamp in the low bits is always ``WARP_SIZE - 1``.
    ``width`` is presumably a power of two dividing WARP_SIZE -- TODO confirm.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp
    # important: need stride 1 and not 0 for recast_tensor to work
    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    val[0] = value
    # Int32 view aliasing the same storage: writes to val_i32 update val.
    val_i32 = cute.recast_tensor(val, cutlass.Int32)
    for i in cutlass.range_constexpr(cute.size(val_i32)):
        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
    return val[0]
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """Logical (zero-filling) right shift of a 32-bit unsigned value: ``val >> shift``.

    Emitted as inline PTX ``shr.u32``, so the vacated high bits are filled with zeros.
    ``shr.s32`` would be an arithmetic shift that sign-extends, giving wrong results for
    ``val >= 2**31``.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            "shr.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Inclusive prefix sum of ``val`` across the lanes of a warp.

    In round ``i``, each lane receives ``val`` from the lane ``offset = 2**i`` below it via
    ``shuffle_sync_up``. It adds that value only if ``lane >= offset``, so lanes without a
    source below are unaffected. After ``log2(WARP_SIZE)`` rounds, lane ``k`` holds the sum of
    the original values of lanes ``0..k``.

    Args:
        val: This lane's contribution.
        lane: This thread's lane index; defaults to ``cute.arch.lane_idx()``.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val)
    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << i
        # Very important that we set mask_and_clamp to 0
        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        # Dynamic (runtime) branch: only lanes that have a source lane below them accumulate.
        if lane >= offset:
            val += partial_sum
        # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
    return val
@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Convert two fp32 values to a packed pair of 16-bit floats held in one Int32.

    Uses PTX ``cvt.rn`` (round to nearest even) to ``bf16x2`` or ``f16x2``, depending on
    ``to_dtype``. PTX puts the first source operand in the upper half. The operands are
    therefore passed as ``$2, $1``, so ``a`` ends up in the low 16 bits and ``b`` in the high
    16 bits.
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    packed_kind = "bf16x2" if to_dtype is cutlass.BFloat16 else "f16x2"
    sources = [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)]
    packed = llvm.inline_asm(
        T.i32(),
        sources,
        f"cvt.rn.{packed_kind}.f32 $0, $2, $1;",
        "=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(packed)
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...
@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...
@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert Float32 tensor to Float16/BFloat16.

    Elements are converted two at a time with ``cvt_f16x2_f32``: ``src[2*i]`` and
    ``src[2*i + 1]`` are packed into word ``i`` of ``dst`` viewed as an Int32 tensor.
    Args:
        src: Source tensor with Float32 element type and an even number of elements.
        dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16).
    Returns:
        None if a destination tensor is given. If a dtype is given, a new fragment with
        ``src``'s shape, filled by recursively calling the tensor variant.
    Raises:
        AssertionError (at trace time) if sizes differ, the size is odd, ``dst`` is not
        BFloat16/Float16, or ``src`` is not Float32.
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: create new tensor and call the tensor variant
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # tensor variant: write to dst
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # Int32 view of dst: each word holds two converted 16-bit values.
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate ``sum(poly[i] * x**i)`` with Horner's rule.

    ``poly`` holds coefficients lowest degree first and must be indexable with a ``len``
    (e.g. a tuple). For a single coefficient, ``poly[0]`` is returned as is. ``loc``/``ip``
    are accepted for the ``dsl_user_op`` interface but are not used in the body.
    """
    deg = len(poly) - 1
    out = poly[deg]
    # Compile-time unrolled: i = deg-1, ..., 0.
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = out * x + poly[i]
    return out
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` with packed f32x2 Horner steps.

    ``poly`` holds coefficients lowest degree first. Returns ``(p(x), p(y))``. This assumes
    ``cute.arch.fma_packed_f32x2(a, b, c)`` computes ``a * b + c`` lane-wise -- confirm.
    ``loc``/``ip`` are accepted for the ``dsl_user_op`` interface but not used in the body.
    """
    deg = len(poly) - 1
    out = (poly[deg], poly[deg])
    # Compile-time unrolled: i = deg-1, ..., 0.
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
    return out
@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return ``x + y`` in fp32, rounded toward negative infinity, with denormals flushed to zero.

    Emitted as inline PTX ``add.rm.ftz.f32``.
    """
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    lhs = Float32(x).ir_value(loc=loc, ip=ip)
    rhs = Float32(y).ir_value(loc=loc, ip=ip)
    summed = llvm.inline_asm(
        T.f32(),
        [lhs, rhs],
        "add.rm.ftz.f32 $0, $1, $2;",
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(summed)
@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Scale ``frac_ex2`` by ``2**k`` by adding ``k`` directly into its exponent bits.

    Both inputs are treated as 32-bit integers. ``x_rounded``'s bits are shifted left by 23,
    into the fp32 exponent field, and integer-added to ``frac_ex2``'s bits. In the ex2
    emulation callers, ``x_rounded`` is ``x + 1.5 * 2**23`` rounded down, so its low bits hold
    ``floor(x)``. The high bits are shifted out of the 32-bit word.
    Assumes ``frac_ex2`` is a positive normal float and that the result does not over- or
    underflow the exponent -- confirm at call sites.
    """
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Float32(x_rounded).ir_value(loc=loc, ip=ip),
                Float32(frac_ex2).ir_value(loc=loc, ip=ip),
            ],
            "{\n\t"
            ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
            "mov.b32 x_rounded_i, $1;\n\t"
            "mov.b32 frac_ex_i, $2;\n\t"
            "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
            # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
            # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
            "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
            "mov.b32 $0, out_i;\n\t"
            "}\n",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Approximate ``2**x`` with a polynomial plus exponent-bit manipulation.

    ``x`` is clamped below at -127.0 and split into floor and fractional part by adding and
    subtracting 1.5 * 2**23. The fractional part goes through ``POLY_EX2[poly_degree]``, and
    the floor is added into the result's exponent bits by ``combine_int_frac_ex2``.
    Assumes ``x <= 127.0``.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # 1.5 * 2**23: adding it moves the integer part of x into the low mantissa bits.
    magic = float(2**23 + 2**22)
    clamped = cute.arch.fmax(x, -127.0)
    # Round down here, so that the fractional part is in [0, 1).
    shifted = add_round_down(clamped, magic, loc=loc, ip=ip)
    # The next two ops are assumed to round to nearest even; the rounding mode matters.
    floor_x = shifted - magic
    frac = clamped - floor_x
    frac_pow2 = evaluate_polynomial(frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(shifted, frac_pow2, loc=loc, ip=ip)
# TODO: check that ex2_emulation_2 produces the same SASS as the inline-PTX version (e2e_asm2)
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Two-lane variant of ``ex2_emulation`` built on packed f32x2 arithmetic.

    Returns approximations of ``(2**x, 2**y)``. Assumes ``x <= 127.0`` and ``y <= 127.0``.
    """
    magic = float(2**23 + 2**22)
    magic_pair = (magic, magic)
    clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # Round down here, so that the fractional parts are in [0, 1).
    shifted = cute.arch.add_packed_f32x2(clamped, magic_pair, rnd="rm")
    # The next two ops should round to nearest even; the rounding mode matters.
    floors = activation.sub_packed_f32x2(shifted, magic_pair)
    fracs = activation.sub_packed_f32x2(clamped, floors)
    frac_pow2 = evaluate_polynomial_2(*fracs, POLY_EX2[poly_degree], loc=loc, ip=ip)
    first = combine_int_frac_ex2(shifted[0], frac_pow2[0], loc=loc, ip=ip)
    second = combine_int_frac_ex2(shifted[1], frac_pow2[1], loc=loc, ip=ip)
    return first, second
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Approximate ``(2**x, 2**y)`` in one inline-PTX block using packed f32x2 math.

    Same algorithm as ``ex2_emulation_2`` with a degree-3 polynomial, written directly in PTX:
      1. clamp both inputs below at -127.0 (``0fC2FE0000``) and pack them into a 64-bit pair;
      2. add 1.5 * 2**23 (``0f4B400000``) rounding toward -inf, subtract it back
         (round-to-nearest) to get the floor, and take ``clamped - floor`` as the fraction;
      3. evaluate the polynomial on the fraction with three packed FMAs (Horner's rule). The
         constants 0f3D9DF09D, 0f3E6906A4, 0f3F31F519, 0f3F800000 are the fp32 values of
         POLY_EX2[3], highest degree first;
      4. shift the rounded bits left by 23 into the exponent field and integer-add them to the
         polynomial result's bit pattern.
    Assumes x <= 127.0 and y <= 127.0. The packed ``.f32x2`` instructions need a recent GPU
    architecture -- presumably SM100+; TODO confirm.
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        # The results are f32 struct members, so they take the fp32 "f" register constraint
        # (as in combine_int_frac_ex2). mov.b32 from the .s32 registers is valid because
        # .b32 accepts any 32-bit register.
        "=f,=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Return a tensor with ``tensor``'s layout whose origin is the element at ``coord``.

    A new pointer is built from the address of element ``coord`` (via ``elem_pointer``). It
    keeps the element type, memory space and assumed alignment of ``tensor.iterator`` and is
    paired with ``tensor.layout``.
    The caller must ensure that the offset preserves the original alignment; it is not
    checked.
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    # We assume that applying the offset does not change the pointer alignment
    new_ptr = cute.make_ptr(
        tensor.element_type,
        elem_pointer(tensor, coord, loc=loc, ip=ip).toint(loc=loc, ip=ip),
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype.

    The scalar is written into a 1-element fragment of ``dtype`` and loaded back as SSA.
    """
    vec = cute.make_fragment(1, dtype)
    vec[0] = a
    return vec.load()
def ssa_to_scalar(val):
    """Return element 0 of ``val``; the counterpart of ``scalar_to_ssa``.

    Kept as a named helper (rather than inlining ``val[0]``) so call sites read symmetrically
    with ``scalar_to_ssa``.
    """
    first_index = 0
    return val[first_index]