Instructions to use kernels-community/flash-attn4 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/flash-attn4 with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/flash-attn4")
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2025, Tri Dao. | |
| import math | |
| import hashlib | |
| import inspect | |
| from typing import Type, Callable, Optional, Tuple, overload | |
| import cutlass | |
| import cutlass.cute as cute | |
| from cutlass import Float32, const_expr | |
| from cutlass.cutlass_dsl import T, dsl_user_op | |
| from cutlass._mlir.dialects import nvvm, llvm | |
| from cutlass.cute.runtime import from_dlpack | |
| from .quack import activation | |
# Function attributes whose values ``hash_callable`` mixes into its hash by default.
_MIXER_ATTRS = ("__vec_size__",)

# Minimax polynomial coefficients approximating 2**x on [0, 1], keyed by degree.
# ``POLY_EX2[d][i]`` is the coefficient of x**i; every entry has d + 1 coefficients and a
# leading 1.0, so the approximation is exact at x = 0.
# Obtained from sollya (command shown for degree 1; presumably the degree argument was
# changed for the other entries):
# fpminimax(exp(x * log(2.0)), 1, [|1,24...|],[0;1],relative);
POLY_EX2 = {
    # Must be a 1-tuple: ``(1.0)`` is just the float 1.0, which breaks ``len(poly)`` and
    # ``poly[i]`` in evaluate_polynomial / evaluate_polynomial_2.
    0: (1.0,),
    1: (
        1.0,
        0.922497093677520751953125,
    ),
    2: (
        1.0,
        0.6657850742340087890625,
        0.330107033252716064453125,
    ),
    3: (
        1.0,
        0.695146143436431884765625,
        0.227564394474029541015625,
        0.077119089663028717041015625,
    ),
    4: (
        1.0,
        0.693042695522308349609375,
        0.2412912547588348388671875,
        5.2225358784198760986328125e-2,
        1.3434938155114650726318359375e-2,
    ),
    5: (
        1.0,
        0.693151414394378662109375,
        0.24016360938549041748046875,
        5.5802188813686370849609375e-2,
        9.01452265679836273193359375e-3,
        1.86810153536498546600341796875e-3,
    ),
}
def _compute_base_hash(func: Callable) -> str:
    """Return a SHA-256 hex digest identifying ``func``.

    The digest is seeded with, in order of preference:

    * the function's source text, when ``inspect.getsource`` can retrieve it;
    * otherwise its raw bytecode (``__code__.co_code``);
    * otherwise its ``repr``.

    The ``repr`` of each closure cell's current contents is then folded in. Two closures
    built from the same source but capturing different values therefore hash differently.
    """
    try:
        seed = inspect.getsource(func).encode()
    except (OSError, TypeError):
        code_obj = getattr(func, "__code__", None)
        seed = code_obj.co_code if code_obj is not None else repr(func).encode()
    digest = hashlib.sha256(seed)
    for captured in getattr(func, "__closure__", None) or ():
        digest.update(repr(captured.cell_contents).encode())
    return digest.hexdigest()
def hash_callable(
    func: Callable, mixer_attrs: Tuple[str, ...] = _MIXER_ATTRS, set_cute_hash: bool = True
) -> str:
    """Hash a callable based on its source code or bytecode and its closure values.

    Resolving the base hash:

    * If ``func`` has a ``__cute_hash__`` attribute, that value is the base hash.
    * Otherwise ``func`` is unwrapped via ``__wrapped__`` (e.g. a ``cute.jit`` wrapper).
      The wrapped function's ``__cute_hash__`` is used when present; otherwise the hash
      is computed with ``_compute_base_hash``.

    Mixing in metadata: the values of ``mixer_attrs`` read from ``func`` are mixed in to
    produce the final dict-key hash. When all of them are absent/None, the base hash is
    returned unchanged.

    Args:
        func: The callable to hash.
        mixer_attrs: Names of attributes of ``func`` whose values are mixed into the hash.
        set_cute_hash: Whether to cache a freshly computed base hash as
            ``__cute_hash__`` on the unwrapped function. Caching is best-effort: objects
            that reject attribute assignment are simply not cached.

    Returns:
        A SHA-256 hex digest string.
    """
    # Resolve base hash
    if hasattr(func, "__cute_hash__"):
        base_hash = func.__cute_hash__
    else:
        # Unwrap decorated functions (e.g., cute.jit wrappers).
        base_func = getattr(func, "__wrapped__", func)
        if hasattr(base_func, "__cute_hash__"):
            base_hash = base_func.__cute_hash__
        else:
            base_hash = _compute_base_hash(base_func)
            if set_cute_hash:
                try:
                    base_func.__cute_hash__ = base_hash
                except (AttributeError, TypeError):
                    # e.g. bound methods / builtins: caching is only an optimization.
                    pass
    # Mix in mutable metadata dunders
    mixer_values = tuple(getattr(func, attr, None) for attr in mixer_attrs)
    if all(v is None for v in mixer_values):
        return base_hash
    hasher = hashlib.sha256(base_hash.encode())
    # Zip against the caller-supplied names (not the module default) so every value is
    # labeled correctly and none are dropped when custom mixer_attrs are passed.
    for attr, val in zip(mixer_attrs, mixer_values):
        hasher.update(f"{attr}={val!r}".encode())
    return hasher.hexdigest()
def create_softcap_scoremod(softcap_val):
    """Build a score_mod callable that applies logit soft-capping.

    The returned function maps each score ``s`` to ``softcap_val * tanh(s / softcap_val)``.
    This bounds scores to ``(-softcap_val, softcap_val)`` while keeping them monotonic in
    ``s``.

    Args:
        softcap_val: The soft-cap magnitude (must be non-zero).

    Returns:
        A function with the score_mod signature
        ``(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors)``. It transforms
        ``acc_S_SSA`` and ignores the index/aux arguments.
    """
    inv_softcap = 1.0 / softcap_val

    def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        scores = acc_S_SSA * inv_softcap
        # Scale back by softcap_val. Multiplying by ``scores`` instead would give
        # (s/c) * tanh(s/c): an even, non-negative function that maps s and -s to the same
        # value, which is not soft-capping.
        return cute.math.tanh(scores, fastmath=True) * softcap_val

    return scoremod_premask_fn
def convert_from_dlpack(x, leading_dim, alignment=16, divisibility=1) -> cute.Tensor:
    """Wrap a DLPack-compatible tensor ``x`` as a ``cute.Tensor`` with dynamic layout.

    The pointer is assumed aligned to ``alignment`` bytes, the layout is marked dynamic
    with ``leading_dim`` as the leading dimension, and the compact shape is marked
    dynamic along ``leading_dim`` using ``x.dim_order()`` as the stride order and the
    given ``divisibility``.
    """
    wrapped = from_dlpack(x, assumed_align=alignment)
    wrapped = wrapped.mark_layout_dynamic(leading_dim=leading_dim)
    return wrapped.mark_compact_shape_dynamic(
        mode=leading_dim,
        stride_order=x.dim_order(),
        divisibility=divisibility,
    )
def convert_from_dlpack_leading_static(
    x, leading_dim, alignment=16, static_modes=None, stride_order=None
) -> cute.Tensor:
    """Wrap ``x`` via DLPack, marking every mode dynamic except a chosen static set.

    ``leading_dim`` and any mode listed in ``static_modes`` keep static shapes. Every
    other mode is marked compact-shape-dynamic using ``stride_order``, which defaults to
    ``x.dim_order()``.
    """
    order = x.dim_order() if stride_order is None else stride_order
    keep_static = () if static_modes is None else static_modes
    result = from_dlpack(x, assumed_align=alignment)
    for mode in range(x.ndim):
        if mode == leading_dim or mode in keep_static:
            continue
        result = result.mark_compact_shape_dynamic(mode=mode, stride_order=order)
    return result
def make_tiled_copy_A(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy for MMA operand A.

    When ``swapAB`` is set, the A and B operands are exchanged and the B-operand copy is
    built instead.
    """
    builder = cute.make_tiled_copy_B if const_expr(swapAB) else cute.make_tiled_copy_A
    return builder(copy_atom, tiled_mma)
def make_tiled_copy_B(
    copy_atom: cute.CopyAtom, tiled_mma: cute.TiledMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.TiledCopy:
    """Build the tiled copy for MMA operand B.

    When ``swapAB`` is set, the A and B operands are exchanged and the A-operand copy is
    built instead.
    """
    builder = cute.make_tiled_copy_A if const_expr(swapAB) else cute.make_tiled_copy_B
    return builder(copy_atom, tiled_mma)
def mma_make_fragment_A(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as MMA operand A for this thread and make its fragment.

    When ``swapAB`` is set, ``mma_make_fragment_B`` is used instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_B(smem, thr_mma)
    partitioned = thr_mma.partition_A(smem)
    return thr_mma.make_fragment_A(partitioned)
def mma_make_fragment_B(
    smem: cute.Tensor, thr_mma: cute.core.ThrMma, swapAB: cutlass.Constexpr[bool] = False
) -> cute.Tensor:
    """Partition ``smem`` as MMA operand B for this thread and make its fragment.

    When ``swapAB`` is set, ``mma_make_fragment_A`` is used instead.
    """
    if const_expr(swapAB):
        return mma_make_fragment_A(smem, thr_mma)
    partitioned = thr_mma.partition_B(smem)
    return thr_mma.make_fragment_B(partitioned)
def get_smem_store_atom(
    arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False
) -> cute.CopyAtom:
    """Choose the copy atom used to store register data to shared memory.

    * ``arch >= 90`` with 16-bit elements: the warp-level ``StMatrix8x8x16bOp`` with
      ``num_matrices=4``, optionally transposed.
    * Otherwise: a universal copy moving ``2 * element_type.width`` bits per copy.
    """
    use_stmatrix = arch >= 90 and element_type.width == 16
    if const_expr(use_stmatrix):
        stmatrix_op = cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4)
        return cute.make_copy_atom(stmatrix_op, element_type)
    copy_bits = 2 * element_type.width
    return cute.make_copy_atom(
        cute.nvgpu.CopyUniversalOp(), element_type, num_bits_per_copy=copy_bits
    )
def warp_reduce(
    val: cute.TensorSSA | cute.Numeric,
    op: Callable,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.TensorSSA | cute.Numeric:
    """Reduce ``val`` with ``op`` across groups of ``width`` lanes using butterfly shuffles.

    Scalar input: ``log2(width)`` rounds are run. Each round combines the value with the
    value from the lane at XOR distance ``1 << step``, so every lane in the group ends up
    holding the reduction.

    TensorSSA input: each element is reduced independently and a TensorSSA of the same
    shape is returned.
    """
    if const_expr(not isinstance(val, cute.TensorSSA)):
        num_rounds = int(math.log2(width))
        for step in cutlass.range_constexpr(num_rounds):
            val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << step))
        return val
    frag = cute.make_fragment(val.shape, val.dtype)
    frag.store(val)
    for idx in cutlass.range_constexpr(cute.size(val.shape)):
        frag[idx] = warp_reduce(frag[idx], op, width)
    return frag.load()
def fmax(
    a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
) -> Float32:
    """Return ``max(a, b)``, or ``max(a, b, c)`` when ``c`` is given, via ``nvvm.fmax``.

    The NVVM builder signature depends on the CUDA version. With CUDA 12.9 the result
    type must be passed explicitly as the first positional argument; later versions
    infer it.
    """
    from cutlass import CUDA_VERSION

    operands = [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)]
    third = Float32(c).ir_value(loc=loc, ip=ip) if c is not None else None
    legacy_api = CUDA_VERSION.major == 12 and CUDA_VERSION.minor == 9
    if legacy_api:
        operands.insert(0, T.f32())
    return Float32(nvvm.fmax(*operands, c=third, loc=loc, ip=ip))
def fmax_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the maximum element of ``x``, also max'ed with ``init_val`` when given.

    The path is selected at trace time:

    * ``arch < 100`` or ``size(x) % 8 != 0``: four independent running maxima over
      stride-4 groups, combined pairwise at the end with 2-input ``fmax``.
    * Otherwise: four running maxima over stride-8 groups, each step folding in two
      new elements with the 3-input ``fmax``.

    NOTE(review): the first path reads ``res[0..3]`` and advances by 4. It appears to
    assume ``size(x)`` is a multiple of 4 (and at least 4); other sizes would index past
    the end of ``res``. Confirm callers never pass such sizes.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        # (A plain ``x.reduce(cute.ReductionOp.MAX, ...)`` and a 2-accumulator variant
        # were previously left here commented out.)
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # Four independent accumulators (presumably to expose instruction-level parallelism).
        local_max = [res[0], res[1], res[2], res[3]]
        for i in cutlass.range_constexpr(4, cute.size(x.shape), 4):
            local_max[0] = fmax(local_max[0], res[i + 0])
            local_max[1] = fmax(local_max[1], res[i + 1])
            local_max[2] = fmax(local_max[2], res[i + 2])
            local_max[3] = fmax(local_max[3], res[i + 3])
        # Tree-combine the four partial maxima.
        local_max[0] = fmax(local_max[0], local_max[1])
        local_max[2] = fmax(local_max[2], local_max[3])
        local_max[0] = fmax(local_max[0], local_max[2])
        return local_max[0] if const_expr(init_val is None) else fmax(local_max[0], init_val)
    else:
        # [2025-06-15] x.reduce only seems to use 50% 3-input max and 50% 2-input max
        # We instead force the 3-input max.
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # init_val, when given, is folded into the first accumulator up front.
        local_max_0 = (
            fmax(init_val, res[0], res[1])
            if const_expr(init_val is not None)
            else fmax(res[0], res[1])
        )
        local_max = [
            local_max_0,
            fmax(res[2], res[3]),
            fmax(res[4], res[5]),
            fmax(res[6], res[7]),
        ]
        # Each accumulator absorbs two new elements per 3-input fmax.
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_max[0] = fmax(local_max[0], res[i], res[i + 1])
            local_max[1] = fmax(local_max[1], res[i + 2], res[i + 3])
            local_max[2] = fmax(local_max[2], res[i + 4], res[i + 5])
            local_max[3] = fmax(local_max[3], res[i + 6], res[i + 7])
        # Combine the four accumulators: one 2-input max, then one 3-input max.
        local_max[0] = fmax(local_max[0], local_max[1])
        return fmax(local_max[0], local_max[2], local_max[3])
def fadd_reduce(
    x: cute.TensorSSA, init_val: float | Float32 | None = None, arch: cutlass.Constexpr[int] = 80
) -> Float32:
    """Return the sum of all elements of ``x``, plus ``init_val`` when given.

    The path is selected at trace time:

    * ``arch < 100`` or ``size(x) % 8 != 0``: delegates to
      ``x.reduce(cute.ReductionOp.ADD, init_val, 0)``, starting from ``Float32.zero``
      when ``init_val`` is None.
    * Otherwise: keeps four running (f32, f32) pair sums over stride-8 groups with
      ``cute.arch.add_packed_f32x2``. The pairs are tree-combined, and the two lanes of
      the final pair are added.
    """
    if const_expr(arch < 100 or cute.size(x.shape) % 8 != 0):
        if const_expr(init_val is None):
            init_val = Float32.zero
        return x.reduce(cute.ReductionOp.ADD, init_val, 0)
        # (A manual 4-accumulator scalar variant was previously left here commented out.)
    else:
        res = cute.make_fragment(x.shape, Float32)
        res.store(x)
        # init_val, when given, goes into the first lane of the first pair (paired with 0.0).
        local_sum_0 = (
            cute.arch.add_packed_f32x2((init_val, 0.0), (res[0], res[1]))
            if const_expr(init_val is not None)
            else (res[0], res[1])
        )
        # Four packed (lo, hi) accumulators.
        local_sum = [local_sum_0, (res[2], res[3]), (res[4], res[5]), (res[6], res[7])]
        for i in cutlass.range_constexpr(8, cute.size(x.shape), 8):
            local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], (res[i + 0], res[i + 1]))
            local_sum[1] = cute.arch.add_packed_f32x2(local_sum[1], (res[i + 2], res[i + 3]))
            local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], (res[i + 4], res[i + 5]))
            local_sum[3] = cute.arch.add_packed_f32x2(local_sum[3], (res[i + 6], res[i + 7]))
        # Tree-combine the accumulators, then add the two lanes of the final pair.
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[1])
        local_sum[2] = cute.arch.add_packed_f32x2(local_sum[2], local_sum[3])
        local_sum[0] = cute.arch.add_packed_f32x2(local_sum[0], local_sum[2])
        return local_sum[0][0] + local_sum[0][1]
def atomic_add_fp32(a: float | Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> None:
    """Atomically add ``a`` to the fp32 value pointed to by ``gmem_ptr``.

    Emits an ``nvvm.atomicrmw`` with ``AtomicOpKind.FADD`` and an f32 result type; the
    returned old value is not used.

    ``loc``/``ip`` are forwarded to both the operand materialization and the atomic op, so
    a caller-supplied insertion point and source location are honored, consistent with
    the other op helpers in this module. (An inline-PTX ``red.global.add.f32`` variant,
    optionally with ``.L2::cache_hint``, is a possible alternative.)
    """
    nvvm.atomicrmw(
        res=T.f32(),
        op=nvvm.AtomicOpKind.FADD,
        ptr=gmem_ptr.llvm_ptr,
        a=Float32(a).ir_value(loc=loc, ip=ip),
        loc=loc,
        ip=ip,
    )
def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer:
    """Return a pointer to the element of ``x`` at logical coordinate ``coord``.

    The coordinate is mapped to a linear offset through ``x.layout``, and that offset is
    added to ``x``'s base iterator.
    """
    base = x.iterator
    offset = cute.crd2idx(coord, x.layout, loc=loc, ip=ip)
    return base + offset
def predicate_k(tAcA: cute.Tensor, limit: cutlass.Int32) -> cute.Tensor:
    """Build a Boolean predicate fragment for the "k" dimension of ``tAcA``.

    Entry ``[v, 0, k]`` is ``tAcA[(0, v), 0, k][1] < limit``. The fragment's middle
    (m/n) mode has stride 0, so all m/n positions share the predicates written at
    index 0; the m/n bound is expected to be handled separately with an ``if``.
    NOTE(review): assumes ``tAcA`` is a partitioned coordinate tensor whose coordinate
    component 1 is the k index — confirm with callers.
    """
    n_v = cute.size(tAcA, mode=[0, 1])
    n_mn = cute.size(tAcA, mode=[1])
    n_k = cute.size(tAcA, mode=[2])
    pred_layout = cute.make_layout((n_v, n_mn, n_k), stride=(n_k, 0, 1))
    tApA = cute.make_fragment(pred_layout, cutlass.Boolean)
    for v in cutlass.range_constexpr(tApA.shape[0]):
        for k in cutlass.range_constexpr(tApA.shape[2]):
            k_coord = tAcA[(0, v), 0, k][1]
            tApA[v, 0, k] = cute.elem_less(k_coord, limit)
    return tApA
def canonical_warp_group_idx(sync: bool = True) -> cutlass.Int32:
    """Return this thread's warp-group index, ``thread_idx.x // 128``.

    When ``sync`` is true, the value is passed through ``make_warp_uniform``.
    """
    tidx = cute.arch.thread_idx()[0]
    wg_idx = tidx // 128
    return cute.arch.make_warp_uniform(wg_idx) if const_expr(sync) else wg_idx
| # @dsl_user_op | |
| # def warp_vote_any_lt(a: float | Float32, b: float | Float32, *, loc=None, ip=None) -> cutlass.Boolean: | |
| # mask = cutlass.Int32(-1) | |
| # return cutlass.Boolean( | |
| # llvm.inline_asm( | |
| # T.i32(), | |
| # [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip), mask.ir_value(loc=loc, ip=ip)], | |
| # ".pred p1, p2;\n" | |
| # "setp.lt.f32 p1, $1, $2;\n" | |
| # "vote.sync.any.pred p2, p1, $3;\n" | |
| # "selp.u32 $0, 1, 0, p2;", | |
| # # "selp.u32 $0, 1, 0, p1;", | |
| # "=r,f,f,r", | |
| # has_side_effects=False, | |
| # is_align_stack=False, | |
| # asm_dialect=llvm.AsmDialect.AD_ATT, | |
| # ) | |
| # ) | |
def shuffle_sync(
    value: cute.Numeric,
    offset: cute.typing.Int,
    width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE,
) -> cute.Numeric:
    """Shuffle ``value`` (any 32-bit-multiple type) from the lane selected by ``offset``.

    ``value`` is staged in a 1-element register tensor and reinterpreted as Int32 words.
    Each word is shuffled with ``cute.arch.shuffle_sync``, and the reassembled value is
    returned.

    ``mask_and_clamp`` packs ``WARP_SIZE - width`` shifted left by 8 together with a
    clamp of ``WARP_SIZE - 1``. The resulting segment masks are:
    1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000.
    """
    assert value.width % 32 == 0, "value type must be a multiple of 32 bits"
    segment_mask = cute.arch.WARP_SIZE - width
    lane_clamp = cute.arch.WARP_SIZE - 1
    mask_and_clamp = (segment_mask << 8) | lane_clamp
    # stride must be 1 (not 0) for recast_tensor to work
    staging = cute.make_rmem_tensor(cute.make_layout((1,), stride=(1,)), type(value))
    staging[0] = value
    words = cute.recast_tensor(staging, cutlass.Int32)
    for w in cutlass.range_constexpr(cute.size(words)):
        words[w] = cute.arch.shuffle_sync(words[w], offset, mask_and_clamp=mask_and_clamp)
    return staging[0]
def shr_u32(val: cutlass.Uint32, shift: cutlass.Uint32, *, loc=None, ip=None) -> cutlass.Uint32:
    """Logical right shift of an unsigned 32-bit integer via inline PTX.

    Uses ``shr.u32``, which fills vacated high bits with zeros. The previous ``shr.s32``
    was an arithmetic shift that replicates the sign bit, so any ``val >= 2**31`` came
    out wrong for a function operating on ``Uint32``.
    """
    return cutlass.Uint32(
        llvm.inline_asm(
            T.i32(),
            [
                cutlass.Uint32(val).ir_value(loc=loc, ip=ip),
                cutlass.Uint32(shift).ir_value(loc=loc, ip=ip),
            ],
            "shr.u32 $0, $1, $2;",
            "=r,r,r",
            has_side_effects=False,
            is_align_stack=False,
            asm_dialect=llvm.AsmDialect.AD_ATT,
        )
    )
def warp_prefix_sum(val: cutlass.Int32, lane: Optional[cutlass.Int32] = None) -> cutlass.Int32:
    """Inclusive prefix sum of ``val`` across the lanes of a warp.

    Runs ``log2(WARP_SIZE)`` steps. At step ``s``, each lane whose index is at least
    ``2**s`` adds the value shuffled up from ``2**s`` lanes below (Hillis-Steele scan).
    ``lane`` defaults to ``cute.arch.lane_idx()``.
    """
    if const_expr(lane is None):
        lane = cute.arch.lane_idx()
    num_steps = int(math.log2(cute.arch.WARP_SIZE))
    for step in cutlass.range_constexpr(num_steps):
        stride = 1 << step
        # mask_and_clamp must be 0 for this up-shuffle
        from_below = cute.arch.shuffle_sync_up(val, offset=stride, mask_and_clamp=0)
        if lane >= stride:
            val += from_below
    return val
def cvt_f16x2_f32(
    a: float | Float32, b: float | Float32, to_dtype: Type, *, loc=None, ip=None
) -> cutlass.Int32:
    """Round two fp32 values to a packed f16x2 / bf16x2 pair held in one Int32.

    Uses ``cvt.rn`` (round-to-nearest-even). The PTX operands are passed as ``$2, $1``
    (``b`` first). Per the PTX ``cvt`` packed-format rules, the first source lands in the
    upper half, so ``a`` should end up in the low 16 bits — confirm against the PTX ISA if
    relying on it.
    """
    assert to_dtype in [cutlass.BFloat16, cutlass.Float16], "to_dtype must be BFloat16 or Float16"
    packed_fmt = "bf16x2" if to_dtype is cutlass.BFloat16 else "f16x2"
    operands = [Float32(a).ir_value(loc=loc, ip=ip), Float32(b).ir_value(loc=loc, ip=ip)]
    packed = llvm.inline_asm(
        T.i32(),
        operands,
        f"cvt.rn.{packed_fmt}.f32 $0, $2, $1;",
        "=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Int32(packed)
@overload
def cvt_f16(src: cute.Tensor, dst: cute.Tensor) -> None: ...


@overload
def cvt_f16(src: cute.Tensor, dtype: Type[cute.Numeric]) -> cute.Tensor: ...


def cvt_f16(src: cute.Tensor, dst_or_dtype):
    """Convert a Float32 tensor to Float16/BFloat16.

    The two typed signatures above are declared with ``@overload`` so type checkers see
    both call forms. Without the decorator, they were plain redefinitions that this
    implementation silently overwrote.

    Args:
        src: Source tensor with Float32 element type and an even number of elements.
        dst_or_dtype: Either a destination tensor (BFloat16 or Float16, same size as
            ``src``) or a dtype (``cutlass.Float16`` / ``cutlass.BFloat16``).

    Returns:
        None if a destination tensor was given (it is written in place). If a dtype was
        given, a new fragment of that dtype with ``src``'s shape.
    """
    if const_expr(isinstance(dst_or_dtype, type)):
        # dtype variant: create new tensor and call the tensor variant
        dtype = dst_or_dtype
        dst = cute.make_fragment(src.shape, dtype)
        cvt_f16(src, dst)
        return dst
    else:
        # tensor variant: write to dst
        dst = dst_or_dtype
        assert cute.size(dst.shape) == cute.size(src.shape), "dst and src must have the same size"
        assert cute.size(src.shape) % 2 == 0, "src must have an even number of elements"
        assert dst.element_type in [cutlass.BFloat16, cutlass.Float16], (
            "dst must be BFloat16 or Float16"
        )
        assert src.element_type is Float32, "src must be Float32"
        # View dst as 32-bit words so each word receives one converted element pair.
        dst_i32 = cute.recast_tensor(dst, cutlass.Int32)
        assert cute.size(dst_i32.shape) * 2 == cute.size(src.shape)
        for i in cutlass.range_constexpr(cute.size(dst_i32)):
            dst_i32[i] = cvt_f16x2_f32(src[2 * i], src[2 * i + 1], dst.element_type)
def evaluate_polynomial(x: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None) -> Float32:
    """Evaluate ``sum(poly[i] * x**i)`` with Horner's rule.

    ``poly`` holds coefficients in increasing-degree order and must be a sequence.
    """
    top = len(poly) - 1
    acc = poly[top]
    for k in cutlass.range_constexpr(top - 1, -1, -1):
        acc = acc * x + poly[k]
    return acc
def evaluate_polynomial_2(
    x: Float32, y: Float32, poly: Tuple[Float32, ...], *, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Evaluate the same polynomial at ``x`` and ``y`` at once with packed f32x2 FMAs.

    Uses Horner's rule, one ``fma_packed_f32x2`` per coefficient. ``poly`` holds
    coefficients in increasing-degree order.
    """
    top = len(poly) - 1
    acc = (poly[top], poly[top])
    for k in cutlass.range_constexpr(top - 1, -1, -1):
        coeff = poly[k]
        acc = cute.arch.fma_packed_f32x2(acc, (x, y), (coeff, coeff))
    return acc
def add_round_down(x: float | Float32, y: float | Float32, *, loc=None, ip=None) -> Float32:
    """Return ``x + y`` in fp32 via inline PTX ``add.rm.ftz.f32``.

    Uses round-toward-minus-infinity (``.rm``) with flush-to-zero (``.ftz``).
    There may be an LLVM/NVVM op that could replace this inline PTX.
    """
    lhs = Float32(x).ir_value(loc=loc, ip=ip)
    rhs = Float32(y).ir_value(loc=loc, ip=ip)
    summed = llvm.inline_asm(
        T.f32(),
        [lhs, rhs],
        "add.rm.ftz.f32 $0, $1, $2;",
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(summed)
def combine_int_frac_ex2(x_rounded: Float32, frac_ex2: Float32, *, loc=None, ip=None) -> Float32:
    """Combine the integer and fractional parts of an ex2 approximation.

    The raw bits of ``x_rounded`` are shifted left by 23, moving its low bits (which hold
    floor(x) in ``ex2_emulation``) into the fp32 exponent field. That is added as a
    32-bit integer to the raw bits of ``frac_ex2``, and the sum is reinterpreted as fp32.
    """
    ptx = (
        "{\n\t"
        ".reg .s32 x_rounded_i, frac_ex_i, x_rounded_e, out_i;\n\t"
        "mov.b32 x_rounded_i, $1;\n\t"
        "mov.b32 frac_ex_i, $2;\n\t"
        "shl.b32 x_rounded_e, x_rounded_i, 23;\n\t"
        # add.u32 generates IMAD instruction and add.s32 generates LEA instruction
        # IMAD uses the FMA pipeline and LEA uses the ALU pipeline, afaik
        "add.s32 out_i, x_rounded_e, frac_ex_i;\n\t"
        "mov.b32 $0, out_i;\n\t"
        "}\n"
    )
    operands = [
        Float32(x_rounded).ir_value(loc=loc, ip=ip),
        Float32(frac_ex2).ir_value(loc=loc, ip=ip),
    ]
    combined = llvm.inline_asm(
        T.f32(),
        operands,
        ptx,
        "=f,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    return cutlass.Float32(combined)
def ex2_emulation(x: Float32, *, poly_degree: int = 3, loc=None, ip=None) -> Float32:
    """Approximate ``2**x`` in software.

    1. Clamp ``x`` below at -127.0 (``x <= 127.0`` is assumed, not checked).
    2. Add ``2**23 + 2**22`` with round-down so floor(x) lands in the low bits.
    3. Recover the fraction in [0, 1).
    4. Evaluate ``POLY_EX2[poly_degree]`` on the fraction.
    5. Splice floor(x) into the exponent with ``combine_int_frac_ex2``.
    """
    assert poly_degree in POLY_EX2, f"Polynomial degree {poly_degree} not supported"
    magic = float(2**23 + 2**22)
    clamped = cute.arch.fmax(x, -127.0)
    # Round down so that the fractional part ends up in [0, 1).
    shifted = add_round_down(clamped, magic, loc=loc, ip=ip)
    # The next two ops are assumed to round to nearest even; the rounding mode matters.
    frac = clamped - (shifted - magic)
    frac_pow2 = evaluate_polynomial(frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return combine_int_frac_ex2(shifted, frac_pow2, loc=loc, ip=ip)
| # TODO: check that the ex2_emulation_2 produces the same SASS as the ptx version | |
def ex2_emulation_2(
    x: Float32, y: Float32, *, poly_degree: int = 3, loc=None, ip=None
) -> Tuple[Float32, Float32]:
    """Approximate ``(2**x, 2**y)`` using packed f32x2 arithmetic.

    This is the two-lane counterpart of ``ex2_emulation``, with the same clamp, rounding
    trick and polynomial. It assumes ``x <= 127.0`` and ``y <= 127.0``.
    """
    magic = float(2**23 + 2**22)
    magic_pair = (magic, magic)
    clamped = (cute.arch.fmax(x, -127.0), cute.arch.fmax(y, -127.0))
    # Round down so that each fractional part ends up in [0, 1);
    # floor(x) and floor(y) then sit in the low bits of each lane.
    shifted = cute.arch.add_packed_f32x2(clamped, magic_pair, rnd="rm")
    # These two subtractions should round to nearest even; the rounding mode matters.
    shifted_back = activation.sub_packed_f32x2(shifted, magic_pair)
    frac = activation.sub_packed_f32x2(clamped, shifted_back)
    frac_pow2 = evaluate_polynomial_2(*frac, POLY_EX2[poly_degree], loc=loc, ip=ip)
    return (
        combine_int_frac_ex2(shifted[0], frac_pow2[0], loc=loc, ip=ip),
        combine_int_frac_ex2(shifted[1], frac_pow2[1], loc=loc, ip=ip),
    )
def e2e_asm2(x: Float32, y: Float32, *, loc=None, ip=None) -> Tuple[Float32, Float32]:
    """Approximate ``(2**x, 2**y)`` with one hand-written inline-PTX block.

    This hard-codes the ``ex2_emulation_2`` scheme with a degree-3 polynomial, using
    packed f32x2 ops:

    1. Clamp each input below at -127.0. There is no upper clamp, so ``x, y <= 127.0``
       is assumed.
    2. Add ``2**23 + 2**22`` with ``.rm`` rounding, then subtract it back with ``.rn`` to
       get the fractional part.
    3. Evaluate the polynomial with three packed FMAs.
    4. Add ``rounded_bits << 23`` to the polynomial's bits.

    Hex float literals:

    * ``0fC2FE0000`` = -127.0
    * ``0f4B400000`` = 12582912.0 = 2**23 + 2**22
    * ``0f3D9DF09D`` ~ 0.0771191, ``0f3E6906A4`` ~ 0.2275644, ``0f3F31F519`` ~ 0.6951461,
      ``0f3F800000`` = 1.0. These are the ``POLY_EX2[3]`` coefficients, highest degree
      first.

    NOTE(review): the result type is a struct of two f32, but the output constraints
    are ``=r`` (32-bit integer registers). Confirm this lowers as intended.
    NOTE(review): the ``y`` operand's ``ir_value()`` does not receive ``loc``/``ip``,
    unlike ``x``.
    """
    out_f32x2 = llvm.inline_asm(
        llvm.StructType.get_literal([T.f32(), T.f32()]),
        [Float32(x).ir_value(loc=loc, ip=ip), Float32(y, loc=loc, ip=ip).ir_value()],
        "{\n\t"
        ".reg .f32 f1, f2, f3, f4, f5, f6, f7;\n\t"
        ".reg .b64 l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;\n\t"
        ".reg .s32 r1, r2, r3, r4, r5, r6, r7, r8;\n\t"
        # clamp both inputs below at -127.0 and pack them into l1
        "max.ftz.f32 f1, $2, 0fC2FE0000;\n\t"
        "max.ftz.f32 f2, $3, 0fC2FE0000;\n\t"
        "mov.b64 l1, {f1, f2};\n\t"
        # l7 = round-down(clamped + 1.5*2^23); l8 = l7 - 1.5*2^23; l9 = fraction
        "mov.f32 f3, 0f4B400000;\n\t"
        "mov.b64 l2, {f3, f3};\n\t"
        "add.rm.ftz.f32x2 l7, l1, l2;\n\t"
        "sub.rn.ftz.f32x2 l8, l7, l2;\n\t"
        "sub.rn.ftz.f32x2 l9, l1, l8;\n\t"
        # broadcast the degree-3 polynomial coefficients
        "mov.f32 f7, 0f3D9DF09D;\n\t"
        "mov.b64 l6, {f7, f7};\n\t"
        "mov.f32 f6, 0f3E6906A4;\n\t"
        "mov.b64 l5, {f6, f6};\n\t"
        "mov.f32 f5, 0f3F31F519;\n\t"
        "mov.b64 l4, {f5, f5};\n\t"
        "mov.f32 f4, 0f3F800000;\n\t"
        "mov.b64 l3, {f4, f4};\n\t"
        # Horner evaluation of 2**fraction for both lanes
        "fma.rn.ftz.f32x2 l10, l9, l6, l5;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l4;\n\t"
        "fma.rn.ftz.f32x2 l10, l10, l9, l3;\n\t"
        # add (rounded bits << 23) into each result's exponent
        "mov.b64 {r1, r2}, l7;\n\t"
        "mov.b64 {r3, r4}, l10;\n\t"
        "shl.b32 r5, r1, 23;\n\t"
        "add.s32 r7, r5, r3;\n\t"
        "shl.b32 r6, r2, 23;\n\t"
        "add.s32 r8, r6, r4;\n\t"
        "mov.b32 $0, r7;\n\t"
        "mov.b32 $1, r8;\n\t"
        "}\n",
        "=r,=r,f,f",
        has_side_effects=False,
        is_align_stack=False,
        asm_dialect=llvm.AsmDialect.AD_ATT,
    )
    out0 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [0], loc=loc, ip=ip))
    out1 = Float32(llvm.extractvalue(T.f32(), out_f32x2, [1], loc=loc, ip=ip))
    return out0, out1
def domain_offset_aligned(
    coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None
) -> cute.Tensor:
    """Return a view of ``tensor`` whose origin is moved to ``coord``.

    The view keeps the same layout and memory space, and reuses the original pointer's
    alignment. This assumes that applying the offset does not change the pointer
    alignment.
    """
    assert isinstance(tensor.iterator, cute.Pointer)
    shifted_addr = elem_pointer(tensor, coord).toint()
    new_ptr = cute.make_ptr(
        tensor.element_type,
        shifted_addr,
        tensor.memspace,
        assumed_align=tensor.iterator.alignment,
    )
    return cute.make_tensor(new_ptr, tensor.layout)
def scalar_to_ssa(a: cute.Numeric, dtype) -> cute.TensorSSA:
    """Wrap scalar ``a`` in a shape-(1,) TensorSSA with element type ``dtype``."""
    single = cute.make_fragment(1, dtype)
    single[0] = a
    return single.load()
def ssa_to_scalar(val):
    """Return element 0 of ``val``; mirrors ``scalar_to_ssa`` (trivial, kept for API symmetry)."""
    first = val[0]
    return first