# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import argparse
import os

import cv2
import numpy as np
import onnxruntime as ort
import requests
import yaml


def download_file(url: str, local_path: str) -> str:
    """Download a file from a URL to a local path.

    Args:
        url (str): URL of the file to download.
        local_path (str): Local path where the file will be saved.

    Returns:
        (str): Path to the local file, whether freshly downloaded or already present.
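
    Examples:
        >>> download_file("https://ultralytics.com/images/bus.jpg", "bus.jpg")  # illustrative URL
        'bus.jpg'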
"""
# Check if the local path already exists
if os.path.exists(local_path):
print(f"File already exists at {local_path}. Skipping download.")
return local_path
# Download the file from the URL
print(f"Downloading {url} to {local_path}...")
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
with open(local_path, "wb") as f:
for chunk in response.iter_content(chunk_size=1024 * 1024):
if chunk:
f.write(chunk)
return local_path


class RTDETR:
    """RT-DETR (Real-Time Detection Transformer) object detection model for ONNX inference and visualization.

    This class implements the RT-DETR model for object detection tasks, supporting ONNX model inference and
    visualization of detection results with bounding boxes and class labels.

    Attributes:
        model_path (str): Path to the ONNX model file.
        img_path (str): Path to the input image.
        conf_thres (float): Confidence threshold for filtering detections.
        iou_thres (float): IoU threshold for non-maximum suppression.
        session (ort.InferenceSession): ONNX runtime inference session.
        model_input (list): Model input metadata.
        input_width (int): Width dimension required by the model.
        input_height (int): Height dimension required by the model.
        classes (list[str]): List of class names from the COCO dataset.
        color_palette (np.ndarray): Random color palette for visualization.
        img (np.ndarray): Loaded input image.
        img_height (int): Height of the input image.
        img_width (int): Width of the input image.

    Methods:
        draw_detections: Draw bounding boxes and labels on the input image.
        preprocess: Preprocess the input image for model inference.
        bbox_cxcywh_to_xyxy: Convert bounding boxes from center format to corner format.
        postprocess: Postprocess model output to extract and visualize detections.
        main: Execute the complete object detection pipeline.

    Examples:
        Initialize RT-DETR detector and run inference
        >>> detector = RTDETR("rtdetr-l.onnx", "image.jpg", conf_thres=0.5)
        >>> output_image = detector.main()
        >>> cv2.imshow("Detections", output_image)
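        >>> cv2.waitKey(0)  # keeps the window open; alternatively, save with cv2.imwrite("detections.jpg", output_image)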
"""
def __init__(
self,
model_path: str,
img_path: str,
conf_thres: float = 0.5,
iou_thres: float = 0.5,
class_names: str | None = None,
):
"""Initialize the RT-DETR object detection model.
Args:
model_path (str): Path to the ONNX model file.
img_path (str): Path to the input image.
conf_thres (float, optional): Confidence threshold for filtering detections.
iou_thres (float, optional): IoU threshold for non-maximum suppression.
class_names (Optional[str], optional): Path to a YAML file containing class names. If None, uses COCO
dataset classes.
"""
self.model_path = model_path
self.img_path = img_path
self.conf_thres = conf_thres
self.iou_thres = iou_thres
self.classes = class_names
# Set up the ONNX runtime session with available execution providers
available = ort.get_available_providers()
providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available]
self.session = ort.InferenceSession(model_path, providers=providers or available)
self.model_input = self.session.get_inputs()
self.input_width = self.model_input[0].shape[2]
self.input_height = self.model_input[0].shape[3]
if self.classes is None:
# Load class names from the COCO dataset YAML file
self.classes = download_file(
"https://raw.githubusercontent.com/ultralytics/ultralytics/main/ultralytics/cfg/datasets/coco8.yaml",
"coco8.yaml",
)
# Parse the YAML file to get class names
with open(self.classes) as f:
class_data = yaml.safe_load(f)
self.classes = list(class_data["names"].values())
# Ensure the classes are a list
if not isinstance(self.classes, list):
raise ValueError("Classes should be a list of class names.")
# Generate a color palette for drawing bounding boxes
self.color_palette: np.ndarray = np.random.uniform(0, 255, size=(len(self.classes), 3))

    def draw_detections(self, box: np.ndarray, score: float, class_id: int) -> None:
        """Draw bounding box and label on the input image for a detected object."""
        # Extract the coordinates of the bounding box
        x1, y1, x2, y2 = box

        # Retrieve the color for the class ID
        color = self.color_palette[class_id]

        # Draw the bounding box on the image
        cv2.rectangle(self.img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

        # Create the label text with class name and score
        label = f"{self.classes[class_id]}: {score:.2f}"

        # Calculate the dimensions of the label text
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Calculate the position of the label text
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # Draw a filled rectangle as the background for the label text
        cv2.rectangle(
            self.img,
            (int(label_x), int(label_y - label_height)),
            (int(label_x + label_width), int(label_y + label_height)),
            color,
            cv2.FILLED,
        )

        # Draw the label text on the image
        cv2.putText(
            self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA
        )

    def preprocess(self) -> np.ndarray:
        """Preprocess the input image for model inference.

        Loads the image, converts color space from BGR to RGB, resizes to model input dimensions, and normalizes
        pixel values to the [0, 1] range.

        Returns:
            (np.ndarray): Preprocessed image data with shape (1, 3, H, W) ready for inference.
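
        Examples:
            >>> blob = detector.preprocess()  # illustrative; blob.shape == (1, 3, 640, 640) for a 640x640 model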
"""
# Read the input image using OpenCV
self.img = cv2.imread(self.img_path)
if self.img is None:
raise FileNotFoundError(f"Image not found or unreadable: '{self.img_path}'")
# Get the height and width of the input image
self.img_height, self.img_width = self.img.shape[:2]
# Convert the image color space from BGR to RGB
img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)
# Resize the image to match the input shape
img = cv2.resize(img, (self.input_width, self.input_height))
# Normalize the image data by dividing it by 255.0
image_data = np.array(img) / 255.0
# Transpose the image to have the channel dimension as the first dimension
image_data = np.transpose(image_data, (2, 0, 1)) # Channel first
# Expand the dimensions of the image data to match the expected input shape
image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
return image_data

    def bbox_cxcywh_to_xyxy(self, boxes: np.ndarray) -> np.ndarray:
        """Convert bounding boxes from center format to corner format.

        Args:
            boxes (np.ndarray): Array of shape (N, 4) where each row represents a bounding box in (center_x,
                center_y, width, height) format.

        Returns:
            (np.ndarray): Array of shape (N, 4) with bounding boxes in (x_min, y_min, x_max, y_max) format.
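
        Examples:
            >>> boxes = np.array([[50.0, 50.0, 20.0, 10.0]])  # one box centered at (50, 50), 20 wide and 10 tall
            >>> detector.bbox_cxcywh_to_xyxy(boxes)  # illustrative; detector is an RTDETR instance
            array([[40., 45., 60., 55.]])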
"""
# Calculate half width and half height of the bounding boxes
half_width = boxes[:, 2] / 2
half_height = boxes[:, 3] / 2
# Calculate the coordinates of the bounding boxes
x_min = boxes[:, 0] - half_width
y_min = boxes[:, 1] - half_height
x_max = boxes[:, 0] + half_width
y_max = boxes[:, 1] + half_height
# Return the bounding boxes in (x_min, y_min, x_max, y_max) format
return np.column_stack((x_min, y_min, x_max, y_max))

    def postprocess(self, model_output: list[np.ndarray]) -> np.ndarray:
        """Postprocess model output to extract and visualize detections.

        Applies confidence thresholding, converts bounding box format, scales coordinates to original image
        dimensions, and draws detection annotations.

        Args:
            model_output (list[np.ndarray]): Output tensors from the model inference.

        Returns:
            (np.ndarray): Annotated image with detection bounding boxes and labels.
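
        Examples:
            >>> blob = detector.preprocess()  # illustrative; these two calls mirror what main() does
            >>> annotated = detector.postprocess(detector.session.run(None, {detector.model_input[0].name: blob}))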
"""
# Squeeze the model output to remove unnecessary dimensions
outputs = np.squeeze(model_output[0])
# Extract bounding boxes and scores from the model output
boxes = outputs[:, :4]
scores = outputs[:, 4:]
# Get the class labels and scores for each detection
labels = np.argmax(scores, axis=1)
scores = np.max(scores, axis=1)
# Apply confidence threshold to filter out low-confidence detections
mask = scores > self.conf_thres
boxes, scores, labels = boxes[mask], scores[mask], labels[mask]
# Convert bounding boxes to (x_min, y_min, x_max, y_max) format
boxes = self.bbox_cxcywh_to_xyxy(boxes)
# Scale bounding boxes to match the original image dimensions
boxes[:, 0::2] *= self.img_width
boxes[:, 1::2] *= self.img_height
# Apply non-maximum suppression (optional for RT-DETR, but useful for filtering overlaps)
xywh_boxes = [[float(b[0]), float(b[1]), float(b[2] - b[0]), float(b[3] - b[1])] for b in boxes]
indices = cv2.dnn.NMSBoxes(xywh_boxes, scores.tolist(), self.conf_thres, self.iou_thres)
indices = indices.flatten().tolist() if len(indices) else []
# Draw detections on the image
for i in indices:
self.draw_detections(boxes[i], float(scores[i]), int(labels[i]))
return self.img

    def main(self) -> np.ndarray:
        """Execute the complete object detection pipeline on the input image.

        Performs preprocessing, ONNX model inference, and postprocessing to generate annotated detection results.

        Returns:
            (np.ndarray): Output image with detection annotations including bounding boxes and class labels.
"""
# Preprocess the image for model input
image_data = self.preprocess()
# Run the model inference
model_output = self.session.run(None, {self.model_input[0].name: image_data})
# Process and return the model output
return self.postprocess(model_output)


if __name__ == "__main__":
    # Set up argument parser for command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="rtdetr-l.onnx", help="Path to the ONNX model file.")
    parser.add_argument("--img", type=str, default="bus.jpg", help="Path to the input image.")
    parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold for object detection.")
    parser.add_argument("--iou-thres", type=float, default=0.5, help="IoU threshold for non-maximum suppression.")
    args = parser.parse_args()

    # Create the detector instance with the specified parameters
    detection = RTDETR(args.model, args.img, args.conf_thres, args.iou_thres, args.class_names)

    # Perform detection and get the output image
    output_image = detection.main()

    # Display the annotated output image
    cv2.namedWindow("Output", cv2.WINDOW_NORMAL)
    cv2.imshow("Output", output_image)
    cv2.waitKey(0)
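    # Close the display window once a key has been pressed
    cv2.destroyAllWindows()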