Spaces:
Running
Running
| """ | |
| Utility functions for YOLOv3 (simplifié pour Gradio) | |
| """ | |
| import torch | |
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Compute the intersection over union (IoU) between two sets of boxes.

    The last dimension of each input holds four box values; all leading
    dimensions are kept, and the result has a trailing dimension of size 1.

    Args:
        boxes_preds: predicted boxes, ``[x, y, w, h]`` or ``[x1, y1, x2, y2]``.
        boxes_labels: reference boxes, same layout as ``boxes_preds``.
        box_format: "midpoint" for ``[x, y, w, h]`` (centre + size); any
            other value is treated as "corners" ``[x1, y1, x2, y2]``.

    Returns:
        Tensor of IoU scores. A 1e-6 epsilon in the denominator avoids
        division by zero for degenerate boxes.
    """

    def corners(boxes):
        # Anything that is not "midpoint" is already in corner form.
        if box_format != "midpoint":
            return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]
        cx, cy = boxes[..., 0:1], boxes[..., 1:2]
        half_w, half_h = boxes[..., 2:3] / 2, boxes[..., 3:4] / 2
        return cx - half_w, cy - half_h, cx + half_w, cy + half_h

    ax1, ay1, ax2, ay2 = corners(boxes_preds)
    bx1, by1, bx2, by2 = corners(boxes_labels)

    # Overlap extents; clamp(0) turns disjoint boxes into a zero-area overlap.
    overlap_w = (torch.min(ax2, bx2) - torch.max(ax1, bx1)).clamp(0)
    overlap_h = (torch.min(ay2, by2) - torch.max(ay1, by1)).clamp(0)
    inter = overlap_w * overlap_h

    area_a = abs((ax2 - ax1) * (ay2 - ay1))
    area_b = abs((bx2 - bx1) * (by2 - by1))
    return inter / (area_a + area_b - inter + 1e-6)
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """
    Apply greedy, per-class Non-Maximum Suppression (NMS).

    Boxes whose confidence is not strictly greater than ``threshold`` are
    dropped first. The survivors are visited in descending confidence order;
    each kept box removes every remaining box of the SAME class whose IoU
    with it is >= ``iou_threshold``. Boxes of different classes never
    suppress each other.

    Args:
        bboxes: list of boxes, each ``[class_pred, prob_score, c1, c2, c3, c4]``.
            The four coordinates follow ``box_format``: ``[x, y, w, h]`` for
            "midpoint", ``[x1, y1, x2, y2]`` for "corners". Boxes laid out as
            ``[x, y, w, h]`` (e.g. the output of ``cells_to_bboxes``) need
            ``box_format="midpoint"``.
        iou_threshold: IoU at or above which a same-class box is suppressed.
        threshold: minimum confidence; boxes with score <= threshold are dropped.
        box_format: "midpoint" or "corners" (default "corners").

    Returns:
        List of the kept boxes (the original box objects), highest
        confidence first. The input list is not modified.

    Raises:
        TypeError: if ``bboxes`` is not a list.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(bboxes, list):
        raise TypeError(f"bboxes must be a list, got {type(bboxes).__name__}")

    # Keep only confident boxes, highest score first (stable for ties).
    remaining = sorted(
        (box for box in bboxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )

    kept = []
    while remaining:
        chosen_box = remaining.pop(0)
        # Build the chosen box's tensor once rather than once per comparison.
        chosen_coords = torch.tensor(chosen_box[2:])
        remaining = [
            box
            for box in remaining
            if box[0] != chosen_box[0]  # different class: never suppressed
            or intersection_over_union(
                chosen_coords,
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold  # same class but low overlap: keep
        ]
        kept.append(chosen_box)
    return kept
def cells_to_bboxes(predictions, anchors, S, is_preds=True):
    """
    Convert one YOLOv3 output scale into a flat list of bounding boxes.

    ``predictions`` is indexed as ``[N, A, S, S, channels]`` where channel 0
    is the objectness score, channels 1:5 are the box values and channels
    5: are class logits (``is_preds=True``) or a single class index in
    channel 5 (``is_preds=False``). The input tensor is NOT modified.

    Args:
        predictions: tensor of shape ``[N, A, S, S, num_classes + 5]``.
        anchors: ``A`` anchor ``(w, h)`` pairs (tensor or nested sequence),
            reshaped to ``(1, A, 1, 1, 2)``. Presumably in grid-cell units,
            since the final width/height are divided by ``S`` — confirm
            against the caller.
        S: grid size of this scale (e.g. 13, 26 or 52).
        is_preds: if True, apply sigmoid to objectness and x/y offsets,
            ``exp(t) * anchor`` to width/height, and take the argmax class.
            If False, the values are used as-is (e.g. targets).

    Returns:
        Nested list of shape ``[N][A * S * S][6]``, each entry
        ``[class, score, x, y, w, h]`` with x, y, w, h divided by ``S``.
    """
    batch_size = predictions.shape[0]
    num_anchors = len(anchors)
    box_predictions = predictions[..., 1:5]

    # Build new tensors instead of writing into `box_predictions`: it is a
    # view of `predictions`, so in-place writes would corrupt the caller's
    # tensor (and double-apply the activations on a second call).
    if is_preds:
        anchors = torch.as_tensor(
            anchors, dtype=box_predictions.dtype, device=box_predictions.device
        ).reshape(1, num_anchors, 1, 1, 2)
        xy = torch.sigmoid(box_predictions[..., 0:2])
        wh = torch.exp(box_predictions[..., 2:4]) * anchors
        scores = torch.sigmoid(predictions[..., 0:1])
        best_class = torch.argmax(predictions[..., 5:], dim=-1).unsqueeze(-1)
    else:
        xy = box_predictions[..., 0:2]
        wh = box_predictions[..., 2:4]
        scores = predictions[..., 0:1]
        best_class = predictions[..., 5:6]

    # Cell offsets broadcast over batch and anchors (no hard-coded anchor
    # count): column index along dim 3 for x, row index along dim 2 for y.
    cells = torch.arange(S, device=predictions.device)
    col_index = cells.reshape(1, 1, 1, S, 1)
    row_index = cells.reshape(1, 1, S, 1, 1)

    x = 1 / S * (xy[..., 0:1] + col_index)
    y = 1 / S * (xy[..., 1:2] + row_index)
    w_h = 1 / S * wh

    # Cast the (possibly integer) class index so every column shares a dtype.
    converted_bboxes = torch.cat(
        (best_class.to(scores.dtype), scores, x, y, w_h), dim=-1
    ).reshape(batch_size, num_anchors * S * S, 6)
    return converted_bboxes.tolist()