nathbns committed on
Commit 6698bc8 · verified · 1 parent: 4490e35

Upload 6 files

Files changed (6)
  1. app.py +269 -0
  2. dataset.py +69 -0
  3. loss.py +83 -0
  4. model.py +110 -0
  5. requirements.txt +8 -0
  6. utils.py +337 -0
app.py ADDED
@@ -0,0 +1,269 @@
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from model import Yolov1
from utils import cellboxes_to_boxes, non_max_suppression
import cv2
import os
import glob
import time

# PASCAL VOC classes
CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]

np.random.seed(42)
COLORS = np.random.randint(50, 255, size=(len(CLASSES), 3), dtype=np.uint8)

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
MODEL_PATH = "checkpoint_epoch_50.pth.tar"

# Load the model
print(f"Loading model from {MODEL_PATH}...")
model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("Model loaded successfully!")

# Model info
MODEL_INFO = {
    "mAP": checkpoint.get("mAP", "N/A"),
    "epoch": checkpoint.get("epoch", "N/A"),
    "device": DEVICE,
    "classes": len(CLASSES)
}
print(f"Training mAP: {MODEL_INFO['mAP']}")
print(f"Device: {DEVICE}")

# Load example images from the data folder
EXAMPLE_IMAGES = []
if os.path.exists("data/images"):
    image_files = glob.glob("data/images/*.jpg")[:20]  # take 20 images
    EXAMPLE_IMAGES = sorted(image_files)
    print(f"{len(EXAMPLE_IMAGES)} example images loaded")

def draw_boxes(image, boxes):
    """Draws the bounding boxes on the image."""
    img_array = np.array(image)
    height, width = img_array.shape[:2]

    for box in boxes:
        # box format: [class_pred, prob_score, x, y, width, height]
        class_pred = int(box[0])
        confidence = float(box[1])
        x_center, y_center, box_width, box_height = box[2:6]

        # Convert normalized coordinates to pixels
        x1 = int((x_center - box_width / 2) * width)
        y1 = int((y_center - box_height / 2) * height)
        x2 = int((x_center + box_width / 2) * width)
        y2 = int((y_center + box_height / 2) * height)

        # Class color
        color = tuple(int(c) for c in COLORS[class_pred])

        # Draw the rectangle
        cv2.rectangle(img_array, (x1, y1), (x2, y2), color, 2)

        # Label text
        label = f"{CLASSES[class_pred]}: {confidence:.2f}"

        # Text background
        (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(img_array, (x1, y1 - text_height - 5), (x1 + text_width, y1), color, -1)

        # White text
        cv2.putText(img_array, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    return Image.fromarray(img_array)

def detect_objects(image, confidence_threshold, iou_threshold, show_confidence=True):
    """Detects objects in an image and reports detailed statistics."""
    if image is None:
        return None, None, "**Please upload or select an image**"

    start_time = time.time()

    # Preprocess the image
    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ])

    # Keep the original image for display
    original_image = image.copy()
    original_size = image.size  # (width, height)

    # Transform the image
    img_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # Prediction
    with torch.no_grad():
        predictions = model(img_tensor)

    # Convert the predictions into bounding boxes
    bboxes = cellboxes_to_boxes(predictions)

    # Non-maximum suppression
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_threshold,
        threshold=confidence_threshold,
        box_format="midpoint"
    )

    inference_time = time.time() - start_time

    # Draw the boxes
    result_image = draw_boxes(original_image, nms_boxes)

    # Detailed statistics
    num_detections = len(nms_boxes)
    class_counts = {}
    confidence_scores = []

    for box in nms_boxes:
        cls = CLASSES[int(box[0])]
        conf = float(box[1])
        class_counts[cls] = class_counts.get(cls, 0) + 1
        confidence_scores.append(conf)

    # Build a detailed report
    stats = "## Detection results\n\n"
    stats += f"**{num_detections} object(s) detected**\n\n"

    if num_detections > 0:
        stats += f"Inference time: **{inference_time:.3f}s**\n"
        stats += f"Image size: **{original_size[0]}x{original_size[1]}**\n"
        stats += f"Mean confidence: **{np.mean(confidence_scores):.2%}**\n\n"

        stats += "### Detected objects:\n\n"
        for cls, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True):
            stats += f"- **{cls}**: {count}\n"

        if show_confidence:
            stats += "\n### Individual confidences:\n\n"
            for i, box in enumerate(nms_boxes[:10], 1):  # top 10
                cls = CLASSES[int(box[0])]
                conf = float(box[1])
                stats += f"{i}. {cls}: {conf:.1%}\n"
            if len(nms_boxes) > 10:
                stats += f"\n*...and {len(nms_boxes) - 10} more detection(s)*\n"
    else:
        stats += "No objects detected.\n\n"

    return original_image, result_image, stats

# Gradio interface
with gr.Blocks(title="YOLO v1 - Object detection", theme=gr.themes.Soft(), css="""
.gradio-container {max-width: 1400px !important}
.example-gallery {height: 400px; overflow-y: auto}
""") as demo:

    # Header
    mAP_display = f"{MODEL_INFO['mAP']:.4f}" if isinstance(MODEL_INFO['mAP'], (int, float)) else MODEL_INFO['mAP']

    gr.Markdown(f"""
    # YOLO v1 - Real-time object detection
    Training mAP: **{mAP_display}**
    ---
    """)

    with gr.Tabs():
        # Main tab - Detection
        with gr.Tab("Detection"):
            gr.Markdown("""
            ### Upload your image or select an example
            **PASCAL VOC classes:** aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow,
            diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="pil", label="Input image")

                    with gr.Accordion("Advanced settings", open=True):
                        confidence_slider = gr.Slider(
                            minimum=0.05,
                            maximum=0.95,
                            value=0.4,
                            step=0.05,
                            label="Confidence threshold",
                            info="Lower = more detections"
                        )
                        iou_slider = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            value=0.5,
                            step=0.05,
                            label="IoU threshold (NMS)",
                            info="Higher = keeps more overlapping boxes"
                        )
                        show_conf_check = gr.Checkbox(
                            value=True,
                            label="Show detailed confidences"
                        )

                    detect_btn = gr.Button("Detect objects", variant="primary", size="lg")

                with gr.Column(scale=2):
                    with gr.Row():
                        original_display = gr.Image(type="pil", label="Original image")
                        output_image = gr.Image(type="pil", label="Result with detections")

                    output_stats = gr.Markdown("**Upload an image and click 'Detect' to get started!**")

            # Example gallery
            if EXAMPLE_IMAGES:
                gr.Markdown("### Examples (click to try)")
                examples_list = [[img, 0.4, 0.5, True] for img in EXAMPLE_IMAGES[:12]]
                gr.Examples(
                    examples=examples_list,
                    inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                    outputs=[original_display, output_image, output_stats],
                    fn=detect_objects,
                    cache_examples=False,
                    examples_per_page=6,
                )

            # Actions
            detect_btn.click(
                fn=detect_objects,
                inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                outputs=[original_display, output_image, output_stats]
            )

            input_image.change(
                fn=detect_objects,
                inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                outputs=[original_display, output_image, output_stats]
            )

        # Info tab
        with gr.Tab("About"):
            mAP_info = f"{MODEL_INFO['mAP']:.4f}" if isinstance(MODEL_INFO['mAP'], (int, float)) else "N/A"
            epoch_info = MODEL_INFO['epoch']
            gr.Markdown(f"""
            ### About this demo
            YOLO v1 trained from scratch on PASCAL VOC (20 classes).
            - Training mAP: **{mAP_info}**
            - Checkpoint epoch: **{epoch_info}**
            - Device: **{MODEL_INFO['device']}**
            """)

# Launch the app
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Launching the YOLO v1 Gradio app")
    print("=" * 60)
    print(f"Model: {MODEL_PATH}")
    print(f"Device: {DEVICE}")
    print(f"Examples loaded: {len(EXAMPLE_IMAGES)}")
    print("=" * 60 + "\n")

    demo.launch(
        share=True,
        server_name="0.0.0.0",  # reachable from the local network
        server_port=7860,
        show_error=True
    )
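For quick checks outside the Gradio UI, the same inference path can be exercised headlessly. The following is a minimal sketch, assuming the checkpoint above exists and using a placeholder image path (test.jpg is hypothetical):

# Headless sketch of the app's inference path (test.jpg is a placeholder).
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import Yolov1
from utils import cellboxes_to_boxes, non_max_suppression

model = Yolov1(split_size=7, num_boxes=2, num_classes=20)
ckpt = torch.load("checkpoint_epoch_50.pth.tar", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
model.eval()

img = Image.open("test.jpg").convert("RGB")
x = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor()])(img)
with torch.no_grad():
    preds = model(x.unsqueeze(0))
boxes = non_max_suppression(
    cellboxes_to_boxes(preds)[0],
    iou_threshold=0.5, threshold=0.4, box_format="midpoint",
)
print(boxes)  # [[class_idx, confidence, x_center, y_center, w, h], ...]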
dataset.py ADDED
@@ -0,0 +1,69 @@
import torch
import os
import pandas as pd
from PIL import Image


class VOCDataset(torch.utils.data.Dataset):
    """
    Uses the original parameters from the YOLOv1 paper:
    a 7x7 grid, 2 boxes per cell, 20 VOC classes.
    """
    def __init__(self, csv_file, img_dir, label_dir, S=7, B=2, C=20, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform  # function applied to the image (and its boxes)
        self.S = S
        self.B = B
        self.C = C

    def __len__(self):
        return len(self.annotations)  # number of rows in the CSV

    def __getitem__(self, index):
        label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
        boxes = []
        with open(label_path) as f:
            for label in f.readlines():
                class_label, x, y, width, height = [
                    float(v) if float(v) != int(float(v)) else int(float(v))
                    for v in label.replace("\n", "").split()
                ]
                boxes.append([class_label, x, y, width, height])

        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = Image.open(img_path)
        boxes = torch.tensor(boxes)

        if self.transform:
            image, boxes = self.transform(image, boxes)

        label_matrix = torch.zeros((self.S, self.S, self.C + 5 * self.B))
        for box in boxes:
            class_label, x, y, width, height = box.tolist()
            class_label = int(class_label)

            # (i, j) = row/column of the cell the box center falls into
            i, j = int(self.S * y), int(self.S * x)
            x_cell, y_cell = self.S * x - j, self.S * y - i

            # width/height expressed relative to the cell size
            width_cell, height_cell = (
                width * self.S,
                height * self.S,
            )

            # one object per cell: keep the first box that lands here
            if label_matrix[i, j, 20] == 0:
                label_matrix[i, j, 20] = 1

                box_coordinates = torch.tensor(
                    [x_cell, y_cell, width_cell, height_cell]
                )
                label_matrix[i, j, 21:25] = box_coordinates

                # one-hot encoding of the class
                label_matrix[i, j, class_label] = 1

        return image, label_matrix
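Note that VOCDataset calls self.transform(image, boxes) with two arguments, so a plain torchvision Compose will not drop in directly. A minimal sketch of a compatible wrapper follows; the CSV and directory paths are hypothetical placeholders:

import torchvision.transforms as T

class Compose:
    """Applies image-only transforms while passing the boxes through unchanged."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes):
        for t in self.transforms:
            img = t(img)  # boxes are normalized, so a pure resize keeps them valid
        return img, bboxes

# Hypothetical paths, for illustration only.
transform = Compose([T.Resize((448, 448)), T.ToTensor()])
dataset = VOCDataset(
    csv_file="data/train.csv", img_dir="data/images",
    label_dir="data/labels", transform=transform,
)
image, label_matrix = dataset[0]  # label_matrix shape: (7, 7, 30)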
loss.py ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
from utils import intersection_over_union


class Loss_Yolo(nn.Module):
    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")

        self.S = S
        self.B = B
        self.C = C

        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)

        # IoU of each predicted box against the target box
        iou_b1 = intersection_over_union(predictions[..., 21:25], target[..., 21:25])
        iou_b2 = intersection_over_union(predictions[..., 26:30], target[..., 21:25])
        ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)

        # bestbox = index of the responsible box (highest IoU)
        iou_maxes, bestbox = torch.max(ious, dim=0)
        exists_box = target[..., 20].unsqueeze(3)  # 1 if an object is in the cell

        # ===== box coordinate loss =====
        box_predictions = exists_box * (
            bestbox * predictions[..., 26:30]
            + (1 - bestbox) * predictions[..., 21:25]
        )

        box_targets = exists_box * target[..., 21:25]

        # sqrt of width/height as in the paper; sign/abs keep gradients stable
        box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4]) + 1e-6
        )
        box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])

        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ===== object loss =====
        pred_box = (
            bestbox * predictions[..., 25:26] + (1 - bestbox) * predictions[..., 20:21]
        )

        object_loss = self.mse(
            torch.flatten(exists_box * pred_box),
            torch.flatten(exists_box * target[..., 20:21]),
        )

        # ===== no-object loss (both boxes) =====
        no_object_loss = self.mse(
            torch.flatten((1 - exists_box) * predictions[..., 20:21], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1),
        )

        no_object_loss += self.mse(
            torch.flatten((1 - exists_box) * predictions[..., 25:26], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1)
        )

        # ===== class loss =====
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :20], end_dim=-2),
            torch.flatten(exists_box * target[..., :20], end_dim=-2),
        )

        loss = (
            self.lambda_coord * box_loss  # the first two rows in the paper
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )

        return loss
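As a sanity check, the loss can be exercised on dummy tensors. A sketch, assuming the S=7, B=2, C=20 layout used throughout (20 class scores, then [obj, x, y, w, h] for each of the two boxes):

import torch

criterion = Loss_Yolo(S=7, B=2, C=20)
# predictions come flattened from the model head: (N, S*S*(C + B*5)) = (N, 1470)
predictions = torch.randn(2, 7 * 7 * 30)
target = torch.zeros(2, 7, 7, 30)
target[0, 3, 3, 20] = 1                                   # an object in cell (3, 3)
target[0, 3, 3, 21:25] = torch.tensor([0.5, 0.5, 1.0, 1.0])  # cell-relative x, y, w, h
target[0, 3, 3, 0] = 1                                    # class 0 ("aeroplane")
loss = criterion(predictions, target)
print(loss.item())  # a finite scalar; sum-reduced MSE over the batch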
model.py ADDED
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn


class CNN(nn.Module):
    """
    Conv -> BatchNorm -> LeakyReLU block.
    **kwargs forwards the remaining conv arguments (kernel_size, stride, padding).
    bias=False because BatchNorm adds its own bias term.
    LeakyReLU: x if x > 0, else 0.1 * x.
    """
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        return self.leakyrelu(self.batchnorm(self.conv(x)))


class Yolov1(nn.Module):
    def __init__(self, in_channels=3, split_size=7, num_boxes=2, num_classes=20):
        super().__init__()

        # Darknet backbone, built from scratch
        self.conv1 = CNN(in_channels, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = CNN(64, 192, kernel_size=3, stride=1, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = CNN(192, 128, kernel_size=1, stride=1, padding=0)
        self.conv4 = CNN(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv5 = CNN(256, 256, kernel_size=1, stride=1, padding=0)
        self.conv6 = CNN(256, 512, kernel_size=3, stride=1, padding=1)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block repeated 4 times: (1x1, 256) -> (3x3, 512)
        self.conv7 = CNN(512, 256, kernel_size=1, stride=1, padding=0)
        self.conv8 = CNN(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv9 = CNN(512, 256, kernel_size=1, stride=1, padding=0)
        self.conv10 = CNN(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv11 = CNN(512, 256, kernel_size=1, stride=1, padding=0)
        self.conv12 = CNN(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv13 = CNN(512, 256, kernel_size=1, stride=1, padding=0)
        self.conv14 = CNN(256, 512, kernel_size=3, stride=1, padding=1)

        self.conv15 = CNN(512, 512, kernel_size=1, stride=1, padding=0)
        self.conv16 = CNN(512, 1024, kernel_size=3, stride=1, padding=1)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Block repeated 2 times: (1x1, 512) -> (3x3, 1024)
        self.conv17 = CNN(1024, 512, kernel_size=1, stride=1, padding=0)
        self.conv18 = CNN(512, 1024, kernel_size=3, stride=1, padding=1)
        self.conv19 = CNN(1024, 512, kernel_size=1, stride=1, padding=0)
        self.conv20 = CNN(512, 1024, kernel_size=3, stride=1, padding=1)

        self.conv21 = CNN(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.conv22 = CNN(1024, 1024, kernel_size=3, stride=2, padding=1)
        self.conv23 = CNN(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.conv24 = CNN(1024, 1024, kernel_size=3, stride=1, padding=1)

        # Model head
        S, B, C = split_size, num_boxes, num_classes
        self.fc1 = nn.Linear(1024 * S * S, 496)  # 496 instead of the paper's 4096, to keep the head small
        self.dropout = nn.Dropout(0.0)  # the paper uses 0.5
        self.leaky = nn.LeakyReLU(0.1)
        self.fc2 = nn.Linear(496, S * S * (C + B * 5))

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.maxpool3(x)

        x = self.conv7(x)
        x = self.conv8(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.conv11(x)
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.conv14(x)

        x = self.conv15(x)
        x = self.conv16(x)
        x = self.maxpool4(x)

        x = self.conv17(x)
        x = self.conv18(x)
        x = self.conv19(x)
        x = self.conv20(x)

        x = self.conv21(x)
        x = self.conv22(x)
        x = self.conv23(x)
        x = self.conv24(x)

        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.leaky(x)
        x = self.fc2(x)
        return x
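A quick shape check: the three stride-2 convolutions/poolings plus the four max-pools reduce a 448x448 input to a 7x7 grid, so the head should emit S*S*(C + B*5) = 7*7*30 = 1470 values per image. A sketch:

import torch

model = Yolov1(in_channels=3, split_size=7, num_boxes=2, num_classes=20)
x = torch.randn(2, 3, 448, 448)  # batch of two 448x448 RGB images
out = model(x)
print(out.shape)  # expected: torch.Size([2, 1470]) == (N, S*S*(C + B*5))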
requirements.txt ADDED
@@ -0,0 +1,8 @@
torch
torchvision
gradio
numpy
pandas
opencv-python
pillow
matplotlib
utils.py ADDED
@@ -0,0 +1,337 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from collections import Counter

def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    if box_format == "midpoint":
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2

    if box_format == "corners":
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]  # (N, 1)
        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # .clamp(0) handles the case where the boxes do not intersect
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    return intersection / (box1_area + box2_area - intersection + 1e-6)


def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Performs Non-Max Suppression on the given bboxes.

    Parameters:
        bboxes (list): list of lists containing all bboxes, each specified
            as [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): IoU above which a lower-scoring bbox of the
            same class is suppressed
        threshold (float): confidence threshold below which predicted bboxes
            are removed (independent of IoU)
        box_format (str): "midpoint" or "corners", specifying the bbox format

    Returns:
        list: bboxes remaining after NMS at the given IoU threshold
    """
    assert type(bboxes) == list

    bboxes = [box for box in bboxes if box[1] > threshold]
    bboxes = sorted(bboxes, key=lambda x: x[1], reverse=True)
    bboxes_after_nms = []

    while bboxes:
        chosen_box = bboxes.pop(0)

        bboxes = [
            box
            for box in bboxes
            if box[0] != chosen_box[0]
            or intersection_over_union(
                torch.tensor(chosen_box[2:]),
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]

        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms


def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """
    Calculates mean average precision.

    Parameters:
        pred_boxes (list): list of lists containing all bboxes, each specified
            as [train_idx, class_prediction, prob_score, x1, y1, x2, y2]
        true_boxes (list): same format as pred_boxes, for the ground-truth boxes
        iou_threshold (float): IoU above which a predicted bbox counts as correct
        box_format (str): "midpoint" or "corners", specifying the bbox format
        num_classes (int): number of classes

    Returns:
        float: mAP across all classes at the given IoU threshold
    """
    # list storing the AP of each class
    average_precisions = []

    # used for numerical stability later on
    epsilon = 1e-6

    for c in range(num_classes):
        detections = []
        ground_truths = []

        # Go through all predictions and targets and keep only
        # the ones that belong to the current class c
        for detection in pred_boxes:
            if detection[1] == c:
                detections.append(detection)

        for true_box in true_boxes:
            if true_box[1] == c:
                ground_truths.append(true_box)

        # Count the ground-truth bboxes of each training example.
        # E.g. if img 0 has 3 boxes and img 1 has 5, Counter gives:
        # amount_bboxes = {0: 3, 1: 5}
        amount_bboxes = Counter([gt[0] for gt in ground_truths])

        # Convert each count to a zero tensor used to mark matched boxes:
        # amount_bboxes = {0: torch.tensor([0, 0, 0]), 1: torch.tensor([0, 0, 0, 0, 0])}
        for key, val in amount_bboxes.items():
            amount_bboxes[key] = torch.zeros(val)

        # sort by box probability, which is index 2
        detections.sort(key=lambda x: x[2], reverse=True)
        TP = torch.zeros((len(detections)))
        FP = torch.zeros((len(detections)))
        total_true_bboxes = len(ground_truths)

        # If no ground truth exists for this class we can safely skip it
        if total_true_bboxes == 0:
            continue

        for detection_idx, detection in enumerate(detections):
            # Only take the ground truths that share the
            # training idx of this detection
            ground_truth_img = [
                bbox for bbox in ground_truths if bbox[0] == detection[0]
            ]

            best_iou = 0
            best_gt_idx = 0

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            if best_iou > iou_threshold:
                # count each ground-truth box only once
                if amount_bboxes[detection[0]][best_gt_idx] == 0:
                    # true positive: mark this ground-truth box as seen
                    TP[detection_idx] = 1
                    amount_bboxes[detection[0]][best_gt_idx] = 1
                else:
                    FP[detection_idx] = 1

            # if the IoU is below the threshold, the detection is a false positive
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = torch.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))
        precisions = torch.cat((torch.tensor([1]), precisions))
        recalls = torch.cat((torch.tensor([0]), recalls))
        # torch.trapz for numerical integration of the precision-recall curve
        average_precisions.append(torch.trapz(precisions, recalls))

    return sum(average_precisions) / len(average_precisions)


def plot_image(image, boxes):
    """Plots predicted bounding boxes on the image."""
    im = np.array(image)
    height, width, _ = im.shape

    # Create figure and axes
    fig, ax = plt.subplots(1)
    # Display the image
    ax.imshow(im)

    # box[0] is the x midpoint, box[2] the width
    # box[1] is the y midpoint, box[3] the height

    # Create a Rectangle patch for each box
    for box in boxes:
        box = box[2:]
        assert len(box) == 4, "Got more values than x, y, w, h in a box!"
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        # Add the patch to the Axes
        ax.add_patch(rect)

    plt.show()

def get_bboxes(
    loader,
    model,
    iou_threshold,
    threshold,
    pred_format="cells",
    box_format="midpoint",
    device="cuda",
):
    all_pred_boxes = []
    all_true_boxes = []

    # make sure the model is in eval mode before collecting bboxes
    model.eval()
    train_idx = 0

    for batch_idx, (x, labels) in enumerate(loader):
        x = x.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            predictions = model(x)

        batch_size = x.shape[0]
        true_bboxes = cellboxes_to_boxes(labels)
        bboxes = cellboxes_to_boxes(predictions)

        for idx in range(batch_size):
            nms_boxes = non_max_suppression(
                bboxes[idx],
                iou_threshold=iou_threshold,
                threshold=threshold,
                box_format=box_format,
            )

            # if batch_idx == 0 and idx == 0:
            #     plot_image(x[idx].permute(1, 2, 0).to("cpu"), nms_boxes)
            #     print(nms_boxes)

            for nms_box in nms_boxes:
                all_pred_boxes.append([train_idx] + nms_box)

            for box in true_bboxes[idx]:
                # many boxes have confidence 0 and are filtered out here
                if box[1] > threshold:
                    all_true_boxes.append([train_idx] + box)

            train_idx += 1

    model.train()
    return all_pred_boxes, all_true_boxes


def convert_cellboxes(predictions, S=7):
    """
    Converts the YOLO output boxes for a grid of size S from cell-relative
    coordinates into whole-image ratios. This implementation is vectorized,
    which makes it harder to read; an equivalent version with two for loops
    over range(S), converting the cells one by one, would be slower but more
    intuitive.
    """
    predictions = predictions.to("cpu")
    batch_size = predictions.shape[0]
    predictions = predictions.reshape(batch_size, S, S, 30)  # 30 = 20 classes + 2 boxes * 5
    bboxes1 = predictions[..., 21:25]
    bboxes2 = predictions[..., 26:30]
    scores = torch.cat(
        (predictions[..., 20].unsqueeze(0), predictions[..., 25].unsqueeze(0)), dim=0
    )
    best_box = scores.argmax(0).unsqueeze(-1)
    best_boxes = bboxes1 * (1 - best_box) + best_box * bboxes2
    cell_indices = torch.arange(S).repeat(batch_size, S, 1).unsqueeze(-1)
    x = 1 / S * (best_boxes[..., :1] + cell_indices)
    y = 1 / S * (best_boxes[..., 1:2] + cell_indices.permute(0, 2, 1, 3))
    w_y = 1 / S * best_boxes[..., 2:4]
    converted_bboxes = torch.cat((x, y, w_y), dim=-1)
    predicted_class = predictions[..., :20].argmax(-1).unsqueeze(-1)
    best_confidence = torch.max(predictions[..., 20], predictions[..., 25]).unsqueeze(
        -1
    )
    converted_preds = torch.cat(
        (predicted_class, best_confidence, converted_bboxes), dim=-1
    )

    return converted_preds


def cellboxes_to_boxes(out, S=7):
    converted_pred = convert_cellboxes(out).reshape(out.shape[0], S * S, -1)
    converted_pred[..., 0] = converted_pred[..., 0].long()
    all_bboxes = []

    for ex_idx in range(out.shape[0]):
        bboxes = []

        for bbox_idx in range(S * S):
            bboxes.append([x.item() for x in converted_pred[ex_idx, bbox_idx, :]])
        all_bboxes.append(bboxes)

    return all_bboxes

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
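A toy run of non_max_suppression (a sketch) illustrating two heavily overlapping boxes of the same class collapsing to one, while a box of a different class survives:

from utils import non_max_suppression

boxes = [
    [0, 0.9, 0.50, 0.50, 0.40, 0.40],  # [class, score, x, y, w, h] (midpoint format)
    [0, 0.6, 0.52, 0.52, 0.40, 0.40],  # IoU ~0.82 with the first box -> suppressed
    [1, 0.8, 0.20, 0.20, 0.10, 0.10],  # different class -> kept regardless of IoU
]
kept = non_max_suppression(boxes, iou_threshold=0.5, threshold=0.4, box_format="midpoint")
print(kept)  # expected: the 0.9 class-0 box and the 0.8 class-1 box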