Add new class 'RTMO_GPU_Batch' that can perform inference on a batch of images
Browse files- rtmo_gpu.py +84 -0
rtmo_gpu.py
CHANGED
|
@@ -378,3 +378,87 @@ class RTMO_GPU(object):
|
|
| 378 |
self.std = std
|
| 379 |
self.device = device
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
self.std = std
|
| 379 |
self.device = device
|
| 380 |
|
| 381 |
+
class RTMO_GPU_Batch(RTMO_GPU):
    """Batched variant of RTMO_GPU.

    Extends the single-image pipeline to a list of images: per-image
    preprocessing, one batched ONNX Runtime inference call (via IO binding),
    then per-image postprocessing.
    """

    def preprocess_batch(self, imgs: List[np.ndarray]) -> Tuple[np.ndarray, List[float]]:
        """Process a batch of images for RTMPose model inference.

        Args:
            imgs (List[np.ndarray]): List of input images.

        Returns:
            tuple:
                - batch_img (np.ndarray): Batch of preprocessed images,
                  stacked along axis 0.
                - ratios (List[float]): Ratio used for preprocessing each image.
        """
        batch_img = []
        ratios = []

        for img in imgs:
            preprocessed_img, ratio = super().preprocess(img)
            batch_img.append(preprocessed_img)
            ratios.append(ratio)

        # Stack along the first dimension to create a batch.
        # NOTE(review): assumes the parent preprocess resizes every image to
        # the same fixed model input size — confirm; np.stack raises otherwise.
        batch_img = np.stack(batch_img, axis=0)

        return batch_img, ratios

    def inference(self, batch_img: np.ndarray):
        """Override to handle batch inference.

        Args:
            batch_img (np.ndarray): Batch of preprocessed images (NHWC).

        Returns:
            outputs (List[np.ndarray]): Outputs of RTMPose model
                ('dets', 'keypoints'), each with a leading batch dimension.
        """
        batch_img = batch_img.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        # ascontiguousarray: IO binding reads the raw buffer, so the data
        # must be contiguous float32.
        batch_img = np.ascontiguousarray(batch_img, dtype=np.float32)

        # Create an IO Binding object and bind inputs/outputs directly to
        # avoid extra host-side copies.
        io_binding = self.session.io_binding()
        io_binding.bind_input(name='input',
                              device_type='cpu',
                              device_id=0,
                              element_type=np.float32,
                              shape=batch_img.shape,
                              buffer_ptr=batch_img.ctypes.data)
        io_binding.bind_output(name='dets')
        io_binding.bind_output(name='keypoints')

        # Run inference with IO Binding.
        self.session.run_with_iobinding(io_binding)

        # Retrieve the outputs from the IO Binding object.
        outputs = [output.numpy() for output in io_binding.get_outputs()]

        return outputs

    def postprocess_batch(
        self,
        outputs: List[np.ndarray],
        ratios: List[float]
    ) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Process outputs for a batch of images.

        Args:
            outputs (List[np.ndarray]): Batched outputs from the model.
            ratios (List[float]): Ratios used for preprocessing each image.

        Returns:
            tuple: two parallel lists — keypoints and scores for each image.
        """
        batch_keypoints = []
        batch_scores = []

        for i, ratio in enumerate(ratios):
            # BUG FIX: the original passed the whole batched `outputs` for
            # every image (the loop index `i` was unused), so each iteration
            # postprocessed image 0's predictions. Slice out image i, keeping
            # the leading batch dimension the single-image postprocess
            # expects — TODO confirm against RTMO_GPU.postprocess.
            image_outputs = [output[i:i + 1] for output in outputs]
            keypoints, scores = super().postprocess(image_outputs, ratio)
            batch_keypoints.append(keypoints)
            batch_scores.append(scores)

        return batch_keypoints, batch_scores

    def __call__(self, images: List[np.ndarray]):
        """Run the full batched pipeline.

        Args:
            images (List[np.ndarray]): List of input images.

        Returns:
            tuple: per-image keypoints and per-image scores lists.
        """
        batch_img, ratios = self.preprocess_batch(images)
        outputs = self.inference(batch_img)
        keypoints, scores = self.postprocess_batch(outputs, ratios)
        return keypoints, scores
|