| """A GEMM schedule rule for GPU operators.""" |
| from typing import Optional, List |
| from contextlib import suppress |
|
|
| from tvm import tir, DataType |
| from tvm.target import Target |
|
|
| from bitblas.base.operator_common import TransformKind |
| from ..base.roller.hint import Hint, IntrinInfo |
| from ..base.roller.rasterization import NoRasterization |
| from ..base import analysis |
| from .base import GPUScheduleRule |
| from ..base.analysis import get_coalesced_veclen |
| from .matmul_analysis import ( |
| auto_inline_consumer_chain, |
| auto_inline_producers, |
| get_reduction_blocks, |
| normalize_to_matmul, |
| get_propagate_map, |
| get_ladder_stage3_map, |
| layout_propagate_chain, |
| find_last_producer_from_buffer, |
| _collect_producers, |
| get_in_out_dtypes, |
| ) |
|
|
|
|


def _bind_thread_based_on_config(sch, block, block_row_warps, block_col_warps, warp_size):
    # progressively split the innermost loop and bind it to the thread hierarchy
    last_loop = sch.get_loops(block)[-1]
    loop_extent = sch.get(last_loop).extent
    vec = get_coalesced_veclen(sch.get(block))

    if loop_extent // (vec * warp_size) == 0:
        last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None])
        sch.bind(B_shared_tx, "threadIdx.x")
        sch.vectorize(B_shared_vi)
        loop_extent = sch.get(last_loop).extent

    # vectorize when the extent is divisible by vec and the quotient by warp_size
    if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent //
                                                                     vec) & (warp_size - 1) == 0:
        last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec])
        sch.vectorize(B_shared_vi)
        loop_extent = sch.get(last_loop).extent

    if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0:
        last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size])
        sch.bind(B_shared_tx, "threadIdx.x")
        loop_extent = sch.get(last_loop).extent

    if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0:
        last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps])
        sch.bind(B_shared_ty, "threadIdx.y")
        loop_extent = sch.get(last_loop).extent

    if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0:
        last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps])
        sch.bind(B_shared_tz, "threadIdx.z")
        loop_extent = sch.get(last_loop).extent

    sch.unroll(last_loop)
    sch.annotate(last_loop, "pragma_unroll_explicit", False)
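
# Illustrative walk-through of the cascade above (hypothetical numbers, not taken
# from a real schedule): with loop_extent = 2048, vec = 8, warp_size = 32,
# block_row_warps = 2 and block_col_warps = 2:
#   2048 -> split [None, 8]  -> vectorize 8 lanes, extent 256
#    256 -> split [None, 32] -> bind threadIdx.x,  extent 8
#      8 -> split [None, 2]  -> bind threadIdx.y,  extent 4
#      4 -> split [None, 2]  -> bind threadIdx.z,  extent 2
# and the remaining extent-2 loop is unrolled.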


def _bind_thread_based_with_block_reduce_on_config(sch, block, block_row_warps, block_col_warps,
                                                   warp_size, reduce_k):
    # same as above, but threadIdx.z is reserved for the cross-thread reduction axis
    last_loop = sch.get_loops(block)[-1]
    loop_extent = sch.get(last_loop).extent
    vec = get_coalesced_veclen(sch.get(block))

    if loop_extent // (vec * warp_size) == 0:
        last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None])
        sch.bind(B_shared_tx, "threadIdx.x")
        sch.vectorize(B_shared_vi)
        loop_extent = sch.get(last_loop).extent

    # vectorize when the extent is divisible by vec and the quotient by warp_size
    if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent //
                                                                     vec) & (warp_size - 1) == 0:
        last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec])
        sch.vectorize(B_shared_vi)
        loop_extent = sch.get(last_loop).extent

    if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0:
        last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size])
        sch.bind(B_shared_tx, "threadIdx.x")
        loop_extent = sch.get(last_loop).extent

    B_shared_ty = None
    if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0:
        last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps])
        loop_extent = sch.get(last_loop).extent

    B_shared_tz = None
    if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0:
        last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps])
        loop_extent = sch.get(last_loop).extent

    if B_shared_ty and B_shared_tz:
        B_shared_ty = sch.fuse(B_shared_tz, B_shared_ty)
        sch.bind(B_shared_ty, "threadIdx.y")
    elif B_shared_ty:
        sch.bind(B_shared_ty, "threadIdx.y")

    if loop_extent // reduce_k >= 1 and loop_extent % reduce_k == 0:
        last_loop, B_shared_tk = sch.split(last_loop, factors=[None, reduce_k])
        sch.bind(B_shared_tk, "threadIdx.z")
        loop_extent = sch.get(last_loop).extent

    sch.unroll(last_loop)
    sch.annotate(last_loop, "pragma_unroll_explicit", False)


def get_index_map_3d(index_map, l=16, r=16):

    def index_map_3d(b, i, j):
        return (
            b,
            i // l,
            j // r,
            *index_map(i % l, j % r),
        )

    return index_map_3d


def get_index_map_5d(index_map):
    """
    For a layout-transformed GEMM, the index map should be 5-D.
    """

    def index_map_5d(b, i, j, ii, jj):
        return (
            b,
            i,
            j,
            *index_map(ii, jj),
        )

    return index_map_5d


def get_index_map(index_map, l=16, r=16, is_5d=False):
    if is_5d:
        return get_index_map_5d(index_map)
    return get_index_map_3d(index_map, l, r)
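
# Sketch of how these helpers compose (illustrative, with an identity inner map):
#
#   identity = lambda ii, jj: (ii, jj)
#   map_3d = get_index_map_3d(identity, l=16, r=16)
#   map_3d(0, 35, 20)  # -> (0, 2, 1, 3, 4): tile (2, 1), in-tile offset (3, 4)
#
# i.e. (b, i, j) -> (b, i // l, j // r, *index_map(i % l, j % r)); the 5-D variant
# assumes the outer tiling is already materialized and only remaps the inner axes.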


def check_weight_decode_info(weight_decode_info):
    conditions = []
    # check source format in ["uint", "int", "fp", "nf", "fp_e4m3"]
    conditions.append("source_format" in weight_decode_info)
    conditions.append(
        weight_decode_info["source_format"]["format"] in ["uint", "int", "fp", "nf", "fp_e4m3"])
    # check source bits in [1, 2, 4, 8]
    conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
    # check target format in ["bfloat16", "float16", "int8"]
    conditions.append("target_format" in weight_decode_info)
    conditions.append(weight_decode_info["target_format"] in ["bfloat16", "float16", "int8"])
    return all(conditions)
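
# A `weight_decode_info` dict that passes the checks above looks roughly like the
# following (all field values here are hypothetical, for illustration only):
#
#   {
#       "source_format": {"format": "uint", "bits": 4},
#       "target_format": "float16",
#       "decode_block": "B_decode",
#       "storage_dtype": "int8",
#       "with_scaling": True,
#       "with_zeros": False,
#       "zeros_mode": "original",
#       "fast_decoding": True,
#       "group_size": 128,
#   }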


class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule):
    """
    The schedule rule for float16 tensor core matmul computation.
    func with attr 'dlight.do_not_tensorize' will not be tensorized.
    """

    def apply(
        self,
        func: tir.PrimFunc,
        target: Target,
        _: bool,
    ):
        """
        For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch.
            quantized weight
                |
                V
            dequantized in register
                |
                V
            save into shared memory
                |
                V
             compute
        """
        from tvm.tir.tensor_intrin.cuda import (
            get_mma_intrin_group,)
        from .intrin import get_lop3_intrin_group

        import_source: List[str] = []

        sch = tir.Schedule(func)
        root_block = analysis.get_root_block(sch)
        blocks = sch.get_child_blocks(root_block)

        if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
            return None

        reduction_blocks = get_reduction_blocks(sch, blocks)
        if reduction_blocks is None:
            return None

        main_block = reduction_blocks[0]
        # always enable shared memory rewrite
        cache_write_required = True

        # Check Dequantize Info
        dequantize_info = func.attrs["dequantize_info"]

        def check_dequantize_info(dequantize_info):
            conditions = []
            # currently only support weight-only dequantization
            conditions.append(len(dequantize_info) == 1)
            return all(conditions)

        assert check_dequantize_info(dequantize_info)

        (weight_decode_info,) = list(dequantize_info.values())

        assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

        # Start schedule: infer the operand dtypes and build a default tensor-core
        # hint, as `apply` is invoked without a tuned config.
        in_dtype, out_dtype = get_in_out_dtypes(sch.get(main_block))
        intrin_info = IntrinInfo(
            in_dtype=in_dtype,
            out_dtype=out_dtype,
            trans_b=True,
        )
        if "weight_transform_kind" in func.attrs:
            intrin_info.weight_transform_kind = int(func.attrs["weight_transform_kind"])

        if "input_transform_kind" in func.attrs:
            intrin_info.input_transform_kind = int(func.attrs["input_transform_kind"])
        # default hint
        config = Hint().from_dict({
            "block": [128, 128],
            "warp": [64, 64],
            "rstep": [32],
            "pipeline_stage": 1,
            "use_async": False,
            "intrin_info": intrin_info,
            "shared_scope": "shared.dyn",
        })
        shared_scope = config.shared_scope

        intrin_group = get_mma_intrin_group(
            load_scope=shared_scope,
            store_scope=shared_scope if cache_write_required else "global",
            a_dtype=intrin_info.in_dtype,
            b_dtype=intrin_info.in_dtype,
            out_dtype=intrin_info.out_dtype,
            trans_a=intrin_info.trans_a,
            trans_b=intrin_info.trans_b,
            smooth_a=intrin_info.smooth_a,
            smooth_b=intrin_info.smooth_b,
            not_use_mma_store_intrinic=False,
        )

        warp_row_tiles = config.warp[0]
        warp_col_tiles = config.warp[1]
        block_row_warps = config.block[0] // warp_row_tiles
        block_col_warps = config.block[1] // warp_col_tiles
        stage = config.pipeline_stage
        use_async = config.use_async
        chunk = config.rstep[0]

        micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"]

        # get the axes for layout transform
        def get_axis(l, r, trans):
            return (r, l) if trans else (l, r)

        a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a)
        b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b)

        def can_enable_swizzle(dtype: str, smooth: bool):
            # inject_permuted_layout only supports float16 and int8 currently
            if dtype == "float16" or dtype == "int8":
                if chunk * DataType(dtype).bits != 512:
                    # currently the swizzle rule only supports 512-bit rows
                    return False
                # with a smooth layout, swizzling is unnecessary
                return not smooth
            return False

        can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a)
        can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b)
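
        # The 512-bit row constraint pins down the reduction chunk per dtype:
        # float16 (16 bits) needs chunk == 32 and int8 (8 bits) needs chunk == 64
        # for the permuted shared-memory layout to apply; otherwise the schedule
        # falls back to padding-based bank-conflict avoidance below.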

        # rewrite the global smooth layout; for dequantize, currently only
        # weight-only layout recovery is supported.
        def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"):
            if not enable:
                return

            # a normalized block may have three read buffers, while the first one is the write buffer
            buffer_offset = (1 if sch.get(main_block).reads[0].buffer
                             == sch.get(main_block).writes[0].buffer else 0)
            buffer_idx = 0 if matrix_name == "A" else 1
            source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer

            # step1: find the last producer block of the source buffer. We assume the
            # layout propagation happens there; otherwise the transform would rewrite
            # both the read and the write buffer.
            propagate_block: tir.Block = find_last_producer_from_buffer(
                sch, main_block, source_buffer)
            # some trick implementations may not have a reindex block
            (weight_dequantize_info,) = dequantize_info.values()
            if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]):
                return

            # step2: obtain the intra-warp index map for this matrix
            intra_indexmap, _ = get_propagate_map(
                trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name)

            # step3: propagate the matmul layout to the last producer block
            intra_indexmap = layout_propagate_chain(
                sch,
                start_block=main_block,
                start_buffer=source_buffer,
                end_block=propagate_block,
                index_map=intra_indexmap,
            )

            def inverse_permutation(i, j, ii, jj):
                return (i, j, *intra_indexmap.map_indices([ii, jj]))

            sch.transform_layout(propagate_block, ("read", 0), inverse_permutation)

        smooth_gmem_layout_rewrite(
            sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A")

        smooth_gmem_layout_rewrite(
            sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B")

        warp_size = 32

        i_factors, j_factors, k_factors = (
            [None, 1, block_row_warps, warp_row_tiles // micro_size_x],
            [1, None, block_col_warps, warp_col_tiles // micro_size_y],
            [None, chunk // micro_size_k],
        )

        num_ty = i_factors[2]
        num_tz = j_factors[2]
        x_pad_factor = i_factors[2] * i_factors[3]
        y_pad_factor = j_factors[2] * j_factors[3]
        k_pad_factor = k_factors[1]
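
        # With the default hint above (block 128x128, warp 64x64, rstep 32) and a
        # 16x16x16 MMA micro kernel, these evaluate to block_row_warps = 2,
        # block_col_warps = 2, i_factors = [None, 1, 2, 4], j_factors = [1, None, 2, 4],
        # k_factors = [None, 2], x_pad_factor = y_pad_factor = 8 and k_pad_factor = 2:
        # each thread block computes a 128x128 tile as a 2x2 grid of 64x64 warp tiles.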

        # Step 1. Normalize the generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
        if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()):
            sch = normalize_to_matmul(sch, main_block, ["n", "t", "n"])

        # Step 2. Padding for dynamic shape kernels
        sch.pad_einsum(
            main_block,
            [
                1,
                micro_size_x * x_pad_factor,
                micro_size_y * y_pad_factor,
                micro_size_k * k_pad_factor,
            ],
        )

        # Step 3. Schedule the matmul to use tensor cores
        block = main_block

        batch, i, j, k = sch.get_loops(block)

        # inner loops for tensor core computation
        i, i_inner = sch.split(i, factors=[None, micro_size_x])
        j, j_inner = sch.split(j, factors=[None, micro_size_y])
        k, k_inner = sch.split(k, factors=[None, micro_size_k])

        sch.reorder(i, j, k, i_inner, j_inner, k_inner)

        block_inner = block
        block_outer = sch.blockize(i_inner)

        i0, i1, i2, i3 = sch.split(i, factors=i_factors)
        j0, j1, j2, j3 = sch.split(j, factors=j_factors)
        k0, k1 = sch.split(k, k_factors)

        sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)

        block_idy = sch.fuse(i0, j0)
        block_idx = sch.fuse(i1, j1)
        thread_idy = i2
        thread_idz = j2

        sch.bind(batch, "blockIdx.z")
        sch.bind(block_idx, "blockIdx.x")
        sch.bind(block_idy, "blockIdx.y")
        sch.bind(thread_idy, "threadIdx.y")
        sch.bind(thread_idz, "threadIdx.z")

        def smooth_layout_recover(block, scope, l=16, r=16, enable=True):
            if not enable:
                return
            sch.transform_layout(
                block,
                scope,
                lambda b, i, j: (
                    b,
                    i // l,
                    j // r,
                    i % l,
                    j % r,
                ),
            )

        smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a)
        smooth_layout_recover(
            block_outer,
            ("read", 1),
            *b_lr,
            enable=intrin_info.inter_transform_b,
        )
        smooth_layout_recover(block_outer, ("write", 0), enable=True)

        def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False):
            block_read = sch.cache_read(block, idx, shared_scope)
            sch.compute_at(block_read, k0, preserve_unit_loops=True)
            ndim = len(sch.get(block_read).iter_vars)
            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])

            f_0, f_1, f_2, f_3, f_4 = sch.split(
                fused, factors=[None, num_ty, num_tz, warp_size, vec_len])

            sch.bind(f_3, "threadIdx.x")
            sch.bind(f_2, "threadIdx.z")
            sch.bind(f_1, "threadIdx.y")
            sch.vectorize(f_4)
            sch.unroll(f_0)
            # apply swizzling
            sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle)
            # if not, apply padding to alleviate bank conflicts
            if not (can_swizzle or is_smooth):
                pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
                sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset)
            sch.annotate(f_0, "pragma_unroll_explicit", False)
            return block_read
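
        # Note on pad_offset above: 8 float16 elements and 16 int8 elements are both
        # 16 bytes, so storage_align skews consecutive shared-memory rows by 16 bytes,
        # spreading accesses across banks when swizzling is disabled.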

        a_g2s = fetch_to_shared(
            block_outer,
            0,
            vec_len=4,
            can_swizzle=can_swizzle_a,
            is_smooth=intrin_info.smooth_a,
        )

        auto_inline_producers(sch, a_g2s)

        def decode_fetch_to_shared(block, idx):
            # step1. create the memory hierarchy
            # global -> local -> shared
            block_shared = sch.cache_read(block, idx, shared_scope)
            sch.compute_at(block_shared, k0, preserve_unit_loops=True)

            decode_factor = get_coalesced_veclen(sch.get(block_shared))
            _, B_shared_vi, _ = sch.split(
                sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor])
            block_shared_local = sch.cache_read(block_shared, 0, "local")
            # global -> dequantized_local -> shared
            # step2. inline the producers into the local block
            weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"])
            weight_producers = _collect_producers(sch, weight_dequantize_block)
            auto_inline_producers(sch, block_shared_local, weight_producers)

            # get the idx of the dequantize source buffer
            def get_idx():
                # for LUT dequantize, the expr is LUT(w), so the weight buffer idx is 1;
                # a more general, structure-based analysis could replace this.
                if weight_decode_info["source_format"]["format"] == "nf":
                    return 1
                return 0

            b_idx = get_idx()
            # global -> prefetch_local -> dequantized_local -> shared
            block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local")

            sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True)
            sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)

            dequantize_block_local = block_shared_local
            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
                # pop the scale block
                auto_inline_producers(sch, block_local_scales)

            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
                auto_inline_producers(sch, block_local_zeros)

            for producer in weight_producers:
                with suppress(Exception):
                    auto_inline_producers(sch, producer)
                    sch.compute_inline(producer)

            # fast type conversion
            if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]):
                source_bit = weight_decode_info["source_format"]["bits"]
                out_dtype = weight_decode_info["target_format"]
                lop3_intrin_info = get_lop3_intrin_group(
                    out_dtype=out_dtype,
                    storage_dtype=weight_decode_info["storage_dtype"],
                    source_format=weight_decode_info["source_format"]["format"],
                    source_bit=source_bit,
                    with_scaling=weight_decode_info["with_scaling"],
                    with_zeros=weight_decode_info["with_zeros"],
                    zeros_mode=weight_decode_info["zeros_mode"],
                )
                sch.tensorize(
                    sch.get_loops(dequantize_block_local)[-1],
                    lop3_intrin_info["compute"],
                )
                import_source.append(lop3_intrin_info["c_source"])

            sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b)
            union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2)
            B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2])
            _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split(
                B_shared_fused, factors=[None, num_ty, num_tz, warp_size])
            if not (can_swizzle_b or intrin_info.smooth_b):
                pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
                sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset)
            sch.bind(B_shared_tx, "threadIdx.x")
            sch.bind(B_shared_ty, "threadIdx.y")
            sch.bind(B_shared_tz, "threadIdx.z")
            sch.vectorize(sch.get_loops(block_shared)[-1])
            sch.vectorize(sch.get_loops(block_shared_local_local)[-1])

            # cache small tensors, e.g. LUT
            if b_idx:
                block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope)
                sch.reverse_compute_at(block_shared_lut, j2)
                _, B_shared_tx = sch.split(
                    sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size])
                sch.bind(B_shared_tx, "threadIdx.x")
            return block_shared_local

        _ = decode_fetch_to_shared(block_outer, 1)

        # create read caches to load matrices from shared memory into warp fragments
        A_mat = sch.cache_read(block_outer, 0, "warp")
        B_mat = sch.cache_read(block_outer, 1, "warp")
        sch.compute_at(A_mat, k1, preserve_unit_loops=True)
        sch.compute_at(B_mat, k1, preserve_unit_loops=True)

        # create a write cache to store the accumulator via shared memory
        if cache_write_required:
            accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope)

        store = sch.cache_write(block_outer, 0, "warp")
        sch.reverse_compute_at(store, j2)

        # split the store loop to match the hardware intrinsic pattern
        i, j = sch.get_loops(store)[-2:]
        i0, i1 = sch.split(i, factors=[None, micro_size_x])
        j0, j1 = sch.split(j, factors=[None, micro_size_y])
        sch.reorder(i0, j0, i1, j1)

        if cache_write_required:
            auto_inline_consumer_chain(sch, accumulator_shared_to_global)
            sch.reverse_compute_at(
                accumulator_shared_to_global,
                sch.get_loops(store)[-6],
                preserve_unit_loops=True,
            )
            vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
            fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:])
            f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
            sch.bind(f1, "threadIdx.x")
            sch.vectorize(f2)
            sch.unroll(f0)
            sch.annotate(f0, "pragma_unroll_explicit", False)
        else:
            auto_inline_consumer_chain(sch, store)

        block_init_c = sch.decompose_reduction(block_outer, k0)
        block_init_c_inner = sch.get_child_blocks(block_init_c)[0]

        # Tensorization by hardware intrinsics
        index_map_a, index_map_b, index_map_c = intrin_group["index_map"]

        sch.transform_layout(
            A_mat,
            ("write", 0),
            get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a),
        )
        sch.transform_layout(
            B_mat,
            ("write", 0),
            get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b),
        )
        sch.transform_layout(
            store,
            ("read", 0),
            get_index_map(index_map_c, is_5d=True),
        )

        i, j = sch.get_loops(A_mat)[-2:]
        i0, i1 = sch.split(i, factors=[None, a_lr[0]])
        j0, j1 = sch.split(j, factors=[None, a_lr[1]])
        sch.reorder(i0, j0, i1, j1)
        ba = sch.blockize(i1)
        sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a)
        sch.tensorize(ba, intrin_group["load_a"])

        i, j = sch.get_loops(B_mat)[-2:]
        i0, i1 = sch.split(i, factors=[None, b_lr[0]])
        j0, j1 = sch.split(j, factors=[None, b_lr[1]])
        sch.reorder(i0, j0, i1, j1)
        weight_transform_kind = 0
        if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs:
            weight_transform_kind = func.attrs["weight_transform_kind"]
        if weight_transform_kind >= TransformKind.LDMatrixTransform:
            fused = sch.fuse(i1, j1)
            vec_len = get_coalesced_veclen(sch.get(B_mat))
            f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
            sch.bind(f1, "threadIdx.x")
            sch.vectorize(f2)
        else:
            bb = sch.blockize(i1)
            sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
            sch.tensorize(bb, intrin_group["load_b"])

        def tensorize_init_store_compute():
            sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
            sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
            sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])

        tensorize_init_store_compute()

        if stage > 1:
            sch.annotate(
                k0,
                ann_key="software_pipeline_stage",
                ann_val=[0, 0, stage - 1],
            )
            sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
        if use_async:
            sch.annotate(k0, "software_pipeline_async_stages", [0])
        # plan rasterization
        if not isinstance(config.rasterization_plan, NoRasterization):
            device_func, invoke_func = config.rasterization_plan.get_code()
            import_source.append(device_func)
            sch.annotate(
                sch.get_loops(block_init_c)[-2],
                ann_key="inject_customized_code_prepend",
                ann_val=invoke_func,
            )
        # plan import source
        if len(import_source) > 0:
            sch.annotate(
                thread_idz,
                ann_key="pragma_import_c",
                ann_val=("\n").join(import_source),
            )
        return sch

    def sch_dequantize_in_register_with_config(
        self,
        func: tir.PrimFunc,
        config: Hint,
    ):
        """
        For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch.
            quantized weight
                |
                V
            dequantized in register
                |
                V
            save into shared memory
                |
                V
             compute
        """
        weight_transform_kind = config.intrin_info.weight_transform_kind
        if weight_transform_kind == TransformKind.LDMatrixTransform:
            return self.sch_warp_memory_prefetch_with_config(func, config)

        from tvm.tir.tensor_intrin.cuda import (
            get_mma_intrin_group,)
        from .intrin import get_lop3_intrin_group

        import_source: List[str] = []

        sch = tir.Schedule(func)
        root_block = analysis.get_root_block(sch)
        blocks = sch.get_child_blocks(root_block)

        if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
            return None

        reduction_blocks = get_reduction_blocks(sch, blocks)
        if reduction_blocks is None:
            return None

        main_block = reduction_blocks[0]
        # always enable shared memory rewrite
        cache_write_required = True

        # Check Dequantize Info
        dequantize_info = func.attrs["dequantize_info"]

        def check_dequantize_info(dequantize_info):
            conditions = []
            # currently only support weight-only dequantization
            conditions.append(len(dequantize_info) == 1)
            return all(conditions)

        assert check_dequantize_info(dequantize_info)

        (weight_decode_info,) = list(dequantize_info.values())

        assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

        # Start schedule from the tuned hint.
        intrin_info = config.intrin_info
        shared_scope = config.shared_scope

        intrin_group = get_mma_intrin_group(
            load_scope=shared_scope,
            store_scope=shared_scope if cache_write_required else "global",
            a_dtype=intrin_info.in_dtype,
            b_dtype=intrin_info.in_dtype,
            out_dtype=intrin_info.out_dtype,
            trans_a=intrin_info.trans_a,
            trans_b=intrin_info.trans_b,
            smooth_a=intrin_info.smooth_a,
            smooth_b=intrin_info.smooth_b,
            not_use_mma_store_intrinic=False,
        )

        warp_row_tiles = config.warp[0]
        warp_col_tiles = config.warp[1]
        block_row_warps = config.block[0] // warp_row_tiles
        block_col_warps = config.block[1] // warp_col_tiles
        stage = config.pipeline_stage
        use_async = config.use_async
        chunk = config.rstep[0]

        micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"]

        # get the axes for layout transform
        def get_axis(l, r, trans):
            return (r, l) if trans else (l, r)

        a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a)
        b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b)

        def can_enable_swizzle(dtype: str, smooth: bool):
            # inject_permuted_layout only supports float16 and int8 currently
            if dtype == "float16" or dtype == "int8":
                if chunk * DataType(dtype).bits != 512:
                    # currently the swizzle rule only supports 512-bit rows
                    return False
                # with a smooth layout, swizzling is unnecessary
                return not smooth
            return False

        can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a)
        can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b)

        # rewrite the global smooth layout; for dequantize, currently only
        # weight-only layout recovery is supported.
        def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"):
            if not enable:
                return

            # a normalized block may have three read buffers, while the first one is the write buffer
            buffer_offset = (1 if sch.get(main_block).reads[0].buffer
                             == sch.get(main_block).writes[0].buffer else 0)
            buffer_idx = 0 if matrix_name == "A" else 1
            source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer

            # step1: find the last producer block of the source buffer. We assume the
            # layout propagation happens there; otherwise the transform would rewrite
            # both the read and the write buffer.
            propagate_block: tir.Block = find_last_producer_from_buffer(
                sch, main_block, source_buffer)
            # some trick implementations may not have a reindex block
            (weight_dequantize_info,) = dequantize_info.values()
            if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]):
                return

            # step2: obtain the intra-warp index map for this matrix
            intra_indexmap, _ = get_propagate_map(
                trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name)

            # step3: propagate the matmul layout to the last producer block
            intra_indexmap = layout_propagate_chain(
                sch,
                start_block=main_block,
                start_buffer=source_buffer,
                end_block=propagate_block,
                index_map=intra_indexmap,
            )

            def inverse_permutation(i, j, ii, jj):
                return (i, j, *intra_indexmap.map_indices([ii, jj]))

            sch.transform_layout(propagate_block, ("read", 0), inverse_permutation)

        smooth_gmem_layout_rewrite(
            sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A")

        smooth_gmem_layout_rewrite(
            sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B")

        warp_size = 32

        i_factors, j_factors, k_factors = (
            [None, 1, block_row_warps, warp_row_tiles // micro_size_x],
            [1, None, block_col_warps, warp_col_tiles // micro_size_y],
            [None, chunk // micro_size_k],
        )

        num_ty = i_factors[2]
        num_tz = j_factors[2]
        x_pad_factor = i_factors[2] * i_factors[3]
        y_pad_factor = j_factors[2] * j_factors[3]
        k_pad_factor = k_factors[1]

        # Step 1. Normalize the generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
        if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()):
            sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"])

        # Step 2. Padding for dynamic shape kernels
        sch.pad_einsum(
            main_block,
            [
                1,
                micro_size_x * x_pad_factor,
                micro_size_y * y_pad_factor,
                micro_size_k * k_pad_factor,
            ],
        )

        # Step 3. Schedule the matmul to use tensor cores
        block = main_block

        batch, i, j, k = sch.get_loops(block)

        # inner loops for tensor core computation
        i, i_inner = sch.split(i, factors=[None, micro_size_x])
        j, j_inner = sch.split(j, factors=[None, micro_size_y])
        k, k_inner = sch.split(k, factors=[None, micro_size_k])

        sch.reorder(i, j, k, i_inner, j_inner, k_inner)

        block_inner = block
        block_outer = sch.blockize(i_inner)

        i0, i1, i2, i3 = sch.split(i, factors=i_factors)
        j0, j1, j2, j3 = sch.split(j, factors=j_factors)
        k0, k1 = sch.split(k, k_factors)

        sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)

        block_idy = sch.fuse(i0, j0)
        block_idx = sch.fuse(i1, j1)
        thread_idy = i2
        thread_idz = j2

        sch.bind(batch, "blockIdx.z")
        sch.bind(block_idx, "blockIdx.x")
        sch.bind(block_idy, "blockIdx.y")
        sch.bind(thread_idy, "threadIdx.y")
        sch.bind(thread_idz, "threadIdx.z")

        # skip the smooth layout rewrite of the store for 8-bit inputs, as the
        # accumulator layout differs there
        enable_store_rewrite = not intrin_info.is_input_8bit()

        def smooth_layout_recover(block, scope, l=16, r=16, enable=True):
            if not enable:
                return
            sch.transform_layout(
                block,
                scope,
                lambda b, i, j: (
                    b,
                    i // l,
                    j // r,
                    i % l,
                    j % r,
                ),
            )

        smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a)
        smooth_layout_recover(
            block_outer,
            ("read", 1),
            *b_lr,
            enable=intrin_info.inter_transform_b,
        )
        smooth_layout_recover(block_outer, ("write", 0), enable=enable_store_rewrite)

        def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False):
            block_read = sch.cache_read(block, idx, shared_scope)
            sch.compute_at(block_read, k0, preserve_unit_loops=True)
            ndim = len(sch.get(block_read).iter_vars)
            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])

            f_0, f_1, f_2, f_3, f_4 = sch.split(
                fused, factors=[None, num_ty, num_tz, warp_size, vec_len])

            sch.bind(f_3, "threadIdx.x")
            sch.bind(f_2, "threadIdx.z")
            sch.bind(f_1, "threadIdx.y")
            sch.vectorize(f_4)
            sch.unroll(f_0)
            sch.annotate(f_0, "pragma_unroll_explicit", False)

            # apply swizzling
            sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle)
            # if not, apply padding to alleviate bank conflicts
            if not (can_swizzle or is_smooth):
                pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
                sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset)
            return block_read

        a_g2s = fetch_to_shared(
            block_outer,
            0,
            vec_len=list(config.vectorize.values())[0],
            can_swizzle=can_swizzle_a,
            is_smooth=intrin_info.smooth_a,
        )

        auto_inline_producers(sch, a_g2s)

        def decode_fetch_to_shared(block, idx):
            # step1. create the memory hierarchy
            # global -> local -> shared
            block_shared = sch.cache_read(block, idx, shared_scope)
            sch.compute_at(block_shared, k0, preserve_unit_loops=True)

            decode_factor = get_coalesced_veclen(sch.get(block_shared))
            _, B_shared_vi, _ = sch.split(
                sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor])
            block_shared_local = sch.cache_read(block_shared, 0, "local")
            # global -> dequantized_local -> shared
            # step2. inline the producers into the local block
            weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"])
            weight_producers = _collect_producers(sch, weight_dequantize_block)
            auto_inline_producers(sch, block_shared_local, weight_producers)

            # get the idx of the dequantize source buffer
            def get_idx():
                # for LUT dequantize, the expr is LUT(w), so the weight buffer idx is 1;
                # a more general, structure-based analysis could replace this.
                if weight_decode_info["source_format"]["format"] == "nf":
                    return 1
                return 0

            b_idx = get_idx()
            # global -> prefetch_local -> dequantized_local -> shared
            block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local")

            sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True)
            sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)

            dequantize_block_local = block_shared_local
            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
                # pop the scale block
                auto_inline_producers(sch, block_local_scales)

            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
                auto_inline_producers(sch, block_local_zeros)

            for producer in weight_producers:
                with suppress(Exception):
                    auto_inline_producers(sch, producer)
                    sch.compute_inline(producer)

            # fast type conversion
            if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]):
                source_bit = weight_decode_info["source_format"]["bits"]
                out_dtype = weight_decode_info["target_format"]
                lop3_intrin_info = get_lop3_intrin_group(
                    out_dtype=out_dtype,
                    storage_dtype=weight_decode_info["storage_dtype"],
                    source_format=weight_decode_info["source_format"]["format"],
                    source_bit=source_bit,
                    with_scaling=weight_decode_info["with_scaling"],
                    with_zeros=weight_decode_info["with_zeros"],
                    zeros_mode=weight_decode_info["zeros_mode"],
                )
                sch.tensorize(
                    sch.get_loops(dequantize_block_local)[-1],
                    lop3_intrin_info["compute"],
                )
                import_source.append(lop3_intrin_info["c_source"])

            sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b)
            union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2)
            B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2])
            _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split(
                B_shared_fused, factors=[None, num_ty, num_tz, warp_size])
            if not (can_swizzle_b or intrin_info.smooth_b):
                pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
                sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset)
            sch.bind(B_shared_tx, "threadIdx.x")
            sch.bind(B_shared_ty, "threadIdx.y")
            sch.bind(B_shared_tz, "threadIdx.z")
            sch.vectorize(sch.get_loops(block_shared)[-1])
            sch.vectorize(sch.get_loops(block_shared_local_local)[-1])

            # cache small tensors, e.g. LUT
            if b_idx:
                block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope)
                sch.reverse_compute_at(block_shared_lut, j2)
                _, B_shared_tx = sch.split(
                    sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size])
                sch.bind(B_shared_tx, "threadIdx.x")
            return block_shared_local

        _ = decode_fetch_to_shared(block_outer, 1)

        # create read caches to load matrices from shared memory into warp fragments
        A_mat = sch.cache_read(block_outer, 0, "warp")
        B_mat = sch.cache_read(block_outer, 1, "warp")
        sch.compute_at(A_mat, k1, preserve_unit_loops=True)
        sch.compute_at(B_mat, k1, preserve_unit_loops=True)

        # create a write cache to store the accumulator via shared memory
        if cache_write_required:
            accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope)

        store = sch.cache_write(block_outer, 0, "warp")
        sch.reverse_compute_at(store, j2)

        # split the store loop to match the hardware intrinsic pattern
        i, j = sch.get_loops(store)[-2:]
        i0, i1 = sch.split(i, factors=[None, micro_size_x])
        j0, j1 = sch.split(j, factors=[None, micro_size_y])
        sch.reorder(i0, j0, i1, j1)

        if cache_write_required:
            auto_inline_consumer_chain(sch, accumulator_shared_to_global)
            sch.reverse_compute_at(
                accumulator_shared_to_global,
                sch.get_loops(store)[-6],
                preserve_unit_loops=True,
            )
            vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
            fuse_iters = 5 if enable_store_rewrite else 3
            fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-fuse_iters:])
            f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
            sch.bind(f1, "threadIdx.x")
            sch.vectorize(f2)
            sch.unroll(f0)
            sch.annotate(f0, "pragma_unroll_explicit", False)
        else:
            auto_inline_consumer_chain(sch, store)

        block_init_c = sch.decompose_reduction(block_outer, k0)
        block_init_c_inner = sch.get_child_blocks(block_init_c)[0]

        # Tensorization by hardware intrinsics
        index_map_a, index_map_b, index_map_c = intrin_group["index_map"]

        sch.transform_layout(
            A_mat,
            ("write", 0),
            get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a),
        )
        sch.transform_layout(
            B_mat,
            ("write", 0),
            get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b),
        )
        sch.transform_layout(
            store,
            ("read", 0),
            get_index_map(index_map_c, is_5d=enable_store_rewrite),
        )

        i, j = sch.get_loops(A_mat)[-2:]
        i0, i1 = sch.split(i, factors=[None, a_lr[0]])
        j0, j1 = sch.split(j, factors=[None, a_lr[1]])
        sch.reorder(i0, j0, i1, j1)
        ba = sch.blockize(i1)
        sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a)
        sch.tensorize(ba, intrin_group["load_a"])

        i, j = sch.get_loops(B_mat)[-2:]
        i0, i1 = sch.split(i, factors=[None, b_lr[0]])
        j0, j1 = sch.split(j, factors=[None, b_lr[1]])
        sch.reorder(i0, j0, i1, j1)
        bb = sch.blockize(i1)
        sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
        sch.tensorize(bb, intrin_group["load_b"])

        def tensorize_init_store_compute():
            sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
            sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
            sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])

        tensorize_init_store_compute()

        if stage > 1:
            sch.annotate(
                k0,
                ann_key="software_pipeline_stage",
                ann_val=[0, 0, stage - 1],
            )
            sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
        if use_async:
            sch.annotate(k0, "software_pipeline_async_stages", [0])
        # plan rasterization
        if not isinstance(config.rasterization_plan, NoRasterization):
            device_func, invoke_func = config.rasterization_plan.get_code()
            import_source.append(device_func)
            sch.annotate(
                sch.get_loops(block_init_c)[-2],
                ann_key="inject_customized_code_prepend",
                ann_val=invoke_func,
            )
        # plan import source
        if len(import_source) > 0:
            sch.annotate(
                thread_idz,
                ann_key="pragma_import_c",
                ann_val=("\n").join(import_source),
            )
        return sch

    def sch_shared_memory_prefetch_with_config(
        self,
        func: tir.PrimFunc,
        config: Hint,
    ):
        """
        For A100-like devices, shared memory prefetch (async copy) is required
        to achieve optimal performance.
            quantized weight
                |
                V
            shared memory prefetch (with async copy)
                |
                V
            dequantized into shared memory
                |
                V
             compute
        """

        weight_transform_kind = config.intrin_info.weight_transform_kind
        if weight_transform_kind == TransformKind.LDMatrixTransform:
            return self.sch_warp_memory_prefetch_with_config(func, config)

        is_cross_thread_reduce = (
            hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None)
        block_reduction_depth = config.block_reduction_depth if is_cross_thread_reduce else 1
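
        # When block_reduction_depth is set (e.g. 2, an illustrative value), the
        # outer reduction loop is split by reduce_k further below and bound to
        # threadIdx.z, so multiple warp groups cooperate on the K dimension of one
        # output tile instead of each owning a full reduction.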

        from tvm.tir.tensor_intrin.cuda import (
            get_mma_intrin_group,)
        from .intrin import get_lop3_intrin_group

        import_source: List[str] = []

        sch = tir.Schedule(func)
        root_block = analysis.get_root_block(sch)
        blocks = sch.get_child_blocks(root_block)

        if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
            return None

        reduction_blocks = get_reduction_blocks(sch, blocks)
        if reduction_blocks is None:
            return None

        main_block = reduction_blocks[0]
        # always enable shared memory rewrite
        cache_write_required = True

        # Check Dequantize Info
        dequantize_info = func.attrs["dequantize_info"]

        def check_dequantize_info(dequantize_info):
            conditions = []
            # currently only support weight-only dequantization
            conditions.append(len(dequantize_info) == 1)
            return all(conditions)

        assert check_dequantize_info(dequantize_info)

        (weight_decode_info,) = list(dequantize_info.values())

        assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info"

        # Start schedule from the tuned hint.
        shared_scope = config.shared_scope

        intrin_info = config.intrin_info
        intrin_group = get_mma_intrin_group(
            load_scope=shared_scope,
            store_scope=shared_scope if cache_write_required else "global",
            a_dtype=intrin_info.in_dtype,
            b_dtype=intrin_info.in_dtype,
            out_dtype=intrin_info.out_dtype,
            trans_a=intrin_info.trans_a,
            trans_b=intrin_info.trans_b,
            smooth_a=intrin_info.smooth_a,
            smooth_b=intrin_info.smooth_b,
            not_use_mma_store_intrinic=False,
        )

        warp_row_tiles = config.warp[0]
        warp_col_tiles = config.warp[1]
        block_row_warps = config.block[0] // warp_row_tiles
        block_col_warps = config.block[1] // warp_col_tiles
        stage = config.pipeline_stage
        use_async = config.use_async
        reduce_k = block_reduction_depth
        chunk = config.rstep[0] // reduce_k

        micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"]

        # get the axes for layout transform
        def get_axis(l, r, trans):
            return (r, l) if trans else (l, r)

        a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a)
        b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b)

        def can_enable_swizzle(dtype: str, smooth: bool):
            # inject_permuted_layout only supports float16 and int8 currently
            if dtype == "float16" or dtype == "int8":
                if (chunk * reduce_k) * DataType(dtype).bits != 512:
                    # currently the swizzle rule only supports 512-bit rows
                    return False
                # with a smooth layout, swizzling is unnecessary
                return not smooth
            return False

        can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a)
        can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b)

        # rewrite the global smooth layout; for dequantize, currently only
        # weight-only layout recovery is supported.
        def smooth_gmem_layout_rewrite(
                sch,
                main_block,
                enable=True,
                trans=False,
                matrix_name="A",
                intrin_group=intrin_group,
        ):
            if not enable:
                return

            # a normalized block may have three read buffers, while the first one is the write buffer
            buffer_offset = (1 if sch.get(main_block).reads[0].buffer
                             == sch.get(main_block).writes[0].buffer else 0)
            buffer_idx = 0 if matrix_name == "A" else 1
            source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer

            # step1: find the last producer block of the source buffer. We assume the
            # layout propagation happens there; otherwise the transform would rewrite
            # both the read and the write buffer.
            propagate_block: tir.Block = find_last_producer_from_buffer(
                sch, main_block, source_buffer)
            # some trick implementations may not have a reindex block
            (weight_dequantize_info,) = dequantize_info.values()
            if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]):
                return
            # step2: obtain the intra-warp index map for this matrix
            intra_indexmap, _ = get_propagate_map(
                trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name)

            # step3: propagate the matmul layout to the last producer block
            intra_indexmap = layout_propagate_chain(
                sch,
                start_block=main_block,
                start_buffer=source_buffer,
                end_block=propagate_block,
                index_map=intra_indexmap,
            )

            def inverse_permutation(i, j, ii, jj):
                return (i, j, *intra_indexmap.map_indices([ii, jj]))

            sch.transform_layout(propagate_block, ("read", 0), inverse_permutation)

            intra_index_map, _ = get_propagate_map(
                trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name)

            # get the index map offset of the dequantize buffer
            def get_offset():
                # for LUT dequantize, the expr is LUT(w), so the weight buffer idx is 1
                if weight_dequantize_info["source_format"]["format"] == "nf":
                    return 1
                return 0

            offset = get_offset()
            dequantize_block = sch.get_block(weight_dequantize_info["decode_block"])
            group_size = weight_dequantize_info["group_size"]

            _, mn, mk = intrin_group["micro_kernel"]

            def get_param_indices(
                    indexmap,
                    l=mn,
                    r=mk,
                    group_size=group_size,
            ):
                # map the transformed iteration back to (n, k // group_size) coordinates
                rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars]
                warp_i, warp_j = rl % l, rr % r
                spatial_i, spatial_j = rl // l, rr // r
                warp_i, warp_j = indexmap.map_indices([warp_i, warp_j])
                new_indices = (
                    spatial_i * l + warp_i,
                    (spatial_j * r + warp_j) // group_size,
                )
                return new_indices
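
            # Worked example (hypothetical shapes): with l = r = 16 and group_size = 128,
            # a dequantize iteration (rl, rr) = (35, 40) has intra-tile part (3, 8) and
            # spatial part (2, 2); after the intra-warp index map reshuffles (3, 8) into
            # some (warp_i, warp_j), the scale/zero buffers are addressed at
            # (2 * 16 + warp_i, (2 * 16 + warp_j) // 128) = (32 + warp_i, 0),
            # i.e. by the logical element position rather than the transformed one.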

            with_scaling = bool(weight_dequantize_info["with_scaling"])
            if with_scaling:
                sch.unsafe_rewrite_buffer_region(
                    dequantize_block,
                    ("read", offset + 1),
                    get_param_indices(intra_index_map),
                )
            with_zeros = bool(weight_dequantize_info["with_zeros"])
            if with_zeros:
                sch.unsafe_rewrite_buffer_region(
                    dequantize_block,
                    ("read", offset + 2),
                    get_param_indices(intra_index_map),
                )

        smooth_gmem_layout_rewrite(
            sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A")

        smooth_gmem_layout_rewrite(
            sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B")

        warp_size = 32

        i_factors, j_factors, k_factors = (
            [None, 1, block_row_warps, warp_row_tiles // micro_size_x],
            [1, None, block_col_warps, warp_col_tiles // micro_size_y],
            [None, chunk // micro_size_k],
        )

        num_ty = i_factors[2]
        num_tz = j_factors[2]
        x_pad_factor = i_factors[2] * i_factors[3]
        y_pad_factor = j_factors[2] * j_factors[3]
        k_pad_factor = k_factors[1]

        # Step 1. Normalize the generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
        if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()):
            sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"])

        # Step 2. Padding for dynamic shape kernels
        sch.pad_einsum(
            main_block,
            [
                1,
                micro_size_x * x_pad_factor,
                micro_size_y * y_pad_factor,
                micro_size_k * k_pad_factor,
            ],
        )

        # Step 3. Schedule the matmul to use tensor cores
        block = main_block

        batch, i, j, k = sch.get_loops(block)

        # inner loops for tensor core computation
        i, i_inner = sch.split(i, factors=[None, micro_size_x])
        j, j_inner = sch.split(j, factors=[None, micro_size_y])
        k, k_inner = sch.split(k, factors=[None, micro_size_k])

        sch.reorder(i, j, k, i_inner, j_inner, k_inner)

        block_inner = block
        block_outer = sch.blockize(i_inner)

        i0, i1, i2, i3 = sch.split(i, factors=i_factors)
        j0, j1, j2, j3 = sch.split(j, factors=j_factors)
        k0, k1 = sch.split(k, k_factors)
        if reduce_k > 1:
            k0, kr = sch.split(k0, [None, reduce_k])

            sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
        else:
            sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)

        block_idy = sch.fuse(i0, j0)
        block_idx = sch.fuse(i1, j1)
        thread_idy = i2
        thread_idz = j2

        sch.bind(batch, "blockIdx.z")
        sch.bind(block_idx, "blockIdx.x")
        sch.bind(block_idy, "blockIdx.y")
        if reduce_k > 1:
            thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
            sch.bind(thread_idy, "threadIdx.y")
            sch.bind(kr, "threadIdx.z")
        else:
            sch.bind(thread_idy, "threadIdx.y")
            sch.bind(thread_idz, "threadIdx.z")

        # skip the smooth layout rewrite of the store for 8-bit inputs, as the
        # accumulator layout differs there
        enable_store_rewrite = not intrin_info.is_input_8bit()

        def smooth_layout_recover(block, scope, l=16, r=16, enable=True):
            if not enable:
                return
            sch.transform_layout(
                block,
                scope,
                lambda b, i, j: (
                    b,
                    i // l,
                    j // r,
                    i % l,
                    j % r,
                ),
            )

        smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a)
        smooth_layout_recover(
            block_outer,
            ("read", 1),
            *b_lr,
            enable=intrin_info.inter_transform_b,
        )
        smooth_layout_recover(block_outer, ("write", 0), enable=enable_store_rewrite)

        def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False):
            block_read = sch.cache_read(block, idx, shared_scope)
            sch.compute_at(block_read, k0, preserve_unit_loops=True)
            ndim = len(sch.get(block_read).iter_vars)
            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])

            if reduce_k > 1:
                f_r, f_0, f_1, f_2, f_3, f_4 = sch.split(
                    fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len])
                sch.bind(f_3, "threadIdx.x")
                f_0 = f_1 = sch.fuse(f_0, f_1)
                sch.bind(f_0, "threadIdx.y")
                sch.bind(f_r, "threadIdx.z")
            else:
                f_0, f_1, f_2, f_3, f_4 = sch.split(
                    fused, factors=[num_ty, num_tz, None, warp_size, vec_len])
                sch.bind(f_3, "threadIdx.x")
                sch.bind(f_1, "threadIdx.z")
                sch.bind(f_0, "threadIdx.y")

            sch.vectorize(f_4)
            # unroll the remaining serial loop; f_0 is thread-bound in both branches here
            sch.unroll(f_2)
            sch.annotate(f_2, "pragma_unroll_explicit", False)

            # apply swizzling
            sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle)
            # if not, apply padding to alleviate bank conflicts
            if not (can_swizzle or is_smooth):
                pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
                sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset)
            return block_read

        a_g2s = fetch_to_shared(
            block_outer,
            0,
            vec_len=list(config.vectorize.values())[0],
            can_swizzle=can_swizzle_a,
            is_smooth=intrin_info.smooth_a,
        )

        auto_inline_producers(sch, a_g2s)

        def decode_fetch_to_shared(block, idx):
            # step1. create the memory hierarchy
            # global -> local -> shared
            block_shared = sch.cache_read(block, idx, shared_scope)
            sch.compute_at(block_shared, k0, preserve_unit_loops=True)

            # use the coalesced vector length as the decode factor
            decode_factor = get_coalesced_veclen(sch.get(block_shared))
            _, B_shared_vi, _ = sch.split(
                sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor])
            block_shared_local = sch.cache_read(block_shared, 0, "local")
            # global -> dequantized_local -> shared
            # step2. inline the producers into the local block
            is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and
                         weight_decode_info["zeros_mode"] == "quantized")
            weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"])
            weight_producers = (
                _collect_producers(sch, weight_dequantize_block) if is_qzeros else [])
            auto_inline_producers(sch, block_shared_local, weight_producers)

            # get the idx of the dequantize source buffer
            def get_idx():
                # for LUT dequantize, the expr is LUT(w), so the weight buffer idx is 1;
                # a more general, structure-based analysis could replace this.
                if weight_decode_info["source_format"]["format"] == "nf":
                    return 1
                return 0

            b_idx = get_idx()
            # global -> prefetch_local -> dequantized_local -> shared
            block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local")
            # global -> prefetch_shared -> vector load -> dequantized_local -> shared
            block_shared_local_local_shared = sch.cache_read(block_shared_local_local, 0,
                                                             shared_scope)
            sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True)
            sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)

            dequantize_block_local = block_shared_local

            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
                auto_inline_producers(sch, block_local_scales)

            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
                auto_inline_producers(sch, block_local_zeros)

            for producer in weight_producers:
                with suppress(Exception):
                    auto_inline_producers(sch, producer)
                    sch.compute_inline(producer)

            # fast type conversion
            if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]):
                source_bit = weight_decode_info["source_format"]["bits"]
                out_dtype = weight_decode_info["target_format"]
                lop3_intrin_info = get_lop3_intrin_group(
                    out_dtype=out_dtype,
                    storage_dtype=weight_decode_info["storage_dtype"],
                    source_format=weight_decode_info["source_format"]["format"],
                    source_bit=source_bit,
                    with_scaling=weight_decode_info["with_scaling"],
                    with_zeros=weight_decode_info["with_zeros"],
                    zeros_mode=weight_decode_info["zeros_mode"],
                )
                sch.tensorize(
                    sch.get_loops(dequantize_block_local)[-1],
                    lop3_intrin_info["compute"],
                )
                import_source.append(lop3_intrin_info["c_source"])

            sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b)
            union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2)
            B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2])
            if reduce_k > 1:
                _, B_shared_rk, B_shared_ty, B_shared_tz, B_shared_tx = sch.split(
                    B_shared_fused, factors=[None, reduce_k, num_ty, num_tz, warp_size])
                sch.bind(B_shared_tx, "threadIdx.x")
                B_shared_tz = B_shared_ty = sch.fuse(B_shared_ty, B_shared_tz)
                sch.bind(B_shared_ty, "threadIdx.y")
                sch.bind(B_shared_rk, "threadIdx.z")
            else:
                _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split(
                    B_shared_fused, factors=[None, num_ty, num_tz, warp_size])
                sch.bind(B_shared_tx, "threadIdx.x")
                sch.bind(B_shared_ty, "threadIdx.y")
                sch.bind(B_shared_tz, "threadIdx.z")

            if not (can_swizzle_b or intrin_info.smooth_b):
                pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
                sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset)
            sch.vectorize(sch.get_loops(block_shared)[-1])
            sch.vectorize(sch.get_loops(block_shared_local_local)[-1])

            sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True)
            ndim = len(sch.get(block_shared_local_local_shared).iter_vars)
            _ = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:])

            if reduce_k > 1:
                _bind_thread_based_with_block_reduce_on_config(sch, block_shared_local_local_shared,
                                                               num_ty, num_tz, warp_size, reduce_k)
            else:
                _bind_thread_based_on_config(sch, block_shared_local_local_shared, num_ty, num_tz,
                                             warp_size)

            # cache small tensors, e.g. LUT
            if b_idx:
                block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope)
                sch.reverse_compute_at(block_shared_lut, j2)
                _, B_shared_tx = sch.split(
                    sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size])
                sch.bind(B_shared_tx, "threadIdx.x")
            return block_shared_local

        _ = decode_fetch_to_shared(block_outer, 1)

        # bind the cross-thread reduction loop once the shared memory prefetch is placed
        if reduce_k > 1:
            sch.bind(kr, "threadIdx.z")
        # create read caches to load matrices from shared memory into warp fragments
        A_mat = sch.cache_read(block_outer, 0, "warp")
        B_mat = sch.cache_read(block_outer, 1, "warp")
        sch.compute_at(A_mat, k1, preserve_unit_loops=True)
        sch.compute_at(B_mat, k1, preserve_unit_loops=True)

        # create a write cache to store the accumulator via shared memory
        if cache_write_required:
            accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope)

        store = sch.cache_write(block_outer, 0, "warp")
        sch.reverse_compute_at(store, j2)

        # split the store loop to match the hardware intrinsic pattern
        i, j = sch.get_loops(store)[-2:]
        i0, i1 = sch.split(i, factors=[None, micro_size_x])
        j0, j1 = sch.split(j, factors=[None, micro_size_y])
        sch.reorder(i0, j0, i1, j1)

        if cache_write_required:
            auto_inline_consumer_chain(sch, accumulator_shared_to_global)
            sch.reverse_compute_at(
                accumulator_shared_to_global,
                sch.get_loops(store)[-6],
                preserve_unit_loops=True,
            )
            vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
            fuse_iters = 5 if enable_store_rewrite else 3
            fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-fuse_iters:])
            f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
            sch.bind(f1, "threadIdx.x")
            sch.vectorize(f2)
            sch.unroll(f0)
            sch.annotate(f0, "pragma_unroll_explicit", False)
        else:
            auto_inline_consumer_chain(sch, store)

        block_init_c = sch.decompose_reduction(block_outer, k0)
        block_init_c_inner = sch.get_child_blocks(block_init_c)[0]

        # Tensorization by hardware intrinsics
        index_map_a, index_map_b, index_map_c = intrin_group["index_map"]

        sch.transform_layout(
            A_mat,
            ("write", 0),
            get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a),
        )
        sch.transform_layout(
            B_mat,
            ("write", 0),
            get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b),
        )
        sch.transform_layout(
            store,
            ("read", 0),
            get_index_map(index_map_c, is_5d=enable_store_rewrite),
        )

        i, j = sch.get_loops(A_mat)[-2:]
        i0, i1 = sch.split(i, factors=[None, a_lr[0]])
        j0, j1 = sch.split(j, factors=[None, a_lr[1]])
        sch.reorder(i0, j0, i1, j1)
        ba = sch.blockize(i1)
        sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a)
        sch.tensorize(ba, intrin_group["load_a"])

        i, j = sch.get_loops(B_mat)[-2:]
        i0, i1 = sch.split(i, factors=[None, b_lr[0]])
        j0, j1 = sch.split(j, factors=[None, b_lr[1]])
        sch.reorder(i0, j0, i1, j1)
        bb = sch.blockize(i1)
        sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
        sch.tensorize(bb, intrin_group["load_b"])

        def tensorize_init_store_compute():
            sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
            sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
            sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])

        tensorize_init_store_compute()

        if stage > 1:
            sch.annotate(
                k0,
                ann_key="software_pipeline_stage",
                ann_val=[0, 0, stage - 1, stage - 1],
            )
            sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3])
        if use_async:
            sch.annotate(k0, "software_pipeline_async_stages", [0])
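        # Reading the annotations above (our interpretation of the TIR software
        # pipeline, not an extra scheduling step): the four pipelined regions under
        # k0 (the A copy, the B prefetch copy, the dequantize staging, and the inner
        # compute) get stages [0, 0, stage - 1, stage - 1] with issue order
        # [0, 1, 2, 3]; with use_async, the stage-0 copies are lowered to cp.async
        # on targets that support it.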
| |
| if not isinstance(config.rasterization_plan, NoRasterization): |
| device_func, invoke_func = config.rasterization_plan.get_code() |
| import_source.append(device_func) |
| sch.annotate( |
| sch.get_loops(block_init_c)[-2], |
| ann_key="inject_customized_code_prepend", |
| ann_val=invoke_func, |
| ) |
| |
| if len(import_source) > 0: |
| sch.annotate( |
| j2, |
| ann_key="pragma_import_c", |
| ann_val=("\n").join(import_source), |
| ) |
| return sch |
|
|
| def sch_warp_memory_prefetch_with_config( |
| self, |
| func: tir.PrimFunc, |
| config: Hint, |
| ): |
| """ |
| For A100 Like devices, the shared memory prefetch(async) is required |
| to achieve optimal performance. |
| quantized weight |
| | |
| V |
| shared memory prefetch (with async copy) |
| | |
| V |
| LDMatrix into warp memory |
| | |
| V |
| Dequantize |
| | |
| V |
| Compute |
| """ |
| from tvm.tir.tensor_intrin.cuda import ( |
| get_mma_intrin_group,) |
| from bitblas.gpu.intrin import get_lop3_intrin_group |
|
|
| import_source: List[str] = [] |
|
|
| sch = tir.Schedule(func) |
| root_block = analysis.get_root_block(sch) |
| blocks = sch.get_child_blocks(root_block) |
|
|
| if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): |
| return None |
|
|
| reduction_blocks = get_reduction_blocks(sch, blocks) |
| if reduction_blocks is None: |
| return None |
|
|
| main_block = reduction_blocks[0] |
| |
| cache_write_required = True |
|
|
| |
| |
| dequantize_info = func.attrs["dequantize_info"] |
|
|
| def check_dequantize_info(dequantize_info): |
| conditions = [] |
| |
| conditions.append(len(dequantize_info) == 1) |
| |
| return all(conditions) |
|
|
| assert check_dequantize_info(dequantize_info) |
|
|
| (weight_decode_info,) = list(dequantize_info.values()) |
|
|
| assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" |
|
|
| # step 0: read the schedule configuration (shared scope, intrin info, |
| # tiling and pipeline parameters) |
| shared_scope = config.shared_scope |
|
|
| intrin_info = config.intrin_info |
| input_transform_kind = intrin_info.input_transform_kind |
| weight_transform_kind = intrin_info.weight_transform_kind |
| assert input_transform_kind <= TransformKind.IntraWarpTransform, "Invalid input_transform_kind" |
| assert weight_transform_kind == TransformKind.LDMatrixTransform, "Invalid weight_transform_kind" |
|
|
| intrin_group = get_mma_intrin_group( |
| load_scope=shared_scope, |
| store_scope=shared_scope if cache_write_required else "global", |
| a_dtype=intrin_info.in_dtype, |
| b_dtype=intrin_info.in_dtype, |
| out_dtype=intrin_info.out_dtype, |
| trans_a=intrin_info.trans_a, |
| trans_b=intrin_info.trans_b, |
| smooth_a=intrin_info.smooth_a, |
| smooth_b=intrin_info.smooth_b, |
| not_use_mma_store_intrinic=False, |
| ) |
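| # intrin_group bundles the ldmatrix load, mma compute, init and store |
| # intrinsics together with their index maps and micro-kernel shape |
| # (e.g. (16, 16, 16) for float16 mma) |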
|
|
| warp_row_tiles = config.warp[0] |
| warp_col_tiles = config.warp[1] |
| block_row_warps = config.block[0] // warp_row_tiles |
| block_col_warps = config.block[1] // warp_col_tiles |
| stage = config.pipeline_stage |
| use_async = config.use_async |
| reduce_k = config.block_reduction_depth if config.block_reduction_depth is not None else 1 |
| chunk = config.rstep[0] // reduce_k |
|
|
| micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] |
|
|
| # operand fragment extents, swapped when the operand is transposed |
| def get_axis(l, r, trans): |
| return (r, l) if trans else (l, r) |
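| # e.g. get_axis(16, 32, trans=False) -> (16, 32), while trans=True swaps |
| # the extents to (32, 16) for a transposed operand |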
|
|
| a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) |
| b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) |
|
|
| def can_enable_swizzle(dtype: str, smooth: bool): |
| # permuted layout is only implemented for float16 and int8 |
| if dtype == "float16" or dtype == "int8": |
| if (chunk * reduce_k) * DataType(dtype).bits != 512: |
| # the swizzle rule currently requires a 512-bit shared-memory row |
| return False |
| # smooth (pre-transformed) layouts do not need swizzling |
| return not smooth |
| return False |
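| # worked example (assuming float16 with chunk * reduce_k = 32): the shared |
| # row is 32 * 16 = 512 bits, so swizzling is enabled whenever the operand |
| # does not already use a smooth layout |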
|
|
| can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) |
|
|
| # rewrite the global layout of smooth inputs back through the producer chain |
| def smooth_gmem_layout_rewrite( |
| sch, |
| main_block, |
| enable=True, |
| trans=False, |
| matrix_name="A", |
| intrin_group=intrin_group, |
| ): |
| if not enable: |
| return |
|
|
| # if the first read aliases the write buffer, real inputs start at offset 1 |
| buffer_offset = (1 if sch.get(main_block).reads[0].buffer |
| == sch.get(main_block).writes[0].buffer else 0) |
| buffer_idx = 0 if matrix_name == "A" else 1 |
| source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer |
|
|
| # apply the transform at the last producer that writes the source buffer, |
| # so only the read side of the chain is rewritten |
| propagate_block: tir.Block = find_last_producer_from_buffer( |
| sch, main_block, source_buffer) |
| # no reindex stage: the producer is the decode block itself, skip |
| (weight_dequantize_info,) = dequantize_info.values() |
| if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): |
| return |
| # step 1: get the intra-warp layout map for this operand |
| intra_indexmap, _ = get_propagate_map( |
| trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) |
|
|
| # step 2: propagate the map along the producer chain up to propagate_block |
|
|
| intra_indexmap = layout_propagate_chain( |
| sch, |
| start_block=main_block, |
| start_buffer=source_buffer, |
| end_block=propagate_block, |
| index_map=intra_indexmap, |
| ) |
|
|
| def inverse_permutation(i, j, ii, jj): |
| return (i, j, *intra_indexmap.map_indices([ii, jj])) |
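| # only the intra-tile coordinates move: (i, j) keep addressing the tile |
| # grid while (ii, jj) are permuted inside each tile by intra_indexmap |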
|
|
| sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) |
|
|
| intra_index_map, _ = get_propagate_map( |
| trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) |
| ladder_stage3_index_map, ladder_stage3_inverse_index_map = get_ladder_stage3_map( |
| dtype=intrin_info.in_dtype) |
|
|
| |
| def get_offset(): |
| # 'nf' sources read a lookup table as their first input, so the |
| # scale/zeros parameters start one read-buffer index later |
| if weight_dequantize_info["source_format"]["format"] == "nf": |
| return 1 |
| return 0 |
|
|
| offset = get_offset() |
| dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) |
| group_size = weight_dequantize_info["group_size"] |
|
|
| _, mn, mk = intrin_group["micro_kernel"] |
|
|
| def get_param_indices( |
| l=mn, |
| r=mk, |
| transform_kind=intrin_info.weight_transform_kind |
| ): |
| # the decode block's iter vars index the stored (transformed) weight |
| rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars] |
| warp_i, warp_j = rl % l, rr % r |
| spatial_i, spatial_j = rl // l, rr // r |
| # kind > 2 (LDMatrixTransform) adds a stage-3 permutation; invert it first |
| if transform_kind > 2: |
| warp_i, warp_j = ladder_stage3_inverse_index_map.map_indices([warp_i, warp_j]) |
| warp_i, warp_j = intra_index_map.map_indices([warp_i, warp_j]) |
| new_indices = ( |
| spatial_i * l + warp_i, |
| (spatial_j * r + warp_j) // group_size, |
| ) |
| return new_indices |
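| # sketch: recover each stored element's logical (n, k) position by undoing |
| # the warp-level (and stage-3) permutations, then read its scale/zeros at |
| # [n, k // group_size] |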
|
|
| with_scaling = bool(weight_dequantize_info["with_scaling"]) |
| if with_scaling: |
| sch.unsafe_rewrite_buffer_region( |
| dequantize_block, |
| ("read", offset + 1), |
| get_param_indices(transform_kind=intrin_info.weight_transform_kind), |
| ) |
| with_zeros = bool(weight_dequantize_info["with_zeros"]) |
| if with_zeros: |
| sch.unsafe_rewrite_buffer_region( |
| dequantize_block, |
| ("read", offset + 2), |
| get_param_indices(transform_kind=intrin_info.weight_transform_kind), |
| ) |
|
|
| smooth_gmem_layout_rewrite( |
| sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") |
|
|
| smooth_gmem_layout_rewrite( |
| sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") |
|
|
| warp_size = 32 |
|
|
| i_factors, j_factors, k_factors = ( |
| [None, 1, block_row_warps, warp_row_tiles // micro_size_x], |
| [1, None, block_col_warps, warp_col_tiles // micro_size_y], |
| [None, chunk // micro_size_k], |
| ) |
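| # e.g. with block = [128, 128], warp = [64, 64], chunk = 32 and a 16x16x16 |
| # micro kernel: i_factors = [None, 1, 2, 4], j_factors = [1, None, 2, 4], |
| # k_factors = [None, 2] |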
|
|
| num_ty = i_factors[2] |
| num_tz = j_factors[2] |
| x_pad_factor = i_factors[2] * i_factors[3] |
| y_pad_factor = j_factors[2] * j_factors[3] |
| k_pad_factor = k_factors[1] |
|
|
| # normalize to a canonical matmul unless already prenormalized upstream |
| if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): |
| sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) |
|
|
| # pad every dimension to a multiple of its tile for dynamic shapes |
| sch.pad_einsum( |
| main_block, |
| [ |
| 1, |
| micro_size_x * x_pad_factor, |
| micro_size_y * y_pad_factor, |
| micro_size_k * k_pad_factor, |
| ], |
| ) |
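| # with the example factors above, x_pad_factor = 2 * 4 = 8, so M is padded |
| # to a multiple of 16 * 8 = 128 (one block tile); K is padded to |
| # micro_size_k * k_pad_factor = one chunk |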
|
|
| # schedule the padded matmul onto tensor cores |
| block = main_block |
|
|
| batch, i, j, k = sch.get_loops(block) |
|
|
| # peel off the micro-kernel (mma tile) loops |
| i, i_inner = sch.split(i, factors=[None, micro_size_x]) |
| j, j_inner = sch.split(j, factors=[None, micro_size_y]) |
| k, k_inner = sch.split(k, factors=[None, micro_size_k]) |
|
|
| sch.reorder(i, j, k, i_inner, j_inner, k_inner) |
|
|
| block_inner = block |
| block_outer = sch.blockize(i_inner) |
|
|
| i0, i1, i2, i3 = sch.split(i, factors=i_factors) |
| j0, j1, j2, j3 = sch.split(j, factors=j_factors) |
| k0, k1 = sch.split(k, k_factors) |
| if reduce_k > 1: |
| k0, kr = sch.split(k0, [None, reduce_k]) |
|
|
| sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3) |
| else: |
| sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) |
|
|
| block_idy = sch.fuse(i0, j0) |
| block_idx = sch.fuse(i1, j1) |
| thread_idy = i2 |
| thread_idz = j2 |
|
|
| sch.bind(batch, "blockIdx.z") |
| sch.bind(block_idx, "blockIdx.x") |
| sch.bind(block_idy, "blockIdx.y") |
| if reduce_k > 1: |
| # fuse both warp axes into threadIdx.y (kr takes threadIdx.z); keep j2 |
| # pointing at the fused loop since later reverse_compute_at targets it |
| thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) |
| sch.bind(thread_idy, "threadIdx.y") |
| sch.bind(kr, "threadIdx.z") |
| else: |
| sch.bind(thread_idy, "threadIdx.y") |
| sch.bind(thread_idz, "threadIdx.z") |
|
|
| # the 5-d store rewrite is disabled for 8-bit inputs |
| enable_store_rewrite = not intrin_info.is_input_8bit() |
|
|
| def smooth_layout_recover(block, scope, l=16, r=16, enable=True): |
| if not enable: |
| return |
| sch.transform_layout( |
| block, |
| scope, |
| lambda b, i, j: ( |
| b, |
| i // l, |
| j // r, |
| i % l, |
| j % r, |
| ), |
| ) |
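| # e.g. with l = r = 16, element (b, 35, 20) moves to (b, 2, 1, 3, 4): the |
| # buffer becomes a grid of 16x16 tiles matching the smooth storage layout |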
|
|
| smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) |
| smooth_layout_recover( |
| block_outer, |
| ("read", 1), |
| *b_lr, |
| enable=intrin_info.inter_transform_b, |
| ) |
| smooth_layout_recover(block_outer, ("write", 0), enable=enable_store_rewrite) |
|
|
| def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): |
| block_read = sch.cache_read(block, idx, shared_scope) |
| sch.compute_at(block_read, k0, preserve_unit_loops=True) |
| ndim = len(sch.get(block_read).iter_vars) |
| fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) |
|
|
| if reduce_k > 1: |
| f_r, f_0, f_1, f_2, f_3, f_4 = sch.split( |
| fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len]) |
| sch.bind(f_3, "threadIdx.x") |
| # fuse ty and tz; both names keep referring to the fused loop |
| f_0 = f_1 = sch.fuse(f_0, f_1) |
| sch.bind(f_0, "threadIdx.y") |
| sch.bind(f_r, "threadIdx.z") |
| else: |
| f_0, f_1, f_2, f_3, f_4 = sch.split( |
| fused, factors=[num_ty, num_tz, None, warp_size, vec_len]) |
| sch.bind(f_3, "threadIdx.x") |
| sch.bind(f_1, "threadIdx.z") |
| sch.bind(f_0, "threadIdx.y") |
|
|
| sch.vectorize(f_4) |
| sch.unroll(f_2) |
| sch.annotate(f_2, "pragma_unroll_explicit", False) |
|
|
| # annotate so lowering can emit the permuted (swizzled) shared layout |
| sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) |
| # without swizzle or smooth layout, pad to alleviate bank conflicts |
| if not (can_swizzle or is_smooth): |
| pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 |
| sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) |
| return block_read |
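| # illustrative padding math: for float16 the offset of 8 elements is 16 |
| # bytes, so consecutive 16-element rows start in different shared-memory |
| # bank groups when the permuted layout cannot be used |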
|
|
| a_g2s = fetch_to_shared( |
| block_outer, |
| 0, |
| vec_len=list(config.vectorize.values())[0], |
| can_swizzle=can_swizzle_a, |
| is_smooth=intrin_info.smooth_a, |
| ) |
|
|
| auto_inline_producers(sch, a_g2s) |
|
|
| # keep the block-reduction axis bound to threadIdx.z |
| if reduce_k > 1: |
| sch.bind(kr, "threadIdx.z") |
| # cache the A operand fragments in warp memory |
| A_mat = sch.cache_read(block_outer, 0, "warp") |
| sch.compute_at(A_mat, k1, preserve_unit_loops=True) |
|
|
| # fetch the packed weight to shared memory, then dequantize in warp memory |
| def warp_memory_dequantize(): |
| B_dequantized_mat = sch.cache_read(block_outer, 1, "warp") |
| is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and |
| weight_decode_info["zeros_mode"] == "quantized") |
| weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) |
| weight_producers = ( |
| _collect_producers(sch, weight_dequantize_block) if is_qzeros else []) |
| auto_inline_producers(sch, B_dequantized_mat, weight_producers) |
|
|
| sch.compute_at(B_dequantized_mat, k1, preserve_unit_loops=False) |
|
|
| |
| def get_idx(): |
| # 'nf' decode reads its lookup table first, so the packed weight is |
| # the second read buffer |
| if weight_decode_info["source_format"]["format"] == "nf": |
| return 1 |
| return 0 |
|
|
| b_idx = get_idx() |
| B_quant_mat = sch.cache_read(B_dequantized_mat, b_idx, "warp") |
| B_shared = sch.cache_read(B_quant_mat, 0, shared_scope) |
| sch.compute_at(B_quant_mat, k1, preserve_unit_loops=True) |
| ndim = len(sch.get(B_quant_mat).iter_vars) |
| f0 = sch.fuse(*sch.get_loops(B_quant_mat)[-ndim:]) |
| vec_len = 4 |
| # cooperative copy: one warp, vectorized by 4 packed elements per lane |
| f1, f2 = sch.split( |
| f0, factors=[None, warp_size, vec_len], preserve_unit_iters=False)[-2:] |
| sch.bind(f1, "threadIdx.x") |
| sch.vectorize(f2) |
| sch.compute_at(B_shared, k0, preserve_unit_loops=True) |
| ndim = len(sch.get(B_shared).iter_vars) |
| _ = sch.fuse(*sch.get_loops(B_shared)[-ndim:]) |
|
|
| if reduce_k > 1: |
| _bind_thread_based_with_block_reduce_on_config( |
| sch, |
| B_shared, |
| num_ty, |
| num_tz, |
| warp_size, |
| reduce_k, |
| ) |
| else: |
| _bind_thread_based_on_config(sch, B_shared, num_ty, num_tz, warp_size) |
| return B_dequantized_mat |
|
|
| B_dequantized_mat = warp_memory_dequantize() |
| # re-bind the block-reduction axis after the helper's compute_at calls |
| if reduce_k > 1: |
| sch.bind(kr, "threadIdx.z") |
| # stage the accumulator through shared memory before the global write |
| if cache_write_required: |
| accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) |
|
|
| store = sch.cache_write(block_outer, 0, "warp") |
| sch.reverse_compute_at(store, j2) |
|
|
| # split the warp store into micro-tile loops |
| i, j = sch.get_loops(store)[-2:] |
| i0, i1 = sch.split(i, factors=[None, micro_size_x]) |
| j0, j1 = sch.split(j, factors=[None, micro_size_y]) |
| sch.reorder(i0, j0, i1, j1) |
|
|
| if cache_write_required: |
| auto_inline_consumer_chain(sch, accumulator_shared_to_global) |
| sch.reverse_compute_at( |
| accumulator_shared_to_global, |
| sch.get_loops(store)[-6], |
| preserve_unit_loops=True, |
| ) |
| fuse_iters = 5 if enable_store_rewrite else 3 |
| fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-fuse_iters:]) |
| f0, f1, f2 = sch.split( |
| fused, |
| factors=[ |
| None, warp_size, |
| get_coalesced_veclen(sch.get(accumulator_shared_to_global)) |
| ]) |
| sch.bind(f1, "threadIdx.x") |
| sch.vectorize(f2) |
| sch.unroll(f0) |
| sch.annotate(f0, "pragma_unroll_explicit", False) |
| else: |
| auto_inline_consumer_chain(sch, store) |
|
|
| block_init_c = sch.decompose_reduction(block_outer, k0) |
| block_init_c_inner = sch.get_child_blocks(block_init_c)[0] |
|
|
| # apply the mma index maps to the fragment buffers |
|
|
| index_map_a, index_map_b, index_map_c = intrin_group["index_map"] |
|
|
| sch.transform_layout( |
| A_mat, |
| ("write", 0), |
| get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), |
| ) |
| sch.transform_layout( |
| B_dequantized_mat, |
| ("write", 0), |
| get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), |
| ) |
| sch.transform_layout( |
| store, |
| ("read", 0), |
| get_index_map(index_map_c, is_5d=enable_store_rewrite), |
| ) |
|
|
| i, j = sch.get_loops(A_mat)[-2:] |
| i0, i1 = sch.split(i, factors=[None, a_lr[0]]) |
| j0, j1 = sch.split(j, factors=[None, a_lr[1]]) |
| sch.reorder(i0, j0, i1, j1) |
| ba = sch.blockize(i1) |
| sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) |
| sch.tensorize(ba, intrin_group["load_a"]) |
|
|
| i, j = sch.get_loops(B_dequantized_mat)[-2:] |
| i0, i1 = sch.split(i, factors=[None, b_lr[0]]) |
| j0, j1 = sch.split(j, factors=[None, b_lr[1]]) |
| sch.reorder(i0, j0, i1, j1) |
| _ = sch.blockize(i1) |
| vec_len = get_coalesced_veclen(sch.get(B_dequantized_mat)) |
| sch.transform_block_layout( |
| B_dequantized_mat, lambda i, j: ((i * b_lr[1] + j) // vec_len, |
| (i * b_lr[1] + j) % vec_len)) |
| i1, j1 = sch.get_loops(B_dequantized_mat)[-2:] |
| fused = sch.fuse(i1, j1) |
| f1, f2 = sch.split(fused, factors=[warp_size, vec_len], preserve_unit_iters=False)[-2:] |
| sch.bind(f1, "threadIdx.x") |
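| # net layout (illustrative, with b_lr[1] = 16 and vec_len = 8): element |
| # (i, j) lands on lane (i * 16 + j) // 8, slot (i * 16 + j) % 8, giving |
| # each lane a contiguous vector of dequantized values |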
| # tensorize the dequantize with the LOP3 fast-decoding intrinsic if enabled |
| if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): |
| source_bit = weight_decode_info["source_format"]["bits"] |
| out_dtype = weight_decode_info["target_format"] |
| lop3_intrin_info = get_lop3_intrin_group( |
| out_dtype=out_dtype, |
| storage_dtype=weight_decode_info["storage_dtype"], |
| source_format=weight_decode_info["source_format"]["format"], |
| source_bit=source_bit, |
| with_scaling=weight_decode_info["with_scaling"], |
| with_zeros=weight_decode_info["with_zeros"], |
| zeros_mode=weight_decode_info["zeros_mode"], |
| storage_scope="warp", |
| ) |
| bf = sch.blockize(f2) |
| sch.tensorize( |
| bf, |
| lop3_intrin_info["compute"], |
| ) |
| # with scaling, the intrinsic needs the scale stride as an extra argument |
| if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: |
| grouped_k = sch.get(bf).reads[1].buffer.shape[-1] |
| # elements decoded per intrinsic call: 8 for float16 outputs, 16 for int8 |
| loop_extent = 8 if out_dtype == "float16" else 16 |
| sch.unsafe_inject_call_argument(bf, -2, loop_extent * grouped_k) |
| import_source.append(lop3_intrin_info["c_source"]) |
|
|
| def tensorize_init_store_compute(): |
| sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) |
| sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) |
| sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) |
|
|
| tensorize_init_store_compute() |
|
|
| if stage > 1: |
| sch.annotate( |
| k0, |
| ann_key="software_pipeline_stage", |
| ann_val=[0, 0, stage - 1], |
| ) |
| sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) |
| if use_async: |
| sch.annotate(k0, "software_pipeline_async_stages", [0]) |
| # inject the rasterization prologue when the plan provides one |
| if not isinstance(config.rasterization_plan, NoRasterization): |
| device_func, invoke_func = config.rasterization_plan.get_code() |
| import_source.append(device_func) |
| sch.annotate( |
| sch.get_loops(block_init_c)[-2], |
| ann_key="inject_customized_code_prepend", |
| ann_val=invoke_func, |
| ) |
| # attach collected device code (lop3 source, rasterization) via pragma_import_c |
| if len(import_source) > 0: |
| sch.annotate( |
| j2, |
| ann_key="pragma_import_c", |
| ann_val=("\n").join(import_source), |
| ) |
| return sch |
|
|
| def apply_config( |
| self, |
| func: tir.PrimFunc, |
| config: Hint, |
| ) -> Optional[tir.Schedule]: |
|
|
| def check_sm_version(arch: str) -> int: |
| sm_version = arch.replace("sm_", "") |
| return int(sm_version) if sm_version.isdigit() else -1 |
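| # e.g. check_sm_version("sm_80") -> 80; malformed arch strings yield -1 |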
|
|
| if check_sm_version(config.arch.target.arch) < 80: |
| """The MMA template only supports sm_80 and above.""" |
| return None |
|
|
| if (config.arch.target.kind.name == "cuda" and |
| check_sm_version(config.arch.target.arch) == 80): |
| return self.sch_shared_memory_prefetch_with_config(func, config) |
| else: |
| return self.sch_dequantize_in_register_with_config(func, config) |
|
|