nathbns commited on
Commit
b343099
·
verified ·
1 Parent(s): 7051cbc

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +284 -0
  2. model.py +162 -0
  3. requirements.txt +8 -0
  4. utils.py +136 -0
app.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🎯 Application Gradio pour YOLOv3 Object Detection - Pascal VOC
3
+ Déployée sur Hugging Face Spaces
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+ from huggingface_hub import hf_hub_download
12
+ import os
13
+
14
+ # Configuration
15
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ IMAGE_SIZE = 416
17
+ NUM_CLASSES = 20
18
+
19
+ # Anchors YOLOv3 (normalisés)
20
+ ANCHORS = [
21
+ [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
22
+ [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
23
+ [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
24
+ ]
25
+
26
+ S = [IMAGE_SIZE // 32, IMAGE_SIZE // 16, IMAGE_SIZE // 8]
27
+
28
+ # Classes Pascal VOC
29
+ PASCAL_CLASSES = [
30
+ "aeroplane", "bicycle", "bird", "boat", "bottle",
31
+ "bus", "car", "cat", "chair", "cow",
32
+ "diningtable", "dog", "horse", "motorbike", "person",
33
+ "pottedplant", "sheep", "sofa", "train", "tvmonitor"
34
+ ]
35
+
36
+ # Import du modèle
37
+ from model import YOLOv3
38
+ from utils import cells_to_bboxes, non_max_suppression
39
+
40
+
41
+ class YOLOv3Detector:
42
+ def __init__(self, checkpoint_path):
43
+ """Initialise le détecteur YOLOv3"""
44
+ print(f"🔧 Device: {DEVICE}")
45
+
46
+ # Charger le modèle
47
+ self.model = YOLOv3(num_classes=NUM_CLASSES).to(DEVICE)
48
+ checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
49
+ self.model.load_state_dict(checkpoint["state_dict"])
50
+ self.model.eval()
51
+
52
+ # Anchors mis à l'échelle
53
+ self.scaled_anchors = (
54
+ torch.tensor(ANCHORS)
55
+ * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
56
+ ).to(DEVICE)
57
+
58
+ # Couleurs pour chaque classe
59
+ np.random.seed(42)
60
+ self.colors = np.random.randint(0, 255, size=(len(PASCAL_CLASSES), 3), dtype=np.uint8)
61
+
62
+ print("✅ Modèle chargé avec succès!")
63
+
64
+ def preprocess_image(self, image):
65
+ """Prétraite l'image pour le modèle"""
66
+ if isinstance(image, Image.Image):
67
+ image = np.array(image)
68
+
69
+ original_shape = image.shape[:2]
70
+ image_resized = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
71
+
72
+ # Normaliser et convertir en tensor
73
+ image_tensor = torch.from_numpy(image_resized).float() / 255.0
74
+ image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
75
+
76
+ return image_tensor.to(DEVICE), original_shape
77
+
78
+ def detect(self, image, conf_threshold=0.5, iou_threshold=0.45):
79
+ """Détecte les objets dans l'image"""
80
+ image_tensor, original_shape = self.preprocess_image(image)
81
+
82
+ with torch.no_grad():
83
+ predictions = self.model(image_tensor)
84
+
85
+ # Convertir les prédictions en bboxes
86
+ bboxes = [[] for _ in range(1)]
87
+ for i in range(3):
88
+ S = predictions[i].shape[2]
89
+ anchor = self.scaled_anchors[i]
90
+ boxes_scale_i = cells_to_bboxes(
91
+ predictions[i], anchor, S=S, is_preds=True
92
+ )
93
+ for idx, box in enumerate(boxes_scale_i):
94
+ bboxes[idx] += box
95
+
96
+ # Appliquer NMS
97
+ nms_boxes = non_max_suppression(
98
+ bboxes[0],
99
+ iou_threshold=iou_threshold,
100
+ threshold=conf_threshold,
101
+ box_format="midpoint",
102
+ )
103
+
104
+ return nms_boxes
105
+
106
+ def draw_boxes(self, image, boxes):
107
+ """Dessine les bounding boxes sur l'image"""
108
+ if isinstance(image, Image.Image):
109
+ image = np.array(image)
110
+
111
+ image = image.copy()
112
+ height, width = image.shape[:2]
113
+
114
+ detections_info = []
115
+
116
+ for box in boxes:
117
+ class_idx = int(box[0])
118
+ confidence = box[1]
119
+ x_center, y_center, box_width, box_height = box[2:]
120
+
121
+ # Convertir en coordonnées pixel
122
+ x1 = int((x_center - box_width / 2) * width)
123
+ y1 = int((y_center - box_height / 2) * height)
124
+ x2 = int((x_center + box_width / 2) * width)
125
+ y2 = int((y_center + box_height / 2) * height)
126
+
127
+ # Couleur pour cette classe
128
+ color = self.colors[class_idx].tolist()
129
+
130
+ # Dessiner le rectangle
131
+ cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
132
+
133
+ # Texte
134
+ label = f"{PASCAL_CLASSES[class_idx]}: {confidence:.2f}"
135
+
136
+ # Fond du texte
137
+ (text_width, text_height), _ = cv2.getTextSize(
138
+ label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
139
+ )
140
+ cv2.rectangle(
141
+ image,
142
+ (x1, y1 - text_height - 4),
143
+ (x1 + text_width, y1),
144
+ color,
145
+ -1
146
+ )
147
+
148
+ # Texte blanc
149
+ cv2.putText(
150
+ image,
151
+ label,
152
+ (x1, y1 - 4),
153
+ cv2.FONT_HERSHEY_SIMPLEX,
154
+ 0.5,
155
+ (255, 255, 255),
156
+ 1
157
+ )
158
+
159
+ detections_info.append(f"• {PASCAL_CLASSES[class_idx]}: {confidence:.1%}")
160
+
161
+ return image, detections_info
162
+
163
+
164
+ # Télécharger le modèle depuis Hugging Face
165
+ print("📥 Téléchargement du modèle depuis Hugging Face...")
166
+ checkpoint_path = hf_hub_download(
167
+ repo_id="nathbns/yolov3_from_scratch",
168
+ filename="checkpoint.pth.tar"
169
+ )
170
+ print(f"✅ Modèle téléchargé: {checkpoint_path}")
171
+
172
+ # Initialiser le détecteur
173
+ print("🚀 Chargement du modèle...")
174
+ detector = YOLOv3Detector(checkpoint_path)
175
+
176
+
177
+ def predict(image, conf_threshold, iou_threshold):
178
+ """Fonction de prédiction pour Gradio"""
179
+ if image is None:
180
+ return None, "❌ Aucune image fournie"
181
+
182
+ # Détecter
183
+ boxes = detector.detect(image, conf_threshold, iou_threshold)
184
+
185
+ # Dessiner
186
+ result_image, detections = detector.draw_boxes(image, boxes)
187
+
188
+ # Texte des détections
189
+ if detections:
190
+ detection_text = f"**✅ {len(detections)} objet(s) détecté(s) :**\n\n" + "\n".join(detections)
191
+ else:
192
+ detection_text = "❌ Aucun objet détecté"
193
+
194
+ return result_image, detection_text
195
+
196
+
197
+ # Interface Gradio
198
+ with gr.Blocks(title="YOLOv3 Object Detection", theme=gr.themes.Soft()) as demo:
199
+ gr.Markdown(
200
+ """
201
+ # 🎯 YOLOv3 Object Detection - Pascal VOC
202
+
203
+ Uploadez une image pour détecter des objets parmi **20 classes Pascal VOC**.
204
+
205
+ **Classes détectables:** personne, vélo, voiture, moto, avion, bus, train, camion, bateau,
206
+ feu de circulation, bouche d'incendie, panneau stop, parcomètre, banc, oiseau, chat, chien,
207
+ cheval, mouton, vache, éléphant, ours, zèbre, girafe, sac à dos, parapluie, etc.
208
+
209
+ ---
210
+ """
211
+ )
212
+
213
+ with gr.Row():
214
+ with gr.Column():
215
+ input_image = gr.Image(type="pil", label="📸 Image d'entrée")
216
+
217
+ with gr.Accordion("⚙️ Paramètres", open=True):
218
+ conf_slider = gr.Slider(
219
+ minimum=0.1,
220
+ maximum=1.0,
221
+ value=0.5,
222
+ step=0.05,
223
+ label="Seuil de confiance",
224
+ info="Plus élevé = moins de détections mais plus sûres"
225
+ )
226
+ iou_slider = gr.Slider(
227
+ minimum=0.1,
228
+ maximum=1.0,
229
+ value=0.45,
230
+ step=0.05,
231
+ label="Seuil NMS (IoU)",
232
+ info="Plus élevé = plus de boîtes qui se chevauchent"
233
+ )
234
+
235
+ detect_btn = gr.Button("🔍 Détecter les objets", variant="primary", size="lg")
236
+
237
+ with gr.Column():
238
+ output_image = gr.Image(label="✨ Résultat")
239
+ output_text = gr.Markdown(label="📊 Détections")
240
+
241
+ # Action
242
+ detect_btn.click(
243
+ fn=predict,
244
+ inputs=[input_image, conf_slider, iou_slider],
245
+ outputs=[output_image, output_text]
246
+ )
247
+
248
+ # Auto-run sur upload
249
+ input_image.change(
250
+ fn=predict,
251
+ inputs=[input_image, conf_slider, iou_slider],
252
+ outputs=[output_image, output_text]
253
+ )
254
+
255
+ gr.Markdown(
256
+ """
257
+ ---
258
+
259
+ ### 📊 Informations sur le modèle
260
+
261
+ - **Architecture:** YOLOv3 (Darknet-53 backbone)
262
+ - **Dataset:** Pascal VOC (16 551 images d'entraînement)
263
+ - **Epochs:** 100
264
+ - **mAP @ 0.5 IoU:** ~38.3%
265
+ - **Classes:** 20 objets courants
266
+ - **Taille d'entrée:** 416x416
267
+
268
+ ---
269
+
270
+ ### 💡 Astuces
271
+
272
+ - **Seuil de confiance bas (0.3):** Plus de détections, mais plus de faux positifs
273
+ - **Seuil de confiance élevé (0.7):** Moins de détections, mais plus précises
274
+ - **Seuil NMS:** Contrôle le chevauchement des boîtes de détection
275
+
276
+ ---
277
+
278
+ Créé avec ❤️ par [nathbns](https://huggingface.co/nathbns)
279
+ """
280
+ )
281
+
282
+ if __name__ == "__main__":
283
+ demo.launch()
284
+
model.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of YOLOv3 architecture
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
+ class CNNBlock(nn.Module):
10
+ """Convolutional block with BatchNorm and LeakyReLU"""
11
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bn_act=True):
12
+ super().__init__()
13
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=not bn_act)
14
+ self.bn = nn.BatchNorm2d(out_channels)
15
+ self.leaky = nn.LeakyReLU(0.1)
16
+ self.use_bn_act = bn_act
17
+
18
+ def forward(self, x):
19
+ if self.use_bn_act:
20
+ return self.leaky(self.bn(self.conv(x)))
21
+ else:
22
+ return self.conv(x)
23
+
24
+
25
+ class ResidualBlock(nn.Module):
26
+ """Residual block with skip connection"""
27
+ def __init__(self, channels, num_repeats=1):
28
+ super().__init__()
29
+ self.layers = nn.ModuleList()
30
+ for _ in range(num_repeats):
31
+ self.layers.append(
32
+ nn.Sequential(
33
+ CNNBlock(channels, channels // 2, kernel_size=1),
34
+ CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
35
+ )
36
+ )
37
+
38
+ def forward(self, x):
39
+ for layer in self.layers:
40
+ x = x + layer(x)
41
+ return x
42
+
43
+
44
+ class ScalePrediction(nn.Module):
45
+ """Scale prediction block for YOLO output"""
46
+ def __init__(self, in_channels, num_classes):
47
+ super().__init__()
48
+ self.pred = nn.Sequential(
49
+ CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
50
+ CNNBlock(2 * in_channels, (num_classes + 5) * 3, kernel_size=1, bn_act=False),
51
+ )
52
+ self.num_classes = num_classes
53
+
54
+ def forward(self, x):
55
+ return (
56
+ self.pred(x)
57
+ .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
58
+ .permute(0, 1, 3, 4, 2)
59
+ )
60
+
61
+
62
+ class YOLOv3(nn.Module):
63
+ """YOLOv3 architecture with Darknet-53 backbone"""
64
+ def __init__(self, in_channels=3, num_classes=20):
65
+ super().__init__()
66
+ self.num_classes = num_classes
67
+
68
+ # Darknet-53 Backbone
69
+ self.conv1 = CNNBlock(in_channels, 32, kernel_size=3, stride=1, padding=1)
70
+
71
+ self.conv2 = CNNBlock(32, 64, kernel_size=3, stride=2, padding=1)
72
+ self.residual1 = ResidualBlock(64, num_repeats=1)
73
+
74
+ self.conv3 = CNNBlock(64, 128, kernel_size=3, stride=2, padding=1)
75
+ self.residual2 = ResidualBlock(128, num_repeats=2)
76
+
77
+ self.conv4 = CNNBlock(128, 256, kernel_size=3, stride=2, padding=1)
78
+ self.residual3 = ResidualBlock(256, num_repeats=8)
79
+
80
+ self.conv5 = CNNBlock(256, 512, kernel_size=3, stride=2, padding=1)
81
+ self.residual4 = ResidualBlock(512, num_repeats=8)
82
+
83
+ self.conv6 = CNNBlock(512, 1024, kernel_size=3, stride=2, padding=1)
84
+ self.residual5 = ResidualBlock(1024, num_repeats=4)
85
+
86
+ # First scale prediction (13x13 - large objects)
87
+ self.conv7 = CNNBlock(1024, 512, kernel_size=1, stride=1, padding=0)
88
+ self.conv8 = CNNBlock(512, 1024, kernel_size=3, stride=1, padding=1)
89
+ self.residual6 = ResidualBlock(1024, num_repeats=1)
90
+ self.conv9 = CNNBlock(1024, 512, kernel_size=1, stride=1, padding=0)
91
+ self.scale_pred1 = ScalePrediction(512, num_classes=num_classes)
92
+
93
+ # Second scale (26x26 - medium objects)
94
+ self.conv10 = CNNBlock(512, 256, kernel_size=1, stride=1, padding=0)
95
+ self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
96
+
97
+ self.conv11 = CNNBlock(768, 256, kernel_size=1, stride=1, padding=0)
98
+ self.conv12 = CNNBlock(256, 512, kernel_size=3, stride=1, padding=1)
99
+ self.residual7 = ResidualBlock(512, num_repeats=1)
100
+ self.conv13 = CNNBlock(512, 256, kernel_size=1, stride=1, padding=0)
101
+ self.scale_pred2 = ScalePrediction(256, num_classes=num_classes)
102
+
103
+ # Third scale (52x52 - small objects)
104
+ self.conv14 = CNNBlock(256, 128, kernel_size=1, stride=1, padding=0)
105
+ self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
106
+
107
+ self.conv15 = CNNBlock(384, 128, kernel_size=1, stride=1, padding=0)
108
+ self.conv16 = CNNBlock(128, 256, kernel_size=3, stride=1, padding=1)
109
+ self.residual8 = ResidualBlock(256, num_repeats=1)
110
+ self.conv17 = CNNBlock(256, 128, kernel_size=1, stride=1, padding=0)
111
+ self.scale_pred3 = ScalePrediction(128, num_classes=num_classes)
112
+
113
+ def forward(self, x):
114
+ # Darknet-53 feature extraction
115
+ x = self.conv1(x)
116
+
117
+ x = self.conv2(x)
118
+ x = self.residual1(x)
119
+
120
+ x = self.conv3(x)
121
+ x = self.residual2(x)
122
+
123
+ x = self.conv4(x)
124
+ route1 = self.residual3(x)
125
+
126
+ x = self.conv5(route1)
127
+ route2 = self.residual4(x)
128
+
129
+ x = self.conv6(route2)
130
+ x = self.residual5(x)
131
+
132
+ # First scale (13x13)
133
+ x = self.conv7(x)
134
+ x = self.conv8(x)
135
+ x = self.residual6(x)
136
+ x = self.conv9(x)
137
+ out1 = self.scale_pred1(x)
138
+
139
+ # Second scale (26x26)
140
+ x = self.conv10(x)
141
+ x = self.upsample1(x)
142
+ x = torch.cat([x, route2], dim=1)
143
+
144
+ x = self.conv11(x)
145
+ x = self.conv12(x)
146
+ x = self.residual7(x)
147
+ x = self.conv13(x)
148
+ out2 = self.scale_pred2(x)
149
+
150
+ # Third scale (52x52)
151
+ x = self.conv14(x)
152
+ x = self.upsample2(x)
153
+ x = torch.cat([x, route1], dim=1)
154
+
155
+ x = self.conv15(x)
156
+ x = self.conv16(x)
157
+ x = self.residual8(x)
158
+ x = self.conv17(x)
159
+ out3 = self.scale_pred3(x)
160
+
161
+ return [out1, out2, out3]
162
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=4.0.0
4
+ huggingface-hub>=0.17.0
5
+ opencv-python-headless>=4.8.0
6
+ numpy>=1.24.0
7
+ Pillow>=10.0.0
8
+
utils.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for YOLOv3 (simplifié pour Gradio)
3
+ """
4
+
5
+ import torch
6
+
7
+
8
+ def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
9
+ """
10
+ Calcule l'intersection over union (IoU) entre deux bounding boxes
11
+
12
+ Args:
13
+ boxes_preds: Prédictions [x, y, w, h] ou [x1, y1, x2, y2]
14
+ boxes_labels: Labels [x, y, w, h] ou [x1, y1, x2, y2]
15
+ box_format: "midpoint" ou "corners"
16
+
17
+ Returns:
18
+ IoU score
19
+ """
20
+ if box_format == "midpoint":
21
+ box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
22
+ box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
23
+ box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
24
+ box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
25
+ box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
26
+ box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
27
+ box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
28
+ box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2
29
+ else: # corners
30
+ box1_x1 = boxes_preds[..., 0:1]
31
+ box1_y1 = boxes_preds[..., 1:2]
32
+ box1_x2 = boxes_preds[..., 2:3]
33
+ box1_y2 = boxes_preds[..., 3:4]
34
+ box2_x1 = boxes_labels[..., 0:1]
35
+ box2_y1 = boxes_labels[..., 1:2]
36
+ box2_x2 = boxes_labels[..., 2:3]
37
+ box2_y2 = boxes_labels[..., 3:4]
38
+
39
+ x1 = torch.max(box1_x1, box2_x1)
40
+ y1 = torch.max(box1_y1, box2_y1)
41
+ x2 = torch.min(box1_x2, box2_x2)
42
+ y2 = torch.min(box1_y2, box2_y2)
43
+
44
+ intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)
45
+ box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
46
+ box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))
47
+
48
+ return intersection / (box1_area + box2_area - intersection + 1e-6)
49
+
50
+
51
+ def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
52
+ """
53
+ Applique le Non-Maximum Suppression (NMS)
54
+
55
+ Args:
56
+ bboxes: Liste de bboxes [class_pred, prob_score, x, y, w, h]
57
+ iou_threshold: Seuil IoU pour supprimer les boxes
58
+ threshold: Seuil de confiance minimum
59
+ box_format: "midpoint" ou "corners"
60
+
61
+ Returns:
62
+ Liste de bboxes après NMS
63
+ """
64
+ assert type(bboxes) == list
65
+
66
+ # Filtrer par confiance
67
+ bboxes = [box for box in bboxes if box[1] > threshold]
68
+ bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
69
+ bboxes_after_nms = []
70
+
71
+ while bboxes:
72
+ chosen_box = bboxes.pop(0)
73
+
74
+ bboxes = [
75
+ box
76
+ for box in bboxes
77
+ if box[0] != chosen_box[0] # Différente classe
78
+ or intersection_over_union(
79
+ torch.tensor(chosen_box[2:]),
80
+ torch.tensor(box[2:]),
81
+ box_format=box_format,
82
+ )
83
+ < iou_threshold # IoU faible
84
+ ]
85
+
86
+ bboxes_after_nms.append(chosen_box)
87
+
88
+ return bboxes_after_nms
89
+
90
+
91
+ def cells_to_bboxes(predictions, anchors, S, is_preds=True):
92
+ """
93
+ Convertit les prédictions YOLOv3 en bounding boxes
94
+
95
+ Args:
96
+ predictions: Tensor [N, 3, S, S, num_classes+5]
97
+ anchors: Anchors pour cette échelle
98
+ S: Taille de la grille (13, 26, ou 52)
99
+ is_preds: Si True, applique sigmoid/exp
100
+
101
+ Returns:
102
+ Liste de bboxes converties
103
+ """
104
+ BATCH_SIZE = predictions.shape[0]
105
+ num_anchors = len(anchors)
106
+ box_predictions = predictions[..., 1:5]
107
+
108
+ if is_preds:
109
+ anchors = anchors.reshape(1, len(anchors), 1, 1, 2)
110
+ box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
111
+ box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
112
+ scores = torch.sigmoid(predictions[..., 0:1])
113
+ best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
114
+ else:
115
+ scores = predictions[..., 0:1]
116
+ best_class = predictions[..., 5:6]
117
+
118
+ # Indices de cellules
119
+ cell_indices = (
120
+ torch.arange(S)
121
+ .repeat(predictions.shape[0], 3, S, 1)
122
+ .unsqueeze(-1)
123
+ .to(predictions.device)
124
+ )
125
+
126
+ # Convertir en coordonnées absolues [0, 1]
127
+ x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
128
+ y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
129
+ w_h = 1 / S * box_predictions[..., 2:4]
130
+
131
+ converted_bboxes = torch.cat((best_class, scores, x, y, w_h), dim=-1).reshape(
132
+ BATCH_SIZE, num_anchors * S * S, 6
133
+ )
134
+
135
+ return converted_bboxes.tolist()
136
+