| |
|
|
| import math |
| import hashlib |
| import inspect |
| from typing import Type, Callable, Optional, Tuple, overload |
|
|
| import cutlass |
| import cutlass.cute as cute |
|
|
| from cutlass import Float32, const_expr |
| from cutlass.cutlass_dsl import T, dsl_user_op |
| from cutlass._mlir.dialects import nvvm, llvm |
| from cutlass.cute.runtime import from_dlpack |
|
|
|
|
| from .quack import activation |
|
|
# Dunder attribute names that hash_callable() mixes into the final hash when
# present on a callable (e.g. set by wrappers to distinguish otherwise
# identical functions).
_MIXER_ATTRS = ("__vec_size__",)
|
|
| |
| |
# Polynomial coefficients (degree -> coefficients, lowest order first) that
# approximate 2**x on the fractional range used by ex2_emulation().
# evaluate_polynomial() consumes these via len()/indexing, so every entry must
# be a tuple — note the trailing comma for degree 0 (BUGFIX: was `(1.0)`,
# which is just the float 1.0 and would crash evaluate_polynomial).
POLY_EX2 = {
    0: (1.0,),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}
|
|
|
|
def _compute_base_hash(func: Callable) -> str:
    """Compute hash from source code or bytecode and closure values."""
    # Prefer the textual source; fall back to raw bytecode, then repr().
    try:
        payload = inspect.getsource(func).encode()
    except (OSError, TypeError):
        code_obj = getattr(func, "__code__", None)
        payload = code_obj.co_code if code_obj is not None else repr(func).encode()

    digest = hashlib.sha256(payload)

    # Fold captured closure values in so that closures over different values
    # (but identical source) hash differently.
    closure = getattr(func, "__closure__", None)
    if closure is not None:
        for cell in closure:
            digest.update(repr(cell.cell_contents).encode())

    return digest.hexdigest()
|
|
|
|
def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable based on the source code or bytecode and closure values.

    Fast-path: if the callable (or its __wrapped__ base) has a ``__cute_hash__``
    attribute, that value is returned immediately as the base hash, then
    metadata dunders are mixed in to produce the final dict-key hash.

    Args:
        func: Callable to hash.
        mixer_attrs: Attribute names whose values (when present on ``func``)
            are mixed into the final hash.
        set_cute_hash: Whether to cache the base hash on ``func.__cute_hash__``.
    """
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap one functools-style layer so a wrapper and its wrapped
        # function share the cached base hash.
        base_func = getattr(func, "__wrapped__", func)
        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)
            if set_cute_hash:
                base_func.__cute_hash__ = base_hash

    # Mix metadata dunders (if any) into the base hash so e.g. different
    # __vec_size__ values map to distinct dict keys.
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)
    if all(v is None for v in mixer_values):
        return base_hash

    hasher = hashlib.sha256(base_hash.encode())
    # BUGFIX: iterate the `mixer_attrs` argument rather than the module-level
    # default, so caller-supplied attribute tuples are honored.
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())

    return hasher.hexdigest()
|
|
|
|
def create_softcap_scoremod(softcap_val):
    """Build a pre-mask score-mod callback applying tanh soft-capping.

    Returns a cute.jit function with the standard score-mod signature
    (acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors).
    """
    inv_softcap = 1.0 / softcap_val

    @cute.jit
    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        # NOTE(review): this computes s' * tanh(s') with s' = scores / softcap,
        # whereas the conventional softcap is softcap * tanh(scores / softcap).
        # Confirm the multiplier is intentional (the caller may fold the
        # softcap factor in elsewhere).
        scores = acc_S_SSA * inv_softcap
        return scores * cute.math.tanh(scores, fastmath=True)

    return scoremod_premask_fn
|
|
|
|
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Import a dlpack-capable tensor as a cute.Tensor with a dynamic layout.

    The layout is marked dynamic along ``leading_dim``, and that mode's
    compact shape is marked dynamic with the given ``divisibility``.
    """
    tensor = from_dlpack(x, assumed_align=alignment)
    tensor = tensor.mark_layout_dynamic(leading_dim=leading_dim)
    tensor = tensor.mark_compact_shape_dynamic(
        mode=leading_dim, stride_order=x.dim_order(), divisibility=divisibility
    )
    return tensor
|
|
|
|
def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Import a dlpack-capable tensor, keeping ``leading_dim`` (and any mode
    listed in ``static_modes``) static while marking every other mode's
    compact shape dynamic.
    """
    order = x.dim_order() if stride_order is None else stride_order
    result = from_dlpack(x, assumed_align=alignment)
    for mode in range(x.ndim):
        if mode == leading_dim:
            continue
        if static_modes is not None and mode in static_modes:
            continue
        result = result.mark_compact_shape_dynamic(mode=mode, stride_order=order)
    return result
|
|
|
|
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Tiled copy for MMA operand A; with swapAB, build the B-side copy instead."""
    if const_expr(swapAB):
        return cute.make_tiled_copy_B(copy_atom, tiled_mma)
    else:
        return cute.make_tiled_copy_A(copy_atom, tiled_mma)
|
|
|
|
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Tiled copy for MMA operand B; with swapAB, build the A-side copy instead."""
    if const_expr(swapAB):
        return cute.make_tiled_copy_A(copy_atom, tiled_mma)
    else:
        return cute.make_tiled_copy_B(copy_atom, tiled_mma)
|
|
|
|
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` for operand A of ``thr_mma`` and build its register
    fragment; with swapAB, operand B is used instead."""
    if const_expr(swapAB):
        return mma_make_fragment_B(smem, thr_mma)
    else:
        return thr_mma.make_fragment_A(thr_mma.partition_A(smem))
|
|
|
|
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` for operand B of ``thr_mma`` and build its register
    fragment; with swapAB, operand A is used instead."""
    if const_expr(swapAB):
        return mma_make_fragment_A(smem, thr_mma)
    else:
        return thr_mma.make_fragment_B(thr_mma.partition_B(smem))
|
|
|
|
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Pick the copy atom used to store results to shared memory.

    Pre-SM90 architectures, or element types that are not 16-bit wide, get a
    plain universal copy moving two elements per instruction; otherwise a
    4-matrix 8x8x16b stmatrix store is used.
    """
    if const_expr(arch < 90 or element_type.width != 16):
        return cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            element_type,
            # Two elements per copy instruction.
            num_bits_per_copy=2 * element_type.width,
        )
    else:
        return cute.make_copy_atom(
            cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4),
            element_type,
        )
|
|
|
|
@cute.jit
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Butterfly-shuffle reduction of ``val`` across ``width`` lanes with ``op``.

    TensorSSA inputs are reduced elementwise: each element is reduced across
    lanes independently and the vector is returned.
    """
    if const_expr(isinstance(val, cute.TensorSSA)):
        # Spill to a register fragment so elements can be reduced one by one.
        res = cute.make_fragment(val.shape, val.dtype)
        res.store(val)
        for i in cutlass.range_constexpr(cute.size(val.shape)):
            res[i] = warp_reduce(res[i], op, width)
        return res.load()
    else:
        # log2(width) butterfly exchange rounds.
        for i in cutlass.range_constexpr(int(math.log2(width))):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
        return val
|
|
|
|
@dsl_user_op
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """max(a, b) — or max(a, b, c) when ``c`` is given — in Float32 via nvvm.fmax."""
    from cutlass import CUDA_VERSION

    # The nvvm.fmax binding shipped with CUDA 12.9 takes an explicit result
    # type as its first argument; other versions infer it from the operands.
    if CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9:
        return Float32(
            nvvm.fmax(
                T.f32(),
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
    else:
        return Float32(
            nvvm.fmax(
                Float32(a).ir_value(loc=loc, ip=ip),
                Float32(b).ir_value(loc=loc, ip=ip),
                c=Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None,
                loc=loc,
                ip=ip,
            )
        )
|
|
|
|
@cute.jit
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Horizontal max-reduction of a Float32 TensorSSA.

    The generic path uses four independent accumulators to shorten dependency
    chains; for arch >= 100 with a vector length divisible by 8, the 3-input
    form of fmax is used to fold pairs per step.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        # Generic path. NOTE: indexes res[0..3] and strides by 4, so it
        # requires cute.size(x.shape) >= 4 and divisible by 4.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # Four parallel partial maxima for instruction-level parallelism.
        local_max = [res[0], res[1], res[2], res[3]]
        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            local_max[0] = fmax(local_max[0], res[i + 0])
            local_max[1] = fmax(local_max[1], res[i + 1])
            local_max[2] = fmax(local_max[2], res[i + 2])
            local_max[3] = fmax(local_max[3], res[i + 3])
        # Tree-combine the four partials.
        local_max[0] = fmax(local_max[0], local_max[1])
        local_max[2] = fmax(local_max[2], local_max[3])
        local_max[0] = fmax(local_max[0], local_max[2])
        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
    else:
        # arch >= 100 path: fold two elements per accumulator per step using
        # the 3-operand fmax. Requires the size to be a multiple of 8.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # Fold init_val (if any) into the first accumulator up front.
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])
|
|
|
|
@cute.jit
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Horizontal sum-reduction of a Float32 TensorSSA.

    The generic path defers to TensorSSA.reduce; for arch >= 100 with a
    vector length divisible by 8, accumulation is done with packed f32x2 adds
    (four 2-wide accumulators, i.e. 8 scalar partial sums in flight).
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # Seed lane 0 of the first packed accumulator with init_val if given.
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        # Tree-combine the packed partials, then the final two lanes.
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        return local_sum[0][0] + local_sum[0][1]
|
|
|
|
@dsl_user_op
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add ``a`` (fp32) to the value at ``gmem_ptr``.

    Lowered to an NVVM atomicrmw FADD on the pointer; the previous value is
    discarded.
    """
    nvvm.atomicrmw(
        res=T.f32(), op=nvvm.AtomicOpKind.FADD, ptr=gmem_ptr.llvm_ptr, a=Float32(a).ir_value()
    )
|
|
|
|
@dsl_user_op
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Pointer to the element of ``x`` at logical coordinate ``coord``."""
    return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
|
|
|
|
@cute.jit
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a boolean predicate fragment for bounds-masking along K.

    Mode 1 of the predicate layout gets stride 0, so a single predicate is
    broadcast across that mode; only the (value, k) coordinates are compared.
    """
    tApA = cute.make_fragment(
        cute.make_layout(
            (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])),
            stride=(cute.size(tAcA, mode=[2]), 0, 1),
        ),
        cutlass.Boolean,
    )
    for rest_v in cutlass.range_constexpr(tApA.shape[0]):
        for rest_k in cutlass.range_constexpr(tApA.shape[2]):
            # Compare the second component of the coordinate (the K index)
            # against the limit.
            tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit)
    return tApA
|
|
|
|
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Index of the calling thread's warp group (128 threads per group).

    With ``sync=True`` the value is made warp-uniform before returning.
    """
    warp_group_idx = cute.arch.thread_idx()[0] // 128
    if const_expr(sync):
        warp_group_idx = cute.arch.make_warp_uniform(warp_group_idx)
    return warp_group_idx
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
@cute.jit
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Warp shuffle of a value whose width is any multiple of 32 bits.

    The value is reinterpreted as a vector of Int32 words, each word is
    shuffled independently, and the result is reinterpreted back.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    # Encode the segment mask and lane clamp in the form the shuffle expects:
    # bits [12:8] = 32 - width (segment mask), bits [4:0] = lane clamp.
    mask = cute.arch.WARP_SIZE - width
    clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = mask << 8 | clamp
    # One-element register tensor so the value can be recast word-by-word.
    val = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    val[0] = value
    val_i32 = cute.recast_tensor(val, cutlass.Int32)
    for i in cutlass.range_constexpr(cute.size(val_i32)):
        val_i32[i] = cute.arch.shuffle_sync(val_i32[i], offset, mask_and_clamp=mask_and_clamp)
    return val[0]
|
|
|
|
@dsl_user_op
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """Logical (zero-filling) right shift of a 32-bit value via inline PTX.

    Args:
        val: Value to shift.
        shift: Shift amount (clamped per PTX ``shr`` semantics).

    Returns:
        ``val >> shift`` with zero fill.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            # BUGFIX: was `shr.s32` (arithmetic shift), which sign-extends
            # inputs >= 2**31; the name and Uint32 operands call for the
            # logical `shr.u32`.
            "shr.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
|
|
|
@cute.jit
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Inclusive prefix sum of ``val`` across the warp (Hillis-Steele scan).

    ``lane`` may be passed in if the caller already holds the lane index;
    otherwise it is queried.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))):
        offset = 1 << i
        # Pull the running sum from `offset` lanes below; lanes below `offset`
        # keep their value unchanged this round.
        partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0)
        if lane >= offset:
            val += partial_sum
    return val
|
|
|
|
@dsl_user_op
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Pack two Float32 values into one f16x2/bf16x2 register (round-to-nearest).

    Returns the packed pair as an Int32. The asm passes $2, $1 (swapped)
    because PTX `cvt` places its first source operand in the high half — so
    ``a`` lands in the low 16 bits and ``b`` in the high 16 bits.
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    return cutlass.Int32(
        llvm.inline_asm(
            T.i32(),
            [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)],
            f"cvt.rn.{'bf16x2' if to_dtype is cutlass.BFloat16 else 'f16x2'}.f32 $0, $2, $1;",
            "=r,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
|
|
|
# Overload: convert src into an existing destination tensor, in place.
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


# Overload: allocate a fresh fragment of `dtype` and return the conversion.
@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...
|
|
|
|
@cute.jit
def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert Float32 tensor to Float16/BFloat16.

    Args:
        src: Source tensor with Float32 element type
        dst_or_dtype: Either a destination tensor or a dtype (Float16/BFloat16)

    Returns:
        None if dst is a tensor, or a new tensor if dtype is provided
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype form: allocate a fragment and recurse into the in-place form.
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # destination-tensor form: convert two elements at a time into dst.
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View dst as Int32 so each packed f16x2/bf16x2 result is one store.
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
|
|
|
|
@dsl_user_op
@cute.jit
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate a polynomial at ``x`` via Horner's rule.

    ``poly`` holds coefficients ordered from lowest to highest degree.
    """
    deg = len(poly) - 1
    out = poly[deg]
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = out * x + poly[i]
    return out
|
|
|
|
@dsl_user_op
@cute.jit
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` simultaneously via
    packed f32x2 FMAs (Horner's rule; coefficients lowest degree first)."""
    deg = len(poly) - 1
    out = (poly[deg], poly[deg])
    for i in cutlass.range_constexpr(deg - 1, -1, -1):
        out = cute.arch.fma_packed_f32x2(out, (x, y), (poly[i], poly[i]))
    return out
|
|
|
|
@dsl_user_op
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Float32 add with round-toward-negative-infinity (PTX ``add.rm.ftz``)."""
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [Float32(x).ir_value(loc=loc, ip=ip), Float32(y).ir_value(loc=loc, ip=ip)],
            "add.rm.ftz.f32 $0, $1, $2;",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
|
|
|
@dsl_user_op
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Assemble 2**x from its parts: scale ``frac_ex2`` (~= 2**frac(x)) by
    2**floor(x).

    ``x_rounded`` is the magic-rounded value from ex2_emulation, whose low
    mantissa bits hold the integer part of x. Shifting its bit pattern left
    by 23 moves that integer into the f32 exponent field; an integer add onto
    frac_ex2's bit pattern then increments the exponent, i.e. multiplies by
    2**floor(x).
    """
    return cutlass.Float32(
        llvm.inline_asm(
            T.f32(),
            [
                Float32(x_rounded).ir_value(loc=loc, ip=ip),
                Float32(frac_ex2).ir_value(loc=loc, ip=ip),
            ],
            "{\n\t"
            ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
            "mov.b32 x_rounded_i, $1;\n\t"
            "mov.b32 frac_ex_i, $2;\n\t"
            "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
            "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
            "mov.b32 $0, out_i;\n\t"
            "}\n",
            "=f,f,f",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
|
|
|
|
@dsl_user_op
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Emulate exp2(x) in plain FP/integer ops: split x into integer and
    fractional parts with the magic-number rounding trick, approximate
    2**frac with a polynomial, then splice the integer part into the
    exponent bits.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    # Magic constant: adding it with round-down pushes the integer part of x
    # into the low mantissa bits of the result.
    fp32_round_int = float(2**23 + 2**22)
    # Clamp to keep the final exponent in range.
    x_clamped = cute.arch.fmax(x, -127.0)
    x_rounded = add_round_down(x_clamped, fp32_round_int, loc=loc, ip=ip)
    # Recover floor(x) and the fractional remainder.
    x_rounded_back = x_rounded - fp32_round_int
    x_frac = x_clamped - x_rounded_back
    x_frac_ex2 = evaluate_polynomial(x_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(x_rounded, x_frac_ex2, loc=loc, ip=ip)
|
|
|
|
| |
@dsl_user_op
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Two-lane version of ex2_emulation using packed f32x2 arithmetic.

    Returns (approx 2**x, approx 2**y).
    """
    # Magic constant: adding it with round-down pushes the integer part of
    # each input into the low mantissa bits.
    fp32_round_int = float(2**23 + 2**22)
    # Clamp both lanes to keep the final exponents in range.
    xy_clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    xy_rounded = cute.arch.add_packed_f32x2(xy_clamped, (fp32_round_int, fp32_round_int), rnd="rm")
    # Recover floor() of each lane and the fractional remainders.
    xy_rounded_back = activation.sub_packed_f32x2(
        xy_rounded, (fp32_round_int, fp32_round_int)
    )
    xy_frac = activation.sub_packed_f32x2(xy_clamped, xy_rounded_back)
    xy_frac_ex2 = evaluate_polynomial_2(*xy_frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    # Splice the integer parts into the exponent bits of each lane.
    x_out = combine_int_frac_ex2(xy_rounded[0], xy_frac_ex2[0], loc=loc, ip=ip)
    y_out = combine_int_frac_ex2(xy_rounded[1], xy_frac_ex2[1], loc=loc, ip=ip)
    return x_out, y_out
|
|
|
|
@dsl_user_op
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Two-lane exp2 emulation fused into a single PTX asm block.

    Mirrors ex2_emulation_2 with a degree-3 polynomial: clamp each lane to
    -127.0 (0fC2FE0000), magic round-down add of 2**23 + 2**22 (0f4B400000),
    packed-FMA Horner evaluation of 2**frac, then splice the integer parts
    into the exponent bits. Returns (approx 2**x, approx 2**y).
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    # Unpack the two result lanes from the returned struct.
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
|
|
|
|
@dsl_user_op
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """View of ``tensor`` whose base pointer is offset by ``coord``, rebuilt
    so the original pointer's alignment assumption is carried over."""
    assert isinstance(tensor.iterator, cute.Pointer)
    new_ptr = cute.make_ptr(
        tensor.element_type,
        elem_pointer(tensor, coord).toint(),
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)
|
|
|
|
@cute.jit
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Convert a scalar to a cute TensorSSA of shape (1,) and given dtype"""
    # Round-trip through a one-element register fragment to get an SSA value.
    vec = cute.make_fragment(1, dtype)
    vec[0] = a
    return vec.load()
|
|
|
|
def ssa_to_scalar(val):
    """Pull the single element back out of a shape-(1,) SSA value; the
    counterpart of scalar_to_ssa above."""
    first = 0
    return val[first]
|
|