File size: 1,677 Bytes
3d3198b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding: utf-8 -*-
# @Time : 2022/7/29



from .base_wrapper import ONNXModel
from pathlib import Path


try:
    from .base_wrapper import TRTWrapper, TRTWrapperSelf
except:
    pass


# from cv2box.utils import try_import

class ModelBase:
    """Load an inference model, choosing the backend wrapper from the file suffix.

    ``model_info`` is a dict. The keys read here are:

    * ``'model_path'`` (required): path of the model file. Its suffix selects
      the backend: ``.engine`` -> TRT wrapper, ``.tjm`` -> ``TJMWrapper``,
      ``.onnx`` / ``.bin`` -> ONNX wrapper.
    * ``'input_dynamic_shape'`` (optional): forwarded to ``ONNXModel``;
      ``None`` when absent.
    * ``'picklable'`` (optional): when truthy, ``OnnxModelPickable`` is used
      instead of ``ONNXModel``.
    * ``'trt_wrapper_self'`` (optional): when the key is present (whatever its
      value), ``TRTWrapperSelf`` is used instead of ``TRTWrapper``.
    * ``'encrypt'`` (optional): key passed to ``load_encrypt_model``; only
      consulted on the non-picklable ONNX path.

    ``provider`` is forwarded unchanged to the TJM and ONNX wrappers.

    Attributes set on the instance: ``model_path``, ``input_dynamic_shape``,
    ``model_type`` (``'trt'``, ``'tjm'`` or ``'onnx'``) and ``model``.

    Raises:
        KeyError: if ``model_info`` has no ``'model_path'`` entry.
        ValueError: if the model file suffix is not supported.
    """

    # NOTE(review): TJMWrapper, load_encrypt_model and OnnxModelPickable are not
    # imported in the visible part of this file -- confirm they are in scope.

    def __init__(self, model_info, provider):
        self.model_path = model_info['model_path']
        self.input_dynamic_shape = model_info.get('input_dynamic_shape')
        picklable = model_info.get('picklable', False)

        suffix = Path(self.model_path).suffix

        # init model
        if suffix == '.engine':
            self.model_type = 'trt'
            # Select the class through a separate local name. Assigning to
            # ``TRTWrapper`` inside this function would make it a local
            # variable and break the default (non-"self") case with an
            # UnboundLocalError.
            wrapper_cls = TRTWrapperSelf if 'trt_wrapper_self' in model_info else TRTWrapper
            self.model = wrapper_cls(self.model_path)
        elif suffix == '.tjm':
            self.model_type = 'tjm'
            self.model = TJMWrapper(self.model_path, provider=provider)
        elif suffix in ('.onnx', '.bin'):
            self.model_type = 'onnx'
            if picklable:
                self.model = OnnxModelPickable(self.model_path, provider=provider)
            else:
                if 'encrypt' in model_info:
                    # model_path is replaced by whatever load_encrypt_model returns
                    # before being handed to ONNXModel.
                    self.model_path = load_encrypt_model(self.model_path, key=model_info['encrypt'])
                self.model = ONNXModel(
                    self.model_path,
                    provider=provider,
                    input_dynamic_shape=self.input_dynamic_shape,
                )
        else:
            # Raising a plain string is itself a TypeError in Python 3; raise a
            # real exception so the message reaches the caller.
            raise ValueError('check model suffix , support engine/tjm/onnx now.')