Matt300209's picture
Add files using upload-large-folder tool
a3f2e4d verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
from typing import Literal, Optional, List
from tvm import tir, DataType
from tvm.target import Target
from bitblas.base.operator_common import TransformKind
from ..base.roller import Hint
from ..base.roller.rasterization import NoRasterization
from ..base import analysis
from .base import GPUScheduleRule
from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo
from ..base.analysis import get_coalesced_veclen
from .matmul_analysis import (
auto_inline_consumer_chain,
is_transpose_block,
is_identity_block,
_collect_producers,
inline_transpose_block,
auto_inline_producers,
get_index_map,
get_reduction_blocks,
get_dequantize_block,
normalize_to_matmul,
get_propagate_map,
)
def get_index_map_3d(index_map, l=16, r=16): # noqa: E741
def index_map_3d(b, i, j):
return (
b,
i // l,
j // r,
*index_map(i % l, j % r),
)
return index_map_3d
def get_index_map_5d(index_map):
"""
for layout transformed gemm, the index map should be 5d
"""
def index_map_5d(b, i, j, ii, jj):
return (
b,
i,
j,
*index_map(ii, jj),
)
return index_map_5d
def get_warp_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741
if is_5d:
return get_index_map_5d(index_map)
return get_index_map_3d(index_map, l, r)
class MatmulTensorizationMMA(GPUScheduleRule):
"""
The schedule rule for float16 tensor core matmul computation.
func with attr 'dlight.do_not_tensorize' will not be tensorized.
"""
def apply( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
if "dequantize_info" in func.attrs:
dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo()
return dequantize_rule.apply(func, target, False)
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
return None
# We first inline all transpose blocks for later analysis of transposed A and B
blocks = inline_transpose_block(sch, blocks)
reduction_blocks = get_reduction_blocks(sch, blocks)
if reduction_blocks is None:
return None
dequantize_block = get_dequantize_block(sch, blocks)
main_block = reduction_blocks[0]
main_block_stmt = sch.get(main_block)
# Supported data types:
# fp16, fp16, fp16: fp16 precision
# fp16, fp16, fp32: fp16 mixed precision
dtype_a = main_block_stmt.reads[0].buffer.dtype
dtype_b = main_block_stmt.reads[1].buffer.dtype
dtype_c = main_block_stmt.writes[0].buffer.dtype
if dtype_a != dtype_b:
return None
# Get index maps
index_maps = get_index_map(main_block_stmt)
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
# Start Schedule
# Step 0. Get schedule config.
# NOTE: we can analyze the config by the hardware spec in the future
# Tensorization by hardware intrinsics
from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel
get_mma_intrin_group, shared_16x16_to_mma_32x8_layout,
)
# tile size
block_m, block_n, block_k = 128, 128, 32
# tensor core intrinsic size
micro_size_m, micro_size_n, micro_size_k = 16, 16, 16
# thread size
# thread_x == warp_size
thread_z, thread_y, thread_x = 2, 2, 32
vector_size = 8
unroll_depth = 4 # noqa: F841
# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
is_transpose_a = is_transpose_block(sch.get(block))
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
is_transpose_b = is_identity_block(sch.get(block))
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)
batch, i, j, k = sch.get_loops(main_block)
swizzle_factor_for_l2_m = [1, None]
swizzle_factor_for_l2_n = [1, None]
# Step 2. Padding for dynamic shape kernels
sch.pad_einsum(
main_block,
[
1,
swizzle_factor_for_l2_m[0] * block_m,
swizzle_factor_for_l2_n[0] * block_n,
block_k,
],
)
# Step 3. Reorder loops for tiling
# Step 3.1 inner loops for tensor core computation
i, i_inner = sch.split(i, factors=[None, micro_size_m])
j, j_inner = sch.split(j, factors=[None, micro_size_n])
k, k_inner = sch.split(k, factors=[None, micro_size_k])
sch.reorder(i, j, k, i_inner, j_inner, k_inner)
block_inner = main_block
block_outer = sch.blockize(i_inner)
# Step 3.2 outer loops for tiling
# split factors for i, j, and k
micro_block_cnt_in_warp_m = block_m // thread_z // micro_size_m
micro_block_cnt_in_warp_n = block_n // thread_y // micro_size_n
micro_block_cnt_in_warp_k = block_k // micro_size_k
i_factors = swizzle_factor_for_l2_m + [thread_z, micro_block_cnt_in_warp_m]
j_factors = swizzle_factor_for_l2_n + [thread_y, micro_block_cnt_in_warp_n]
k_factors = [None, micro_block_cnt_in_warp_k]
i0, i1, i2, i3 = sch.split(i, factors=i_factors)
j0, j1, j2, j3 = sch.split(j, factors=j_factors)
k0, k1 = sch.split(k, factors=k_factors)
sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)
block_axis = sch.fuse(batch, i0, j0, i1, j1)
sch.bind(block_axis, "blockIdx.x")
sch.bind(i2, "threadIdx.z")
sch.bind(j2, "threadIdx.y")
# Step 4. Read/write to shared mem and register
def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose):
# 1) Read to shared memory
block_read_smem = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn")
sch.compute_at(block_read_smem, k0)
auto_inline_producers(sch, block_read_smem,
[dequantize_block] if dequantize_block else [])
# For transposed read, we directly load transposed tensor from global
# Then use ldmatrix.trans to handle transpose later
if (tensor_name == "A" and is_transpose) or (tensor_name == "B" and not is_transpose):
# specifical handle transpose read (for NN matmul or TT matmul)
v0, v1 = sch.get_loops(block_read_smem)[-2:]
sch.reorder(v1, v0)
sch.transform_layout(block_read_smem, ("write", 0), lambda b, i, j: (b, j, i))
# bind loops
fused = sch.fuse(*sch.get_loops(block_read_smem)[-2:])
f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size])
sch.bind(f1, "threadIdx.z")
sch.bind(f2, "threadIdx.y")
sch.bind(f3, "threadIdx.x")
sch.vectorize(f4)
# swizzling
sch.annotate(block_read_smem, ann_key="permuted_layout", ann_val=1)
# 2) Read to register
block_read_reg = sch.cache_read(block_outer, read_buffer_idx, "warp")
sch.compute_at(block_read_reg, k1)
# bind_loops
micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n
micro_size_1, micro_size_2 = ((micro_size_spatial,
micro_size_k) if not is_transpose else
(micro_size_k, micro_size_spatial))
v00, v01 = sch.split(sch.get_loops(block_read_reg)[-2], [None, micro_size_1])
v10, v11 = sch.split(sch.get_loops(block_read_reg)[-1], [None, micro_size_2])
sch.reorder(v00, v10, v01, v11)
# reorder read axis to match the layout of ldmatrix
sch.transform_layout(
block_read_reg,
("write", 0),
lambda v0, v1, v2: (
v0,
v1 // micro_size_1,
v2 // micro_size_2,
*shared_16x16_to_mma_32x8_layout(v1 % micro_size_1, v2 % micro_size_2),
),
)
# swizzling
mma_read_block = sch.blockize(sch.get_loops(block_read_reg)[-2])
sch.annotate(mma_read_block, ann_key="permuted_layout", ann_val=1)
return block_read_smem, block_read_reg
block_read_a, block_read_reg_a = fetch_input(block_outer, 0, "A", is_transpose_a)
block_read_b, block_read_reg_b = fetch_input(block_outer, 1, "B", is_transpose_b)
# Write to register, and then smem
def store_output(block_outer, write_buffer_idx):
# 1) Write to shared memory
block_write_smem = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn")
sch.reverse_compute_at(block_write_smem, block_axis)
auto_inline_consumer_chain(sch, block_write_smem)
# bind loops
write_vec_len = get_coalesced_veclen(sch.get(block_write_smem))
fused = sch.fuse(*sch.get_loops(block_write_smem)[-2:])
f0, f1, f2 = sch.split(fused, [None, thread_x, write_vec_len])
sch.bind(f1, "threadIdx.x")
sch.vectorize(f2)
# 2) Write to register
block_write_reg = sch.cache_write(block_outer, write_buffer_idx, "warp")
# bind loops
v0, v1, v2 = sch.get_loops(block_write_reg)[-3:]
v11, v12, v13 = sch.split(v1, factors=[thread_z, None, micro_size_m])
v21, v22, v23 = sch.split(v2, factors=[thread_y, None, micro_size_n])
sch.reorder(v11, v21, v12, v22, v13, v23)
sch.bind(v11, "threadIdx.z")
sch.bind(v21, "threadIdx.y")
# reorder write axis to match the layout of ldmatrix
sch.transform_layout(
block_write_reg,
("read", 0),
lambda v0, v1, v2: (
v0,
v1 // micro_size_m,
v2 // micro_size_n,
*shared_16x16_to_mma_32x8_layout(v1 % micro_size_m, v2 % micro_size_n),
),
)
return block_write_smem, block_write_reg
_, block_write_reg = store_output(block_outer, 0)
# Step 5. Schedule tensor core computation
block_init = sch.decompose_reduction(block_outer, k0)
block_init_inner = sch.get_child_blocks(block_init)[0]
intrin_group = get_mma_intrin_group(
load_scope="shared.dyn",
store_scope="shared.dyn",
a_dtype=str(dtype_a),
b_dtype=str(dtype_b),
out_dtype=str(dtype_c),
trans_a=is_transpose_a,
trans_b=is_transpose_b,
not_use_mma_store_intrinic=False,
)
sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"])
sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"])
weight_transform_kind = 0
if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs:
weight_transform_kind = func.attrs["weight_transform_kind"]
if weight_transform_kind >= TransformKind.LDMatrixTransform:
fused = sch.fuse(sch.get_loops(block_read_reg_b)[-2:])
vec_len = get_coalesced_veclen(sch.get(block_read_reg_b))
f0, f1, f2 = sch.split(fused, factors=[None, 32, vec_len])
sch.bind(f1, "threadIdx.x")
sch.vectorize(f2)
else:
sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"])
sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])
sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"])
# Step 6. Async pipeline
sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3])
sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0])
# Step 7. Handle dequantize block
# Now we just add a dummy kernel to compute dequantize
if dequantize_block is not None:
auto_inline_producers(sch, dequantize_block)
loops = sch.get_loops(dequantize_block)
loop = sch.fuse(*loops)
v0, v1, v2, v3 = sch.split(loop, [None, 128, 2, 4])
sch.bind(v0, "blockIdx.x")
sch.bind(v1, "threadIdx.x")
sch.unroll(v2)
sch.vectorize(v3)
return sch
def apply_config( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
config: Hint,
) -> Optional[tir.Schedule]:
if "dequantize_info" in func.attrs:
dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo()
return dequantize_rule.apply_config(func, config)
is_cross_thread_reduce = (
hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None)
block_reduction_depth = config.block_reduction_depth if is_cross_thread_reduce else 1
from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel
get_mma_intrin_group,)
import_source: List[str] = []
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
return None
reduction_blocks = get_reduction_blocks(sch, blocks)
if reduction_blocks is None:
return None
main_block = reduction_blocks[0]
output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)]
def check_require_cache(func: tir.PrimFunc, config):
conditions: List[bool] = []
# check if has dynamic symbolic
def check_has_dynamic(func: tir.PrimFunc):
for param in func.params:
if param not in func.buffer_map:
continue
arg = func.buffer_map[param]
for i in arg.shape:
if isinstance(i, tir.Var):
return True
return False
conditions.append(check_has_dynamic(func))
# check if has post process
conditions.append(sch.get(main_block) not in output_blocks)
# check if not use async copy
conditions.append(config.use_async is False)
return any(conditions)
# cache_write_required = check_require_cache(func, config=config)
cache_write_required = True
# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J]
if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()):
sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"])
shared_scope = config.shared_scope
intrin_info = config.intrin_info
input_transform_kind = intrin_info.input_transform_kind
weight_transform_kind = intrin_info.weight_transform_kind
assert input_transform_kind <= TransformKind.IntraWarpTransform, "Only support up to intra-warp transform"
intrin_group = get_mma_intrin_group(
load_scope=shared_scope,
store_scope=shared_scope if cache_write_required else "global",
a_dtype=intrin_info.in_dtype,
b_dtype=intrin_info.in_dtype,
out_dtype=intrin_info.out_dtype,
trans_a=intrin_info.trans_a,
trans_b=intrin_info.trans_b,
smooth_a=intrin_info.smooth_a,
smooth_b=intrin_info.smooth_b,
not_use_mma_store_intrinic=False,
)
# Start Schedule
# Step 0. Get schedule config.
warp_row_tiles = config.warp[0]
warp_col_tiles = config.warp[1]
block_row_warps = config.block[0] // warp_row_tiles
block_col_warps = config.block[1] // warp_col_tiles
stage = config.pipeline_stage
use_async = config.use_async
reduce_k = block_reduction_depth
chunk = config.rstep[0] // reduce_k
# tensor core intrinsic size
micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"]
# get the axis for layout transform
def get_axis(l, r, trans): # noqa: E741
return (r, l) if trans else (l, r) # noqa: E741
a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a)
b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b)
def can_enable_swizzle(dtype: str, smooth: bool):
# inject_permuted_layout only support float16 currently
if dtype == "float16" or dtype == "int8":
if (chunk * reduce_k) * DataType(dtype).bits != (512):
# currently the swizzle rule only support 512 bit.
return False
# if we use smooth layout, we don't need to do swizzling
return not smooth
return False
can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a)
can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b)
warp_size = 32
i_factors, j_factors, k_factors = (
[None, 1, block_row_warps, warp_row_tiles // micro_size_x],
[1, None, block_col_warps, warp_col_tiles // micro_size_y],
[None, chunk // micro_size_k],
)
num_ty = i_factors[2]
num_tz = j_factors[2]
x_pad_factor = i_factors[2] * i_factors[3]
y_pad_factor = j_factors[2] * j_factors[3]
k_pad_factor = k_factors[1]
# Step 2. Padding for dynamic shape kernels
sch.pad_einsum(
main_block,
[
1,
micro_size_x * x_pad_factor,
micro_size_y * y_pad_factor,
micro_size_k * k_pad_factor,
],
)
# Step 3. Schedule matmul to use tensor core
block = main_block
batch, i, j, k = sch.get_loops(block)
# inner loops for tensor core computation
i, i_inner = sch.split(i, factors=[None, micro_size_x])
j, j_inner = sch.split(j, factors=[None, micro_size_y])
k, k_inner = sch.split(k, factors=[None, micro_size_k])
sch.reorder(i, j, k, i_inner, j_inner, k_inner)
block_inner = block
block_outer = sch.blockize(i_inner)
i0, i1, i2, i3 = sch.split(i, factors=i_factors)
j0, j1, j2, j3 = sch.split(j, factors=j_factors)
k0, k1 = sch.split(k, k_factors)
if reduce_k > 1:
k0, kr = sch.split(k0, [None, reduce_k])
sch.reorder(i0, j0, i1, j1, i2, j2, kr, k0, k1, i3, j3)
else:
sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3)
block_idy = sch.fuse(i0, j0)
block_idx = sch.fuse(i1, j1)
thread_idy = i2
thread_idz = j2
sch.bind(batch, "blockIdx.z")
sch.bind(block_idx, "blockIdx.x")
sch.bind(block_idy, "blockIdx.y")
if reduce_k > 1:
thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz)
sch.bind(thread_idy, "threadIdx.y")
sch.bind(kr, "threadIdx.z")
else:
sch.bind(thread_idy, "threadIdx.y")
sch.bind(thread_idz, "threadIdx.z")
# rewrite smooth layout of shared memory
# enable_store_rewrite is a hack as in some cases
# lower vectorization factor may failed to generate
# expected high performance code
enable_store_rewrite = not intrin_info.is_input_8bit()
def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741
if not enable:
return
sch.transform_layout(
block,
scope,
lambda b, i, j: (
b,
i // l,
j // r,
i % l,
j % r,
),
)
smooth_smem_layout_rewrite(
block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a)
smooth_smem_layout_rewrite(
block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b)
smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=enable_store_rewrite)
def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, reduce_k=1):
block_read = sch.cache_read(block, idx, shared_scope)
sch.compute_at(block_read, k0, preserve_unit_loops=True)
ndim = len(sch.get(block_read).iter_vars)
fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
if reduce_k > 1:
f_r, f_0, f_1, f_2, f_3, f_4 = sch.split(
fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len])
sch.bind(f_3, "threadIdx.x")
f_0 = f_1 = sch.fuse(f_0, f_1)
sch.bind(f_0, "threadIdx.y")
sch.bind(f_r, "threadIdx.z")
else:
f_0, f_1, f_2, f_3, f_4 = sch.split(
fused, factors=[num_ty, num_tz, None, warp_size, vec_len])
sch.bind(f_3, "threadIdx.x")
sch.bind(f_1, "threadIdx.z")
sch.bind(f_0, "threadIdx.y")
sch.vectorize(f_4)
sch.unroll(f_2)
# Apply Swizzling
sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle)
# if not, apply padding to alleviate bank conflict
if not (can_swizzle or is_smooth):
pad_offset = 8 if intrin_info.in_dtype == "float16" else 16
sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset)
sch.annotate(f_2, "pragma_unroll_explicit", False)
return block_read
if len(config.vectorize.values()) < 2:
return None
a_g2s = fetch_to_shared(
block_outer,
0,
vec_len=list(config.vectorize.values())[0],
can_swizzle=can_swizzle_a,
is_smooth=intrin_info.smooth_a,
reduce_k=reduce_k,
)
b_g2s = fetch_to_shared(
block_outer,
1,
vec_len=list(config.vectorize.values())[1],
can_swizzle=can_swizzle_b,
is_smooth=intrin_info.smooth_b,
reduce_k=reduce_k,
)
# rewrite global smooth layout
def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"):
if not enable:
return
# step1: find the first producer block
# Notes: we assume the layout propagate happens in the first producer block
# otherwise, the layout transform will have no effect as it will transform both
# read and write buffer
producers = _collect_producers(sch, block)
g2s_block = a_g2s if matrix_name == "A" else b_g2s
propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block)
# step2: transform the layout with inverse permutation
intra_indexmap, _ = get_propagate_map(
trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name)
def inverse_permutation(i, j, ii, jj):
return (i, j, *intra_indexmap.map_indices([ii, jj]))
sch.transform_layout(propagate_block, ("read", 0), inverse_permutation)
smooth_gmem_layout_rewrite(
sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A")
smooth_gmem_layout_rewrite(
sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B")
auto_inline_producers(sch, a_g2s)
auto_inline_producers(sch, b_g2s)
# create read cache to load matrix from shared memory to wmma fragments
A_mat = sch.cache_read(block_outer, 0, "warp")
B_mat = sch.cache_read(block_outer, 1, "warp")
sch.compute_at(A_mat, k1)
sch.compute_at(B_mat, k1)
# create write cache to store matrix from wmma fragments to shared memory and global memory
if cache_write_required:
accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope)
store = sch.cache_write(block_outer, 0, "warp")
sch.reverse_compute_at(store, j2)
# split the store loop to match hardware intrinsic pattern
i, j = sch.get_loops(store)[-2:]
i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False)
j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False)
sch.reorder(i0, j0, i1, j1)
if cache_write_required:
auto_inline_consumer_chain(sch, accumulator_shared_to_global)
sch.reverse_compute_at(
accumulator_shared_to_global,
sch.get_loops(store)[-6],
preserve_unit_loops=True,
)
vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global))
fuse_iters = 5 if enable_store_rewrite else 3
fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-fuse_iters:])
f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
sch.bind(f1, "threadIdx.x")
sch.vectorize(f2)
sch.unroll(f0)
sch.annotate(f0, "pragma_unroll_explicit", False)
else:
auto_inline_consumer_chain(sch, store)
block_init_c = sch.decompose_reduction(block_outer, k0)
block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
# Tensorization by hardware intrinsics
index_map_a, index_map_b, index_map_c = intrin_group["index_map"]
sch.transform_layout(
A_mat,
("write", 0),
get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a),
)
sch.transform_layout(
B_mat,
("write", 0),
get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b),
)
sch.transform_layout(
store,
("read", 0),
get_warp_index_map(index_map_c, is_5d=enable_store_rewrite),
)
i, j = sch.get_loops(A_mat)[-2:]
i0, i1 = sch.split(i, factors=[None, a_lr[0]])
j0, j1 = sch.split(j, factors=[None, a_lr[1]])
sch.reorder(i0, j0, i1, j1)
ba = sch.blockize(i1)
sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a)
sch.tensorize(ba, intrin_group["load_a"])
i, j = sch.get_loops(B_mat)[-2:]
i0, i1 = sch.split(i, factors=[None, b_lr[0]])
j0, j1 = sch.split(j, factors=[None, b_lr[1]])
sch.reorder(i0, j0, i1, j1)
if weight_transform_kind >= TransformKind.LDMatrixTransform:
fused = sch.fuse(i1, j1)
vec_len = get_coalesced_veclen(sch.get(B_mat))
f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len])
sch.bind(f1, "threadIdx.x")
sch.vectorize(f2)
else:
bb = sch.blockize(i1)
sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b)
sch.tensorize(bb, intrin_group["load_b"])
def tensorize_init_store_compute():
sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"])
sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"])
tensorize_init_store_compute()
if stage > 1:
sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1])
sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2])
if use_async:
sch.annotate(k0, "software_pipeline_async_stages", [0])
# plan rasteration
if not isinstance(config.rasterization_plan, NoRasterization):
device_func, invoke_func = config.rasterization_plan.get_code()
import_source.append(device_func)
sch.annotate(
sch.get_loops(block_init_c)[-2],
ann_key="inject_customized_code_prepend",
ann_val=invoke_func,
)
# plan import source
if len(import_source) > 0:
sch.annotate(
thread_idz,
ann_key="pragma_import_c",
ann_val=("\n").join(import_source),
)
return sch