# Hugging Face Spaces demo: YOLOv3 (from scratch) object detection on Pascal VOC.
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import os | |
# --- Configuration ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_SIZE = 416
NUM_CLASSES = 20

# YOLOv3 anchor boxes as (width, height) pairs normalised to [0, 1].
# One group of three anchors per prediction scale, largest objects first.
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]

# Grid sizes of the three detection heads (strides 32, 16 and 8).
S = [IMAGE_SIZE // stride for stride in (32, 16, 8)]

# The 20 Pascal VOC class names, in label-index order.
PASCAL_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
| # Import du modèle | |
| from model import YOLOv3 | |
| from utils import cells_to_bboxes, non_max_suppression | |
class YOLOv3Detector:
    """YOLOv3 wrapper: weight loading, preprocessing, inference and rendering.

    Relies on the project-local ``YOLOv3`` architecture and the
    ``cells_to_bboxes`` / ``non_max_suppression`` helpers from ``utils``.
    """

    def __init__(self, checkpoint_path):
        """Load trained weights and precompute per-scale anchors.

        Args:
            checkpoint_path: Path to a torch checkpoint containing a
                ``"state_dict"`` entry compatible with ``YOLOv3``.
        """
        self.model = YOLOv3(num_classes=NUM_CLASSES).to(DEVICE)
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.eval()

        # Anchors rescaled to grid-cell units: each normalised (w, h) is
        # multiplied by the grid size of its detection scale.
        self.scaled_anchors = (
            torch.tensor(ANCHORS)
            * torch.tensor(S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(DEVICE)

        # One fixed pseudo-random colour per class (seeded for reproducibility).
        np.random.seed(42)
        self.colors = np.random.randint(
            0, 255, size=(len(PASCAL_CLASSES), 3), dtype=np.uint8
        )

    def preprocess_image(self, image):
        """Convert an input image to a normalised (1, 3, H, W) float tensor.

        Accepts a PIL image or a numpy array. Grayscale and RGBA inputs are
        converted to 3-channel RGB first — the previous version crashed on
        grayscale (2-D array has no channel axis to permute) and silently fed
        a 4-channel tensor to the model on RGBA screenshots.

        Returns:
            Tuple of (tensor on DEVICE, original (height, width)).
        """
        if isinstance(image, Image.Image):
            # PIL handles palette / grayscale / alpha conversion robustly.
            image = np.array(image.convert("RGB"))
        else:
            image = np.asarray(image)
            if image.ndim == 2:
                # Grayscale -> RGB
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            elif image.shape[2] == 4:
                # RGBA -> RGB
                image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        original_shape = image.shape[:2]
        image_resized = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
        # Scale to [0, 1], reorder to channel-first and add a batch dimension.
        image_tensor = torch.from_numpy(image_resized).float() / 255.0
        image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
        return image_tensor.to(DEVICE), original_shape

    def detect(self, image, conf_threshold=0.5, iou_threshold=0.45):
        """Detect objects in `image`.

        Args:
            image: PIL image or numpy array.
            conf_threshold: Minimum objectness/confidence to keep a box.
            iou_threshold: IoU threshold used by non-max suppression.

        Returns:
            List of boxes ``[class_idx, confidence, x_c, y_c, w, h]`` in
            normalised midpoint format, after NMS.
        """
        image_tensor, _original_shape = self.preprocess_image(image)
        with torch.no_grad():
            predictions = self.model(image_tensor)

        # Collect boxes from all three scales. Batch size is taken from the
        # tensor instead of being hard-coded to 1.
        batch_size = image_tensor.shape[0]
        bboxes = [[] for _ in range(batch_size)]
        for scale_idx in range(3):
            # Renamed from `S` so it no longer shadows the module constant.
            grid_size = predictions[scale_idx].shape[2]
            anchor = self.scaled_anchors[scale_idx]
            boxes_scale = cells_to_bboxes(
                predictions[scale_idx], anchor, S=grid_size, is_preds=True
            )
            for idx, box in enumerate(boxes_scale):
                bboxes[idx] += box

        # Apply non-max suppression to the (single) image of the batch.
        nms_boxes = non_max_suppression(
            bboxes[0],
            iou_threshold=iou_threshold,
            threshold=conf_threshold,
            box_format="midpoint",
        )
        return nms_boxes

    def draw_boxes(self, image, boxes):
        """Draw detection boxes and class labels onto a copy of `image`.

        Args:
            image: PIL image or numpy array.
            boxes: Iterable of ``[class_idx, confidence, x_c, y_c, w, h]``
                with normalised midpoint coordinates (as returned by detect).

        Returns:
            Tuple of (annotated numpy image, list of "• class: conf%" strings).
        """
        if isinstance(image, Image.Image):
            image = np.array(image)
        image = image.copy()
        height, width = image.shape[:2]
        detections_info = []
        for box in boxes:
            class_idx = int(box[0])
            confidence = box[1]
            x_center, y_center, box_width, box_height = box[2:]
            # Normalised midpoint -> pixel corners, clamped to the frame so
            # partially out-of-image boxes still render sensibly.
            x1 = max(0, int((x_center - box_width / 2) * width))
            y1 = max(0, int((y_center - box_height / 2) * height))
            x2 = min(width - 1, int((x_center + box_width / 2) * width))
            y2 = min(height - 1, int((y_center + box_height / 2) * height))
            color = self.colors[class_idx].tolist()
            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
            label = f"{PASCAL_CLASSES[class_idx]}: {confidence:.2f}"
            (text_width, text_height), _ = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
            )
            # Keep the label visible when the box touches the top edge
            # (previously the filled background could land at negative y).
            label_y = max(y1, text_height + 4)
            cv2.rectangle(
                image,
                (x1, label_y - text_height - 4),
                (x1 + text_width, label_y),
                color,
                -1,
            )
            cv2.putText(
                image,
                label,
                (x1, label_y - 4),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 255, 255),
                1,
            )
            detections_info.append(f"• {PASCAL_CLASSES[class_idx]}: {confidence:.1%}")
        return image, detections_info
# Download the trained weights from the Hugging Face Hub (cached locally
# between runs by huggingface_hub).
checkpoint_path = hf_hub_download(
    repo_id="nathbns/yolov3_from_scratch",
    filename="checkpoint.pth.tar"
)

# Single detector instance built once at startup and shared by all requests.
detector = YOLOv3Detector(checkpoint_path)
def predict(image, conf_threshold, iou_threshold):
    """Gradio callback: detect objects and return (annotated image, markdown).

    Args:
        image: Uploaded PIL image, or None when the input is cleared.
        conf_threshold: Confidence threshold from the UI slider.
        iou_threshold: NMS IoU threshold from the UI slider.
    """
    if image is None:
        return None, "Aucune image fournie"

    boxes = detector.detect(image, conf_threshold, iou_threshold)
    result_image, detections = detector.draw_boxes(image, boxes)

    if not detections:
        return result_image, "Aucun objet détecté"

    header = f"**{len(detections)} objet(s) détecté(s) :**\n\n"
    return result_image, header + "\n".join(detections)
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="YOLOv3 Object Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# YOLOv3 Object Detection - Pascal VOC
"""
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Image d'entrée")

            with gr.Accordion("Paramètres", open=True):
                conf_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Seuil de confiance",
                    info="Plus élevé = moins de détections mais plus sûres",
                )
                iou_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.45,
                    step=0.05,
                    label="Seuil NMS (IoU)",
                    info="Plus élevé = plus de boîtes qui se chevauchent",
                )

            detect_btn = gr.Button(
                "Détecter les objets", variant="primary", size="lg"
            )

        with gr.Column():
            output_image = gr.Image(label="Résultat")
            output_text = gr.Markdown(label="Détections")

    # Same prediction pipeline fires on the button click and automatically
    # whenever a new image is uploaded.
    event_kwargs = dict(
        fn=predict,
        inputs=[input_image, conf_slider, iou_slider],
        outputs=[output_image, output_text],
    )
    detect_btn.click(**event_kwargs)
    input_image.change(**event_kwargs)

if __name__ == "__main__":
    demo.launch()