File size: 2,474 Bytes
d987cda
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import supervision as sv
import numpy as np
import cv2
import warnings
import rfdetr
from base_inference import BaseInference

# Suppress PyTorch meshgrid warnings.
# `message` is a regular expression matched against the start of the warning
# text, so only UserWarnings whose message begins with "torch.meshgrid" are
# ignored; every other warning is still reported.
warnings.filterwarnings("ignore", category=UserWarning, message="torch.meshgrid")


class RFDETRInference(BaseInference):
    """
    A class to perform inference using RF-DETR models of different sizes.

    The instantiated RF-DETR model is kept in ``self.model``; ``infer`` runs it
    on a single image and returns the result as ``sv.Detections``.
    """

    def __init__(self, version='small', pretrain_weights="./models/rfdetr_small/checkpoint_best_total.pth"):
        """
        Initializes the RFDETR model.

        Args:
            version (str): Model version ('nano', 'small', 'medium', 'base', 'base2', 'large').
                'base' and 'base2' both select ``rfdetr.RFDETRBase``.
            pretrain_weights (str): Path to the pretrained .pth weights file.

        Raises:
            ValueError: If an unsupported version is passed.
        """
        # NOTE(review): BaseInference.__init__ is not called here; confirm the
        # base class does not need any initialisation of its own.

        # Map version names to RFDETR model classes
        model_cls = {
            'nano': rfdetr.RFDETRNano,
            'small': rfdetr.RFDETRSmall,
            'medium': rfdetr.RFDETRMedium,
            'base': rfdetr.RFDETRBase,
            'base2': rfdetr.RFDETRBase,
            'large': rfdetr.RFDETRLarge,
        }.get(version)

        # Explicit None check: a missing key is the only "unsupported" case.
        if model_cls is None:
            raise ValueError(f"Unsupported version: {version}")

        self.model = model_cls(pretrain_weights=pretrain_weights)

    def infer(self, image, confidence=0.5, use_nms=False, nms_thresh=0.7):
        """
        Perform inference on a single image.

        Args:
            image (np.ndarray): Input image in OpenCV layout: BGR colour, or
                grayscale given as (H, W) or (H, W, 1).
            confidence (float): Confidence threshold.
            use_nms (bool): Whether to apply Non-Maximum Suppression.
            nms_thresh (float): NMS IoU threshold.

        Returns:
            sv.Detections: Detection results including bounding boxes, class IDs, and confidences.
                Only these three fields are copied from the model output.

        Raises:
            ValueError: If ``image`` is None (for example a failed image read).
        """
        # Fail early with a clear message instead of handing None to the model.
        if image is None:
            raise ValueError("image is None; expected a BGR or grayscale np.ndarray")

        if len(image.shape) == 2 or image.shape[2] == 1:
            # Grayscale: replicate the single channel into three. All channels
            # are equal, so the result is valid in either channel order.
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        else:
            # The model's predict() is documented to take RGB input while this
            # method accepts BGR, so reverse the channel axis. A contiguous
            # copy is made so downstream tensor conversion never receives a
            # negatively-strided view.
            image = np.ascontiguousarray(image[..., ::-1])

        # Perform prediction once; NMS is an optional post-processing step.
        detections = self.model.predict(image, threshold=confidence)
        if use_nms:
            detections = detections.with_nms(
                threshold=nms_thresh, class_agnostic=True)

        return sv.Detections(
            xyxy=np.array(detections.xyxy),
            class_id=np.array(detections.class_id),
            confidence=np.array(detections.confidence)
        )