# NOTE: scrape residue from the model-hub page this file was downloaded from
# ("blingBillie's picture", "Upload 138 files", commit d443007) -- kept as a
# comment so the module remains valid Python.
# Used as reference
import argparse
import os
import time
from typing import List, Tuple
import numpy as np
import torch
import tvm
from transformers import AutoTokenizer # type: ignore[import]
from tvm import relax
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
from tvm.runtime import ShapeTuple
from mlc_llm import utils
def _parse_args():
    """Parse CLI options for the debug pipeline and derive model paths."""
    parser = argparse.ArgumentParser()
    utils.argparse_add_common(parser)
    parser.add_argument("--device-name", type=str, default="auto")
    parser.add_argument("--debug-dump", action="store_true", default=False)
    parser.add_argument("--artifact-path", type=str, default="dist")
    parser.add_argument("--prompt", type=str, default="The capital of Canada is")
    parser.add_argument("--profile", action="store_true", default=False)
    opts = parser.parse_args()
    # model_path is anchored at the *original* artifact root; artifact_path is
    # then re-pointed at the per-model/per-dtype build directory.
    opts.model_path = os.path.join(opts.artifact_path, "models", opts.model)
    opts.artifact_path = os.path.join(opts.artifact_path, opts.model, opts.dtype)
    utils.argparse_postproc_common(opts)
    return opts
class LibCompare(LibCompareVMInstrument):
    """VM instrument that records, per kernel, a one-shot mean runtime and
    how many times the kernel was invoked."""

    def __init__(self, mod, device):
        super().__init__(mod, device, verbose=False)
        # kernel name -> (mean seconds for one call, invocation count)
        self.time_eval_results = {}

    def compare(
        self,
        name: str,
        ref_args: List[tvm.nd.NDArray],
        new_args: List[tvm.nd.NDArray],
        ret_indices: List[int],
    ):
        # Shape functions are bookkeeping, not compute -- skip them entirely.
        if name.startswith("shape_func"):
            return
        seen = self.time_eval_results.get(name)
        if seen is None:
            # First sighting: verify outputs against the reference lib, then
            # benchmark the kernel once and remember its mean runtime.
            super().compare(name, ref_args, new_args, ret_indices)
            mean_sec = self.mod.time_evaluator(name, dev=self.device)(*new_args).mean
            self.time_eval_results[name] = (mean_sec, 1)
        else:
            # Subsequent calls only bump the invocation counter.
            self.time_eval_results[name] = (seen[0], seen[1] + 1)
def print_as_table(sorted_list: List[Tuple[str, Tuple[float, int]]]):
    """Pretty-print per-kernel profiling results as an aligned text table.

    Parameters
    ----------
    sorted_list:
        Entries of the form ``(kernel_name, (mean_seconds, call_count))``,
        typically pre-sorted by descending total weighted time.
    """
    print(
        "Name".ljust(50)
        + "Time (ms)".ljust(12)
        + "Count".ljust(8)
        + "Total time (ms)".ljust(18)
        + "Percentage (%)"
    )
    # Grand total across all kernels, converted from seconds to milliseconds.
    total_time = sum(mean * count for _, (mean, count) in sorted_list) * 1000
    for name, (mean, count) in sorted_list:
        # Named `time_ms`, not `time`, to avoid shadowing the imported module.
        time_ms = mean * 1000
        weighted_time = time_ms * count
        # Guard the degenerate all-zero-timings case instead of raising
        # ZeroDivisionError; report 0% for every row then.
        percentage = weighted_time / total_time * 100 if total_time else 0.0
        print(
            name.ljust(50)
            + f"{time_ms:.4f}".ljust(12)
            + str(count).ljust(8)
            + f"{weighted_time:.4f}".ljust(18)
            + f"{percentage:.2f}"
        )
    print(f"Total time: {total_time:.4f} ms")
    print()
def deploy_to_pipeline(args) -> None:
    """Load a compiled Relax model and time one encode + one decode step.

    Loads the compiled ``.so`` library and weight params from
    ``args.artifact_path``, tokenizes ``args.prompt``, performs one warm-up
    pass, then times a single "encoding" (prefill) call and a single
    "decoding" call.  Optionally dumps debug tensors (``--debug-dump``) and
    per-kernel profiling tables (``--profile``).
    """
    device = tvm.device(args.device_name)
    const_params = utils.load_params(args.artifact_path, device)
    # Compiled library path: <artifact>/<model>_<device>_<dtype>.so
    ex = tvm.runtime.load_module(
        os.path.join(
            args.artifact_path, f"{args.model}_{args.device_name}_{args.dtype}.so"
        )
    )
    vm = relax.VirtualMachine(ex, device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    print("Tokenizing...")
    # Prompt token ids as an int32 NDArray of shape (1, seq_len) on the device.
    inputs = tvm.nd.array(
        tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(),
        device,
    )
    # Fixed token id fed to the decode step; 6234 looks like an arbitrary
    # stand-in for a sampled token -- TODO confirm.
    first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), device)
    seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])
    second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1])
    kv_caches = vm["create_kv_cache"]()
    # skip warm up: run encode+decode once before timing so one-time setup
    # costs do not pollute the measurement; the KV cache is rebuilt below.
    logits, kv_caches = vm["encoding"](inputs, seq_len_shape, kv_caches, const_params)
    logits, kv_caches = vm["decoding"](
        first_sampled_token, second_seq_len_shape, kv_caches, const_params
    )
    device.sync()
    kv_caches = vm["create_kv_cache"]()
    print("Running inference...")
    start = time.time()
    logits, kv_caches = vm["encoding"](inputs, seq_len_shape, kv_caches, const_params)
    # device.sync() before each timestamp so async kernel launches are
    # actually finished when the wall-clock is read.
    device.sync()
    encoding_end = time.time()
    logits, kv_caches = vm["decoding"](
        first_sampled_token, second_seq_len_shape, kv_caches, const_params
    )
    device.sync()
    end = time.time()
    # View the first KV-cache entry; the shape (7, 32, 128) is hard-coded --
    # presumably (seq_len, num_heads, head_dim) for the reference prompt;
    # TODO confirm against the model config.
    fcache_view = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
    first_k_cache = fcache_view(kv_caches[0], ShapeTuple([7, 32, 128]))
    if args.debug_dump:
        print(f"output kv_cache[0]:\n{first_k_cache.numpy().transpose(1, 0, 2)}")
        print(f"output logits:\n{logits.numpy()}")
    print(
        f"Time elapsed: encoding {(encoding_end - start)} seconds, decoding {end - encoding_end} secs"
    )
    if args.profile:
        # Instrument the VM so every kernel call is timed/counted by LibCompare.
        cmp_instrument = LibCompare(ex, device)
        vm.set_instrument(cmp_instrument)
        print("Profiling...")
        kv_caches = vm["create_kv_cache"]()
        logits, kv_caches = vm["encoding"](
            inputs, seq_len_shape, kv_caches, const_params
        )
        print("======================= Encoding Profiling =======================")
        # Sort by total weighted time (mean * count), descending.
        print_as_table(
            sorted(
                cmp_instrument.time_eval_results.items(),
                key=lambda x: -(x[1][0] * x[1][1]),
            )
        )
        # Reset the accumulator so decode timings are reported separately.
        cmp_instrument.time_eval_results.clear()
        logits, kv_caches = vm["decoding"](
            first_sampled_token, second_seq_len_shape, kv_caches, const_params
        )
        print("======================= Decoding Profiling =======================")
        print_as_table(
            sorted(
                cmp_instrument.time_eval_results.items(),
                key=lambda x: -(x[1][0] * x[1][1]),
            )
        )
# Script entry point: parse CLI options once and run the debug pipeline.
if __name__ == "__main__":
    ARGS = _parse_args()
    deploy_to_pipeline(ARGS)