| |
| |
|
|
| from tvm import DataType |
| from typing import Literal |
| from .mma_layout import ( |
| ldmatrix_32x8_to_shared_16x16_layout, |
| ldmatrix_trans_32x8_to_shared_16x16_layout, |
| ldmatrix_16x32_to_shared_16x32_layout_a, |
| ldmatrix_16x32_to_shared_16x32_layout_b, |
| mma_store_32x8_to_shared_16x16_layout, |
| ) |
| from .mfma_layout import (thread_id_shared_access_64x4_to_16x16_layout_C_n_m) |
|
|
| from .mma_layout import get_swizzle_layout |
| from .mma_layout import make_mma_swizzle_layout |
| from .mfma_layout import make_mfma_swizzle_layout |
|
|
|
|
| |
| |
def get_ldmatrix_offset(
    matrix: Literal["A", "B"],
    row_idx,
    col_idx,
    stride,
    dtype: Literal["float16", "int8"] = "float16",
    transposed: bool = False,
):
    """Return the flat shared-memory offset for one ldmatrix-loaded element.

    A layout function from ``.mma_layout`` maps ``(row_idx, col_idx)`` to a
    ``(row, col)`` position in a shared-memory tile. The result is the
    row-major flat offset ``row * stride + col``.

    Layout selection:
      * 16-bit dtypes: ``ldmatrix_trans_32x8_to_shared_16x16_layout`` when
        ``transposed`` is true, else ``ldmatrix_32x8_to_shared_16x16_layout``.
        ``matrix`` is validated but does not otherwise affect the result.
      * 8-bit dtypes: ``ldmatrix_16x32_to_shared_16x32_layout_b`` for
        ``("B", transposed=True)`` and ``ldmatrix_16x32_to_shared_16x32_layout_a``
        for ``("A", transposed=False)``; other combinations are rejected.

    Args:
        matrix: Which MMA operand is being loaded, ``"A"`` or ``"B"``.
        row_idx: First input to the layout function. Presumably the lane or
            thread index (see the ``32x8`` / ``16x32`` names); confirm in
            ``.mma_layout``.
        col_idx: Second input to the layout function. Presumably the
            per-lane element index; confirm in ``.mma_layout``.
        stride: Row stride of the shared-memory buffer, in elements.
        dtype: Element dtype name. Only its bit width (via tvm ``DataType``)
            is used, so any 8-bit dtype takes the "int8" path.
        transposed: Whether the operand is stored transposed.

    Returns:
        ``row * stride + col`` for the mapped position.

    Raises:
        ValueError: if ``matrix`` is not ``"A"``/``"B"``, the 8-bit
            operand/transpose combination is unsupported, or the dtype
            width is neither 8 nor 16 bits.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if matrix not in ("A", "B"):
        raise ValueError("matrix should be either A or B")

    dtype_bits = DataType(dtype).bits
    if dtype_bits == 16:
        if transposed:
            transform_func = ldmatrix_trans_32x8_to_shared_16x16_layout
        else:
            transform_func = ldmatrix_32x8_to_shared_16x16_layout
    elif dtype_bits == 8:
        if matrix == "B" and transposed:
            transform_func = ldmatrix_16x32_to_shared_16x32_layout_b
        elif matrix == "A" and not transposed:
            transform_func = ldmatrix_16x32_to_shared_16x32_layout_a
        else:
            raise ValueError("ldmatrix only supports B transposed and A non-transposed for int8")
    else:
        raise ValueError(f"Unsupported dtype {dtype}")

    # Every supported branch ends with the same row-major flattening.
    new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
    return new_row_idx * stride + new_col_idx
|
|
|
|
def mma_store_index_map(*args, **kwargs):
    """Alias for ``mma_store_32x8_to_shared_16x16_layout``.

    All positional and keyword arguments are forwarded unchanged, and the
    layout function's result is returned as-is.
    """
    layout_fn = mma_store_32x8_to_shared_16x16_layout
    return layout_fn(*args, **kwargs)
|
|
|
|
def mfma_store_index_map(*args, **kwargs):
    """Alias for ``thread_id_shared_access_64x4_to_16x16_layout_C_n_m``.

    All positional and keyword arguments are forwarded unchanged, and the
    layout function's result is returned as-is.
    """
    layout_fn = thread_id_shared_access_64x4_to_16x16_layout_C_n_m
    return layout_fn(*args, **kwargs)
|
|
|
|
def get_mma_micro_size(dtype: Literal["float16", "int8"]):
    """Return the ``(micro_size_x, micro_size_y, micro_size_k)`` tile sizes.

    The x and y sizes are always 16. The k size is 32 when ``dtype`` is
    exactly the string ``"int8"``, and 16 for any other value.
    """
    tile_x = tile_y = 16
    tile_k = 32 if dtype == "int8" else 16
    return tile_x, tile_y, tile_k
|
|
|
|
def index_to_coordinates(index, shape):
    """Unravel a flat row-major ``index`` into one coordinate per dimension.

    The last dimension varies fastest. Working from the innermost
    dimension outward, each step does::

        coordinates[-1] = index % shape[-1]; index = index // shape[-1]
        coordinates[-2] = index % shape[-2]; index = index // shape[-2]
        coordinates[-3] = index % shape[-3]; index = index // shape[-3]
        ...

    Only ``%`` and ``//`` are applied to ``index``, so any value supporting
    them can be passed (plain ints or, presumably, symbolic index
    expressions — confirm for the caller's type). No bounds check is done:
    an out-of-range ``index`` wraps in the outermost dimension.

    Example specialisation: with ``kk = micro_size_k // num_elems_per_byte``
    and ``shape = (block_N // micro_size_y, block_K // kk, micro_size_y, kk)``
    this expands to::

        vjj = index % kk
        vii = (index // kk) % micro_size_y
        vj  = (index // (kk * micro_size_y)) % (block_K // kk)
        vi  = (index // (kk * micro_size_y * (block_K // kk))) % (block_N // micro_size_y)

    >>> index_to_coordinates(23, (2, 3, 4))
    [1, 2, 3]

    Args:
        index: Flat row-major index.
        shape: Sequence of dimension extents, outermost first.

    Returns:
        A list of coordinates, outermost first, with ``len(shape)`` entries
        (empty for an empty ``shape``).
    """
    coordinates = []
    # Peel dimensions from the innermost (fastest-varying) outward.
    for extent in reversed(shape):
        coordinates.append(index % extent)
        index = index // extent
    # Coordinates were collected innermost-first; restore outermost-first order.
    coordinates.reverse()
    return coordinates
|
|