yolo3_from_scratch / utils.py
nathbns's picture
Upload 4 files
b343099 verified
"""
Utility functions for YOLOv3 (simplifié pour Gradio)
"""
import torch
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Compute the intersection over union (IoU) between two sets of boxes.

    Args:
        boxes_preds: Tensor whose last dimension holds at least 4 values:
            [x, y, w, h] when box_format == "midpoint", otherwise
            [x1, y1, x2, y2].
        boxes_labels: Tensor with the same layout as ``boxes_preds``
            (the two are broadcast against each other).
        box_format: "midpoint" for center/size boxes; any other value is
            treated as "corners".

    Returns:
        Tensor of IoU scores with a trailing dimension of size 1. A small
        epsilon (1e-6) is added to the union to avoid division by zero.
    """

    def _corners(boxes):
        # Return (x1, y1, x2, y2), each keeping a trailing dim of size 1.
        if box_format == "midpoint":
            cx, cy = boxes[..., 0:1], boxes[..., 1:2]
            half_w, half_h = boxes[..., 2:3] / 2, boxes[..., 3:4] / 2
            return cx - half_w, cy - half_h, cx + half_w, cy + half_h
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]

    p_x1, p_y1, p_x2, p_y2 = _corners(boxes_preds)
    l_x1, l_y1, l_x2, l_y2 = _corners(boxes_labels)

    # Overlap rectangle; negative extents (no overlap) are clamped to 0.
    inter_w = (torch.min(p_x2, l_x2) - torch.max(p_x1, l_x1)).clamp(0)
    inter_h = (torch.min(p_y2, l_y2) - torch.max(p_y1, l_y1)).clamp(0)
    inter = inter_w * inter_h

    area_p = abs((p_x2 - p_x1) * (p_y2 - p_y1))
    area_l = abs((l_x2 - l_x1) * (l_y2 - l_y1))
    union = area_p + area_l - inter + 1e-6
    return inter / union
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Apply class-wise Non-Maximum Suppression (NMS) to a list of boxes.

    Args:
        bboxes: List of boxes, each ``[class_pred, prob_score, *coords]`` where
            ``coords`` is ``[x, y, w, h]`` for ``box_format="midpoint"`` or
            ``[x1, y1, x2, y2]`` for ``box_format="corners"`` (the default).
        iou_threshold: A box is suppressed when it has the same class as a
            higher-scoring kept box and their IoU is >= this value.
        threshold: Minimum confidence; boxes with ``prob_score <= threshold``
            are discarded before suppression.
        box_format: "midpoint" or "corners"; forwarded to
            ``intersection_over_union``.

    Returns:
        The kept boxes (the original list elements), ordered by descending
        ``prob_score``.

    Raises:
        TypeError: If ``bboxes`` is not a list.
    """
    # Explicit check rather than ``assert``: it must survive ``python -O``,
    # and ``isinstance`` also accepts list subclasses.
    if not isinstance(bboxes, list):
        raise TypeError(f"bboxes must be a list, got {type(bboxes).__name__}")

    # Keep only confident boxes, highest score first.
    remaining = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )

    kept = []
    while remaining:
        chosen, *rest = remaining
        # Build the chosen box's coordinate tensor once per iteration instead
        # of once per comparison.
        chosen_coords = torch.tensor(chosen[2:])
        remaining = [
            box
            for box in rest
            # Boxes of a different class are never suppressed; the IoU is only
            # computed (lazily, via ``or``) for boxes of the same class.
            if box[0] != chosen[0]
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]
        kept.append(chosen)
    return kept
def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """
    Convert one scale of YOLOv3 outputs (or targets) into a flat list of boxes.

    Args:
        predictions: Tensor of shape [N, num_anchors, S, S, C]. Channel 0 is
            the objectness score and channels 1:5 are the box (x, y, w, h).
            When ``is_preds`` is True, channels 5: are class logits; otherwise
            channel 5 holds the class index. The tensor is NOT modified.
        anchors: Anchor (w, h) pairs for this scale, shape [num_anchors, 2]
            (a tensor or nested sequence). Only used when ``is_preds`` is True.
            Presumably expressed in grid-cell units, since the resulting
            widths/heights are divided by ``S``. TODO confirm against caller.
        S: Grid size for this scale (e.g. 13, 26 or 52).
        is_preds: If True, apply sigmoid to objectness and x/y offsets,
            ``exp(w, h) * anchors`` to sizes, and argmax over class logits.
            If False, the values are used as-is.

    Returns:
        Nested list of shape [N, num_anchors * S * S, 6]. Each box is
        ``[class, score, x, y, w, h]`` in midpoint format, with coordinates
        and sizes divided by ``S`` (i.e. relative to the full grid).
    """
    batch_size = predictions.shape[0]
    num_anchors = len(anchors)
    # Clone: ``predictions[..., 1:5]`` is a view, so the in-place activations
    # below would otherwise write through into the caller's tensor.
    box_predictions = predictions[..., 1:5].clone()
    if is_preds:
        # Accept tensors or nested lists, and match the predictions' device.
        anchors = torch.as_tensor(
            anchors, dtype=box_predictions.dtype, device=box_predictions.device
        ).reshape(1, num_anchors, 1, 1, 2)
        box_predictions[..., 0:2] = torch.sigmoid(box_predictions[..., 0:2])
        box_predictions[..., 2:] = torch.exp(box_predictions[..., 2:]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # Column index of every cell: cell_indices[n, a, i, j, 0] == j.
    # Uses num_anchors (not a hard-coded 3) so any anchor count works.
    cell_indices = (
        torch.arange(S, device=predictions.device)
        .repeat(batch_size, num_anchors, S, 1)
        .unsqueeze(-1)
    )
    # x offsets come from the column index, y offsets from the row index
    # (the same grid with its two spatial axes swapped).
    x = 1 / S * (box_predictions[..., 0:1] + cell_indices)
    y = 1 / S * (box_predictions[..., 1:2] + cell_indices.permute(0, 1, 3, 2, 4))
    w_h = 1 / S * box_predictions[..., 2:4]

    # argmax yields int64; cast explicitly so torch.cat sees matching dtypes.
    converted_bboxes = torch.cat(
        (best_class.to(scores.dtype), scores, x, y, w_h), dim=-1
    ).reshape(batch_size, num_anchors * S * S, 6)
    return converted_bboxes.tolist()