Panagiota Moraiti committed on
Commit
6bfc4b5
·
1 Parent(s): b8aa74f

Add python files

Browse files
scripts/model_factory.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inference_rfdetr import RFDETRInference
2
+
3
+
4
def get_model(model_name, version, pretrain_weights):
    """
    Build the inference wrapper that matches the requested model.

    Args:
        model_name (str): Model family identifier; 'rfdetr' is the only name handled.
        version (str): Variant string forwarded to the wrapper (e.g., 'small', 'nano').
        pretrain_weights (str): Path to model weights, forwarded to the wrapper.

    Returns:
        A model inference object, built as RFDETRInference(version, pretrain_weights).

    Raises:
        ValueError: If model_name is anything other than 'rfdetr'.
    """
    # Reject unknown names first so the supported path reads straight through.
    if model_name != 'rfdetr':
        raise ValueError(f"Unsupported model: {model_name}")

    return RFDETRInference(version, pretrain_weights)
23
+
scripts/plot_bboxes_save_images_and_yolo_predictions.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import supervision as sv
2
+ import cv2
3
+ import os
4
+
5
+
6
def convert_to_yolo(x1, y1, x2, y2, img_width, img_height):
    """
    Turn corner coordinates into a YOLO-style box (relative center-x, center-y, width, height).

    Args:
        x1, y1: Coordinates of the first box corner.
        x2, y2: Coordinates of the opposite box corner.
        img_width: Image width used to scale the horizontal values.
        img_height: Image height used to scale the vertical values.

    Returns:
        tuple: (x_center, y_center, width, height), each divided by the matching image dimension.
    """
    # Horizontal quantities are scaled by the width, vertical ones by the height.
    rel_w = (x2 - x1) / img_width
    rel_h = (y2 - y1) / img_height
    rel_cx = (x1 + x2) / 2 / img_width
    rel_cy = (y1 + y2) / 2 / img_height
    return rel_cx, rel_cy, rel_w, rel_h
18
+
19
+
20
def save_yolo_labels(save_path, detections, image):
    """
    Saves detection boxes in YOLO format to a .txt file.

    The file goes into a "predictions" directory located one level above the
    directory of save_path and is named after save_path's base name. Each line is
    "<class_id> <x_center> <y_center> <width> <height> <confidence>"; the
    confidence column is left out when the detections carry no scores.

    Args:
        save_path (str): Base path to match image name; only its directory and
            base name are used.
        detections (sv.Detections): Detection results (xyxy, class_id, confidence).
        image (np.ndarray): Image to get original dimensions.

    Raises:
        ValueError: If detections.class_id is None, since every label line needs a class.
    """
    class_ids = detections.class_id
    if class_ids is None:
        raise ValueError("detections.class_id is None; cannot write YOLO labels")

    boxes = detections.xyxy
    scores = detections.confidence
    if scores is None:
        # Keep zip() aligned when there are no scores; None marks "no confidence column".
        scores = [None] * len(boxes)

    img_height, img_width = image.shape[:2]

    base_name = os.path.splitext(os.path.basename(save_path))[0]
    preds_dir = os.path.join(os.path.dirname(save_path), "../predictions")
    os.makedirs(preds_dir, exist_ok=True)
    txt_path = os.path.join(preds_dir, base_name + ".txt")

    lines = []
    for (x1, y1, x2, y2), label, conf in zip(boxes, class_ids, scores):
        x_center, y_center, width, height = convert_to_yolo(x1, y1, x2, y2, img_width, img_height)
        line = f"{label} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
        if conf is not None:
            line += f" {conf:.6f}"
        lines.append(line + "\n")

    # An empty detection set still produces an (empty) label file.
    with open(txt_path, "w") as f:
        f.writelines(lines)
45
+
46
+ # print(f"[LABELS SAVED] {os.path.basename(txt_path)} → {os.path.abspath(txt_path)}")
47
+
48
+
49
def annotate_image(image, detections, class_names):
    """
    Draws bounding boxes and class labels on the image.

    Each detection is labelled "<class name> <confidence>" with the confidence
    shown to two decimals. All boxes are drawn first, then all labels.

    Args:
        image (np.ndarray): The image to annotate.
        detections (sv.Detections): Detection results.
        class_names (dict): Class ID to name mapping.

    Returns:
        np.ndarray: Annotated image. With no detections the input image is
        returned as-is.

    Raises:
        KeyError: If a detected class ID is missing from class_names.
    """
    # Nothing to draw: skip annotator construction entirely.
    if len(detections) == 0:
        return image

    labels = [
        f"{class_names[class_id]} {conf:.2f}"
        for class_id, conf in zip(detections.class_id, detections.confidence)
    ]

    # Build each annotator once and draw every detection in a single call,
    # instead of re-creating both annotators for each detection.
    box_annotator = sv.BoxAnnotator(thickness=6)
    label_annotator = sv.LabelAnnotator(text_scale=2.0, text_thickness=4)

    # NOTE(review): supervision annotators appear to draw on `scene` in place;
    # confirm, and pass a copy if callers need the original pixels.
    image = box_annotator.annotate(scene=image, detections=detections)
    image = label_annotator.annotate(scene=image, detections=detections, labels=labels)

    return image
76
+
77
+
78
def process_image_frame(image, detections, class_names, save_path, plot_dets=True, save_preds=True, show=False):
    """
    Handles image output: annotation, saving image, saving labels, and optional GUI display.

    Args:
        image (np.ndarray): Input image.
        detections (sv.Detections): Detection results.
        class_names (dict): Class ID to name mapping.
        save_path (str): Path to save image and labels.
        plot_dets (bool): Save annotated image.
        save_preds (bool): Save YOLO labels.
        show (bool): Show OpenCV window.
    """
    import warnings

    annotated_image = annotate_image(image, detections, class_names)

    if plot_dets:
        # Create the target directory first so a missing folder does not make
        # the write fail.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # cv2.imwrite reports failure through its return value; surface it
        # instead of silently losing the image.
        if not cv2.imwrite(save_path, annotated_image):
            warnings.warn(f"Failed to write annotated image to {save_path}")

    if save_preds:
        # Only the image dimensions are read here, so passing `image` is fine
        # even after annotation.
        save_yolo_labels(save_path, detections, image)

    if show:
        # Scale so the longer side becomes 750 px, keeping the aspect ratio.
        h, w = image.shape[:2]
        scale = min(750 / w, 750 / h)
        resized = cv2.resize(annotated_image, (int(w * scale), int(h * scale)))
        cv2.imshow("Detection", resized)
        # Keep the window up for 1.5 s (or until a key press), then close it.
        cv2.waitKey(1500)
        cv2.destroyAllWindows()
107
+
108
+
109
def process_video_frame(frame, detections, class_names, plot_dets=True, show=False, video_writer=None):
    """
    Annotates one video frame, then optionally writes and/or displays it.

    Args:
        frame (np.ndarray): Video frame.
        detections (sv.Detections): Detection results.
        class_names (dict): Class ID to name mapping.
        plot_dets (bool): Write the annotated frame to video_writer when one is given.
        show (bool): Display the annotated frame in a window.
        video_writer (cv2.VideoWriter): OpenCV video writer object; no frame is
            written when this is None.
    """
    drawn = annotate_image(frame, detections, class_names)

    # Writing needs both the flag and an actual writer.
    should_write = plot_dets and video_writer is not None
    if should_write:
        video_writer.write(drawn)

    if not show:
        return

    # Preview is scaled so the longer side becomes 750 px, aspect ratio kept.
    frame_h, frame_w = frame.shape[:2]
    ratio = min(750 / frame_w, 750 / frame_h)
    preview_size = (int(frame_w * ratio), int(frame_h * ratio))
    preview = cv2.resize(drawn, preview_size)
    cv2.imshow("Detection", preview)
    # 1 ms wait lets the window refresh without stalling playback.
    cv2.waitKey(1)