| import supervision as sv |
| import numpy as np |
| import cv2 |
| import warnings |
| import rfdetr |
| from base_inference import BaseInference |
|
|
| |
# Suppress UserWarnings whose text starts with "torch.meshgrid" (the message
# argument is a case-insensitive regex matched against the start of the warning).
# NOTE(review): presumably raised from inside the torch/RF-DETR stack rather
# than from this module — confirm before removing.
warnings.filterwarnings(
    action="ignore",
    message="torch.meshgrid",
    category=UserWarning,
)
|
|
|
|
class RFDETRInference(BaseInference):
    """
    A class to perform inference using RF-DETR models of different sizes.

    Input images are taken in OpenCV (BGR / BGRA / grayscale) layout and are
    converted to the 3-channel RGB layout that ``rfdetr``'s ``predict`` expects.
    """

    # Supported version strings mapped to the name of the rfdetr model class that
    # implements them. Names are resolved with getattr only for the requested
    # version, so an rfdetr release lacking one class breaks only that version.
    _MODEL_CLASS_NAMES = {
        'nano': 'RFDETRNano',
        'small': 'RFDETRSmall',
        'medium': 'RFDETRMedium',
        'base': 'RFDETRBase',
        # NOTE(review): 'base2' builds the same class as 'base'; presumably only
        # the weights file differs — confirm this alias is intended.
        'base2': 'RFDETRBase',
        'large': 'RFDETRLarge',
    }

    def __init__(self, version='small', pretrain_weights="./models/rfdetr_small/checkpoint_best_total.pth"):
        """
        Initializes the RFDETR model.

        Args:
            version (str): Model version ('nano', 'small', 'medium', 'base', 'base2', 'large').
            pretrain_weights (str): Path to the pretrained .pth weights file.

        Raises:
            ValueError: If an unsupported version is passed, or if the installed
                rfdetr package does not provide the model class for that version.
        """
        class_name = self._MODEL_CLASS_NAMES.get(version)
        if class_name is None:
            raise ValueError(f"Unsupported version: {version}")

        model_cls = getattr(rfdetr, class_name, None)
        if model_cls is None:
            raise ValueError(
                f"Version {version!r} requires rfdetr.{class_name}, which the "
                "installed rfdetr package does not provide"
            )

        self.model = model_cls(pretrain_weights=pretrain_weights)

    @staticmethod
    def _to_rgb(image):
        """
        Convert an OpenCV-style image into a contiguous 3-channel RGB array.

        Args:
            image (np.ndarray): Grayscale (H, W) or (H, W, 1), BGR (H, W, 3),
                or BGRA (H, W, 4) image.

        Returns:
            np.ndarray: The image as an (H, W, 3) RGB array.

        Raises:
            ValueError: If ``image`` is None or its shape is not one of the above.
        """
        if image is None:
            raise ValueError("image is None")

        if image.ndim == 2:
            channels = 1
        elif image.ndim == 3:
            channels = image.shape[2]
        else:
            channels = None

        conversion = {
            1: cv2.COLOR_GRAY2RGB,
            3: cv2.COLOR_BGR2RGB,
            4: cv2.COLOR_BGRA2RGB,  # drop alpha; the model takes 3 channels
        }.get(channels)
        if conversion is None:
            raise ValueError(f"Unsupported image shape: {image.shape}")

        # cvtColor returns a new contiguous array, unlike a negative-stride
        # view such as image[..., ::-1].
        return cv2.cvtColor(image, conversion)

    def infer(self, image, confidence=0.5, use_nms=False, nms_thresh=0.7):
        """
        Perform inference on a single image.

        Args:
            image (np.ndarray): Input image in OpenCV layout: BGR, BGRA, or
                grayscale. It is converted to RGB before being passed to the model.
            confidence (float): Confidence threshold.
            use_nms (bool): Whether to apply Non-Maximum Suppression.
            nms_thresh (float): NMS IoU threshold.

        Returns:
            sv.Detections: Detection results including bounding boxes, class IDs, and confidences.

        Raises:
            ValueError: If ``image`` is None or has an unsupported shape.
        """
        rgb_image = self._to_rgb(image)

        detections = self.model.predict(rgb_image, threshold=confidence)
        if use_nms:
            # Class-agnostic: overlapping boxes are suppressed even when their
            # class labels differ.
            detections = detections.with_nms(threshold=nms_thresh, class_agnostic=True)

        # Return a fresh Detections carrying only boxes, class IDs and scores.
        return sv.Detections(
            xyxy=np.array(detections.xyxy),
            class_id=np.array(detections.class_id),
            confidence=np.array(detections.confidence),
        )
| |