from bitblas import tvm
from typing import Optional, List, Dict, Union
from tvm import IRModule
from bitblas.base.arch import TileDevice, is_cuda_arch, is_cdna_arch
from bitblas.utils import match_global_kernel
from bitblas.utils.rtmod_analysis import get_annotated_device_mod
import re
import logging

from .base import (BaseWrapper, PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY, PREDEF_INIT_FUNC,
                   PREDEF_HOST_FUNC)

logger = logging.getLogger(__name__)

class TLCUDASourceWrapper(object):
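    """Generate host-side glue for a TileLang-compiled CUDA kernel.

    Parses launch metadata (grid/block extents, dynamic shared memory size)
    from the scheduled IRModule and wraps the generated kernel source with an
    init function and a C host function that launches the kernel.
    """
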
    _TYPE_MAP = {
        "float32": "float",
        "float16": "half_t",
        "bfloat16": "bfloat16_t",
        "e4m3_float8": "float_e4m3_t",
        "e5m2_float8": "float_e5m2_t",
        "float64": "double",
        "int64": "int64_t",
        "int32": "int",
        "uint32": "unsigned int",
        "bool": "int8_t",
        "int8": "int8_t",
        "uint8": "uint8_t",
        "int16": "int16_t",
        "uchar": "uint8_t",
    }

    backend = "tl"

    def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice):
        self.mod = scheduled_ir_module
        self.arch = arch
        self.source = source
        self.function_name: Optional[str] = None
        self.dynamic_smem_buf: Optional[int] = None
        self.block_info: Union[List[int], Dict] = [1, 1, 1]
        self.grid_info: Union[List[int], Dict] = [1, 1, 1]
        self.parse_source_information()
        self.srcpath: Optional[str] = None
        self.libpath: Optional[str] = None
        self.lib_code: Optional[str] = self.update_lib_code(source)

    def parse_source_information(self):
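        # Extract the kernel name, dynamic shared memory size, and launch
        # (block/grid) extents annotated on the single device function.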
        device_mod = get_annotated_device_mod(self.mod, self.arch.target, backend=self.backend)
        assert len(device_mod.functions) == 1, (
            "Only one function per module is supported for a static shape kernel.")
        for g_var, func in device_mod.functions.items():
            self.function_name = g_var.name_hint
            attrs = func.attrs
            if "dyn_shared_memory_buf" in attrs:
                self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
            if "thread_extent" in attrs:
                thread_extent = attrs["thread_extent"]
                for tag, extent in thread_extent.items():
                    if "threadIdx" in tag:
                        self.block_info["xyz".index(tag[-1])] = extent
                    elif "blockIdx" in tag:
                        self.grid_info["xyz".index(tag[-1])] = extent

    def get_dynamic_symbolic_set(self, prim_func):
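        # Collect the names of dynamic symbolic variables that appear in the
        # shapes of the function's buffers.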
        dynamic_symbolic_set = set()
        for param in prim_func.params:
            buffer = prim_func.buffer_map[param]
            for dim in buffer.shape:
                if isinstance(dim, tvm.tir.Var):
                    dynamic_symbolic_set.add(dim.name)
        return dynamic_symbolic_set

    def get_cuda_init_func(self):
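        # Build the host-side init function; when the kernel requests dynamic
        # shared memory, include the predefined attribute-setting call.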
        call_str = ""
        if self.dynamic_smem_buf is not None:
            call_str = (
                PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
                                                           self.dynamic_smem_buf))
        init_funcs = PREDEF_INIT_FUNC.format(call_str)
        return init_funcs

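    def get_stream_type(self, function_args):
        # Stream-argument hook: appends the CUDA stream parameter. Subclasses
        # (e.g. TLHIPSourceWrapper) override this to supply a HIP stream, so
        # the shared host-code generation stays backend-agnostic.
        function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"})
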
    def update_lib_code(self, code: str):
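        # Assemble the full library code for a static-shape kernel:
        # kernel source + init function + host-side launch wrapper.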
        self.lib_code = code
        # Find the index of the global kernel declaration in the generated source
        index = match_global_kernel(code)
        declaration = code[index:].split(";")[0]

        function_name = self.function_name
        # Host-side initialization function (dynamic shared memory setup)
        init_func = self.get_cuda_init_func()

        # Locate the opening brace of the kernel body
        index = code.index("{", index)
        function_args = []
        # Collect buffer arguments from the primary function's parameters
        for param in self.prim_func.params:
            buffer = self.prim_func.buffer_map[param]
            function_args.append({
                "name": buffer.name,
                "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
            })

        dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)
        # Pass each dynamic symbolic dimension as a plain int argument
        for dyn_sym in dynamic_symbolic_set:
            function_args.append({"name": dyn_sym, "type": "int"})
        # Append the backend-specific stream argument via the overridable hook
        self.get_stream_type(function_args)

        def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

        def func_call_args(s, function_args):
            # Extract call-site argument names by matching identifiers in the
            # kernel declaration against the known function arguments
            pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
            matches = re.findall(pattern, s)
            call_args = []
            for match in matches:
                for arg in function_args:
                    if arg["name"] == match:
                        call_args.append(match)
            return call_args

        call_args = ", ".join(func_call_args(declaration, function_args))
        block_info, grid_info = self.block_info, self.grid_info

        def legalize_c(p):
            # Convert a TIR expression into a legal C expression: lower IntImm
            # to int so str() yields a plain literal, and replace Python's
            # floor division "//" with C's "/".
            if isinstance(p, tvm.tir.IntImm):
                p = int(p)
            return str(p).replace("//", "/")

        # Prepare the block and grid dimensions for the kernel launch
        block_str = "dim3({}, {}, {})".format(
            legalize_c(block_info[0]),
            legalize_c(block_info[1]),
            legalize_c(block_info[2]),
        )
        grid_str = "dim3({}, {}, {})".format(
            legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
        # Dynamic shared memory size in bytes, 0 when not used
        smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
        # Skip the launch entirely when a dynamic dimension is zero
        if len(dynamic_symbolic_set) != 0:
            call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0])
        else:
            call_str = ""
        call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
                                                             smem_str, call_args)
        # Wrap the guarded launch in the predefined host function template
        host_func = PREDEF_HOST_FUNC.format(def_args, call_str)
        # Final library code: kernel source + init function + host wrapper
        lib_code = self.source + init_func + host_func
        return lib_code

    @property
    def prim_func(self):
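        # Resolve the primary TIR function: the sole global function, "main",
        # or the first function flagged with tir.is_global_func.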
        if len(self.mod.get_global_vars()) == 1:
            return self.mod[self.mod.get_global_vars()[0]]
        elif "main" in self.mod:
            return self.mod["main"]
        else:
            for _, function in self.mod.functions_items():
                attr = function.attrs
                if "tir.is_global_func" in attr and attr["tir.is_global_func"]:
                    return function
            raise ValueError("Cannot find primary function in the module.")


class TLCUDASourceWrapperWithDynamic(TLCUDASourceWrapper):
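    """Dynamic-shape variant of TLCUDASourceWrapper.

    The module holds several kernels specialized for different shape ranges;
    the generated host function dispatches among them by comparing the dynamic
    dimension against each kernel's opt_shapes bound.
    """
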
    def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice):
        super().__init__(scheduled_ir_module, source, arch)

    def get_cuda_init_func(self):
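        # Here self.dynamic_smem_buf maps kernel name -> buffer size, so each
        # specialized kernel gets its own attribute-setting call.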
        call_str = ""
        for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
            if dynamic_smem_buf is not None:
                call_str += PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(
                    function_name, dynamic_smem_buf)
        init_funcs = PREDEF_INIT_FUNC.format(call_str)
        return init_funcs

    def create_dispatch_func(self, code, function_informations):
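        # Build a host function that guards against zero-sized dynamic dims and
        # selects a specialized kernel via an if / else-if chain.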
        # Extract the set of dynamic symbolic names used in the primary function
        dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func)

        # Find the location of the first global kernel in the code
        index = match_global_kernel(code)

        # Use that kernel's declaration to recover the argument order
        dummy_declaration = code[index:].split(";")[0]

        function_name = self.function_name

        index = code.index("{", index)
        function_args = []
        # Collect buffer arguments from the primary function's parameters
        for param in self.prim_func.params:
            buffer = self.prim_func.buffer_map[param]
            function_args.append({
                "name": buffer.name,
                "type": self._TYPE_MAP[buffer.dtype] + "* __restrict__",
            })
        # Pass each dynamic symbolic dimension as a plain int argument
        for dyn_sym in dynamic_symbolic_set:
            function_args.append({"name": dyn_sym, "type": "int"})
        # Append the backend-specific stream argument via the overridable hook
        self.get_stream_type(function_args)

        def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])

        def func_call_args(s: str, function_args):
            # Match identifiers in the declaration, stripping digits and
            # underscores so specialized parameter names line up with the
            # original argument names
            pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)"
            matches = re.findall(pattern, s)
            call_args = []
            for match in matches:
                match = re.sub(r"\d+", "", match)
                match = re.sub(r"_", "", match)
                for arg in function_args:
                    if arg["name"] == match:
                        call_args.append(match)
            return call_args

        call_args = ", ".join(func_call_args(dummy_declaration, function_args))

        def legalize_c(p):
            # Convert a TIR expression into a legal C expression: lower IntImm
            # to int so str() yields a plain literal, and replace Python's
            # floor division "//" with C's "/".
            if isinstance(p, tvm.tir.IntImm):
                p = int(p)
            return str(p).replace("//", "/")

        num_items = len(function_informations)
        _call_str = ""
        for last_range, (function_name, info) in enumerate(function_informations.items()):
            # Launch configuration for this specialized kernel
            block_info, grid_info = info["block_info"], info["grid_info"]
            block_str = "dim3({}, {}, {})".format(
                legalize_c(block_info[0]),
                legalize_c(block_info[1]),
                legalize_c(block_info[2]),
            )
            grid_str = "dim3({}, {}, {})".format(
                legalize_c(grid_info[0]),
                legalize_c(grid_info[1]),
                legalize_c(grid_info[2]),
            )
            # Dynamic shared memory size in bytes, 0 when not used
            smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"])
            opt_shapes = info["opt_shapes"]
            # Dispatching currently supports exactly one dynamic symbol
            (symbolic,) = list(dynamic_symbolic_set)
            range_str = opt_shapes[symbolic]
            if last_range == 0:
                call_str = " if ({} == 0) return; \n".format(symbolic)
                call_str += " if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
                    symbolic,
                    range_str,
                    function_name,
                    grid_str,
                    block_str,
                    smem_str,
                    call_args,
                )
            else:
                call_str = " else if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
                    symbolic,
                    range_str,
                    function_name,
                    grid_str,
                    block_str,
                    smem_str,
                    call_args,
                )
            if last_range == num_items - 1:
                # The final branch falls back to the last (largest) specialization
                call_str += " else {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
                    function_name, grid_str, block_str, smem_str, call_args)
            _call_str += call_str

        # Wrap the dispatch chain in the predefined host function template
        host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
        return host_func

    def parse_source_information(self):
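        # Gather per-kernel launch metadata keyed by kernel name, since a
        # dynamic-shape module carries several specialized device functions.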
        device_mod = get_annotated_device_mod(self.mod, self.arch.target, backend=self.backend)
        block_info_map = {}
        grid_info_map = {}
        dynamic_smem_buf_map = {}
        for g_var, func in device_mod.functions.items():
            # Default block and grid configurations
            block_info = [1, 1, 1]
            grid_info = [1, 1, 1]
            function_name = g_var.name_hint
            attrs = func.attrs
            dynamic_smem_buf = None
            if "dyn_shared_memory_buf" in attrs:
                dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
            if "thread_extent" in attrs:
                # Extract block and grid sizes from the thread_extent annotations
                thread_extent = attrs["thread_extent"]
                for tag, extent in thread_extent.items():
                    if "threadIdx" in tag:
                        block_info["xyz".index(tag[-1])] = extent
                    elif "blockIdx" in tag:
                        grid_info["xyz".index(tag[-1])] = extent
            # Record the launch configuration for this kernel
            block_info_map[function_name] = block_info
            grid_info_map[function_name] = grid_info
            dynamic_smem_buf_map[function_name] = dynamic_smem_buf
        # Store the maps, keyed by kernel name, on the wrapper
        self.block_info = block_info_map
        self.grid_info = grid_info_map
        self.dynamic_smem_buf = dynamic_smem_buf_map

    def update_lib_code(self, code: str):
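        # Collect dispatch metadata for every specialized kernel, sort kernels
        # by their opt_shapes, then assemble the final library code.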
        function_informations = {}
        for g_var, func in self.mod.functions.items():
            function_name = g_var.name_hint
            # Skip functions that carry no launch information
            if (function_name not in self.block_info) or (function_name not in self.grid_info):
                continue

            attrs = func.attrs
            assert "opt_shapes" in attrs
            opt_shapes = attrs["opt_shapes"]
            function_informations[function_name] = {
                "function_name": function_name,
                "opt_shapes": opt_shapes,
                "block_info": self.block_info[function_name],
                "grid_info": self.grid_info[function_name],
                "dynamic_smem_buf": self.dynamic_smem_buf[function_name],
            }

        def compare_map_objects(map_obj):
            # tvm Map objects are not directly comparable; compare their values
            comparable_representation = list(map_obj.values())
            return comparable_representation

        # Sort kernels by opt_shapes so the dispatch chain tests ranges in order
        function_informations = dict(
            sorted(
                function_informations.items(),
                key=lambda item: compare_map_objects(item[1]["opt_shapes"]),
            ))

        self.lib_code = code

        # Create the initialization and dispatch functions
        init_func = self.get_cuda_init_func()
        host_func = self.create_dispatch_func(code, function_informations)
        # Final library code: kernel source + init function + dispatch wrapper
        lib_code = self.source + init_func + host_func
        return lib_code


class TLHIPSourceWrapper(TLCUDASourceWrapper):
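    """HIP (ROCm) variant of the source wrapper.

    Reuses the CUDA host-code generation, overriding the init and stream hooks
    to emit HIP equivalents (e.g. hipStream_t instead of cudaStream_t).
    """
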
    def __init__(self, scheduled_ir_module: IRModule, source: str, arch: TileDevice):
        super().__init__(scheduled_ir_module, source, arch)

    def get_hip_init_func(self):
        call_str = ""
        # If a dynamic shared memory buffer is specified, prepare the attribute call
        if self.dynamic_smem_buf is not None:
            call_str = PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
                                                                  self.dynamic_smem_buf)
        init_funcs = PREDEF_INIT_FUNC.format(call_str)
        return init_funcs

    def get_cuda_init_func(self):
        # Route the init hook used by update_lib_code to the HIP variant
        return self.get_hip_init_func()

    def get_stream_type(self, function_args):
        # Use a HIP stream in the host wrapper signature instead of a CUDA stream
        function_args.append({"name": "stream=hipStreamDefault", "type": "hipStream_t"})


class TLWrapper(BaseWrapper):
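    """Backend entry point: picks a source wrapper for the target arch.

    Illustrative usage, assuming `arch`, `scheduled_ir_module`, and the
    generated `c_source` are already available:

        wrapper = TLWrapper(arch)
        wrapper.assign_optimized_module(scheduled_ir_module)
        lib_code = wrapper.wrap(c_source, is_dynamic=False)
    """
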
    def __init__(self, arch: TileDevice):
        super().__init__()
        self.scheduled_ir_module = None
        self.arch = arch
        self.lib = None

    def assign_optimized_module(self, scheduled_ir_module: IRModule):
        self.scheduled_ir_module = scheduled_ir_module

    def wrap(self, c_source: str, is_dynamic: bool = False):
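        # Select the wrapper class for the target architecture and return the
        # generated library code (kernel source + init + host wrapper).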
        assert self.scheduled_ir_module is not None, "Please assign an optimized module first."
        if is_cuda_arch(self.arch):
            wrapper_class = (
                TLCUDASourceWrapper if not is_dynamic else TLCUDASourceWrapperWithDynamic)
        elif is_cdna_arch(self.arch):
            wrapper_class = TLHIPSourceWrapper
        else:
            raise ValueError(f"Unsupported platform: {self.arch.platform}")
        wrapper = wrapper_class(self.scheduled_ir_module, c_source, self.arch)
        return wrapper.lib_code