Luigi committed on
Commit
0cdc9a7
·
1 Parent(s): 09ccc6e

Support Inference over batch with TensorRT Engine Model

Browse files
Files changed (1) hide show
  1. rtmo_gpu.py +26 -12
rtmo_gpu.py CHANGED
@@ -242,7 +242,6 @@ def is_onnx_model(model_path):
242
  ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
243
  return True
244
  except Exception as e:
245
- print('Error:', type(e))
246
  return False
247
 
248
  def is_trt_engine(model_path):
@@ -513,20 +512,35 @@ class RTMO_GPU_Batch(RTMO_GPU):
513
 
514
  input = batch_img
515
 
516
- # Create an IO Binding object
517
- io_binding = self.session.io_binding()
518
 
519
- # Bind the model inputs and outputs to the IO Binding object
520
- io_binding.bind_input(name='input', device_type='cpu', device_id=0, element_type=np.float32, shape=input.shape, buffer_ptr=input.ctypes.data)
521
- io_binding.bind_output(name='dets')
522
- io_binding.bind_output(name='keypoints')
523
 
524
- # Run inference with IO Binding
525
- self.session.run_with_iobinding(io_binding)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
 
527
- # Retrieve the outputs from the IO Binding object
528
- outputs = [output.numpy() for output in io_binding.get_outputs()]
529
-
530
  return outputs
531
 
532
  def postprocess_batch(
 
242
  ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
243
  return True
244
  except Exception as e:
 
245
  return False
246
 
247
  def is_trt_engine(model_path):
 
512
 
513
  input = batch_img
514
 
515
+ if self.model_format == 'onnx':
 
516
 
517
+ # Create an IO Binding object
518
+ io_binding = self.session.io_binding()
 
 
519
 
520
+ if not self.is_yolo_nas_pose:
521
+ # RTMO
522
+ io_binding.bind_input(name='input', device_type='cpu', device_id=0, element_type=np.float32, shape=input.shape, buffer_ptr=input.ctypes.data)
523
+ io_binding.bind_output(name='dets')
524
+ io_binding.bind_output(name='keypoints')
525
+ else:
526
+ # NAS Pose, flat format
527
+ io_binding.bind_input(name='input', device_type='cpu', device_id=0, element_type=np.uint8, shape=input.shape, buffer_ptr=input.ctypes.data)
528
+ io_binding.bind_output(name='graph2_flat_predictions')
529
+
530
+ # Run inference with IO Binding
531
+ self.session.run_with_iobinding(io_binding)
532
+
533
+ # Retrieve the outputs from the IO Binding object
534
+ outputs = [output.numpy() for output in io_binding.get_outputs()]
535
+
536
+ else: # 'engine'
537
+
538
+ if not self.session.is_active:
539
+ self.session.activate()
540
+
541
+ outputs = self.session.infer(feed_dict={'input': input}, check_inputs=False)
542
+ outputs = [output for output in outputs.values()]
543
 
 
 
 
544
  return outputs
545
 
546
  def postprocess_batch(