| |
| |
|
|
| from typing import Union |
| from tvm import arith, DataType |
| import tvm.tl.language as T |
|
|
|
|
def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
    """Map ``(thread_id, local_id)`` to a ``(row, col)`` pair.

    ``row`` is ``thread_id % 16``. ``col`` is ``local_id % 8`` offset by 8
    for every full group of 16 contained in ``thread_id``.
    """
    group = thread_id // 16
    lane = thread_id % 16
    return lane, 8 * group + local_id % 8
|
|
|
|
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    """Map ``(thread_id, local_id)`` to a ``(row, col)`` pair.

    ``row`` combines ``thread_id // 16`` (in steps of 8) with
    ``thread_id % 8``. ``col`` combines ``(thread_id % 16) // 8`` (in steps
    of 8) with ``local_id % 8``.
    """
    upper = thread_id // 16
    octet = (thread_id % 16) // 8
    row = 8 * upper + (thread_id % 8)
    col = 8 * octet + local_id % 8
    return row, col
|
|
|
|
def ldmatrix_16x32_to_shared_16x32_layout_a(thread_id, local_id):
    """Map ``(thread_id, local_id)`` to a ``(row, col)`` pair.

    ``row`` is ``thread_id % 16``. ``col`` is ``local_id % 16`` offset by 16
    for every full group of 16 contained in ``thread_id``.
    """
    group = thread_id // 16
    lane = thread_id % 16
    return lane, 16 * group + local_id % 16
|
|
|
|
def ldmatrix_16x32_to_shared_16x32_layout_b(thread_id, local_id):
    """Map ``(thread_id, local_id)`` to a ``(row, col)`` pair.

    ``row`` combines ``thread_id // 16`` (in steps of 8) with
    ``thread_id % 8``. ``col`` combines ``(thread_id % 16) // 8`` (in steps
    of 16) with ``local_id % 16``.
    """
    upper = thread_id // 16
    octet = (thread_id % 16) // 8
    row = 8 * upper + (thread_id % 8)
    col = 16 * octet + local_id % 16
    return row, col
|
|
|
|
def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
    """Map ``(thread_id, local_id)`` to a ``(row, col)`` pair.

    ``row`` is ``thread_id % 16``. ``col`` is ``local_id`` (used as-is, with
    no modulo) offset by 16 for every full group of 16 in ``thread_id``.
    """
    group = thread_id // 16
    lane = thread_id % 16
    return lane, local_id + group * 16
|
|
|
|
def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
    """Map ``(thread_id, local_id)`` to a ``(row, col)`` pair.

    ``row`` combines ``thread_id // 16`` (in steps of 8) with
    ``thread_id % 8``. ``col`` is ``local_id`` (used as-is, with no modulo)
    offset by 16 times ``(thread_id % 16) // 8``.
    """
    upper = thread_id // 16
    octet = (thread_id % 16) // 8
    row = upper * 8 + (thread_id % 8)
    col = local_id + 16 * octet
    return row, col
|
|
|
|
def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
    """Map ``(thread_id, local_id)`` to a ``(row, col)`` pair.

    ``row`` is ``thread_id // 4`` plus 8 when ``local_id % 4`` is 2 or 3.
    ``col`` is built from ``local_id // 4`` (steps of 8),
    ``thread_id % 4`` (steps of 2) and the parity of ``local_id``.
    """
    row_block = local_id % 4 // 2
    col_block = local_id // 4
    row = 8 * row_block + (thread_id // 4)
    col = 8 * col_block + (thread_id % 4) * 2 + (local_id % 2)
    return row, col
|
|
|
|
def shared_16x16_to_mma_32x8_smoothlayout(i, j):
    """Map ``(i, j)`` to ``(i * 2 + j // 8, j % 8)``.

    Each input row is split into consecutive 8-wide halves that become two
    adjacent output rows.
    """
    out_row = i * 2 + j // 8
    out_col = j % 8
    return (out_row, out_col)
|
|
|
|
def shared_16x32_to_mma_32x16_smoothlayout(i, j):
    """Map ``(i, j)`` to ``(i * 2 + j // 16, j % 16)``.

    Each input row is split into consecutive 16-wide halves that become two
    adjacent output rows.
    """
    out_row = i * 2 + j // 16
    out_col = j % 16
    return (out_row, out_col)
|
|
|
|
def shared_32x16_to_mma_32x16_smoothlayout(i, j):
    """Map ``(i, j)`` to ``(i * 2 + j // 16, j % 16)``.

    NOTE(review): whenever ``j < 16`` the ``j // 16`` term is zero, so the
    first output is simply ``2 * i``. Confirm this is what is intended for
    the input extents implied by the function name.
    """
    out_row = i * 2 + j // 16
    out_col = j % 16
    return (out_row, out_col)
|
|
|
|
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
    """Return ``(row_idx, new_col)`` with an XOR-permuted column index.

    The column index is split into groups of ``128 // dtype.bits`` elements.
    The group number is XOR-ed with a value derived from ``row_idx``, the
    position inside the group is kept, and the recombined expression is
    simplified with ``arith.Analyzer``.

    Args:
        row_idx: Row index expression; returned unchanged.
        col_idx: Column index expression within the row.
        row_size: Number of elements per row.
        dtype: Element type, either a ``DataType`` or a string that
            ``DataType`` accepts.

    Returns:
        A ``(row_idx, new_col)`` tuple.

    Raises:
        AssertionError: If ``dtype.bits * row_size`` is not a multiple of 512.
    """
    analyzer = arith.Analyzer()
    if isinstance(dtype, str):
        dtype = DataType(dtype)

    # 128 is divided by the element width in bits, so it is a width in bits
    # (the original local was named BANK_SIZE_BYTES).
    group_bits = 128
    group_elems = group_bits // dtype.bits
    col_outer = col_idx // group_elems
    col_inner = col_idx % group_elems

    # Total number of bits in one row; selects the permutation variant.
    row_bits = dtype.bits * row_size
    # Every multiple of 1024 is also a multiple of 512, so a single up-front
    # check covers both variants and carries the diagnostic message.
    assert row_bits % 512 == 0, f"Unsupported dtype {dtype} with {row_bits} bits"

    row_in_group = row_idx % group_elems
    if row_bits % 1024 == 0:
        # NOTE(review): row_in_group spans group_elems values while col_outer
        # spans row_size // group_elems; when the former is larger (e.g.
        # 8-bit elements with 128-element rows) the XOR can leave the row.
        # Confirm callers never rely on that combination.
        swizzled_outer = col_outer ^ row_in_group
    else:
        # Consecutive rows that fit in 32 bits worth of elements share one
        # XOR operand.
        # NOTE(review): 32 // dtype.bits is 0 for types wider than 32 bits,
        # which makes the division below fail -- confirm such types are
        # never passed here.
        interleave_elems = 32 // dtype.bits
        swizzled_outer = col_outer ^ (row_in_group // interleave_elems)

    return row_idx, analyzer.simplify(swizzled_outer * group_elems + col_inner)
|
|
|
|
def make_mma_swizzle_layout(shared_buf, is_smooth: bool = False):
    """Build a ``T.Layout`` over ``shared_buf.shape``.

    The layout is the identity when ``is_smooth`` is true or when the last
    dimension's size in bits (``shape[-1] * DataType(dtype).bits``) is not a
    multiple of 512. Otherwise the last two indices are remapped through
    ``get_swizzle_layout`` and any leading indices pass through unchanged.

    Args:
        shared_buf: Object exposing ``dtype`` and ``shape`` attributes.
        is_smooth: When true, skip swizzling and return the identity layout.

    Returns:
        The ``T.Layout`` constructed for ``shared_buf.shape``.
    """
    buf_shape = shared_buf.shape
    buf_dtype = shared_buf.dtype

    last_dim_bits = buf_shape[-1] * DataType(buf_dtype).bits
    swizzlable = last_dim_bits % 512 == 0
    if is_smooth or not swizzlable:
        return T.Layout(buf_shape, lambda *args: args)

    def _swizzle_last_two(*indices):
        # Only the trailing (row, col) pair is permuted.
        *leading, row, col = indices
        new_row, new_col = get_swizzle_layout(row, col, buf_shape[-1], buf_dtype)
        return [*leading, new_row, new_col]

    return T.Layout(buf_shape, _swizzle_last_two)
|
|