| import argparse | |
| import os | |
| import pickle | |
| from platform import system | |
| from typing import List | |
| import tvm | |
| import tvm.testing | |
| from tvm import relax | |
| import mlc_llm | |
| from mlc_llm import utils | |
| from mlc_llm.relax_model import gpt_neox, llama, moss | |
def _parse_args():
    """Parse command-line options and derive the paths used by the build."""
    parser = argparse.ArgumentParser()
    utils.argparse_add_common(parser)
    parser.add_argument("--quantization-sym", action="store_true", default=False)
    parser.add_argument(
        "--quantization-mode", type=str, choices=["int4", "int3", "fp4"], default="int4"
    )
    parser.add_argument(
        "--quantization-storage-nbit", type=int, choices=[32, 16, 8], default=32
    )
    parser.add_argument("--no-quantize", action="store_true", default=False)
    parser.add_argument("--max-seq-len", type=int, default=-1)
    parser.add_argument("--target", type=str, default="auto")
    parser.add_argument(
        "--db-path",
        type=str,
        default=None,
        help="Path to log database. Default: ./log_db/{model}",
    )
    parser.add_argument("--artifact-path", type=str, default="dist")
    parser.add_argument(
        "--use-cache",
        type=int,
        default=1,
        help="Whether to use previously pickled IRModule and skip trace.",
    )
    parser.add_argument("--debug-dump", action="store_true", default=False)
    parser.add_argument("--debug-load-script", action="store_true", default=False)
    parser.add_argument(
        "--llvm-mingw",
        type=str,
        default="",
        help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows",
    )
    parser.add_argument("--system-lib", action="store_true", default=False)
    opts = parser.parse_args()
    assert opts.max_seq_len == -1 or opts.max_seq_len > 0
    # Derived locations: raw weights under <artifact>/models/<model>,
    # build outputs under <artifact>/<model>/<dtype>.
    opts.model_path = os.path.join(opts.artifact_path, "models", opts.model)
    opts.artifact_path = os.path.join(opts.artifact_path, opts.model, opts.dtype)
    opts.export_kwargs = {}
    opts.lib_format = "so"
    opts.db_path = opts.db_path or os.path.join("log_db", opts.model)
    utils.parse_target(opts)
    utils.argparse_postproc_common(opts)
    return opts
def debug_dump_script(mod, name, args):
    """Debug dump mode"""
    # No-op unless --debug-dump was requested on the command line.
    if not args.debug_dump:
        return
    out_path = os.path.join(args.artifact_path, "debug", name)
    script_text = mod.script(show_meta=True)
    with open(out_path, "w") as handle:
        handle.write(script_text)
    print(f"Dump mod to {out_path}")
def debug_load_script(name, args):
    """Execute a script from the debug directory and return its ``Module``.

    Used by the --debug-load-script path so a hand-edited dump can replace
    the traced module.

    Args:
        name: File name inside ``<artifact_path>/debug``.
        args: Parsed options providing ``artifact_path``.

    Returns:
        The ``Module`` object defined by the executed script.
    """
    input_path = os.path.join(args.artifact_path, "debug", name)
    scope = {"__file__": input_path}
    # Fix: the original leaked the file handle via exec(compile(open(...))).
    with open(input_path, "rb") as infile:
        source = infile.read()
    # NOTE: exec of a local debug artifact -- only run on trusted files.
    exec(compile(source, input_path, "exec"), scope, scope)
    return scope["Module"]
def debug_dump_shader(ex, name, args):
    """Debug dump mode"""
    # No-op unless --debug-dump was requested on the command line.
    if not args.debug_dump:
        return
    kind = args.target.kind.default_keys[0]
    extensions = {
        "webgpu": ".wgsl",
        "cuda": ".cu",
        "metal": ".mtl",
        "opencl": ".cl",
    }
    out_path = os.path.join(
        args.artifact_path, "debug", name + extensions.get(kind, ".txt")
    )
    shader_source = ex.mod.imported_modules[0].imported_modules[0].get_source()
    with open(out_path, "w") as handle:
        handle.write(shader_source)
    print(f"Dump shader to {out_path}")
def mod_transform_before_build(
    mod: tvm.IRModule,
    model_params: List[tvm.nd.NDArray],
    args: argparse.Namespace,
) -> tvm.IRModule:
    """First-stage: Legalize ops and trace"""
    entry_funcs = ["encoding", "decoding", "create_kv_cache", "softmax_with_temperature"]
    if not args.no_quantize:
        # "*3" quantization modes use group size 40; everything else uses 32.
        quantizer = mlc_llm.transform.GroupQuantize(
            group_size=40 if args.quantization_mode.endswith("3") else 32,
            sym=args.quantization_sym,
            mode=args.quantization_mode,
            storage_nbit=args.quantization_storage_nbit,
            dtype=args.dtype,
        )
        mod = quantizer(mod)
    mod = mlc_llm.transform.FuseTransposeMatmul()(mod)
    mod = relax.pipeline.get_pipeline()(mod)
    mod = mlc_llm.transform.FuseDecodeMatmulEwise(args.dtype, args.target_kind)(mod)
    mod = relax.transform.DeadCodeElimination(entry_funcs)(mod)
    mod = relax.transform.LiftTransformParams()(mod)
    # Separate the lifted parameter-transform functions from the deploy module,
    # run them once over the weights, and persist the transformed weights.
    mod_transform, mod_deploy = utils.split_transform_deploy_mod(mod, entry_funcs)
    debug_dump_script(mod_transform, "mod_lift_params.py", args)
    new_params = utils.transform_params(mod_transform, model_params)
    utils.save_params(new_params, args.artifact_path)
    return mod_deploy
def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
    # Second stage: schedule the deployable IRModule for the configured target
    # and export the compiled library to
    # <artifact_path>/<model>_<target_kind>_<dtype>.<lib_format>.
    target_kind = args.target_kind
    debug_dump_script(mod_deploy, "mod_before_build.py", args)
    if target_kind != "cpu":
        from tvm import meta_schedule as ms
        # Prefer a tuning-log database when one exists at --db-path; otherwise
        # fall back to an empty in-memory database (default schedules apply).
        if os.path.exists(args.db_path):
            db = ms.database.create(work_dir=args.db_path)
        else:
            db = ms.database.MemoryDatabase()
        # NOTE(review): the schedule-application target is hard-coded to
        # "apple/m1-gpu-restricted" rather than args.target -- presumably it
        # must match the target the log database was tuned for; confirm.
        with db, tvm.target.Target("apple/m1-gpu-restricted"):
            mod_deploy = relax.transform.MetaScheduleApplyDatabase()(mod_deploy)
            if args.target_kind == "android":
                mod_deploy = mlc_llm.dispatch.DispatchTIROperatorAdreno()(mod_deploy)
            mod_deploy = mlc_llm.dispatch.DispatchTIROperator(args.model_category)(
                mod_deploy
            )
            mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)
            mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)
    if args.debug_load_script:
        # Allow a hand-edited debug script to replace the traced module.
        mod_deploy = debug_load_script("mod_build_stage_debug.py", args)
    debug_dump_script(mod_deploy, "mod_build_stage.py", args)
    ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib)
    output_filename = f"{args.model}_{target_kind}_{args.dtype}.{args.lib_format}"
    debug_dump_shader(ex, f"{args.model}_{target_kind}_{args.dtype}", args)
    lib_path = os.path.join(args.artifact_path, output_filename)
    ex.export_library(lib_path, **args.export_kwargs)
    print(f"Finish exporting to {lib_path}")
def dump_split_tir(mod: tvm.IRModule, args: argparse.Namespace = None):
    """Split the module's TIR into static- and dynamic-shape parts and dump both.

    Writes ``mod_tir_static.py`` and ``mod_tir_dynamic.py`` under the artifact
    path as runnable TVMScript files.

    Args:
        mod: Deploy-stage IRModule to split.
        args: Parsed build options providing ``artifact_path``. Defaults to the
            module-level ``ARGS`` so the existing call site keeps working
            (fix: the original read the global directly, unlike every sibling
            function which takes ``args``).
    """
    if args is None:
        args = ARGS  # backward-compatible fallback to the global options
    template = """
from tvm.script import ir as I
from tvm.script import tir as T
# fmt: off
{content}
# fmt: on
"""
    mod_static, mod_dynamic = utils.split_static_dynamic_tir(mod)
    static_path = os.path.join(args.artifact_path, "mod_tir_static.py")
    dynamic_path = os.path.join(args.artifact_path, "mod_tir_dynamic.py")
    print(f"Dump static shape TIR to {static_path}")
    with open(static_path, "w") as o_f:
        o_f.write(template.format(content=mod_static.script()))
    print(f"Dump dynamic shape TIR to {dynamic_path}")
    with open(dynamic_path, "w") as o_f:
        o_f.write(template.format(content=mod_dynamic.script()))
if __name__ == "__main__":
    ARGS = _parse_args()
    os.makedirs(ARGS.artifact_path, exist_ok=True)
    os.makedirs(os.path.join(ARGS.artifact_path, "debug"), exist_ok=True)
    # The traced/pre-transformed module is cached so reruns can skip tracing.
    cache_path = os.path.join(
        ARGS.artifact_path, f"mod_cache_before_build_{ARGS.dtype}.pkl"
    )
    use_cache = ARGS.use_cache and os.path.isfile(cache_path)
    if not use_cache:
        # Trace the model into a Relax IRModule, then legalize/quantize it.
        if ARGS.model_category == "llama":
            mod, params = llama.get_model(ARGS)
        elif ARGS.model_category == "gpt_neox":
            mod, params = gpt_neox.get_model(ARGS.model, ARGS.model_path, ARGS.dtype)
        elif ARGS.model_category == "moss":
            mod, params = moss.get_model(ARGS)
        else:
            raise ValueError(f"Model {ARGS.model} not supported")
        mod = mod_transform_before_build(mod, params, ARGS)
        with open(cache_path, "wb") as outfile:
            pickle.dump(mod, outfile)
        print(f"Save a cached module to {cache_path}.")
    else:
        print(
            f"Load cached module from {cache_path} and skip tracing. "
            "You can use --use-cache=0 to retrace"
        )
        # NOTE: pickle is only safe on caches this script wrote itself; never
        # point --artifact-path at untrusted data.
        # Fix: the original leaked the file handle via pickle.load(open(...)).
        with open(cache_path, "rb") as cache_file:
            mod = pickle.load(cache_file)
    dump_split_tir(mod)
    build(mod, ARGS)