# Copyright (c) 2025, Tri Dao.

import math
import hashlib
import inspect
from typing import Type, Callable, Optional, Tuple, overload

import cutlass
import cutlass.cute as cute

from cutlass import Float32, const_expr
from cutlass.cutlass_dsl import T, dsl_user_op
from cutlass._mlir.dialects import nvvm, llvm
from cutlass.cute.runtime import from_dlpack

from .quack import activation

# Attribute names whose values are mixed into the hash produced by `hash_callable`.
_MIXER_ATTRS = ("__vec_size__",)

# Polynomial coefficients approximating 2**x on [0, 1], keyed by polynomial degree.
# Coefficients are stored lowest order first (poly[i] multiplies x**i).
# Obtained from sollya:
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
POLY_EX2 = {
    # NOTE: must be a 1-tuple; a bare `(1.0)` is a float and breaks `len(poly)`.
    0: (1.0,),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}


def _compute_base_hash(func: Callable) -> str:
    """Compute hash from source code or bytecode and closure values.

    The hashed payload is, in order of preference: the source text returned by
    ``inspect.getsource``, the raw bytecode ``func.__code__.co_code``, or
    ``repr(func)``. The ``repr`` of every closure cell's contents is then mixed in.

    Returns:
        The SHA-256 hex digest of the payload.
    """
    try:
        data = inspect.getsource(func).encode()
    except (OSError, TypeError):
        # Source is unavailable: fall back to bytecode, then to repr.
        if hasattr(func, "__code__") and func.__code__ is not None:
            data = func.__code__.co_code
        else:
            data = repr(func).encode()

    hasher = hashlib.sha256(data)

    if hasattr(func, "__closure__") and func.__closure__ is not None:
        for cell in func.__closure__:
            hasher.update(repr(cell.cell_contents).encode())

    return hasher.hexdigest()


def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable based on the source code or bytecode and closure values.

    Fast-path: if the callable (or its ``__wrapped__`` base) has a ``__cute_hash__``
    attribute, that value is used as the base hash. Otherwise the base hash is computed
    from the unwrapped callable with ``_compute_base_hash``. The values of the
    ``mixer_attrs`` attributes on ``func`` are then mixed in to produce the final hash.

    Args:
        func: The callable to hash.
        mixer_attrs: Names of attributes on ``func`` to mix into the hash. Attributes
            that are missing count as ``None``; if all of them are ``None`` the base
            hash is returned unchanged.
        set_cute_hash: Whether to cache a freshly computed base hash on the unwrapped
            callable as ``__cute_hash__``.

    Returns:
        A hex digest string.
    """
    # Resolve base hash
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)
        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)
            if set_cute_hash:
                base_func.__cute_hash__ = base_hash

    # Mix in mutable metadata dunders
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)
    if all(v is None for v in mixer_values):
        return base_hash

    hasher = hashlib.sha256(base_hash.encode())
    # Pair each value with the attribute name it was actually read from, i.e. the
    # `mixer_attrs` argument rather than the module-level default.
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())
    return hasher.hexdigest()


def create_softcap_scoremod(softcap_val):
    """Build a ``cute.jit`` score-modification function for a given softcap value.

    The returned function multiplies the incoming scores by ``1 / softcap_val`` and
    returns the scaled scores times ``tanh`` of the scaled scores. The index and
    ``aux_tensors`` arguments are accepted but unused.

    NOTE(review): conventional soft-capping is ``softcap * tanh(score / softcap)``;
    confirm that multiplying by the scaled scores here is intended.
    """
    inv_softcap = 1.0 / softcap_val

    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        scores = acc_S_SSA * inv_softcap
        return scores * cute.math.tanh(scores, fastmath=True)

    return scoremod_premask_fn


def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap ``x`` as a cute tensor via DLPack with a dynamic layout.

    Args:
        x: Source object passed to ``from_dlpack``; must also provide ``dim_order()``.
        leading_dim: Mode passed as the leading dimension when marking the layout
            dynamic and as the mode whose compact shape is marked dynamic.
        alignment: Assumed alignment forwarded to ``from_dlpack``.
        divisibility: Divisibility forwarded to ``mark_compact_shape_dynamic``.
    """
    return (
        from_dlpack(x, assumed_align=alignment)
        .mark_layout_dynamic(leading_dim=leading_dim)
        .mark_compact_shape_dynamic(
            mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
        )
    )


def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Wrap ``x`` as a cute tensor via DLPack, keeping the leading mode static.

    Every mode other than ``leading_dim`` and those listed in ``static_modes`` has its
    compact shape marked dynamic.

    Args:
        x: Source object passed to ``from_dlpack``; must provide ``ndim`` and
            ``dim_order()``.
        leading_dim: Mode that is never marked dynamic.
        alignment: Assumed alignment forwarded to ``from_dlpack``.
        static_modes: Optional collection of additional modes to leave static.
        stride_order: Stride order to use; defaults to ``x.dim_order()``.
    """
    if stride_order is None:
        stride_order = x.dim_order()
    x_ = from_dlpack(x, assumed_align=alignment)
    for i in range(x.ndim):
        if i != leading_dim and (static_modes is None or i not in static_modes):
            x_ = x_.mark_compact_shape_dynamic(mode=i, stride_order=stride_order)
    return x_


def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Make the tiled copy for operand A, or for operand B when ``swapAB`` is set."""
    if const_expr(swapAB):
        return cute.make_tiled_copy_B(copy_atom, tiled_mma)
    else:
        return cute.make_tiled_copy_A(copy_atom, tiled_mma)
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Make the tiled copy for operand B, or for operand A when ``swapAB`` is set."""
    if const_expr(swapAB):
        return cute.make_tiled_copy_A(copy_atom, tiled_mma)
    else:
        return cute.make_tiled_copy_B(copy_atom, tiled_mma)


def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as operand A and make its MMA fragment.

    With ``swapAB`` set, the tensor is treated as operand B instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_B(smem, thr_mma)
    else:
        return thr_mma.make_fragment_A(thr_mma.partition_A(smem))


def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as operand B and make its MMA fragment.

    With ``swapAB`` set, the tensor is treated as operand A instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_A(smem, thr_mma)
    else:
        return thr_mma.make_fragment_B(thr_mma.partition_B(smem))


def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Select the copy atom used for storing ``element_type`` values.

    For ``arch < 90`` or element types that are not 16 bits wide, a
    ``CopyUniversalOp`` atom copying two elements' worth of bits is returned.
    Otherwise an ``StMatrix8x8x16bOp`` atom with 4 matrices is returned; ``transpose``
    is only used on this path.
    """
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            num_bits_per_copy=2 * element_type.width,
        )
    else:
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )


@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` with ``op`` using butterfly shuffles.

    For a scalar, ``log2(width)`` rounds are performed with shuffle offsets
    1, 2, 4, ...; each round combines the value with the shuffled value via ``op``.
    For a ``TensorSSA``, every element is reduced independently and the result is
    returned as a ``TensorSSA`` of the same shape.
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val


@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Return the result of ``nvvm.fmax`` on ``a``, ``b`` and, if given, ``c``.

    All operands are converted to ``Float32`` first. The call shape depends on the
    CUDA version reported by ``cutlass.CUDA_VERSION``.
    """
    from cutlass import CUDA_VERSION

    # * NVVM call based on nvvm version
    if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
        # Old API: requires explicit result type as first positional argument
        return Float32(
            nvvm.fmax(
                T.f32(),
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
    else:
        # New API: infers result type automatically
        return Float32(
            nvvm.fmax(
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )


@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Reduce all elements of ``x`` (and ``init_val``, if given) to their maximum.

    Two code paths are used:

    * ``arch < 100`` or a size that is not a multiple of 8: four running maxima are
      updated with 2-input ``fmax`` and then combined. This path indexes elements
      0..3 and advances in steps of 4, so it expects the size of ``x`` to be a
      positive multiple of 4.
    * otherwise: four running maxima are updated with 3-input ``fmax``, consuming
      8 elements per step.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        # if const_expr(init_val is None):
        #     init_val = -cutlass.Float32.inf
        # return x.reduce(cute.ReductionOp.MAX, init_val, 0)
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # local_max = [res[0], res[1]]
        # for i in cutlass.range_constexpr(2, cute.size(x.shape), 2):
        #     local_max[0] = fmax(local_max[0], res[i + 0])
        #     local_max[1] = fmax(local_max[1], res[i + 1])
        # local_max[0] = fmax(local_max[0], local_max[1])
        # return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
        local_max = [res[0], res[1], res[2], res[3]]
        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            local_max[0] = fmax(local_max[0], res[i + 0])
            local_max[1] = fmax(local_max[1], res[i + 1])
            local_max[2] = fmax(local_max[2], res[i + 2])
            local_max[3] = fmax(local_max[3], res[i + 3])
        local_max[0] = fmax(local_max[0], local_max[1])
        local_max[2] = fmax(local_max[2], local_max[3])
        local_max[0] = fmax(local_max[0], local_max[2])
        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])


@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Sum all elements of ``x``, plus ``init_val`` if given.

    For ``arch < 100`` or a size that is not a multiple of 8 this defers to
    ``x.reduce`` with ``ReductionOp.ADD``. Otherwise four packed (f32x2) partial sums
    are accumulated with ``cute.arch.add_packed_f32x2``, 8 elements per step, and the
    two lanes of the final packed sum are added together.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
        # res = cute.make_fragment(x.shape, Float32)
        # res.store(x)
        # local_sum = [res[0], res[1], res[2], res[3]]
        # for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
        #     local_sum[0] += res[i + 0]
        #     local_sum[1] += res[i + 1]
        #     local_sum[2] += res[i + 2]
        #     local_sum[3] += res[i + 3]
        # local_sum[0] += local_sum[1]
        # local_sum[2] += local_sum[3]
        # local_sum[0] += local_sum[2]
        # return local_sum[0] if const_expr(init_val is None) else local_sum[0] + init_val
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # init_val is added to the first lane only; the second lane gets 0.0.
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            # cute.arch.add_packed_f32x2((init_val / 2, init_val / 2), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        return local_sum[0][0] + local_sum[0][1]


@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add the Float32 value ``a`` to the location ``gmem_ptr`` points at.

    Emits an ``nvvm.atomicrmw`` with the ``FADD`` op; the value returned by the
    atomic is discarded. ``loc`` and ``ip`` are forwarded to the emitted ops, as in
    the other ``dsl_user_op`` helpers of this module.
    """
    # gmem_ptr_i64 = gmem_ptr.toint(loc=loc, ip=ip).ir_value()
    # # cache_hint = cutlass.Int64(0x12F0000000000000)
    # llvm.inline_asm(
    #     None,
    #     [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip)],
    #     # [gmem_ptr_i64, Float32(a).ir_value(loc=loc, ip=ip), cache_hint.ir_value()],
    #     "red.global.add.f32 [$0], $1;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, 0x12F0000000000000;",
    #     # "red.global.add.L2::cache_hint.f32 [$0], $1, $2;",
    #     "l,f",
    #     # "l,f,l",
    #     has_side_effects=True,
    #     is_align_stack=False,
    #     asm_dialect=llvm.AsmDialect.AD_ATT,
    # )
    nvvm.atomicrmw(
        res=T.f32(),
        op=nvvm.AtomicOpKind.FADD,
        ptr=gmem_ptr.llvm_ptr,
        a=Float32(a).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )


@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Return the tensor's iterator advanced by the layout index of ``coord``."""
    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)


@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a Boolean predicate fragment comparing coordinates of ``tAcA`` to ``limit``.

    The fragment's middle mode has stride 0, so only the first and last modes carry
    distinct predicates. Each predicate is ``elem_less`` of component ``[1]`` of the
    corresponding entry of ``tAcA`` against ``limit``.
    """
    # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if"
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA


def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Return the x thread index divided by 128.

    When ``sync`` is true the value is passed through ``cute.arch.make_warp_uniform``.
    """
    warp_group_idx = cute.arch.thread_idx()[0] // 128
    if const_expr(sync):
        warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx)
    return warp_group_idx


# @dsl_user_op
# def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean:
#     mask = cutlass.Int32(-1)
#     return cutlass.Boolean(
#         llvm.inline_asm(
#             T.i32(),
#             [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)],
#             ".pred p1, p2;\n"
#             "setp.lt.f32 p1, $1, $2;\n"
#             "vote.sync.any.pred p2, p1, $3;\n"
#             "selp.u32 $0, 1, 0, p2;",
#             # "selp.u32 $0, 1, 0, p1;",
#             "=r,f,f,r",
#             has_side_effects=False,
#             is_align_stack=False,
#             asm_dialect=llvm.AsmDialect.AD_ATT,
#         )
#     )


@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Shuffle a value whose type is a multiple of 32 bits wide.

    The value is stored into a one-element tensor, reinterpreted as ``Int32`` words,
    and each word is shuffled with ``cute.arch.shuffle_sync`` using ``offset`` and a
    ``mask_and_clamp`` derived from ``width``.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp
    # important: need stride 1 and not 0 for recast_tensor to work
    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    val[0] = value
    val_i32 = cute.recast_tensor(val, cutlass.Int32)
    for i in cutlass.range_constexpr(cute.size(val_i32)):
        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
    return val[0]
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """Logically shift the unsigned 32-bit ``val`` right by ``shift`` bits.

    Uses the PTX ``shr.u32`` instruction so vacated high bits are filled with zeros;
    a signed shift (``shr.s32``) would replicate bit 31 for values >= 2**31.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            "shr.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Compute an inclusive prefix sum of ``val`` across the lanes of a warp.

    For each power-of-two offset below the warp size, the value from ``offset`` lanes
    below is fetched with ``shuffle_sync_up`` and added when ``lane >= offset``.

    Args:
        val: This lane's contribution.
        lane: This lane's index; defaults to ``cute.arch.lane_idx()``.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, val = %d", cute.arch.thread_idx()[0] % 32, val)
    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << i
        # Very important that we set mask_and_clamp to 0
        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        if lane >= offset:
            val += partial_sum
        # if cute.arch.thread_idx()[0] >= 128 and cute.arch.thread_idx()[0] < 128 + 32 and cute.arch.block_idx()[0] == 0: cute.printf("tidx = %d, partial_sum = %d, val = %d", cute.arch.thread_idx()[0] % 32, partial_sum, val)
    return val


@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Convert two Float32 values to a packed pair of 16-bit floats.

    Emits PTX ``cvt.rn.bf16x2.f32`` or ``cvt.rn.f16x2.f32`` depending on ``to_dtype``
    and returns the packed result as an ``Int32``. The PTX source operands are passed
    in the order ``(b, a)``.
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
            f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
            "=r,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...


@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert Float32 tensor to Float16/BFloat16.

    Args:
        src: Source tensor with Float32 element type and an even number of elements
        dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)

    Returns:
        None if dst is a tensor, or a new tensor if dtype is provided
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: create new tensor and call the tensor variant
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # tensor variant: write to dst
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View dst as 32-bit words so each word receives one converted pair.
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)


@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate a polynomial at ``x`` using Horner's scheme.

    ``poly`` holds the coefficients lowest order first, i.e. ``poly[i]`` multiplies
    ``x**i``. It must be a non-empty sequence.
    """
    deg = len(poly) - 1
    out = poly[deg]
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = out * x + poly[i]
    return out


@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` with packed f32x2 FMAs.

    ``poly`` holds the coefficients lowest order first, as in ``evaluate_polynomial``.
    """
    deg = len(poly) - 1
    out = (poly[deg], poly[deg])
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
    return out


@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return ``x + y`` computed with PTX ``add.rm.ftz.f32``."""
    # There's probably a way to call llvm or nvvm to do this instead of ptx
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
            "add.rm.ftz.f32 $0, $1, $2;",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Combine the integer and fractional parts of an emulated ``ex2``.

    The bits of ``x_rounded`` are shifted left by 23 and added, as integers, to the
    bits of ``frac_ex2``; the sum is reinterpreted as a Float32.
    """
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Float32(x_rounded).ir_value(loc=loc, ip=ip),
                Float32(frac_ex2).ir_value(loc=loc, ip=ip),
            ],
            "{\n\t"
            ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
            "mov.b32 x_rounded_i, $1;\n\t"
            "mov.b32 frac_ex_i, $2;\n\t"
            "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
            # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
            # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
            "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
            "mov.b32 $0, out_i;\n\t"
            "}\n",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )


@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Approximate ``2**x`` with a polynomial from ``POLY_EX2``.

    ``x`` is clamped below at -127.0 and split into a rounded-down integer part and
    a fractional part; the polynomial of degree ``poly_degree`` is evaluated on the
    fractional part and recombined with the integer part.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # We assume x <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    x_clamped = cute.arch.fmax(x, -127.0)
    # We want to round down here, so that the fractional part is in [0, 1)
    x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
    # The integer floor of x is now in the last 8 bits of x_rounded
    # We assume the next 2 ops round to nearest even. The rounding mode is important.
    x_rounded_back = x_rounded - fp32_round_int
    x_frac = x_clamped - x_rounded_back
    x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)


# TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Approximate ``2**x`` and ``2**y`` together using packed f32x2 operations.

    Follows the same steps as ``ex2_emulation`` on both inputs at once.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # We assume x <= 127.0 and y <= 127.0
    fp32_round_int = float(2**23 + 2**22)
    xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # We want to round down here, so that the fractional part is in [0, 1)
    xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm")
    # The integer floor of x & y are now in the last 8 bits of xy_rounded
    # We want the next 2 ops to round to nearest even. The rounding mode is important.
    xy_rounded_back = activation.sub_packed_f32x2(
        xy_rounded, (fp32_round_int, fp32_round_int)
    )
    xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)
    xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
    y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
    return x_out, y_out


@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Inline-PTX version of the two-value ``ex2`` emulation.

    The assembly clamps both inputs below at -127.0 (0fC2FE0000), adds the rounding
    constant 0f4B400000 with round-down, recovers the fractional parts, evaluates a
    degree-3 polynomial with packed FMAs, and adds the shifted integer parts to the
    result bits.
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1


@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Return a tensor with the same layout whose pointer is offset to ``coord``.

    The new pointer is rebuilt from the integer address of the element at ``coord``
    and is given the same assumed alignment as the original iterator.
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    # We assume that applying the offset does not change the pointer alignment
    new_ptr = cute.make_ptr(
        tensor.element_type,
        elem_pointer(tensor, coord, loc=loc, ip=ip).toint(),
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)


@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype"""
    vec = cute.make_fragment(1, dtype)
    vec[0] = a
    return vec.load()


def ssa_to_scalar(val):
    """Could inline but nice for reflecting the above api"""
    return val[0]