import torch
import argparse
from HybridTensor.utils.utils import _get_device
from HybridTensor.utils.activations import OPT_MODELS
from HybridTensor.models.opt import SparseConfig, build_sparse_opt
from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch
from HybridTensor.utils.activations import build_mlp_topk_lookup
from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv
from datasets import load_dataset
from transformers.models.opt import OPTConfig
from transformers import AutoTokenizer
from flash_attn.models.opt import opt_config_to_gpt2_config
from flash_attn.utils.generation import update_graph_cache
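

# Benchmark script: batched text generation with a sparse OPT model (MLP and attention routers),
# reporting latency and throughput with and without CUDA graphs.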
def arg_parser():
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=1024)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.5)
    parser.add_argument('--print_results', type=bool, default=True)
    parser.add_argument('--iterations', type=int, default=1)
    parser.add_argument('--check_results', type=bool, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
    parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')
    parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b')
    parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation')
    parser.add_argument('--use_cuda_graph', type=bool, default=False, help='Use CUDA graph for inference')
    return parser.parse_args()
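

# update_router_config pushes the per-layer MLP top-k lookup and a single attention top-k value into every layer's routers.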
def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk):
    for i in range(num_layers):
        model.transformer.layers[i].mlp_topk = mlp_topk_lookup[i]
        # model.transformer.layers[i].mlp_topk = 512
        model.transformer.layers[i].mha_router.topk = attn_topk
        # model.transformer.layers[i].skip_mlp_router = True
    model.transformer.layers[0].mha_router.topk = 1.0  # dense attention in layer 0


if __name__ == "__main__":
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index - 1]
    dtype = torch.float16
    device = _get_device(args.gpu)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # args.mlp_ckpt_dir = None
    # args.attn_ckpt_dir = None
    model = build_sparse_opt(args, model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device=device, dtype=dtype)
    model.eval()
    print(model)
    print("Model loaded with sparse routers")
    # mlp_topk_lookup = build_mlp_topk_lookup("results/mlp_results/batch_activations/opt-6.7b", args.batch_size, args.delta)
    mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size)
    print("MLP topk values updated: ", mlp_topk_lookup)
    update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk)  # this sets the router config for all layers using a single config
    # update_router_config(model, model.config.n_layer, 2048, args.attn_topk)
    print("Router config updated \n")
    max_length = args.seq_len + 20
    batch_size = args.batch_size
    # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
    # input_texts = ["In a distant galaxy, a spaceship"]
    # tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=False).to(device)
    # input_ids=tokenized_inputs["input_ids"]
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    tokens = tokenize_dataset(dataset, tokenizer)
    input_ids = get_random_batch(tokens, args.batch_size, args.seq_len).to(device)
    print("Input ids generated, starting inference")
    # input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(device)

    position_ids = None
    eos_token_id = tokenizer.eos_token_id
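
    # CUDA events measure GPU time for the generation loops.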
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    with torch.no_grad():
        # warm up
        _ = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            eos_token_id=eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            cg=False,
        )
        print("Warm up done")
        start_event.record()
        for i in range(args.iterations):
            out = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                eos_token_id=eos_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                enable_timing=False,
                cg=False,
            )
        end_event.record()
        torch.cuda.synchronize()
print("Without CUDA graph")
elapsed_time = start_event.elapsed_time(end_event) / args.iterations
print(f"Average time per genearation : {elapsed_time:.1f} ms")
# Compute throughput and latency per token
num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
throughput = batch_size * num_tokens_generated / (elapsed_time / 1000) # tokens per second
latency_per_token = elapsed_time / num_tokens_generated # ms per token
print(f"Number of tokens generated: {num_tokens_generated}")
print(f"Throughput: {throughput:.1f} tokens/second")
print(f"Latency per token: {latency_per_token:.1f} ms")
# print(tokenizer.batch_decode(out.sequences.tolist()))
print("\n")
# print only the new tokens generated
print("New tokens generated:")
print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))
# ====================== With CUDA graph ======================
        if args.use_cuda_graph:
            batch_size, seqlen_og = input_ids.shape
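            # flash_attn's update_graph_cache pre-captures CUDA graphs for the single-token decoding step;
            # generate(..., cg=True) then replays them, avoiding per-step kernel launch overhead.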
            model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
            print("With CUDA graph")
            torch.cuda.synchronize()

            start_event.record()
            for i in range(args.iterations):
                out = model.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    cg=True,
                    eos_token_id=eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                    enable_timing=False,
                )
            end_event.record()
            torch.cuda.synchronize()

            elapsed_time = start_event.elapsed_time(end_event) / args.iterations
            print(f"Average time per generation: {elapsed_time:.1f} ms")

            # Compute throughput and latency per token
            num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
            throughput = batch_size * num_tokens_generated / (elapsed_time / 1000)  # tokens per second
            latency_per_token = elapsed_time / num_tokens_generated  # ms per token
            print(f"Number of tokens generated: {num_tokens_generated}")
            print(f"Throughput: {throughput:.1f} tokens/second")
            print(f"Latency per token: {latency_per_token:.1f} ms")
            # print(tokenizer.batch_decode(out.sequences.tolist()))
            print("New tokens generated:")
            print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))