Luigi commited on
Commit
2f2f685
·
1 Parent(s): 634d4ff

Add 'get_model_format_and_input_shape' helper function

Browse files
Files changed (1) hide show
  1. rtmo_gpu.py +16 -12
rtmo_gpu.py CHANGED
@@ -299,6 +299,19 @@ def get_trt_input_shapes(model_path):
299
  input_shapes[binding] = engine.get_binding_shape(binding)
300
  return input_shapes
301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  class RTMO_GPU(object):
303
 
304
  def preprocess(self, img: np.ndarray):
@@ -465,18 +478,7 @@ class RTMO_GPU(object):
465
  raise FileNotFoundError(f"The specified ONNX model file was not found: {model}")
466
 
467
  self.model = model
468
- if is_onnx_model(model):
469
- self.model_format = 'onnx'
470
- self.input_shape = get_onnx_input_shapes(self.model)['input']
471
- elif is_trt_engine(model):
472
- self.model_format = 'engine'
473
- from polygraphy.backend.common import BytesFromPath
474
- from polygraphy.backend.trt import EngineFromBytes, TrtRunner, load_plugins
475
- load_plugins(plugins=['libmmdeploy_tensorrt_ops.so'])
476
- self.input_shape = get_trt_input_shapes(self.model)['input']
477
- else:
478
- raise TypeError("Your model is neither ONNX nor Engine !")
479
-
480
 
481
  if self.model_format == 'onnx':
482
 
@@ -497,6 +499,8 @@ class RTMO_GPU(object):
497
  providers=providers[device])
498
 
499
  else: # 'engine'
 
 
500
  engine = EngineFromBytes(BytesFromPath(model))
501
  self.session = TrtRunner(engine)
502
 
 
299
  input_shapes[binding] = engine.get_binding_shape(binding)
300
  return input_shapes
301
 
302
+ def get_model_format_and_input_shape(model):
303
+ if is_onnx_model(model):
304
+ model_format = 'onnx'
305
+ input_shape = get_onnx_input_shapes(model)['input']
306
+ elif is_trt_engine(model):
307
+ model_format = 'engine'
308
+ from polygraphy.backend.trt import load_plugins
309
+ load_plugins(plugins=['libmmdeploy_tensorrt_ops.so'])
310
+ input_shape = get_trt_input_shapes(model)['input']
311
+ else:
312
+ raise TypeError("Your model is neither ONNX nor Engine !")
313
+ return model_format, input_shape
314
+
315
  class RTMO_GPU(object):
316
 
317
  def preprocess(self, img: np.ndarray):
 
478
  raise FileNotFoundError(f"The specified ONNX model file was not found: {model}")
479
 
480
  self.model = model
481
+ self.model_format, self.input_shape = get_model_format_and_input_shape(self.model)
 
 
 
 
 
 
 
 
 
 
 
482
 
483
  if self.model_format == 'onnx':
484
 
 
499
  providers=providers[device])
500
 
501
  else: # 'engine'
502
+ from polygraphy.backend.common import BytesFromPath
503
+ from polygraphy.backend.trt import EngineFromBytes, TrtRunner
504
  engine = EngineFromBytes(BytesFromPath(model))
505
  self.session = TrtRunner(engine)
506