| | |
| | |
| |
|
| |
|
| |
|
| | from .base_wrapper import ONNXModel |
| | from pathlib import Path |
| |
|
| |
|
# The TRT wrappers are optional: when this import fails the module still
# loads and the two names simply stay undefined (only '.engine' models,
# which need them, become unusable).
try:
    from .base_wrapper import TRTWrapper, TRTWrapperSelf
except Exception:  # best-effort, but never swallow KeyboardInterrupt/SystemExit
    pass
| |
|
| |
|
| | |
| |
|
class ModelBase:
    """Load a model and select the backend wrapper from the file suffix.

    Suffix to backend mapping:
        '.engine'        -> TRTWrapper / TRTWrapperSelf, model_type 'trt'
        '.tjm'           -> TJMWrapper, model_type 'tjm'
        '.onnx' / '.bin' -> ONNXModel / OnnxModelPickable, model_type 'onnx'
    """

    def __init__(self, model_info, provider):
        """Create the backend model described by ``model_info``.

        Args:
            model_info: Mapping with a required 'model_path' entry and the
                optional entries 'input_dynamic_shape', 'picklable',
                'trt_wrapper_self' and 'encrypt'.
            provider: Forwarded unchanged to the TJM and ONNX wrappers; it is
                not used for '.engine' models.

        Raises:
            KeyError: If 'model_path' is missing from ``model_info``.
            ValueError: If the model path has an unsupported suffix.
        """
        self.model_path = model_info['model_path']
        self.input_dynamic_shape = model_info.get('input_dynamic_shape')
        picklable = model_info.get('picklable', False)

        suffix = Path(self.model_path).suffix

        if suffix == '.engine':
            self.model_type = 'trt'
            # Pick the wrapper through a separate local name. Rebinding
            # ``TRTWrapper`` inside this function would make it function-local
            # and raise UnboundLocalError whenever 'trt_wrapper_self' is absent.
            # Only the presence of the key matters, not its value.
            # The lookup is done here, inside the TRT branch, so that a failed
            # optional TRT import cannot break loading of other model types.
            if 'trt_wrapper_self' in model_info:
                wrapper_cls = TRTWrapperSelf
            else:
                wrapper_cls = TRTWrapper
            self.model = wrapper_cls(self.model_path)
        elif suffix == '.tjm':
            self.model_type = 'tjm'
            self.model = TJMWrapper(self.model_path, provider=provider)
        elif suffix in ('.onnx', '.bin'):
            self.model_type = 'onnx'
            if picklable:
                # NOTE(review): this path ignores 'encrypt' and
                # 'input_dynamic_shape' -- confirm that is intended.
                self.model = OnnxModelPickable(self.model_path, provider=provider)
            else:
                if 'encrypt' in model_info:
                    # model_path is replaced by whatever load_encrypt_model
                    # returns, and that value is what ONNXModel receives.
                    self.model_path = load_encrypt_model(
                        self.model_path, key=model_info['encrypt']
                    )
                self.model = ONNXModel(
                    self.model_path,
                    provider=provider,
                    input_dynamic_shape=self.input_dynamic_shape,
                )
        else:
            # Raising a bare string is a TypeError in Python 3; raise a real
            # exception that carries the message instead.
            raise ValueError(
                'check model suffix , support engine/tjm/onnx now.'
                f' (got {suffix!r})'
            )
| |
|