from bitblas import tvm
from tvm import te
from bitblas.gpu.matmul_analysis import get_propagate_map
from bitblas.base.operator_common import TransformKind
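
# TE-based compute definitions for general matmul: plain NN/NT layouts with an
# optional bias add and an optional cast from accum_dtype to out_dtype, plus
# "propagate" variants whose A and/or B operands are stored in a 4D micro-tiled
# layout and re-indexed back to a 2D view before the reduction.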


def matmul_nn(
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
):
    # A symbolic M enables dynamic-shape kernels.
    if not isinstance(M, int):
        M = tvm.te.var("m")
    A = te.placeholder((M, K), name="A", dtype=in_dtype)
    B = te.placeholder((K, N), name="B", dtype=in_dtype)
    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)

    # C[i, j] = sum_k A[i, k] * B[k, j], accumulated in accum_dtype.
    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D

    if with_bias:
        E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
        last_output = E

    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]

    func = te.create_prim_func(args)

    return tvm.IRModule.from_expr(func)


def matmul_nt(
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
):
    if not isinstance(M, int):
        M = tvm.te.var("m")
    A = te.placeholder((M, K), name="A", dtype=in_dtype)
    # B is stored transposed: (N, K) instead of (K, N).
    B = te.placeholder((N, K), name="B", dtype=in_dtype)
    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)

    # C[i, j] = sum_k A[i, k] * B[j, k], accumulated in accum_dtype.
    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D

    if with_bias:
        E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
        last_output = E

    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]

    func = te.create_prim_func(args)

    return tvm.IRModule.from_expr(func)


def matmul(
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
    layout="nt",
):
    # Dispatch on layout: "nn" keeps B as (K, N); "nt" expects B transposed as (N, K).
    if layout == "nn":
        return matmul_nn(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias)
    return matmul_nt(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias)


def matmul_nt_propagate_a(
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
    transform_kind: TransformKind = TransformKind.IntraWarpTransform,
):
    # Accept a plain int for transform_kind, mirroring matmul_nt_propagate_b.
    if isinstance(transform_kind, int):
        transform_kind = TransformKind(transform_kind)
    if not isinstance(M, int):
        M = tvm.te.var("m")
    # Micro-tile sizes of the permuted layout: 16x16 by default, 16x32 for 8-bit dtypes.
    l = r = 16  # noqa: E741
    if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
        l, r = 16, 32  # noqa: E741

    _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A")

    # A is stored in a 4D micro-tiled layout: (M // l, K // r, l, r).
    A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype)
    B = te.placeholder((N, K), name="B", dtype=in_dtype)
    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)

    def fcompute(i, j):
        # Map a logical (i, j) index back into the micro-tiled layout, undoing
        # the intra-warp permutation via the inverse index map when requested.
        warp_i, warp_j = i % l, j % r
        spatial_args = i // l, j // r
        if transform_kind >= TransformKind.IntraWarpTransform:
            warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
        new_index = (*spatial_args, warp_i, warp_j)
        return A[new_index]

    A_reindex = te.compute(
        (M, K),
        fcompute,
        name="A_reindex",
    )

    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(
            A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D

    if with_bias:
        E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
        last_output = E

    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]

    func = te.create_prim_func(args)
    func = func.with_attr("input_transform_kind", transform_kind.value)

    return tvm.IRModule.from_expr(func)


def matmul_nt_propagate_b(
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
    transform_kind: TransformKind = TransformKind.IntraWarpTransform,
):
    if isinstance(transform_kind, int):
        transform_kind = TransformKind(transform_kind)
    if not isinstance(M, int):
        M = tvm.te.var("m")
    l = r = 16  # noqa: E741
    if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
        l, r = 16, 32  # noqa: E741

    _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B")

    A = te.placeholder((M, K), name="A", dtype=in_dtype)
    # B is stored in a 4D micro-tiled layout: (N // l, K // r, l, r).
    B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype)
    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)

    def fcompute(i, j):
        # Map a logical (i, j) index back into the micro-tiled layout, undoing
        # the intra-warp permutation via the inverse index map when requested.
        warp_i, warp_j = i % l, j % r
        spatial_args = i // l, j // r
        if transform_kind >= TransformKind.IntraWarpTransform:
            warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
        new_index = (*spatial_args, warp_i, warp_j)
        return B[new_index]

    B_reindex = te.compute(
        (N, K),
        fcompute,
        name="B_reindex",
    )

    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(
            A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D

    if with_bias:
        E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
        last_output = E

    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]

    func = te.create_prim_func(args)
    func = func.with_attr("weight_transform_kind", transform_kind.value)

    return tvm.IRModule.from_expr(func)


def matmul_nt_propagate_a_propagate_b(
    M,
    N,
    K,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
    transform_kind_input: TransformKind = TransformKind.IntraWarpTransform,
    transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform,
):
    # Accept plain ints for the transform kinds, mirroring matmul_nt_propagate_b.
    if isinstance(transform_kind_input, int):
        transform_kind_input = TransformKind(transform_kind_input)
    if isinstance(transform_kind_weight, int):
        transform_kind_weight = TransformKind(transform_kind_weight)
    if not isinstance(M, int):
        M = tvm.te.var("m")
    l = r = 16  # noqa: E741
    if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
        l, r = 16, 32  # noqa: E741

    # Both operands are stored in 4D micro-tiled layouts.
    A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype)
    B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype)
    Bias = te.placeholder((N,), name="Bias", dtype=in_dtype)

    _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A")

    def fcompute_a(i, j):
        warp_i, warp_j = i % l, j % r
        spatial_args = i // l, j // r
        if transform_kind_input >= TransformKind.IntraWarpTransform:
            warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
        new_index = (*spatial_args, warp_i, warp_j)
        return A[new_index]

    A_reindex = te.compute(
        (M, K),
        fcompute_a,
        name="A_reindex",
    )

    _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B")

    def fcompute_b(i, j):
        warp_i, warp_j = i % l, j % r
        spatial_args = i // l, j // r
        if transform_kind_weight >= TransformKind.IntraWarpTransform:
            warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
        new_index = (*spatial_args, warp_i, warp_j)
        return B[new_index]

    B_reindex = te.compute(
        (N, K),
        fcompute_b,
        name="B_reindex",
    )

    k = te.reduce_axis((0, K), name="k")
    C = te.compute(
        (M, N),
        lambda i, j: te.sum(
            A_reindex[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype),
            axis=k,
        ),
        name="C",
    )
    last_output = C
    if accum_dtype != out_dtype:
        D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D")
        last_output = D

    if with_bias:
        E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
        last_output = E

    args = [A, B, Bias, last_output] if with_bias else [A, B, last_output]

    func = te.create_prim_func(args)
    func = func.with_attr("input_transform_kind", transform_kind_input.value)
    func = func.with_attr("weight_transform_kind", transform_kind_weight.value)

    return tvm.IRModule.from_expr(func)


def select_implementation(
    M=None,
    N=16384,
    K=16384,
    in_dtype="float16",
    out_dtype="float16",
    accum_dtype="float16",
    with_bias=False,
    layout="nt",
    propagate_a: TransformKind = TransformKind.NonTransform,
    propagate_b: TransformKind = TransformKind.NonTransform,
):
    # Pick the TE implementation matching the requested layout; operand layout
    # propagation is only supported for the "nt" layout.
    if layout == "nn":
        if propagate_a or propagate_b:
            raise ValueError(
                "Currently only support propagate_a=False and propagate_b=False for layout=nn")
        return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
    elif layout == "nt":
        if propagate_a and propagate_b:
            return matmul_nt_propagate_a_propagate_b(
                M,
                N,
                K,
                in_dtype,
                out_dtype,
                accum_dtype,
                with_bias,
                transform_kind_input=propagate_a,
                transform_kind_weight=propagate_b,
            )
        elif propagate_a:
            return matmul_nt_propagate_a(
                M,
                N,
                K,
                in_dtype,
                out_dtype,
                accum_dtype,
                with_bias,
                transform_kind=propagate_a,
            )
        elif propagate_b:
            return matmul_nt_propagate_b(
                M,
                N,
                K,
                in_dtype,
                out_dtype,
                accum_dtype,
                with_bias,
                transform_kind=propagate_b,
            )
        else:
            return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout)
    else:
        raise ValueError(f"Unsupported layout: {layout}")
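

# Minimal usage sketch (illustrative only, not part of the original module):
# build the TIR for a plain fp16 NT matmul and for a weight-propagated variant,
# then print the resulting IRModules. The shapes below are arbitrary example
# values; all functions and enums used are defined in this file.
if __name__ == "__main__":
    ir_module = select_implementation(M=1024, N=1024, K=1024, layout="nt")
    print(ir_module)

    propagated = select_implementation(
        M=1024,
        N=1024,
        K=1024,
        layout="nt",
        propagate_b=TransformKind.IntraWarpTransform,
    )
    print(propagated)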