Support Inference over batch with TensorRT Engine Model
Browse files- rtmo_gpu.py +26 -12
rtmo_gpu.py
CHANGED
|
@@ -242,7 +242,6 @@ def is_onnx_model(model_path):
|
|
| 242 |
ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
| 243 |
return True
|
| 244 |
except Exception as e:
|
| 245 |
-
print('Error:', type(e))
|
| 246 |
return False
|
| 247 |
|
| 248 |
def is_trt_engine(model_path):
|
|
@@ -513,20 +512,35 @@ class RTMO_GPU_Batch(RTMO_GPU):
|
|
| 513 |
|
| 514 |
input = batch_img
|
| 515 |
|
| 516 |
-
|
| 517 |
-
io_binding = self.session.io_binding()
|
| 518 |
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
io_binding.bind_output(name='dets')
|
| 522 |
-
io_binding.bind_output(name='keypoints')
|
| 523 |
|
| 524 |
-
|
| 525 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 526 |
|
| 527 |
-
# Retrieve the outputs from the IO Binding object
|
| 528 |
-
outputs = [output.numpy() for output in io_binding.get_outputs()]
|
| 529 |
-
|
| 530 |
return outputs
|
| 531 |
|
| 532 |
def postprocess_batch(
|
|
|
|
| 242 |
ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
| 243 |
return True
|
| 244 |
except Exception as e:
|
|
|
|
| 245 |
return False
|
| 246 |
|
| 247 |
def is_trt_engine(model_path):
|
|
|
|
| 512 |
|
| 513 |
input = batch_img
|
| 514 |
|
| 515 |
+
if self.model_format == 'onnx':
|
|
|
|
| 516 |
|
| 517 |
+
# Create an IO Binding object
|
| 518 |
+
io_binding = self.session.io_binding()
|
|
|
|
|
|
|
| 519 |
|
| 520 |
+
if not self.is_yolo_nas_pose:
|
| 521 |
+
# RTMO
|
| 522 |
+
io_binding.bind_input(name='input', device_type='cpu', device_id=0, element_type=np.float32, shape=input.shape, buffer_ptr=input.ctypes.data)
|
| 523 |
+
io_binding.bind_output(name='dets')
|
| 524 |
+
io_binding.bind_output(name='keypoints')
|
| 525 |
+
else:
|
| 526 |
+
# NAS Pose, flat format
|
| 527 |
+
io_binding.bind_input(name='input', device_type='cpu', device_id=0, element_type=np.uint8, shape=input.shape, buffer_ptr=input.ctypes.data)
|
| 528 |
+
io_binding.bind_output(name='graph2_flat_predictions')
|
| 529 |
+
|
| 530 |
+
# Run inference with IO Binding
|
| 531 |
+
self.session.run_with_iobinding(io_binding)
|
| 532 |
+
|
| 533 |
+
# Retrieve the outputs from the IO Binding object
|
| 534 |
+
outputs = [output.numpy() for output in io_binding.get_outputs()]
|
| 535 |
+
|
| 536 |
+
else: # 'engine'
|
| 537 |
+
|
| 538 |
+
if not self.session.is_active:
|
| 539 |
+
self.session.activate()
|
| 540 |
+
|
| 541 |
+
outputs = self.session.infer(feed_dict={'input': input}, check_inputs=False)
|
| 542 |
+
outputs = [output for output in outputs.values()]
|
| 543 |
|
|
|
|
|
|
|
|
|
|
| 544 |
return outputs
|
| 545 |
|
| 546 |
def postprocess_batch(
|