# Object_Detection_HUB / scripts / inference_rfdetr.py
# Panagiota Moraiti
# Add scripts
# d987cda
import supervision as sv
import numpy as np
import cv2
import warnings
import rfdetr
from base_inference import BaseInference
# Silence UserWarnings whose message begins with "torch.meshgrid"
# (the pattern is matched against the start of the warning text).
warnings.filterwarnings(
    action="ignore",
    message="torch.meshgrid",
    category=UserWarning,
)
class RFDETRInference(BaseInference):
    """
    A class to perform inference using RF-DETR models of different sizes.
    """

    # Version keyword -> name of the model class inside the ``rfdetr`` package.
    # Only the requested name is looked up (in ``__init__``), so a class that
    # is absent from the installed rfdetr release does not break the versions
    # that are present. 'base' and 'base2' both resolve to RFDETRBase.
    _MODEL_CLASS_NAMES = {
        'nano': 'RFDETRNano',
        'small': 'RFDETRSmall',
        'medium': 'RFDETRMedium',
        'base': 'RFDETRBase',
        'base2': 'RFDETRBase',
        'large': 'RFDETRLarge',
    }

    def __init__(self, version='small', pretrain_weights="./models/rfdetr_small/checkpoint_best_total.pth"):
        """
        Initializes the RFDETR model.

        Args:
            version (str): Model version ('nano', 'small', 'medium', 'base', 'base2', 'large').
            pretrain_weights (str): Path to the pretrained .pth weights file.

        Raises:
            ValueError: If an unsupported version is passed, or if the
                installed rfdetr package does not provide the model class
                for the requested version.
        """
        class_name = self._MODEL_CLASS_NAMES.get(version)
        if class_name is None:
            raise ValueError(f"Unsupported version: {version}")

        # Resolve lazily: touching every rfdetr.RFDETR* attribute up front
        # would raise AttributeError for *any* version as soon as a single
        # variant is missing from the installed package.
        model_cls = getattr(rfdetr, class_name, None)
        if model_cls is None:
            raise ValueError(
                f"Unsupported version: {version} "
                f"(installed rfdetr package does not provide {class_name})"
            )

        self.model = model_cls(pretrain_weights=pretrain_weights)

    def infer(self, image, confidence=0.5, use_nms=False, nms_thresh=0.7):
        """
        Perform inference on a single image.

        Args:
            image (np.ndarray): Input image (BGR format). A 2-D array or an
                array with a single channel is expanded to three channels
                before prediction.
            confidence (float): Confidence threshold passed to the model's
                ``predict`` call.
            use_nms (bool): Whether to apply Non-Maximum Suppression.
            nms_thresh (float): NMS IoU threshold.

        Returns:
            sv.Detections: Detection results including bounding boxes, class IDs, and confidences.

        Raises:
            ValueError: If ``image`` is None.
        """
        # Fail early with a clear message instead of handing None to the model.
        if image is None:
            raise ValueError("Input image is None")

        # Convert grayscale to BGR
        if len(image.shape) == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        # NOTE(review): the image is forwarded in the channel order it was
        # received in. RF-DETR's predict() appears to expect RGB arrays, so
        # confirm what callers actually pass before adding a BGR->RGB swap.
        detections = self.model.predict(image, threshold=confidence)

        # Class-agnostic NMS: overlapping boxes suppress each other
        # regardless of their predicted class.
        if use_nms:
            detections = detections.with_nms(
                threshold=nms_thresh, class_agnostic=True)

        # Rebuild a Detections object carrying only boxes, class IDs and
        # confidences; any other fields on the prediction are dropped.
        return sv.Detections(
            xyxy=np.array(detections.xyxy),
            class_id=np.array(detections.class_id),
            confidence=np.array(detections.confidence)
        )