| |
| |
|
|
| |
| """A GEMM schedule rule for GPU operators.""" |
| from dataclasses import dataclass |
| from typing import Optional, List |
| from contextlib import suppress |
| from tvm import tir |
| from tvm.target import Target |
| from tvm.tir.stmt import ForKind |
|
|
| from ..base import analysis |
| from ..base.analysis import get_coalesced_veclen |
| from .base import GPUScheduleRule |
| from . import utils |
| from .matmul_analysis import ( |
| auto_inline_consumer_chain, |
| auto_inline_producers, |
| get_in_out_dtypes, |
| get_index_map, |
| normalize_to_matmul, |
| _collect_producers, |
| get_reduction_blocks, |
| ) |
| from .matmul_mma import MatmulTensorizationMMA |
| from .matmul_wmma import ( |
| MatmulInt8Tensorization, |
| MatmulTensorizationWMMA, |
| ) |
| from .matmul_mfma import MatmulTensorizationMFMA |
| from functools import reduce |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class Matmul(GPUScheduleRule):
    """The schedule rule for matmul-like computation"""

    @dataclass
    class Config:
        # Default (fallback) schedule configuration; get_configs() returns
        # target-specific overrides of these values.
        block_size_x: int = 8  # threads bound to threadIdx.x
        block_size_y: int = 8  # threads bound to threadIdx.y
        vthread_x: int = 1  # virtual threads bound to vthread.x
        vthread_y: int = 1  # virtual threads bound to vthread.y
        micro_size_x: int = 4  # per-thread tile extent along x
        micro_size_y: int = 4  # per-thread tile extent along y
        micro_size_k: int = 8  # inner reduction (k) tile extent
        vector_size: int = 1  # vectorized load/store width
        unroll: int = 256  # max auto-unroll step; 0 disables the unroll pragmas
        use_shared: bool = True  # stage operands through shared memory
        storage_align: bool = False  # apply storage_align to the shared buffers
        inner_x: bool = False  # place the x axis innermost in the loop order
|
|
| def get_configs(self, target: Target) -> Config: |
| """Get the schedule config for the target""" |
| if target.kind.name in {"cuda", "rocm", "hip"}: |
| return Matmul.Config( |
| block_size_x=8, |
| block_size_y=16, |
| vthread_x=1, |
| vthread_y=1, |
| micro_size_x=4, |
| micro_size_y=4, |
| micro_size_k=16, |
| vector_size=2, |
| unroll=256, |
| use_shared=True, |
| storage_align=True, |
| inner_x=False, |
| ) |
| elif target.kind.name == "opencl" and "android" in str(target.host): |
| return Matmul.Config( |
| block_size_x=8, |
| block_size_y=8, |
| vthread_x=1, |
| vthread_y=1, |
| micro_size_x=8, |
| micro_size_y=2, |
| micro_size_k=16, |
| vector_size=8, |
| unroll=64, |
| use_shared=False, |
| storage_align=False, |
| inner_x=True, |
| ) |
| else: |
| return Matmul.Config() |
|
|
| def apply( |
| self, |
| func: tir.PrimFunc, |
| target: Target, |
| _: bool, |
| ) -> Optional[tir.Schedule]: |
| if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): |
| return None |
| sch = tir.Schedule(func) |
| root_block = analysis.get_root_block(sch) |
| blocks = sch.get_child_blocks(root_block) |
|
|
| reduction_blocks = get_reduction_blocks(sch, blocks) |
| if reduction_blocks is None: |
| return None |
|
|
| main_block = reduction_blocks[0] |
| block_stmt = sch.get(main_block) |
| sch = normalize_to_matmul(sch, main_block) |
| if sch is None: |
| return None |
|
|
| |
| |
| |
| |
| |
| minimal_tensorize_threshold = 64 |
| block_stmt = sch.get(main_block) |
| if target.kind.name == "cuda" and utils.get_sm_version(target) >= 70: |
| apply_tensorization: bool = True |
| |
| |
| in_dtype, out_dtype = get_in_out_dtypes(block_stmt) |
| if in_dtype not in ["int8", "float16"]: |
| apply_tensorization = False |
| for item_var in block_stmt.iter_vars[1:]: |
| extent = item_var.dom.extent |
| if isinstance(extent, |
| tir.expr.IntImm) and extent.value <= minimal_tensorize_threshold: |
| apply_tensorization = False |
| if apply_tensorization: |
| if in_dtype == "int8" and out_dtype == "int32": |
| tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) |
| elif utils.get_sm_version(target) >= 80: |
| |
| tensorize_sch = MatmulTensorizationMMA().apply(func, target, _) |
| else: |
| |
| tensorize_sch = MatmulTensorizationWMMA().apply(func, target, _) |
| if tensorize_sch is not None: |
| return tensorize_sch |
| elif target.kind.name == "hip": |
| apply_tensorization: bool = True |
| |
| |
| in_dtype, out_dtype = get_in_out_dtypes(block_stmt) |
| if in_dtype not in ["int8", "float16"]: |
| apply_tensorization = False |
| for item_var in block_stmt.iter_vars[1:]: |
| extent = item_var.dom.extent |
| if isinstance(extent, |
| tir.expr.IntImm) and extent.value <= minimal_tensorize_threshold: |
| apply_tensorization = False |
| if apply_tensorization: |
| |
| tensorize_sch = MatmulTensorizationMFMA().apply(func, target, _) |
| if tensorize_sch is not None: |
| return tensorize_sch |
|
|
| |
| config = self.get_configs(target) |
|
|
| |
| y_kernel_size = config.vthread_y * config.block_size_y * config.micro_size_y |
| x_kernel_size = config.vthread_x * config.block_size_x * config.micro_size_x |
| if config.inner_x: |
| sch.pad_einsum( |
| main_block, |
| [1, y_kernel_size, x_kernel_size, config.micro_size_k], |
| ) |
| batch, y, x, k = sch.get_loops(main_block) |
| else: |
| sch.pad_einsum( |
| main_block, |
| [1, x_kernel_size, y_kernel_size, config.micro_size_k], |
| ) |
| batch, x, y, k = sch.get_loops(main_block) |
| by, vy, ty, yi = sch.split( |
| y, [None, config.vthread_y, config.block_size_y, config.micro_size_y]) |
| bx, vx, tx, xi = sch.split( |
| x, [None, config.vthread_x, config.block_size_x, config.micro_size_x]) |
| ko, ki = sch.split(k, factors=[None, config.micro_size_k]) |
| sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) |
| by = sch.fuse(batch, by) |
| sch.bind(bx, "blockIdx.x") |
| sch.bind(by, "blockIdx.y") |
| sch.bind(vy, "vthread.y") |
| sch.bind(vx, "vthread.x") |
| sch.bind(ty, "threadIdx.y") |
| sch.bind(tx, "threadIdx.x") |
| inner_loop = config.micro_size_x if config.inner_x else config.micro_size_y |
| if inner_loop % config.vector_size == 0: |
| _, v = sch.split(xi, [None, config.vector_size]) |
| sch.vectorize(v) |
|
|
| if config.unroll > 0: |
| sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=config.unroll) |
| sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) |
|
|
| l2g = sch.cache_write(main_block, 0, "local") |
| sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) |
| if config.micro_size_x % config.vector_size == 0: |
| _, v = sch.split(sch.get_loops(l2g)[-1], [None, config.vector_size]) |
| sch.vectorize(v) |
|
|
| if config.use_shared: |
|
|
| def _cooperative_fetch(index, vec_len): |
| block = sch.cache_read(main_block, index, "shared") |
| num_loops = len(sch.get_loops(block)) |
| sch.compute_at(block, ko, preserve_unit_loops=True) |
| loops = sch.get_loops(block)[-num_loops:] |
| ty, tx, _, vec = sch.split( |
| sch.fuse(*loops), |
| factors=[config.block_size_y, config.block_size_x, None, vec_len], |
| ) |
| sch.vectorize(vec) |
| sch.bind(ty, "threadIdx.y") |
| sch.bind(tx, "threadIdx.x") |
| if config.storage_align: |
| sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) |
| return block |
|
|
| a_g2s = _cooperative_fetch(0, vec_len=config.vector_size) |
| b_g2s = _cooperative_fetch(1, vec_len=config.vector_size) |
|
|
| auto_inline_producers(sch, a_g2s) |
| auto_inline_producers(sch, b_g2s) |
| else: |
| auto_inline_producers(sch, main_block) |
|
|
| auto_inline_consumer_chain(sch, l2g) |
| sch.decompose_reduction(main_block, ko) |
|
|
| |
| def is_scheduled(block: tir.schedule.BlockRV) -> bool: |
| loops = sch.get_loops(block) |
| loop_kinds = {sch.get(loop).kind for loop in loops} |
| return loop_kinds != {ForKind.SERIAL} |
|
|
| blocks = sch.get_child_blocks(root_block) |
| max_threads_per_block = utils.max_threads_per_block(target) |
| for block in blocks: |
| if is_scheduled(block): |
| continue |
| |
| s_loops = sch.get_loops(block) |
| bx, tx = sch.split( |
| sch.fuse(*s_loops), |
| factors=[ |
| None, |
| 256, |
| ], |
| ) |
| sch.bind(bx, "blockIdx.x") |
| sch.bind(tx, "threadIdx.x") |
|
|
| return sch |
|
|
    def apply_config(
        self,
        func: tir.PrimFunc,
        config,
    ) -> tir.Schedule:
        """Schedule a matmul PrimFunc according to a tuner-provided *config*.

        *config* is expected to carry ``block``, ``thread``, ``step``,
        ``rstep``, ``cached_tensors`` and ``vectorize`` fields (a BitBLAS
        hint object) -- TODO confirm against the tuner's hint definition.
        Returns ``None`` when the function cannot be scheduled.
        """
        if "dequantize_info" in func.attrs:
            # Quantized weights take the dequantize-in-register path.
            return self.sch_dequantize_in_register_with_config(func, config)
        sch = tir.Schedule(func)
        root_block = analysis.get_root_block(sch)
        blocks = sch.get_child_blocks(root_block)

        reduction_blocks = get_reduction_blocks(sch, blocks)
        if reduction_blocks is None:
            return None

        # This rule only handles 2D tiling configs.
        if len(config.block) != 2:
            logger.debug(f"Warning: block config {config.block} is not valid for matmul, skip.")
            return None

        main_block = reduction_blocks[0]

        block_stmt = sch.get(main_block)

        # Normalize the block to a canonical matmul iteration space.
        index_maps = get_index_map(block_stmt, ["n", "n", "n"])
        if index_maps is None:
            return None
        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

        # Reindex both operands and the output so their layouts match the
        # canonical index maps.
        block = sch.reindex(main_block, ("read", 0))
        sch.transform_layout(block, ("write", 0), a_index_map)
        block = sch.reindex(main_block, ("read", 1))
        sch.transform_layout(block, ("write", 0), b_index_map)
        block = sch.reindex(main_block, ("write", 0))
        sch.transform_layout(block, ("read", 0), c_index_map)
        sch.transform_block_layout(main_block, matmul_index_map)

        # Derive tiling extents from the config.  NOTE(review): the vthread
        # factor is hard-coded as step*2 here (unlike the dequantize path
        # below, which uses step directly) -- confirm this matches the
        # tuner's search space.
        block_row_warps = config.block[0] // (config.thread[0] * config.step[0])
        block_col_warps = config.block[1] // (config.thread[1] * config.step[1])
        thread_row_tiles = config.thread[0] // (config.step[0] * 2)
        thread_col_tiles = config.thread[1] // (config.step[1] * 2)
        vthread_row_tiles = (config.step[0] * 2)
        vthread_col_tiles = (config.step[1] * 2)
        chunk = config.rstep[0]

        # Block-level tile sizes.
        BM = block_row_warps * vthread_row_tiles * thread_row_tiles
        BN = block_col_warps * vthread_col_tiles * thread_col_tiles
        BK = chunk

        sch.pad_einsum(
            main_block,
            [1, BM, BN, BK],
        )
        batch, y, x, k = sch.get_loops(main_block)
        by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles])
        bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles])
        ko, ki = sch.split(k, factors=[None, BK])
        sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi)
        by = sch.fuse(batch, by)
        sch.bind(bx, "blockIdx.x")
        sch.bind(by, "blockIdx.y")
        sch.bind(vy, "vthread.y")
        sch.bind(vx, "vthread.x")
        sch.bind(ty, "threadIdx.y")
        sch.bind(tx, "threadIdx.x")

        def prod(iterable):
            # Product of all elements (1 for an empty iterable).
            return reduce(lambda x, y: x * y, iterable, 1)

        # Accumulate in registers; write-back is scheduled after inlining.
        l2g = sch.cache_write(main_block, 0, "local")
        sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)

        def _cooperative_fetch(index, vec_len):
            # Stage operand *index* through shared memory (cooperative,
            # block-wide load at ko) and then into registers at ki.
            block = sch.cache_read(main_block, index, "shared")
            num_loops = len(sch.get_loops(block))
            block_local = sch.cache_read(main_block, index, "local")
            sch.compute_at(block_local, ki, preserve_unit_loops=True)
            sch.compute_at(block, ko, preserve_unit_loops=True)
            loops = sch.get_loops(block)[-num_loops:]
            _, ty, tx, vec = sch.split(
                sch.fuse(*loops),
                factors=[None, block_row_warps, block_col_warps, vec_len],
            )

            auto_inline_producers(sch, block)

            def is_trivial_load(block):
                # Vectorization is only applied when the single read and
                # write regions line up on the innermost dimension.
                reads = sch.get(block).reads
                writes = sch.get(block).writes
                if len(reads) != 1 or len(writes) != 1:
                    return False
                return all(
                    read.region[-1] == write.region[-1] for read, write in zip(reads, writes))

            if is_trivial_load(block):
                sch.vectorize(vec)

            sch.bind(ty, "threadIdx.y")
            sch.bind(tx, "threadIdx.x")

            # Vectorize the shared->register copy.
            _, vec = sch.split(
                sch.fuse(*sch.get_loops(block_local)[-2:]),
                [None, vec_len // prod(config.step)],
            )
            sch.vectorize(vec)

            return block

        for i, input_region in enumerate(sch.get(main_block).reads):
            # Strip schedule-introduced suffixes to recover the original
            # buffer name used as the key in the config.
            _buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "")
            if _buffer_name not in config.cached_tensors:
                logger.warning(
                    f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip."
                )
                continue

            # Per-buffer vectorization width (defaults to 1).
            vectorize = config.vectorize.get(_buffer_name, 1)

            _cooperative_fetch(i, vec_len=vectorize)

        auto_inline_consumer_chain(sch, l2g)

        # NOTE(review): `vectorize` is whatever value the last loop
        # iteration above left behind; if every buffer was skipped this
        # raises NameError -- confirm that path cannot occur.
        _, vec = sch.split(
            sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)])
        sch.vectorize(vec)

        sch.decompose_reduction(main_block, ko)
        return sch
|
|
    def sch_dequantize_in_register_with_config(
        self,
        func: tir.PrimFunc,
        config,
    ) -> tir.Schedule:
        '''
        For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch.
            quantized weight
                |
                V
            dequantized in register
                |
                V
            save into shared memory
                |
                V
            compute
        '''
        from .intrin import get_lop3_intrin_group

        # C source snippets for fast-decoding intrinsics; injected into the
        # kernel via the "pragma_import_c" annotation at the end.
        import_source: List[str] = []

        sch = tir.Schedule(func)
        root_block = analysis.get_root_block(sch)
        blocks = sch.get_child_blocks(root_block)

        reduction_blocks = get_reduction_blocks(sch, blocks)
        if reduction_blocks is None:
            return None

        # This rule only handles 2D tiling configs.
        if len(config.block) != 2:
            logger.debug(f"Warning: block config {config.block} is not valid for matmul, skip.")
            return None

        # Dequantize description attached by the frontend.
        dequantize_info = func.attrs["dequantize_info"]

        def check_dequantize_info(dequantize_info):
            conditions = []
            # Exactly one tensor may be dequantized.
            conditions.append(len(dequantize_info) == 1)
            return all(conditions)

        assert check_dequantize_info(dequantize_info)

        (weight_decode_info,) = list(dequantize_info.values())

        def check_weight_decode_info(weight_decode_info):
            conditions = []
            # Source format must be declared and one of the supported kinds.
            conditions.append("source_format" in weight_decode_info)
            conditions.append(weight_decode_info["source_format"]["format"] in
                              ["uint", "int", "fp", "nf", "fp_e4m3"])
            # Only 1/2/4/8-bit sources are supported.
            conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8])
            # Decoded (target) dtype must be declared and supported.
            conditions.append("target_format" in weight_decode_info)
            conditions.append(weight_decode_info["target_format"] in ["float16", "int8"])
            return all(conditions)

        assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info"

        main_block = reduction_blocks[0]

        block_stmt = sch.get(main_block)

        # Normalize to a canonical matmul layout; "t" marks the weight
        # operand as transposed.
        index_maps = get_index_map(block_stmt, ["n", "t", "n"])
        if index_maps is None:
            return None
        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

        # Reindex operands and output to match the canonical index maps.
        block = sch.reindex(main_block, ("read", 0))
        sch.transform_layout(block, ("write", 0), a_index_map)
        block = sch.reindex(main_block, ("read", 1))
        sch.transform_layout(block, ("write", 0), b_index_map)
        block = sch.reindex(main_block, ("write", 0))
        sch.transform_layout(block, ("read", 0), c_index_map)
        sch.transform_block_layout(main_block, matmul_index_map)

        # Derive tiling extents from the config.
        block_row_warps = config.block[0] // (config.thread[0] * config.step[0])
        block_col_warps = config.block[1] // (config.thread[1] * config.step[1])
        thread_row_tiles = config.thread[0] // (config.step[0])
        thread_col_tiles = config.thread[1] // (config.step[1])
        vthread_row_tiles = (config.step[0])
        vthread_col_tiles = (config.step[1])
        chunk = config.rstep[0]
        shared_scope = config.shared_scope

        num_ty = block_row_warps
        num_tx = block_col_warps

        # Block-level tile sizes.
        BM = block_row_warps * vthread_row_tiles * thread_row_tiles
        BN = block_col_warps * vthread_col_tiles * thread_col_tiles

        def find_valid_number(k, chunk, magic=16):
            # Round *chunk* down to the largest multiple of *magic* that
            # evenly divides k.
            num = (chunk // magic) * magic
            while num > 0:
                if k % num == 0:
                    return num
                num -= magic
            # NOTE(review): returns None when no multiple of *magic*
            # divides k (e.g. chunk < magic); pad_einsum below would then
            # receive None -- confirm callers guarantee this cannot happen.
            return None

        K = func.buffer_map[func.params[0]].shape[-1]
        # Choose a reduction tile that evenly divides K.
        BK = find_valid_number(K, chunk)
        # Vectorized register-copy width.
        align_factor = 4
        sch.pad_einsum(
            main_block,
            [1, BM, BN, BK],
        )
        batch, y, x, k = sch.get_loops(main_block)
        by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles])
        bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles])
        ko, ki, kii = sch.split(k, factors=[None, (BK // align_factor), align_factor])
        sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, kii, yi, xi)
        # NOTE(review): unlike apply_config, by/bx/vy/vx are never bound to
        # blockIdx/vthread here -- confirm whether a later pass handles it.
        sch.bind(ty, "threadIdx.y")
        sch.bind(tx, "threadIdx.x")

        def prod(iterable):
            # Product of all elements (1 for an empty iterable).
            return reduce(lambda x, y: x * y, iterable, 1)

        # Accumulate in registers; write-back is scheduled after inlining.
        l2g = sch.cache_write(main_block, 0, "local")
        sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)

        def _cooperative_fetch(index, vec_len, align_factor=2):
            # Stage operand *index* through shared memory (cooperative,
            # block-wide load at ko) and then into registers at ki.
            block = sch.cache_read(main_block, index, "shared")
            num_loops = len(sch.get_loops(block))
            block_local = sch.cache_read(main_block, index, "local")
            sch.compute_at(block_local, ki, preserve_unit_loops=True)
            sch.compute_at(block, ko, preserve_unit_loops=True)
            loops = sch.get_loops(block)[-num_loops:]
            _, ty, tx, vec = sch.split(
                sch.fuse(*loops),
                factors=[None, block_row_warps, block_col_warps, vec_len],
            )

            auto_inline_producers(sch, block)

            def is_trivial_load(block):
                # Vectorization is only applied when the single read and
                # write regions line up on the innermost dimension.
                reads = sch.get(block).reads
                writes = sch.get(block).writes
                if len(reads) != 1 or len(writes) != 1:
                    return False
                return all(
                    read.region[-1] == write.region[-1] for read, write in zip(reads, writes))

            if is_trivial_load(block):
                sch.vectorize(vec)

            sch.bind(ty, "threadIdx.y")
            sch.bind(tx, "threadIdx.x")

            # Vectorize the shared->register copy.
            fused = sch.fuse(*sch.get_loops(block_local)[-2:])
            _, vec = sch.split(
                fused,
                [None, align_factor],
            )
            sch.vectorize(vec)

            return block

        # Only the activation operand (reads[:1]) is staged here; the
        # quantized weight goes through decode_fetch_to_shared below.
        for i, input_region in enumerate(sch.get(main_block).reads[:1]):
            # Strip schedule-introduced suffixes to recover the original
            # buffer name used as the key in the config.
            _buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "")
            if _buffer_name not in config.cached_tensors:
                logger.warning(
                    f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip."
                )
                continue

            # Per-buffer vectorization width (defaults to 1).
            vectorize = config.vectorize.get(_buffer_name, 1)

            _cooperative_fetch(i, vec_len=vectorize, align_factor=align_factor)

        def decode_fetch_to_shared(block, idx):
            # Dequantize the weight in registers, then store the decoded
            # values into shared memory.
            block_shared = sch.cache_read(block, idx, shared_scope)
            sch.compute_at(block_shared, ko, preserve_unit_loops=True)

            # Split the store loop so each lane decodes one coalesced vector.
            decode_factor = get_coalesced_veclen(sch.get(block_shared))
            _, B_shared_vi, _ = sch.split(
                sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor])
            block_shared_local = sch.cache_read(block_shared, 0, "local")
            # Inline the decode chain into the register stage.
            weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"])
            weight_producers = _collect_producers(sch, weight_dequantize_block)
            auto_inline_producers(sch, block_shared_local, weight_producers)

            def get_idx():
                # "nf" (lookup-table) formats read the LUT as operand 0, so
                # the quantized weight is operand 1; otherwise operand 0 --
                # presumably matches the decode block's read order, verify.
                if weight_decode_info["source_format"]["format"] == "nf":
                    return 1
                return 0

            b_idx = get_idx()
            # Stage the raw quantized weight into registers as well.
            block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local")

            sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True)
            sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)

            dequantize_block_local = block_shared_local
            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
                # Scales follow the weight operand in the read list.
                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
                auto_inline_producers(sch, block_local_scales)

            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
                # Zero points follow the scales.
                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
                auto_inline_producers(sch, block_local_zeros)

            # Best-effort: inline away any remaining weight producers.
            for producer in weight_producers:
                with suppress(Exception):
                    auto_inline_producers(sch, producer)
                    sch.compute_inline(producer)

            # Tensorize the decode with the LOP3 fast-decoding intrinsic
            # when the frontend requested it.
            if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]):
                source_bit = weight_decode_info["source_format"]["bits"]
                out_dtype = weight_decode_info["target_format"]
                lop3_intrin_info = get_lop3_intrin_group(
                    out_dtype=out_dtype,
                    storage_dtype=weight_decode_info["storage_dtype"],
                    source_format=weight_decode_info["source_format"]["format"],
                    source_bit=source_bit,
                    with_scaling=weight_decode_info["with_scaling"],
                    with_zeros=weight_decode_info["with_zeros"],
                    zeros_mode=weight_decode_info["zeros_mode"],
                )
                sch.tensorize(
                    sch.get_loops(dequantize_block_local)[-1],
                    lop3_intrin_info["compute"],
                )
                import_source.append(lop3_intrin_info["c_source"])

            # Bind the shared-store loops (excluding the two innermost
            # split levels) over the thread grid.
            union_len = (2 + 2)
            B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2])

            _, B_shared_ty, B_shared_tx = sch.split(B_shared_fused, factors=[None, num_ty, num_tx])
            sch.bind(B_shared_tx, "threadIdx.x")
            sch.bind(B_shared_ty, "threadIdx.y")
            sch.vectorize(sch.get_loops(block_shared)[-1])
            sch.vectorize(sch.get_loops(block_shared_local_local)[-1])

            # For LUT-based formats, also stage the lookup table in shared
            # memory (b_idx is truthy only for "nf").
            if b_idx:
                block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope)
                sch.reverse_compute_at(block_shared_lut, bx)
                _, B_shared_tx = sch.split(
                    sch.get_loops(block_shared_lut)[-1], factors=[None, num_tx])
                sch.bind(B_shared_tx, "threadIdx.x")
            return block_shared_local

        _ = decode_fetch_to_shared(main_block, 1)

        def fetch_to_local(block, index, align_factor=2):
            # Copy operand *index* from shared memory into registers at ki,
            # vectorized by *align_factor*.
            block_local = sch.cache_read(block, index, "local")
            sch.compute_at(block_local, ki, preserve_unit_loops=True)
            fused = sch.fuse(*sch.get_loops(block_local)[-2:])
            _, vec = sch.split(
                fused,
                [None, align_factor],
            )
            sch.vectorize(vec)
            return block_local

        fetch_to_local(main_block, 1, align_factor=align_factor)

        auto_inline_consumer_chain(sch, l2g)

        # Vectorized epilogue: write the accumulator back with the widest
        # coalesced vector length.
        l2g_vec = get_coalesced_veclen(sch.get(l2g))
        _, vec = sch.split(sch.fuse(*sch.get_loops(l2g)[-2:]), [None, l2g_vec])
        sch.vectorize(vec)

        sch.decompose_reduction(main_block, ko)
        # Attach the collected intrinsic C source to the kernel.
        if len(import_source) > 0:
            sch.annotate(
                ty,
                ann_key="pragma_import_c",
                ann_val=("\n").join(import_source),
            )
        return sch
|
|