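"""Generation benchmark for a sparse OPT model with MLP/attention routers.

Builds an OPT model with sparse routers (HybridTensor), samples a random batch
from the WikiText-2 test split, and reports average generation latency,
throughput, and per-token latency, optionally repeating the run with
CUDA-graph decoding.
"""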
import torch
import argparse

from HybridTensor.utils.utils import _get_device
from HybridTensor.utils.activations import OPT_MODELS
from HybridTensor.models.opt import SparseConfig, build_sparse_opt 
from HybridTensor.benchmarks.generation.gen_util import tokenize_dataset, get_random_batch
from HybridTensor.utils.activations import build_mlp_topk_lookup
from HybridTensor.routers.mlp.mlp_router_optim import load_router_dict_from_csv

from datasets import load_dataset

from transformers.models.opt import OPTConfig
from transformers import AutoTokenizer
from flash_attn.models.opt import opt_config_to_gpt2_config
from flash_attn.utils.generation import update_graph_cache

def arg_parser():
    parser = argparse.ArgumentParser(description='Inference benchmarking')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--model_index', type=int, default=5)
    parser.add_argument('--seq_len', type=int, default=1024)
    parser.add_argument('--index_size', type=int, default=8192)
    parser.add_argument('--head_density', type=float, default=0.5)
    # argparse's type=bool treats any non-empty string as True, so use proper boolean flags
    parser.add_argument('--print_results', action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument('--iterations', type=int, default=1)
    parser.add_argument('--check_results', action=argparse.BooleanOptionalAction, default=False)
    parser.add_argument('--results_dir', type=str, default='results')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--attn_topk', type=float, default=0.5, help='Attention topk for sparse model')
    parser.add_argument('--mlp_ckpt_dir', type=str, default='<PATH_TO_MLP_ROUTER_CHECKPOINTS>')
    parser.add_argument('--attn_ckpt_dir', type=str, default='<PATH_TO_ATTENTION_CHECKPOINTS>')
    parser.add_argument('--batch_stats_dir', type=str, default='configs/mlp_router/opt-6.7b')
    parser.add_argument('--delta', type=int, default=256, help='Delta value for MLP topk calculation')
    parser.add_argument('--use_cuda_graph', action=argparse.BooleanOptionalAction, default=False, help='Use CUDA graph for inference')

    return parser.parse_args()

def update_router_config(model, num_layers, mlp_topk_lookup, attn_topk):
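    """Apply the per-layer MLP top-k values from mlp_topk_lookup and the attention
    router top-k to every transformer layer; layer 0 keeps dense attention."""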
    for i in range(num_layers):
        model.transformer.layers[i].mlp_topk = mlp_topk_lookup[i]
        # model.transformer.layers[i].mlp_topk = 512
        model.transformer.layers[i].mha_router.topk = attn_topk
        
        # model.transformer.layers[i].skip_mlp_router = True
    model.transformer.layers[0].mha_router.topk = 1.0  # dense attention in layer 0

if __name__ == "__main__":
    args = arg_parser()
    model_name = OPT_MODELS[args.model_index - 1]
    dtype = torch.float16
    device = _get_device(args.gpu)
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # args.mlp_ckpt_dir = None
    # args.attn_ckpt_dir = None
    
    model = build_sparse_opt(args, model_name, args.mlp_ckpt_dir, args.attn_ckpt_dir, device=device, dtype=dtype)
    model.eval()
    print(model)
    print("Model loaded with sparse routers")
    
    # mlp_topk_lookup = build_mlp_topk_lookup("results/mlp_results/batch_activations/opt-6.7b", args.batch_size, args.delta)
    mlp_topk_lookup = load_router_dict_from_csv(args.batch_stats_dir, args.batch_size)
    print("MLP topk lookup loaded: ", mlp_topk_lookup)
    update_router_config(model, model.config.n_layer, mlp_topk_lookup, args.attn_topk)  # apply per-layer MLP top-k and attention top-k
    # update_router_config(model, model.config.n_layer, 2048, args.attn_topk)
    print("Router config updated\n")
    
    
    max_length = args.seq_len + 20  # prompt length plus 20 newly generated tokens
    batch_size = args.batch_size
    
    # input_texts = ["Hello, my dog is cute and", "The future of AI is", "In a distant galaxy, a spaceship", "The cat is sleeping on the "]
    # input_texts = ["In a distant galaxy, a spaceship"]
    # tokenized_inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=False).to(device)
    # input_ids=tokenized_inputs["input_ids"]
    
    # Sample a random batch of seq_len-token sequences from the WikiText-2 test split
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    tokens = tokenize_dataset(dataset, tokenizer)
    input_ids = get_random_batch(tokens, args.batch_size, args.seq_len).to(device)
    
    print("Input ids generated, starting inference")
    
    # input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(device)
    position_ids = None
    eos_token_id = tokenizer.eos_token_id
    
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
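    # CUDA events give GPU-side timing; one untimed warm-up run precedes the measured loop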
    
    with torch.no_grad():
        # warm up
        _ = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            eos_token_id=eos_token_id,
            return_dict_in_generate=True,
            output_scores=True,
            enable_timing=False,
            cg=False,
            )
        
        print("Warm up done")
        
        start_event.record()
        for i in range(args.iterations):
            out = model.generate(
                input_ids=input_ids,
                max_length=max_length,
                eos_token_id=eos_token_id,
                return_dict_in_generate=True,
                output_scores=True,
                enable_timing=False,
                cg=False,
                )
        
        end_event.record()
        
        torch.cuda.synchronize()
        print("Without CUDA graph")
        elapsed_time = start_event.elapsed_time(end_event) / args.iterations
        print(f"Average time per genearation : {elapsed_time:.1f} ms")
        
        # Compute throughput and latency per token
        num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
        throughput = batch_size * num_tokens_generated / (elapsed_time / 1000)  # tokens per second
        latency_per_token = elapsed_time / num_tokens_generated  # ms per token

        print(f"Number of tokens generated: {num_tokens_generated}")
        print(f"Throughput: {throughput:.1f} tokens/second")
        print(f"Latency per token: {latency_per_token:.1f} ms")

        # print(tokenizer.batch_decode(out.sequences.tolist()))
        print("\n")
        
        # print only the new tokens generated 
        print("New tokens generated:")
        print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))
        
        # ====================== With CUDA graph ======================
        if args.use_cuda_graph:
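            # Capture decoding CUDA graphs for this batch size and max length so that
            # generate(cg=True) replays them instead of re-launching kernels each step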
            batch_size, seqlen_og = input_ids.shape
            model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
            print("With CUDA graph")
            torch.cuda.synchronize()
            
            start_event.record()
            
            for i in range(args.iterations):
                out = model.generate(
                    input_ids=input_ids,
                    max_length=max_length,
                    cg=True,
                    eos_token_id=eos_token_id,
                    return_dict_in_generate=True,
                    output_scores=True,
                    enable_timing=False,
                    )
            
            end_event.record()
            
            torch.cuda.synchronize()
            
            elapsed_time = start_event.elapsed_time(end_event) / args.iterations
            print(f"Average time per generation: {elapsed_time:.1f} ms")
            
            # Compute throughput and latency per token
            num_tokens_generated = out.sequences.shape[1] - input_ids.shape[1]
            throughput = batch_size * num_tokens_generated / (elapsed_time / 1000)  # tokens per second
            latency_per_token = elapsed_time / num_tokens_generated  # ms per token

            print(f"Number of tokens generated: {num_tokens_generated}")
            print(f"Throughput: {throughput:.1f} tokens/second")
            print(f"Latency per token: {latency_per_token:.1f} ms")

            # print(tokenizer.batch_decode(out.sequences.tolist()))
            print("New tokens generated:")
            print(tokenizer.batch_decode(out.sequences[:, input_ids.shape[1]:].tolist()))