| from typing import List | |
| import argparse | |
| import os | |
| import tvm | |
| from tvm import relax | |
| from tvm.runtime import ShapeTuple | |
| from tvm import rpc | |
| from tvm.relax.testing.lib_comparator import LibCompareVMInstrument | |
| import numpy as np | |
| import torch | |
| from transformers import AutoTokenizer | |
| from mlc_llm import utils | |
class LibCompare(LibCompareVMInstrument):
    """VM instrument that cross-checks each executed kernel against a
    comparison library.

    Every non-shape function the VM runs is re-executed through the
    comparison module ``mod`` on ``device`` and numerically validated
    (via the base-class ``compare``); optionally each kernel is also
    time-evaluated once, with results aggregated in
    ``time_eval_results``.
    """

    def __init__(self, mod, device, time_eval, skip_rounds=0):
        """Construct the instrument.

        Parameters
        ----------
        mod : tvm.runtime.Module
            The comparison library module to validate against.
        device : tvm.runtime.Device
            The device the comparison module runs on.
        time_eval : bool
            Whether to time-evaluate each kernel on its first visit.
        skip_rounds : int
            Number of leading instrument calls to skip validating.
        """
        super().__init__(mod, device, True)
        self.time_eval = time_eval
        # name -> (mean seconds per call, number of invocations)
        self.time_eval_results = {}
        # kernel names already validated (and timed) once
        self.visited = set()
        self.skip_rounds = skip_rounds
        # Defensive: `skip_instrument` reads self.counter; make sure it
        # exists even if the base class does not initialize it.
        self.counter = getattr(self, "counter", 0)
        # Loosen tolerances relative to the base-class defaults.
        self.atol = 1e-2
        self.rtol = 1e-3

    def skip_instrument(self, func, name, before_run, ret_val, *args):
        """Return True to skip validating this kernel invocation.

        Skips shape functions, the first ``skip_rounds`` calls, and any
        kernel already visited (while still bumping its invocation count
        for the timing report).
        """
        print(f"run {name}")
        if name.startswith("shape_func"):
            return True
        if self.counter < self.skip_rounds:
            self.counter += 1
            print(f"[{self.counter}] Skip validating {name}..")
            return True
        if name in self.visited:
            if self.time_eval and name in self.time_eval_results:
                # Kernel already timed: only increment its call count.
                record = self.time_eval_results[name]
                self.time_eval_results[name] = (record[0], record[1] + 1)
            return True
        self.visited.add(name)
        return False

    def compare(
        self,
        name: str,
        ref_args: List[tvm.nd.NDArray],
        new_args: List[tvm.nd.NDArray],
        ret_indices: List[int],
    ):
        """Numerically validate one kernel, then time-evaluate it once.

        Delegates the value comparison to the base class; on the first
        visit of ``name`` (and only if ``time_eval`` is set) records the
        mean runtime on the comparison device.
        """
        super().compare(name, ref_args, new_args, ret_indices)

        if self.time_eval and name not in self.time_eval_results:
            res = self.mod.time_evaluator(name, self.device)(*new_args)
            self.time_eval_results[name] = (res.mean, 1)
            print(f"Time-eval result {name} on {self.device}: {res}")
def print_as_table(sorted_list):
    """Pretty-print time-eval records as an aligned text table.

    Parameters
    ----------
    sorted_list : list[tuple[str, tuple[float, int]]]
        Items of ``(kernel_name, (mean_seconds, call_count))``, typically
        pre-sorted by descending total time.
    """
    print(
        "Name".ljust(50)
        + "Time (ms)".ljust(12)
        + "Count".ljust(8)
        + "Total time (ms)".ljust(18)
        + "Percentage (%)"
    )
    # Total wall time across all kernels, converted to milliseconds.
    total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000
    for record in sorted_list:
        # Renamed from `time` to avoid shadowing the stdlib module name.
        time_ms = record[1][0] * 1000
        weighted_time = time_ms * record[1][1]
        # Guard against ZeroDivisionError when every record timed at 0.
        percentage = weighted_time / total_time * 100 if total_time > 0 else 0.0
        print(
            record[0].ljust(50)
            + "{:.4f}".format(time_ms).ljust(12)
            + str(record[1][1]).ljust(8)
            + "{:.4f}".format(weighted_time).ljust(18)
            + "{:.2f}".format(percentage)
        )
    print("Total time: {:.4f} ms".format(total_time))
    print()
class TestState:
    """Bundles everything a debug run needs: the primary-device VM, the
    comparison library and device (local or over RPC), the loaded model
    parameters, and the comparison instrument wired into the VM."""

    def __init__(self, args):
        self.primary_device = tvm.device(args.primary_device)
        executable = tvm.runtime.load_module(
            os.path.join(
                args.artifact_path,
                f"{args.model}_{args.primary_device}_{args.dtype}.so",
            )
        )
        self.vm = relax.VirtualMachine(executable, self.primary_device)
        # Set up self.sess / self.lib / self.cmp_device for the chosen target.
        if args.cmp_device == "iphone":
            self._setup_iphone(args)
        elif args.cmp_device == "android":
            self._setup_android(args)
        else:
            self._setup_local(args)
        self.const_params_dict = utils.load_params(
            args.artifact_path, self.primary_device
        )
        self.cmp_instrument = LibCompare(
            self.lib,
            self.cmp_device,
            time_eval=args.time_eval,
            skip_rounds=args.skip_rounds,
        )
        self.vm.set_instrument(self.cmp_instrument)

    def _setup_iphone(self, args):
        # Connect to an iPhone through the RPC proxy and load the dylib there.
        lib_name = f"{args.model}_{args.cmp_device}_{args.dtype}.dylib"
        local_lib_path = os.path.join(args.artifact_path, lib_name)
        proxy_host = os.environ.get("TVM_RPC_PROXY_HOST", "127.0.0.1")
        proxy_port = int(os.environ.get("TVM_RPC_PROXY_PORT", "9090"))
        self.sess = rpc.connect(proxy_host, proxy_port, "iphone")
        self.sess.upload(local_lib_path)
        self.lib = self.sess.load_module(lib_name)
        self.cmp_device = self.sess.metal()

    def _setup_android(self, args):
        # Request an Android device from the RPC tracker and load the .so there.
        lib_name = f"{args.model}_{args.cmp_device}_{args.dtype}.so"
        local_lib_path = os.path.join(args.artifact_path, lib_name)
        tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0")
        tracker_port = int(os.environ.get("TVM_TRACKER_PORT", "9190"))
        tracker = rpc.connect_tracker(tracker_host, tracker_port)
        self.sess = tracker.request("android")
        self.sess.upload(local_lib_path)
        self.lib = self.sess.load_module(lib_name)
        self.cmp_device = self.sess.cl(0)

    def _setup_local(self, args):
        # No RPC session: load the comparison library directly on this machine.
        self.sess = None
        self.lib = tvm.runtime.load_module(
            os.path.join(
                args.artifact_path,
                f"{args.model}_{args.cmp_device}_{args.dtype}.so",
            )
        )
        self.cmp_device = tvm.device(args.cmp_device)
def deploy_to_pipeline(args) -> None:
    """Run one encode step and one decode step through the primary VM while
    the comparison instrument validates (and optionally times) each kernel.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed CLI options; see ``_parse_args``.
    """
    state = TestState(args)
    primary_device = state.primary_device
    # Reuse the parameters TestState already loaded instead of reading the
    # whole parameter set from disk a second time.
    const_params = state.const_params_dict
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    print("Tokenizing...")
    inputs = tvm.nd.array(
        tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(),
        primary_device,
    )
    # Hard-coded token id used as the first "sampled" token for decoding.
    first_sampled_token = tvm.nd.array(
        np.array([[6234]]).astype("int32"), primary_device
    )
    seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])
    second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1])
    kv_caches = state.vm["create_kv_cache"]()

    def _report_and_reset():
        # Dump timing stats sorted by descending total time, then clear them
        # so the next phase starts fresh.
        print_as_table(
            sorted(
                state.cmp_instrument.time_eval_results.items(),
                key=lambda x: -(x[1][0] * x[1][1]),
            )
        )
        state.cmp_instrument.time_eval_results.clear()

    print("Running inference...")
    print("======================= Starts Encoding =======================")
    logits, kv_caches = state.vm["encoding"](
        inputs, seq_len_shape, kv_caches, const_params
    )
    _report_and_reset()
    # Let the decode phase re-validate kernels already seen during encoding.
    state.cmp_instrument.visited.clear()
    print("======================= Starts Decoding =======================")
    logits, kv_caches = state.vm["decoding"](
        first_sampled_token, second_seq_len_shape, kv_caches, const_params
    )
    _report_and_reset()
def _parse_args():
    """Parse CLI options, derive model/artifact paths, and resolve the
    primary device when it is left on "auto"."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--artifact-path", type=str, default="dist")
    parser.add_argument("--primary-device", type=str, default="auto")
    parser.add_argument("--cmp-device", type=str, required=True)
    parser.add_argument("--prompt", type=str, default="The capital of Canada is")
    parser.add_argument("--model", type=str, default="vicuna-v1-7b")
    parser.add_argument(
        "--dtype", type=str, choices=["float32", "float16"], default="float16"
    )
    parser.add_argument("--time-eval", default=False, action="store_true")
    parser.add_argument("--skip-rounds", type=int, default=0)
    parsed = parser.parse_args()
    # Derived paths: tokenizer lives under <artifact>/models/<model>, while
    # compiled artifacts live under <artifact>/<model>/<dtype>.
    parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
    parsed.artifact_path = os.path.join(
        parsed.artifact_path, parsed.model, parsed.dtype
    )
    if parsed.primary_device == "auto":
        # Prefer CUDA, fall back to Metal, otherwise fail loudly.
        if tvm.cuda().exist:
            parsed.primary_device = "cuda"
        elif tvm.metal().exist:
            parsed.primary_device = "metal"
        else:
            raise ValueError("Cannot auto deduce device-name, please set it")
    return parsed
if __name__ == "__main__":
    # Script entry point: parse options and run the debug pipeline.
    deploy_to_pipeline(_parse_args())