File size: 3,372 Bytes
866ee56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# --------------------------------------------------------
# InternVL
# Copyright (c) 2022 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------

import argparse
import os
import time

import torch
from config import get_config
from models import build_model
from tqdm import tqdm


def get_args():
    """Parse command-line options and load the matching model config.

    Besides the CLI options, three derived fields are attached to the
    returned namespace:

    * ``cfg``  -- path ``./configs/<model_name>.yaml``
    * ``ckpt`` -- path ``<ckpt_dir>/<model_name>.pth``
    * ``size`` -- input resolution: the last ``_``-separated token of the
      model name (text before the first ``.``), parsed as an int, e.g.
      ``internimage_t_1k_224`` -> 224

    Returns:
        tuple: ``(args, cfg)`` -- the argparse namespace and the config
        object returned by ``get_config(args)``.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('--model_name', type=str, default='internimage_t_1k_224')
    add('--ckpt_dir', type=str,
        default='/mnt/petrelfs/share_data/huangzhenhang/code/internimage/checkpoint_dir/new/cls')
    add('--onnx', default=False, action='store_true')
    add('--trt', default=False, action='store_true')

    args = parser.parse_args()
    name = args.model_name
    args.cfg = os.path.join('./configs', f'{name}.yaml')
    args.ckpt = os.path.join(args.ckpt_dir, f'{name}.pth')
    # e.g. 'internimage_t_1k_224' -> '224' -> 224
    stem = name.split('.')[0]
    args.size = int(stem.split('_')[-1])

    return args, get_config(args)


def get_model(args, cfg):
    """Build the model described by ``cfg`` and load its trained weights.

    The checkpoint at ``args.ckpt`` is deserialized onto the CPU and its
    ``'model'`` entry is loaded as the state dict (strict by default, so
    missing/unexpected keys raise).

    Args:
        args: Parsed CLI namespace; only ``args.ckpt`` is read here.
        cfg: Config passed to ``build_model``.

    Returns:
        The model with weights loaded.
    """
    net = build_model(cfg)
    checkpoint = torch.load(args.ckpt, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    return net


def speed_test(model, input, warmup=100, iters=100):
    """Benchmark ``model(input)`` and print elapsed time and throughput.

    Args:
        model: Any callable taking one positional argument; in this script
            either a PyTorch module or an mmdeploy ``TRTWrapper``.
        input: The argument passed to ``model`` on every call.
        warmup: Number of untimed calls made before measuring.
        iters: Number of timed calls.

    Returns:
        float: Throughput in calls per second (``float('inf')`` if the timed
        region took no measurable time).
    """
    # torch.cuda.synchronize() raises when CUDA is unavailable, so only
    # synchronize when there is a device to wait for.
    use_cuda = torch.cuda.is_available()

    # Inference-only benchmark: don't spend time recording autograd graphs.
    with torch.no_grad():
        # warm-up
        for _ in tqdm(range(warmup)):
            model(input)

        # speed test
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        # No progress bar here: its per-iteration overhead would be timed too.
        for _ in range(iters):
            model(input)
        # CUDA kernels launch asynchronously; wait for the queued work to
        # finish before stopping the clock, otherwise only launch time is
        # measured.
        if use_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

    th = iters / elapsed if elapsed > 0 else float('inf')
    print(f'using time: {elapsed}, throughput {th}')
    return th


def torch2onnx(args, cfg):
    """Export the checkpointed model to ``<model_name>.onnx`` in the CWD.

    Args:
        args: Parsed CLI namespace; ``model_name`` and ``size`` are read here
            (``ckpt`` is read by ``get_model``).
        cfg: Model config passed through to ``get_model``.

    Returns:
        The loaded CUDA model, left in eval mode.
    """
    model = get_model(args, cfg).cuda()
    # Export (and hand back) the model in inference mode so any
    # train/eval-dependent layers behave as they will at deployment time.
    model.eval()

    onnx_name = f'{args.model_name}.onnx'
    # No dynamic_axes are given, so the graph is exported with this static
    # 1x3xSxS input shape -- the same shape onnx2trt builds the engine for.
    dummy_input = torch.rand(1, 3, args.size, args.size).cuda()
    torch.onnx.export(model,
                      dummy_input,
                      onnx_name,
                      input_names=['input'],
                      output_names=['output'])

    return model


def onnx2trt(args):
    """Convert ``<model_name>.onnx`` into a TensorRT engine via mmdeploy.

    The engine is built for one static input shape
    ``(1, 3, args.size, args.size)`` -- min, opt and max shapes are identical --
    bound to the ONNX input named ``'input'``, with ``max_workspace_size``
    set to ``2 ** 30`` (1 GiB, assuming mmdeploy takes bytes).

    ``args.model_name`` is passed as ``from_onnx``'s second argument;
    ``check()`` later loads ``<model_name>.engine``, so this is presumably the
    output file prefix -- confirm against the installed mmdeploy version.
    """
    from mmdeploy.backend.tensorrt import from_onnx

    static_shape = [1, 3, args.size, args.size]
    input_shapes = {
        'input': {
            'min_shape': list(static_shape),
            'opt_shape': list(static_shape),
            'max_shape': list(static_shape),
        },
    }
    from_onnx(
        f'{args.model_name}.onnx',
        args.model_name,
        input_shapes,
        max_workspace_size=1 << 30,
    )


def check(args, cfg):
    """Compare the TensorRT engine with the PyTorch model, then benchmark both.

    Loads ``<model_name>.engine`` through mmdeploy's ``TRTWrapper``, runs both
    models on one random ``(1, 3, size, size)`` input, prints both output
    shapes and the max/mean absolute difference, then runs ``speed_test`` on
    each model with that same input.

    Args:
        args: Parsed CLI namespace; ``model_name`` and ``size`` are read here.
        cfg: Model config passed through to ``get_model``.
    """
    from mmdeploy.backend.tensorrt.wrapper import TRTWrapper

    model = get_model(args, cfg).cuda()
    model.eval()
    trt_model = TRTWrapper(f'{args.model_name}.engine',
                           ['output'])

    x = torch.randn(1, 3, args.size, args.size).cuda()

    # Pure inference comparison: no autograd graph is needed, and without
    # no_grad the printed deltas would also carry a grad_fn.
    with torch.no_grad():
        torch_out = model(x)
        trt_out = trt_model(dict(input=x))['output']

    print('torch out shape:', torch_out.shape)
    print('trt out shape:', trt_out.shape)

    delta = (torch_out - trt_out).abs()
    print('max delta:', delta.max())
    print('mean delta:', delta.mean())

    speed_test(model, x)
    speed_test(trt_model, dict(input=x))


def main():
    """Parse CLI arguments and run the requested export/verification steps.

    ``--onnx`` exports the PyTorch model to ``<model_name>.onnx``.
    ``--trt`` also performs that export, because the TensorRT conversion reads
    the ONNX file. It then converts the file with mmdeploy and compares the
    resulting engine against the PyTorch model (outputs and throughput).
    """
    args, cfg = get_args()

    # onnx2trt consumes <model_name>.onnx, so --trt implies the ONNX export.
    if args.onnx or args.trt:
        torch2onnx(args, cfg)
        print('torch -> onnx: success')

    if args.trt:
        onnx2trt(args)
        print('onnx -> trt: success')
        check(args, cfg)


if __name__ == '__main__':
    main()