"""Smoke-test and profiling harness for a compiled MLC-LLM model.

Loads a compiled Relax module and its parameters, runs one prefill
("encoding") step and one decode ("decoding") step against the KV cache,
reports wall-clock timings, and optionally profiles every kernel via
LibCompareVMInstrument.
"""
import argparse
import os
import time
from typing import List, Tuple

import numpy as np
import torch
import tvm
from transformers import AutoTokenizer
from tvm import relax
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
from tvm.runtime import ShapeTuple

from mlc_llm import utils


def _parse_args():
    args = argparse.ArgumentParser()
    utils.argparse_add_common(args)
    args.add_argument("--device-name", type=str, default="auto")
    args.add_argument("--debug-dump", action="store_true", default=False)
    args.add_argument("--artifact-path", type=str, default="dist")
    args.add_argument("--prompt", type=str, default="The capital of Canada is")
    args.add_argument("--profile", action="store_true", default=False)
    parsed = args.parse_args()
    parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
    parsed.artifact_path = os.path.join(
        parsed.artifact_path, parsed.model, parsed.dtype
    )
    utils.argparse_postproc_common(parsed)
    return parsed


class LibCompare(LibCompareVMInstrument):
    """VM instrument that benchmarks each kernel on its first invocation
    and counts how many times it is called afterwards."""

    def __init__(self, mod, device):
        super().__init__(mod, device, verbose=False)
        # Maps kernel name -> (mean time in seconds, call count).
        self.time_eval_results = {}

    def compare(
        self,
        name: str,
        ref_args: List[tvm.nd.NDArray],
        new_args: List[tvm.nd.NDArray],
        ret_indices: List[int],
    ):
        # Shape functions run on the host; skip them when profiling.
        if name.startswith("shape_func"):
            return
        if name not in self.time_eval_results:
            # First call: compare against the reference module, then benchmark.
            super().compare(name, ref_args, new_args, ret_indices)
            res = self.mod.time_evaluator(name, dev=self.device)(*new_args).mean
            self.time_eval_results[name] = (res, 1)
        else:
            # Subsequent calls: only bump the invocation count.
            record = self.time_eval_results[name]
            self.time_eval_results[name] = (record[0], record[1] + 1)


def print_as_table(sorted_list: List[Tuple[str, Tuple[float, int]]]):
    print(
        "Name".ljust(50)
        + "Time (ms)".ljust(12)
        + "Count".ljust(8)
        + "Total time (ms)".ljust(18)
        + "Percentage (%)"
    )
    # Entries store (mean seconds per call, call count); convert to ms.
    total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000
    for record in sorted_list:
        # Named `time_ms` to avoid shadowing the `time` module imported above.
        time_ms = record[1][0] * 1000
        weighted_time = time_ms * record[1][1]
        percentage = weighted_time / total_time * 100
        print(
            record[0].ljust(50)
            + "{:.4f}".format(time_ms).ljust(12)
            + str(record[1][1]).ljust(8)
            + "{:.4f}".format(weighted_time).ljust(18)
            + "{:.2f}".format(percentage)
        )
    print("Total time: {:.4f} ms".format(total_time))
    print()


def deploy_to_pipeline(args) -> None:
    device = tvm.device(args.device_name)
    const_params = utils.load_params(args.artifact_path, device)
    ex = tvm.runtime.load_module(
        os.path.join(
            args.artifact_path, f"{args.model}_{args.device_name}_{args.dtype}.so"
        )
    )
    vm = relax.VirtualMachine(ex, device)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    print("Tokenizing...")
    inputs = tvm.nd.array(
        tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(),
        device,
    )
    # A fixed token id stands in for a token that would normally be sampled
    # from the prefill logits.
    first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), device)
    seq_len_shape = ShapeTuple([inputs.shape[1]])
    second_seq_len_shape = ShapeTuple([inputs.shape[1] + 1])
    kv_caches = vm["create_kv_cache"]()

logits, kv_caches = vm["encoding"](inputs, seq_len_shape, kv_caches, const_params) |
|
|
logits, kv_caches = vm["decoding"]( |
|
|
first_sampled_token, second_seq_len_shape, kv_caches, const_params |
|
|
) |
|
|
device.sync() |
|
|
|
|
|
kv_caches = vm["create_kv_cache"]() |
|
|
print("Running inference...") |
|
|
start = time.time() |
|
|
logits, kv_caches = vm["encoding"](inputs, seq_len_shape, kv_caches, const_params) |
|
|
device.sync() |
|
|
encoding_end = time.time() |
|
|
logits, kv_caches = vm["decoding"]( |
|
|
first_sampled_token, second_seq_len_shape, kv_caches, const_params |
|
|
) |
|
|
device.sync() |
|
|
end = time.time() |
|
|
fcache_view = tvm.get_global_func("vm.builtin.attention_kv_cache_view") |
|
|
first_k_cache = fcache_view(kv_caches[0], ShapeTuple([7, 32, 128])) |
|
|
if args.debug_dump: |
|
|
print(f"output kv_cache[0]:\n{first_k_cache.numpy().transpose(1, 0, 2)}") |
|
|
print(f"output logits:\n{logits.numpy()}") |
|
|
print( |
|
|
f"Time elapsed: encoding {(encoding_end - start)} seconds, decoding {end - encoding_end} secs" |
|
|
) |
|
|
|
|
    if args.profile:
        cmp_instrument = LibCompare(ex, device)
        vm.set_instrument(cmp_instrument)

        print("Profiling...")
        kv_caches = vm["create_kv_cache"]()

        logits, kv_caches = vm["encoding"](
            inputs, seq_len_shape, kv_caches, const_params
        )
        print("======================= Encoding Profiling =======================")
        # Sort kernels by total time (mean per-call time x call count), descending.
        print_as_table(
            sorted(
                cmp_instrument.time_eval_results.items(),
                key=lambda x: -(x[1][0] * x[1][1]),
            )
        )
        cmp_instrument.time_eval_results.clear()

        logits, kv_caches = vm["decoding"](
            first_sampled_token, second_seq_len_shape, kv_caches, const_params
        )
        print("======================= Decoding Profiling =======================")
        print_as_table(
            sorted(
                cmp_instrument.time_eval_results.items(),
                key=lambda x: -(x[1][0] * x[1][1]),
            )
        )


if __name__ == "__main__":
    ARGS = _parse_args()
    deploy_to_pipeline(ARGS)