Spaces:
Sleeping
Sleeping
| import supervision as sv | |
| import numpy as np | |
| import cv2 | |
| import warnings | |
| import rfdetr | |
| from base_inference import BaseInference | |
# Silence PyTorch's UserWarning whose text begins with "torch.meshgrid".
# `message` is a regular expression matched against the start of the
# warning message (case-insensitively), per the `warnings` module rules.
warnings.filterwarnings(action="ignore", message="torch.meshgrid", category=UserWarning)
class RFDETRInference(BaseInference):
    """
    Object detector backed by an RF-DETR model of a selectable size.

    Builds one of the ``rfdetr`` model classes from a checkpoint and exposes
    ``infer`` for single-image prediction, returning ``sv.Detections``.
    """

    def __init__(self, version='small', pretrain_weights="./models/rfdetr_small/checkpoint_best_total.pth"):
        """
        Initializes the RFDETR model.

        Args:
            version (str): Model version ('nano', 'small', 'medium', 'base',
                'base2', 'large'). 'base2' builds the same class as 'base'.
            pretrain_weights (str): Path to the pretrained .pth weights file.

        Raises:
            ValueError: If an unsupported version is passed.
        """
        # NOTE(review): the default weights path points at a 'small'
        # checkpoint regardless of `version`; callers selecting another size
        # presumably need to pass matching weights — confirm.
        # NOTE(review): BaseInference.__init__ is not called (as before);
        # confirm the base class needs no initialization.
        model_classes = {
            'nano': rfdetr.RFDETRNano,
            'small': rfdetr.RFDETRSmall,
            'medium': rfdetr.RFDETRMedium,
            'base': rfdetr.RFDETRBase,
            'base2': rfdetr.RFDETRBase,
            'large': rfdetr.RFDETRLarge,
        }
        model_cls = model_classes.get(version)
        if model_cls is None:
            raise ValueError(f"Unsupported version: {version}")
        self.model = model_cls(pretrain_weights=pretrain_weights)

    @staticmethod
    def _to_rgb(image):
        """
        Convert an OpenCV-style image array to the 3-channel RGB order rfdetr expects.

        Accepts grayscale (H, W) / (H, W, 1), BGR (H, W, 3) and BGRA (H, W, 4)
        ndarrays. Non-ndarray inputs are returned unchanged so that rfdetr's
        own handling of other input types still applies.

        Raises:
            ValueError: If ``image`` is None or an ndarray with an
                unsupported shape / channel count.
        """
        if image is None:
            raise ValueError("image must not be None")
        if not isinstance(image, np.ndarray):
            return image
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3 and image.shape[2] == 3:
            # Inputs follow OpenCV's BGR convention; rfdetr's predict expects RGB.
            return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if image.ndim == 3 and image.shape[2] == 4:
            # Drop the alpha channel and reorder BGRA -> RGB.
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        raise ValueError(f"Unsupported image shape: {image.shape}")

    def infer(self, image, confidence=0.5, use_nms=False, nms_thresh=0.7):
        """
        Perform inference on a single image.

        Args:
            image (np.ndarray): Input image in OpenCV channel order: BGR
                (H, W, 3), BGRA (H, W, 4) or grayscale (H, W) / (H, W, 1).
                It is converted to RGB before being passed to the model.
            confidence (float): Confidence threshold passed to ``predict``.
            use_nms (bool): Whether to apply class-agnostic Non-Maximum Suppression.
            nms_thresh (float): NMS IoU threshold (used only when ``use_nms``).

        Returns:
            sv.Detections: Detection results containing only bounding boxes
                (``xyxy``), class IDs and confidences.

        Raises:
            ValueError: If ``image`` is None or has an unsupported shape.
        """
        rgb_image = self._to_rgb(image)
        detections = self.model.predict(rgb_image, threshold=confidence)
        if use_nms:
            detections = detections.with_nms(threshold=nms_thresh, class_agnostic=True)
        # Re-wrap so callers get a Detections object holding only these three
        # fields (any extra data from predict is intentionally dropped).
        return sv.Detections(
            xyxy=np.array(detections.xyxy),
            class_id=np.array(detections.class_id),
            confidence=np.array(detections.confidence),
        )