Luigi
commited on
Commit
·
2f2f685
1
Parent(s):
634d4ff
Add 'get_model_format_and_input_shape' helper function
Browse files- rtmo_gpu.py +16 -12
rtmo_gpu.py
CHANGED
|
@@ -299,6 +299,19 @@ def get_trt_input_shapes(model_path):
|
|
| 299 |
input_shapes[binding] = engine.get_binding_shape(binding)
|
| 300 |
return input_shapes
|
| 301 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
class RTMO_GPU(object):
|
| 303 |
|
| 304 |
def preprocess(self, img: np.ndarray):
|
|
@@ -465,18 +478,7 @@ class RTMO_GPU(object):
|
|
| 465 |
raise FileNotFoundError(f"The specified ONNX model file was not found: {model}")
|
| 466 |
|
| 467 |
self.model = model
|
| 468 |
-
|
| 469 |
-
self.model_format = 'onnx'
|
| 470 |
-
self.input_shape = get_onnx_input_shapes(self.model)['input']
|
| 471 |
-
elif is_trt_engine(model):
|
| 472 |
-
self.model_format = 'engine'
|
| 473 |
-
from polygraphy.backend.common import BytesFromPath
|
| 474 |
-
from polygraphy.backend.trt import EngineFromBytes, TrtRunner, load_plugins
|
| 475 |
-
load_plugins(plugins=['libmmdeploy_tensorrt_ops.so'])
|
| 476 |
-
self.input_shape = get_trt_input_shapes(self.model)['input']
|
| 477 |
-
else:
|
| 478 |
-
raise TypeError("Your model is neither ONNX nor Engine !")
|
| 479 |
-
|
| 480 |
|
| 481 |
if self.model_format == 'onnx':
|
| 482 |
|
|
@@ -497,6 +499,8 @@ class RTMO_GPU(object):
|
|
| 497 |
providers=providers[device])
|
| 498 |
|
| 499 |
else: # 'engine'
|
|
|
|
|
|
|
| 500 |
engine = EngineFromBytes(BytesFromPath(model))
|
| 501 |
self.session = TrtRunner(engine)
|
| 502 |
|
|
|
|
| 299 |
input_shapes[binding] = engine.get_binding_shape(binding)
|
| 300 |
return input_shapes
|
| 301 |
|
| 302 |
+
def get_model_format_and_input_shape(model):
|
| 303 |
+
if is_onnx_model(model):
|
| 304 |
+
model_format = 'onnx'
|
| 305 |
+
input_shape = get_onnx_input_shapes(model)['input']
|
| 306 |
+
elif is_trt_engine(model):
|
| 307 |
+
model_format = 'engine'
|
| 308 |
+
from polygraphy.backend.trt import load_plugins
|
| 309 |
+
load_plugins(plugins=['libmmdeploy_tensorrt_ops.so'])
|
| 310 |
+
input_shape = get_trt_input_shapes(model)['input']
|
| 311 |
+
else:
|
| 312 |
+
raise TypeError("Your model is neither ONNX nor Engine !")
|
| 313 |
+
return model_format, input_shape
|
| 314 |
+
|
| 315 |
class RTMO_GPU(object):
|
| 316 |
|
| 317 |
def preprocess(self, img: np.ndarray):
|
|
|
|
| 478 |
raise FileNotFoundError(f"The specified ONNX model file was not found: {model}")
|
| 479 |
|
| 480 |
self.model = model
|
| 481 |
+
self.model_format, self.input_shape = get_model_format_and_input_shape(self.model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 482 |
|
| 483 |
if self.model_format == 'onnx':
|
| 484 |
|
|
|
|
| 499 |
providers=providers[device])
|
| 500 |
|
| 501 |
else: # 'engine'
|
| 502 |
+
from polygraphy.backend.common import BytesFromPath
|
| 503 |
+
from polygraphy.backend.trt import EngineFromBytes, TrtRunner
|
| 504 |
engine = EngineFromBytes(BytesFromPath(model))
|
| 505 |
self.session = TrtRunner(engine)
|
| 506 |
|