import torch
import argparse

from HybridTensor.utils.utils import _get_device
from HybridTensor.utils.activations import OPT_MODELS, build_mlp_topk_lookup
from HybridTensor.models.opt import SparseConfig, build_sparse_opt
from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch
from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv

from datasets import load_dataset

from transformers.models.opt import OPTConfig
from transformers import AutoTokenizer
from flash_attn.models.opt import opt_config_to_gpt2_config
from flash_attn.utils.generation import update_graph_cache

def arg_parser():
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=1024)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.5)
    parser.add_argument('--print_results', type=bool, default=True)
    parser.add_argument('--iterations', type=int, default=1)
    # store_true avoids the argparse pitfall where type=bool parses any non-empty string as True
    parser.add_argument('--check_results', action='store_true')
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
    parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')
    parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b')
    parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation')
    parser.add_argument('--use_cuda_graph', action='store_true', help='Use CUDA graph for inference')

    return parser.parse_args()

def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk):
    for i in range(num_layers):
        model.transformer.layers[i].mlp_topk = mlp_topk_lookup[i]
        model.transformer.layers[i].mha_router.topk = attn_topk

    # Keep the first layer's attention fully dense (topk = 1.0)
    model.transformer.layers[0].mha_router.topk = 1.0

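
# Example invocation (the script name and checkpoint paths below are placeholders, not taken from the repo):
#   python run_sparse_opt_generation.py --model_index 5 --batch_size 16 --seq_len 1024 \
#       --attn_topk 0.5 --mlp_ckpt_dir <PATH_TO_MLP_ROUTER_CHECKPOINTS> \
#       --attn_ckpt_dir <PATH_TO_ATTENTION_CHECKPOINTS> --use_cuda_graph
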
if __name__ == "__main__":
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index - 1]
    dtype = torch.float16
    device = _get_device(args.gpu)

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Build the OPT model with sparse MLP and attention routers
    model = build_sparse_opt(args, model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device=device, dtype=dtype)
    model.eval()
    print(model)
    print("Model loaded with sparse routers")

    # Set per-layer MLP topk values from batch statistics and a uniform attention topk
    mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size)
    print("MLP topk values updated: ", mlp_topk_lookup)
    update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk)

    print("Router config updated \n")

    max_length = args.seq_len + 20
    batch_size = args.batch_size

    # Build a random batch of prompts from WikiText-2
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    tokens = tokenize_dataset(dataset, tokenizer)
    input_ids = get_random_batch(tokens, args.batch_size, args.seq_len).to(device)

    print("Input ids generated, starting inference")

    position_ids = None
    eos_token_id = tokenizer.eos_token_id

    # CUDA events measure GPU time around the generation loops
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        # Warm-up pass (not timed)
        _ = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            eos_token_id=eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            cg=False,
        )

        print("Warm up done")

        # Timed generation without CUDA graph
        start_event.record()
        for i in range(args.iterations):
            out = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                eos_token_id=eos_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                enable_timing=False,
                cg=False,
            )
        end_event.record()

        torch.cuda.synchronize()
        print("Without CUDA graph")
        elapsed_time = start_event.elapsed_time(end_event) / args.iterations
        print(f"Average time per generation: {elapsed_time:.1f} ms")

        num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
        throughput = batch_size * num_tokens_generated / (elapsed_time / 1000)
        latency_per_token = elapsed_time / num_tokens_generated

        print(f"Number of tokens generated: {num_tokens_generated}")
        print(f"Throughput: {throughput:.1f} tokens/second")
        print(f"Latency per token: {latency_per_token:.1f} ms")

        print("\n")

        print("New tokens generated:")
        print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))

        if args.use_cuda_graph:
            # Capture the decoding step in a CUDA graph to reduce per-token launch overhead
            batch_size, seqlen_og = input_ids.shape
            model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
            print("With CUDA graph")
            torch.cuda.synchronize()

            start_event.record()
            for i in range(args.iterations):
                out = model.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    cg=True,
                    eos_token_id=eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                    enable_timing=False,
                )
            end_event.record()

            torch.cuda.synchronize()

            elapsed_time = start_event.elapsed_time(end_event) / args.iterations
            print(f"Average time per generation: {elapsed_time:.1f} ms")

            num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
            throughput = batch_size * num_tokens_generated / (elapsed_time / 1000)
            latency_per_token = elapsed_time / num_tokens_generated

            print(f"Number of tokens generated: {num_tokens_generated}")
            print(f"Throughput: {throughput:.1f} tokens/second")
            print(f"Latency per token: {latency_per_token:.1f} ms")

            print("New tokens generated:")
            print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))