import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from model import Yolo_V1
from utils import cellboxes_to_boxes, non_max_suppression
import cv2
import os
import glob
import time
from huggingface_hub import hf_hub_download

# PASCAL VOC classes
CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
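
# Deterministic palette: seeding NumPy here keeps each class's box color
# stable across runs.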
np.random.seed(42)
COLORS = np.random.randint(50, 255, size=(len(CLASSES), 3), dtype=np.uint8)

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
MODEL_REPO_ID = "nathbns/yolov1_from_scratch"
MODEL_FILENAME = "checkpoint_epoch_50.pth.tar"

# Load the model checkpoint from the Hugging Face Hub
print(f"Loading model from {MODEL_REPO_ID}...")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Model downloaded from the Hugging Face Hub: {model_path}")
except Exception as e:
    print(f"Download failed: {e}")
    print("Falling back to a local checkpoint...")
    model_path = MODEL_FILENAME

model = Yolo_V1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
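# With split_size=7, num_boxes=2 and num_classes=20, the head predicts a 7x7
# grid with (20 + 2*5) = 30 values per cell, i.e. one 7*7*30 output tensor per
# image, matching the original YOLO v1 formulation.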
checkpoint = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("Model loaded successfully!")

# Model metadata
MODEL_INFO = {
    "mAP": checkpoint.get("mAP", "N/A"),
    "epoch": checkpoint.get("epoch", "N/A"),
    "device": DEVICE,
    "classes": len(CLASSES),
}
print(f"Training mAP: {MODEL_INFO['mAP']}")
print(f"Device: {DEVICE}")

# Load example images from the data folder
EXAMPLE_IMAGES = []
if os.path.exists("data/images"):
    # Sort before slicing so the same 20 images are picked on every run
    EXAMPLE_IMAGES = sorted(glob.glob("data/images/*.jpg"))[:20]
print(f"{len(EXAMPLE_IMAGES)} example images loaded")

def draw_boxes(image, boxes):
    """Draw bounding boxes on the image."""
    img_array = np.array(image)
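    # np.array(PIL.Image) yields an RGB array and we never round-trip through
    # cv2.imread/imwrite, so the OpenCV drawing calls below receive RGB
    # triples directly (no BGR conversion needed).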
    height, width = img_array.shape[:2]
    for box in boxes:
        # box format: [class_pred, prob_score, x, y, width, height]
        class_pred = int(box[0])
        confidence = float(box[1])
        x_center, y_center, box_width, box_height = box[2:6]
        # Convert from normalized coordinates to pixels
        x1 = int((x_center - box_width / 2) * width)
        y1 = int((y_center - box_height / 2) * height)
        x2 = int((x_center + box_width / 2) * width)
        y2 = int((y_center + box_height / 2) * height)
        # Class color
        color = tuple(int(c) for c in COLORS[class_pred])
        # Draw the rectangle
        cv2.rectangle(img_array, (x1, y1), (x2, y2), color, 2)
        # Label text
        label = f"{CLASSES[class_pred]}: {confidence:.2f}"
        # Label background
        (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(img_array, (x1, y1 - text_height - 5), (x1 + text_width, y1), color, -1)
        # White text
        cv2.putText(img_array, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    return Image.fromarray(img_array)

def detect_objects(image, confidence_threshold, iou_threshold, show_confidence=True):
    """Detect objects in an image and return detailed statistics."""
    if image is None:
        return None, None, "**Please upload or select an image**"
    start_time = time.time()
    # Preprocess the image
    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ])
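    # YOLO v1 operates on a fixed 448x448 input; since the predicted boxes are
    # normalized to [0, 1], they can later be drawn on the original-resolution
    # image without any rescaling.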
    # Keep the original image for display
    original_image = image.copy()
    original_size = image.size  # (width, height)
    # Transform the image
    img_tensor = transform(image).unsqueeze(0).to(DEVICE)
    # Prediction
    with torch.no_grad():
        predictions = model(img_tensor)
    # Convert the raw grid predictions into bounding boxes
    bboxes = cellboxes_to_boxes(predictions)
    # Non-maximum suppression
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_threshold,
        threshold=confidence_threshold,
        box_format="midpoint",
    )
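    # Each surviving box is [class_pred, prob_score, x, y, w, h] with midpoint
    # coordinates normalized to [0, 1] (hence box_format="midpoint"), exactly
    # what draw_boxes expects.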
    inference_time = time.time() - start_time
    # Draw the boxes
    result_image = draw_boxes(original_image, nms_boxes)
    # Detailed statistics
    num_detections = len(nms_boxes)
    class_counts = {}
    confidence_scores = []
    for box in nms_boxes:
        cls = CLASSES[int(box[0])]
        conf = float(box[1])
        class_counts[cls] = class_counts.get(cls, 0) + 1
        confidence_scores.append(conf)
    # Build a detailed report
    stats = "## Detection results\n\n"
    stats += f"**{num_detections} object(s) detected**\n\n"
    if num_detections > 0:
        stats += f"Inference time: **{inference_time:.3f}s**\n"
        stats += f"Image size: **{original_size[0]}x{original_size[1]}**\n"
        stats += f"Mean confidence: **{np.mean(confidence_scores):.2%}**\n\n"
        stats += "### Detected objects:\n\n"
        for cls, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True):
            stats += f"- **{cls}**: {count}\n"
        if show_confidence:
            stats += "\n### Individual confidences:\n\n"
            for i, box in enumerate(nms_boxes[:10], 1):  # top 10 only
                cls = CLASSES[int(box[0])]
                conf = float(box[1])
                stats += f"{i}. {cls}: {conf:.1%}\n"
            if len(nms_boxes) > 10:
                stats += f"\n*...and {len(nms_boxes) - 10} more detection(s)*\n"
    else:
        stats += "No objects detected.\n\n"
    return original_image, result_image, stats
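
# Minimal smoke test outside the UI (a sketch: the YOLO_SMOKE_TEST env var is
# our own convention, and it assumes data/images/ is populated). Run with
# `YOLO_SMOKE_TEST=1 python app.py` to exercise the pipeline once.
if os.environ.get("YOLO_SMOKE_TEST") == "1" and EXAMPLE_IMAGES:
    _img = Image.open(EXAMPLE_IMAGES[0]).convert("RGB")
    _, _annotated, _report = detect_objects(_img, confidence_threshold=0.4, iou_threshold=0.5)
    _annotated.save("smoke_test_prediction.jpg")
    print(_report)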

# Gradio interface
with gr.Blocks(title="YOLO v1 - Object detection", theme=gr.themes.Soft(), css="""
.gradio-container {max-width: 1400px !important}
.example-gallery {height: 400px; overflow-y: auto}
""") as demo:
    # Header
    mAP_display = f"{MODEL_INFO['mAP']:.4f}" if isinstance(MODEL_INFO['mAP'], (int, float)) else MODEL_INFO['mAP']
    gr.Markdown(f"""
    # YOLO v1 - Real-time object detection
    Training mAP: **{mAP_display}** | Device: **{DEVICE}**

    ---
    """)
    with gr.Tabs():
        # Main tab - detection
        with gr.Tab("Detection"):
            gr.Markdown("""
            ### Upload your own image or pick an example
            **PASCAL VOC classes:** aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow,
            diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor
            """)
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="pil", label="Input image")
                    with gr.Accordion("Advanced settings", open=True):
                        confidence_slider = gr.Slider(
                            minimum=0.05,
                            maximum=0.95,
                            value=0.4,
                            step=0.05,
                            label="Confidence threshold",
                            info="Lower = more detections",
                        )
                        iou_slider = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            value=0.5,
                            step=0.05,
                            label="IoU threshold (NMS)",
                            info="Higher = keeps more overlapping boxes",
                        )
                        show_conf_check = gr.Checkbox(
                            value=True,
                            label="Show individual confidences",
                        )
                    detect_btn = gr.Button("Detect objects", variant="primary", size="lg")
                with gr.Column(scale=2):
                    with gr.Row():
                        original_display = gr.Image(type="pil", label="Original image")
                        output_image = gr.Image(type="pil", label="Result with detections")
                    output_stats = gr.Markdown("**Upload an image and click 'Detect' to get started!**")
            # Example gallery
            if EXAMPLE_IMAGES:
                gr.Markdown("### Examples (click to try)")
                examples_list = [[img, 0.4, 0.5, True] for img in EXAMPLE_IMAGES[:12]]
                gr.Examples(
                    examples=examples_list,
                    inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                    outputs=[original_display, output_image, output_stats],
                    fn=detect_objects,
                    cache_examples=False,
                    examples_per_page=6,
                )
            # Actions
            detect_btn.click(
                fn=detect_objects,
                inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                outputs=[original_display, output_image, output_stats],
            )
            input_image.change(
                fn=detect_objects,
                inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                outputs=[original_display, output_image, output_stats],
            )
        # Info tab
        with gr.Tab("About"):
            mAP_info = f"{MODEL_INFO['mAP']:.4f}" if isinstance(MODEL_INFO['mAP'], (int, float)) else "N/A"
            epoch_info = MODEL_INFO["epoch"]
            gr.Markdown(f"**Model:** {MODEL_REPO_ID}/{MODEL_FILENAME} | **Training mAP:** {mAP_info} | **Epochs:** {epoch_info} | **Device:** {DEVICE}")

# Launch the app
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Launching the YOLO v1 Gradio app")
    print("=" * 60)
    print(f"Model: {MODEL_REPO_ID}/{MODEL_FILENAME}")
    print(f"Device: {DEVICE}")
    print(f"Examples loaded: {len(EXAMPLE_IMAGES)}")
    print("=" * 60 + "\n")
    demo.launch(
        share=True,
        server_name="0.0.0.0",  # reachable from the local network
        server_port=7860,
        show_error=True,
    )