"""Cross-check a compiled MLC-LLM model between a primary and a comparison device.

The model runs on the primary (reference) device while every kernel is re-run
on a comparison device, either locally or over RPC (iPhone or Android), and the
intermediate results are validated against each other. With --time-eval, each
kernel is additionally benchmarked on the comparison device and a per-kernel
timing table is printed after the encoding and decoding steps.
"""
import argparse
import os
from typing import List

import numpy as np
import torch
from transformers import AutoTokenizer

import tvm
from tvm import relax, rpc
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
from tvm.runtime import ShapeTuple

from mlc_llm import utils


class LibCompare(LibCompareVMInstrument):
    """VM instrument that validates and optionally benchmarks every kernel call."""

    def __init__(self, mod, device, time_eval, skip_rounds=0):
        super().__init__(mod, device, True)
        self.time_eval = time_eval
        # Maps kernel name -> (mean time in seconds, number of invocations).
        self.time_eval_results = {}
        self.visited = set()
        self.skip_rounds = skip_rounds
        self.counter = 0  # number of initial kernel calls skipped so far
        self.atol = 1e-2
        self.rtol = 1e-3

    def skip_instrument(self, func, name, before_run, ret_val, *args):
        print(f"run {name}")
        # Shape functions only compute shapes; no need to cross-check them.
        if name.startswith("shape_func"):
            return True
        # Optionally skip the first few kernel calls.
        if self.counter < self.skip_rounds:
            self.counter += 1
            print(f"[{self.counter}] Skip validating {name}...")
            return True
        # Each kernel is validated only once; later calls just bump its count
        # so the timing table can report a weighted total.
        if name in self.visited:
            if self.time_eval and name in self.time_eval_results:
                record = self.time_eval_results[name]
                self.time_eval_results[name] = (record[0], record[1] + 1)
            return True
        self.visited.add(name)
        return False

    def compare(
        self,
        name: str,
        ref_args: List[tvm.nd.NDArray],
        new_args: List[tvm.nd.NDArray],
        ret_indices: List[int],
    ):
        super().compare(name, ref_args, new_args, ret_indices)

        if self.time_eval and name not in self.time_eval_results:
            res = self.mod.time_evaluator(name, self.device)(*new_args)
            self.time_eval_results[name] = (res.mean, 1)
            print(f"Time-eval result {name} on {self.device}: {res}")


def print_as_table(sorted_list):
    """Print (name, (mean_time, count)) records as a fixed-width timing table."""
    print(
        "Name".ljust(50)
        + "Time (ms)".ljust(12)
        + "Count".ljust(8)
        + "Total time (ms)".ljust(18)
        + "Percentage (%)"
    )
    total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000
    for record in sorted_list:
        time = record[1][0] * 1000
        weighted_time = time * record[1][1]
        percentage = weighted_time / total_time * 100
        print(
            record[0].ljust(50)
            + "{:.4f}".format(time).ljust(12)
            + str(record[1][1]).ljust(8)
            + "{:.4f}".format(weighted_time).ljust(18)
            + "{:.2f}".format(percentage)
        )
    print("Total time: {:.4f} ms".format(total_time))
    print()
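
# A minimal sketch of how the instrument above plugs into a Relax VM on its
# own, outside the TestState plumbing below (the .so paths and the device are
# hypothetical placeholders, not artifacts produced by this script):
#
#   vm = relax.VirtualMachine(tvm.runtime.load_module("primary.so"), tvm.cuda())
#   cmp_lib = tvm.runtime.load_module("candidate.so")
#   vm.set_instrument(LibCompare(cmp_lib, tvm.cuda(), time_eval=True))
#   # From here on, every kernel the VM launches is re-run against cmp_lib,
#   # validated within atol/rtol, and (optionally) time-evaluated.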
class TestState:
    """Loads the primary-device VM, the comparison library, and the weights."""

    def __init__(self, args):
        self.primary_device = tvm.device(args.primary_device)
        ex = tvm.runtime.load_module(
            os.path.join(
                args.artifact_path,
                f"{args.model}_{args.primary_device}_{args.dtype}.so",
            )
        )
        self.vm = relax.VirtualMachine(ex, self.primary_device)
        if args.cmp_device == "iphone":
            # iOS devices are reached through an RPC proxy; the library is a .dylib.
            lib_name = f"{args.model}_{args.cmp_device}_{args.dtype}.dylib"
            local_lib_path = os.path.join(args.artifact_path, lib_name)
            proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1")
            proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090"))
            self.sess = rpc.connect(proxy_host, proxy_port, "iphone")
            self.sess.upload(local_lib_path)
            self.lib = self.sess.load_module(lib_name)
            self.cmp_device = self.sess.metal()
        elif args.cmp_device == "android":
            # Android devices are requested from an RPC tracker and run via OpenCL.
            lib_name = f"{args.model}_{args.cmp_device}_{args.dtype}.so"
            local_lib_path = os.path.join(args.artifact_path, lib_name)
            tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0")
            tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190"))
            tracker = rpc.connect_tracker(tracker_host, tracker_port)
            self.sess = tracker.request("android")
            self.sess.upload(local_lib_path)
            self.lib = self.sess.load_module(lib_name)
            self.cmp_device = self.sess.cl(0)
        else:
            # Any other comparison device is assumed to be local.
            self.sess = None
            self.lib = tvm.runtime.load_module(
                os.path.join(
                    args.artifact_path,
                    f"{args.model}_{args.cmp_device}_{args.dtype}.so",
                )
            )
            self.cmp_device = tvm.device(args.cmp_device)
        self.const_params_dict = utils.load_params(
            args.artifact_path, self.primary_device
        )
        self.cmp_instrument = LibCompare(
            self.lib,
            self.cmp_device,
            time_eval=args.time_eval,
            skip_rounds=args.skip_rounds,
        )
        self.vm.set_instrument(self.cmp_instrument)


def deploy_to_pipeline(args) -> None:
    state = TestState(args)
    primary_device = state.primary_device
    # Reuse the weights already loaded by TestState instead of loading them twice.
    const_params = state.const_params_dict
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    print("Tokenizing...")
    inputs = tvm.nd.array(
        tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(),
        primary_device,
    )
    # An arbitrary fixed token id, used only to drive a single decoding step.
    first_sampled_token = tvm.nd.array(
        np.array([[6234]]).astype("int32"), primary_device
    )
    seq_len_shape = ShapeTuple([inputs.shape[1]])
    second_seq_len_shape = ShapeTuple([inputs.shape[1] + 1])
    kv_caches = state.vm["create_kv_cache"]()

    print("Running inference...")
    print("======================= Start Encoding =======================")
    logits, kv_caches = state.vm["encoding"](
        inputs, seq_len_shape, kv_caches, const_params
    )
    # Sort kernels by their total contribution (mean time * call count).
    print_as_table(
        sorted(
            state.cmp_instrument.time_eval_results.items(),
            key=lambda x: -(x[1][0] * x[1][1]),
        )
    )
    state.cmp_instrument.time_eval_results.clear()
    state.cmp_instrument.visited.clear()
    print("======================= Start Decoding =======================")
    logits, kv_caches = state.vm["decoding"](
        first_sampled_token, second_seq_len_shape, kv_caches, const_params
    )
    print_as_table(
        sorted(
            state.cmp_instrument.time_eval_results.items(),
            key=lambda x: -(x[1][0] * x[1][1]),
        )
    )
    state.cmp_instrument.time_eval_results.clear()


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--artifact-path", type=str, default="dist")
    parser.add_argument("--primary-device", type=str, default="auto")
    parser.add_argument("--cmp-device", type=str, required=True)
    parser.add_argument("--prompt", type=str, default="The capital of Canada is")
    parser.add_argument("--model", type=str, default="vicuna-v1-7b")
    parser.add_argument(
        "--dtype", type=str, choices=["float32", "float16"], default="float16"
    )
    parser.add_argument("--time-eval", default=False, action="store_true")
    parser.add_argument("--skip-rounds", type=int, default=0)
    parsed = parser.parse_args()
    parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
    parsed.artifact_path = os.path.join(
        parsed.artifact_path, parsed.model, parsed.dtype
    )

    if parsed.primary_device == "auto":
        if tvm.cuda().exist:
            parsed.primary_device = "cuda"
        elif tvm.metal().exist:
            parsed.primary_device = "metal"
        else:
            raise ValueError(
                "Cannot automatically deduce the primary device, "
                "please pass --primary-device explicitly"
            )
    return parsed


if __name__ == "__main__":
    args = _parse_args()
    deploy_to_pipeline(args)
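
# Example invocation, assuming this file is saved as compare_lib.py and the
# model artifacts have already been built under dist/ (the model name, devices,
# and paths below are illustrative, not requirements of the script):
#
#   python compare_lib.py --model vicuna-v1-7b --dtype float16 \
#       --primary-device cuda --cmp-device android --time-eval
#
# This runs one encoding and one decoding step on the primary device, validates
# every kernel against the Android build over RPC, and prints a per-kernel
# timing table after each step.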