import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from model import Yolo_V1
from utils import cellboxes_to_boxes, non_max_suppression
import cv2
import os
import glob
import time
from huggingface_hub import hf_hub_download

# PASCAL VOC classes
CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]

np.random.seed(42)
COLORS = np.random.randint(50, 255, size=(len(CLASSES), 3), dtype=np.uint8)

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
MODEL_REPO_ID = "nathbns/yolov1_from_scratch"
MODEL_FILENAME = "checkpoint_epoch_50.pth.tar"

# Load the model weights from the Hugging Face Hub, falling back to a local checkpoint
print(f"Loading model from {MODEL_REPO_ID}...")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Model downloaded from the Hugging Face Hub: {model_path}")
except Exception as e:
    print(f"Download failed: {e}")
    print("Trying to load a local checkpoint...")
    model_path = MODEL_FILENAME

model = Yolo_V1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
checkpoint = torch.load(model_path, map_location=DEVICE)
model.load_state_dict(checkpoint["state_dict"])
model.eval()
print("Model loaded successfully!")

# Model metadata, for display in the UI
MODEL_INFO = {
    "mAP": checkpoint.get("mAP", "N/A"),
    "epoch": checkpoint.get("epoch", "N/A"),
    "device": DEVICE,
    "classes": len(CLASSES)
}
print(f"Training mAP: {MODEL_INFO['mAP']}")
print(f"Device: {DEVICE}")

# Load example images from the data folder (sort before slicing so the subset is deterministic)
EXAMPLE_IMAGES = []
if os.path.exists("data/images"):
    EXAMPLE_IMAGES = sorted(glob.glob("data/images/*.jpg"))[:20]  # keep the first 20 images
    print(f"{len(EXAMPLE_IMAGES)} example images loaded")


def draw_boxes(image, boxes):
    """Draw bounding boxes and class labels on the image."""
    img_array = np.array(image)
    height, width = img_array.shape[:2]

    for box in boxes:
        # box format: [class_pred, prob_score, x, y, width, height]
        class_pred = int(box[0])
        confidence = float(box[1])
        x_center, y_center, box_width, box_height = box[2:6]

        # Convert from normalized midpoint coordinates to pixel corners
        x1 = int((x_center - box_width / 2) * width)
        y1 = int((y_center - box_height / 2) * height)
        x2 = int((x_center + box_width / 2) * width)
        y2 = int((y_center + box_height / 2) * height)

        # Per-class color
        color = tuple(int(c) for c in COLORS[class_pred])

        # Bounding box
        cv2.rectangle(img_array, (x1, y1), (x2, y2), color, 2)

        # Label with a filled background, drawn in white on top
        label = f"{CLASSES[class_pred]}: {confidence:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
        cv2.rectangle(img_array, (x1, y1 - text_height - 5), (x1 + text_width, y1), color, -1)
        cv2.putText(img_array, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    return Image.fromarray(img_array)
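
# Quick smoke test for draw_boxes (commented out; the sample box below is made up
# for illustration). Boxes use YOLO "midpoint" format:
# [class_idx, confidence, x_center, y_center, width, height],
# with all coordinates normalized to [0, 1].
#
#     blank = Image.new("RGB", (448, 448), "black")
#     sample_box = [6, 0.90, 0.5, 0.5, 0.4, 0.3]  # hypothetical "car" in the center
#     draw_boxes(blank, [sample_box]).show()
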
def detect_objects(image, confidence_threshold, iou_threshold, show_confidence=True):
    """Detect objects in an image and build a detailed statistics report."""
    if image is None:
        return None, None, "**Please upload or select an image**"

    start_time = time.time()

    # Preprocessing: YOLO v1 expects a 448x448 input
    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ])

    # Keep the original image for display
    original_image = image.copy()
    original_size = image.size  # (width, height)

    # Transform the image
    img_tensor = transform(image).unsqueeze(0).to(DEVICE)

    # Prediction
    with torch.no_grad():
        predictions = model(img_tensor)

    # Convert the grid-cell predictions into bounding boxes
    bboxes = cellboxes_to_boxes(predictions)

    # Non-maximum suppression
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_threshold,
        threshold=confidence_threshold,
        box_format="midpoint"
    )

    inference_time = time.time() - start_time

    # Draw the boxes
    result_image = draw_boxes(original_image, nms_boxes)

    # Detailed statistics
    num_detections = len(nms_boxes)
    class_counts = {}
    confidence_scores = []
    for box in nms_boxes:
        cls = CLASSES[int(box[0])]
        conf = float(box[1])
        class_counts[cls] = class_counts.get(cls, 0) + 1
        confidence_scores.append(conf)

    # Build a detailed report
    stats = "## Detection results\n\n"
    stats += f"**{num_detections} object(s) detected**\n\n"

    if num_detections > 0:
        stats += f"Inference time: **{inference_time:.3f}s**\n"
        stats += f"Image size: **{original_size[0]}x{original_size[1]}**\n"
        stats += f"Mean confidence: **{np.mean(confidence_scores):.2%}**\n\n"
        stats += "### Detected objects:\n\n"
        for cls, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True):
            stats += f"- **{cls}**: {count}\n"

        if show_confidence:
            stats += "\n### Individual confidences:\n\n"
            for i, box in enumerate(nms_boxes[:10], 1):  # top 10
                cls = CLASSES[int(box[0])]
                conf = float(box[1])
                stats += f"{i}. {cls}: {conf:.1%}\n"
            if len(nms_boxes) > 10:
                stats += f"\n*...and {len(nms_boxes) - 10} more detection(s)*\n"
    else:
        stats += "No objects detected.\n\n"

    return original_image, result_image, stats
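
# Headless usage sketch (commented out; "test.jpg" is a placeholder path, not a file
# shipped with this repo). detect_objects can be called without the Gradio UI:
#
#     img = Image.open("test.jpg").convert("RGB")
#     _, annotated, report = detect_objects(img, confidence_threshold=0.4, iou_threshold=0.5)
#     annotated.save("annotated.jpg")
#     print(report)
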
# Gradio interface
with gr.Blocks(title="YOLO v1 - Object Detection", theme=gr.themes.Soft(), css="""
    .gradio-container {max-width: 1400px !important}
    .example-gallery {height: 400px; overflow-y: auto}
""") as demo:
    # Header
    mAP_display = f"{MODEL_INFO['mAP']:.4f}" if isinstance(MODEL_INFO['mAP'], (int, float)) else MODEL_INFO['mAP']
    gr.Markdown(f"""
    # YOLO v1 - Real-time object detection
    Training mAP: **{mAP_display}**

    ---
    """)

    with gr.Tabs():
        # Main tab - detection
        with gr.Tab("Detection"):
            gr.Markdown("""
            ### Upload your own image or pick an example
            **PASCAL VOC classes:** aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair,
            cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="pil", label="Input image")

                    with gr.Accordion("Advanced settings", open=True):
                        confidence_slider = gr.Slider(
                            minimum=0.05, maximum=0.95, value=0.4, step=0.05,
                            label="Confidence threshold",
                            info="Lower = more detections"
                        )
                        iou_slider = gr.Slider(
                            minimum=0.1, maximum=0.9, value=0.5, step=0.05,
                            label="IoU threshold (NMS)",
                            info="Higher = keeps more overlapping boxes"
                        )
                        show_conf_check = gr.Checkbox(
                            value=True,
                            label="Show individual confidences"
                        )

                    detect_btn = gr.Button("Detect objects", variant="primary", size="lg")

                with gr.Column(scale=2):
                    with gr.Row():
                        original_display = gr.Image(type="pil", label="Original image")
                        output_image = gr.Image(type="pil", label="Result with detections")
                    output_stats = gr.Markdown("**Upload an image and click 'Detect' to get started!**")

            # Example gallery
            if EXAMPLE_IMAGES:
                gr.Markdown("### Examples (click to try)")
                examples_list = [[img, 0.4, 0.5, True] for img in EXAMPLE_IMAGES[:12]]
                gr.Examples(
                    examples=examples_list,
                    inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                    outputs=[original_display, output_image, output_stats],
                    fn=detect_objects,
                    cache_examples=False,
                    examples_per_page=6,
                )

            # Event wiring: the button and any change to the input image trigger detection
            detect_btn.click(
                fn=detect_objects,
                inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                outputs=[original_display, output_image, output_stats]
            )
            input_image.change(
                fn=detect_objects,
                inputs=[input_image, confidence_slider, iou_slider, show_conf_check],
                outputs=[original_display, output_image, output_stats]
            )

        # Info tab
        with gr.Tab("About"):
            mAP_info = f"{MODEL_INFO['mAP']:.4f}" if isinstance(MODEL_INFO['mAP'], (int, float)) else 'N/A'
            epoch_info = MODEL_INFO['epoch']
            gr.Markdown(f"""
            ### Model
            - Repository: {MODEL_REPO_ID}
            - Checkpoint: {MODEL_FILENAME} (epoch {epoch_info})
            - Training mAP: {mAP_info}
            - Device: {MODEL_INFO['device']}
            - Classes: {MODEL_INFO['classes']} (PASCAL VOC)
            """)

# Launch the app
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Launching the YOLO v1 Gradio app")
    print("=" * 60)
    print(f"Model: {MODEL_REPO_ID}/{MODEL_FILENAME}")
    print(f"Device: {DEVICE}")
    print(f"Examples loaded: {len(EXAMPLE_IMAGES)}")
    print("=" * 60 + "\n")

    demo.launch(
        share=True,
        server_name="0.0.0.0",  # reachable from the local network
        server_port=7860,
        show_error=True
    )
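
# How to run (a sketch; the package list is inferred from the imports above, so
# prefer the repo's own requirements file if one exists):
#
#     pip install torch torchvision gradio opencv-python huggingface_hub numpy pillow
#     python app.py
#
# The UI is served on http://localhost:7860; with share=True, Gradio also prints a
# temporary public *.gradio.live link.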