# Author: nathbns — last commit cf3565d ("Update app.py", verified)
import torch
import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from model import Yolo_V1
from utils import cellboxes_to_boxes, non_max_suppression
import cv2
import os
import glob
import time
from huggingface_hub import hf_hub_download
# PASCAL VOC class names; a predicted integer class id indexes into this list.
CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

# One colour per class, channel values drawn from [50, 255). Seeding the
# global NumPy RNG right before the draw makes the palette identical on
# every run.
np.random.seed(42)
COLORS = np.random.randint(
    low=50, high=255, size=(len(CLASSES), 3), dtype=np.uint8
)

# Prefer CUDA, then Apple's MPS backend, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# Hugging Face Hub repository and file name of the trained checkpoint.
MODEL_REPO_ID = "nathbns/yolov1_from_scratch"
MODEL_FILENAME = "checkpoint_epoch_50.pth.tar"
def _as_python_scalar(value):
    """Return ``value.item()`` for a single-element tensor, else ``value`` unchanged.

    Checkpoint metadata (e.g. the mAP) may have been saved as a 0-d tensor
    (assumption — depends on the training script). Converting it to a plain
    Python number lets the UI's ``isinstance(value, (int, float))`` checks
    format it as a number instead of printing ``tensor(...)``.
    """
    if isinstance(value, torch.Tensor) and value.numel() == 1:
        return value.item()
    return value


# Load the model: download the checkpoint from the Hugging Face Hub, falling
# back to a file of the same name in the working directory.
print(f"Chargement du modèle depuis {MODEL_REPO_ID}...")
try:
    model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    print(f"Modèle téléchargé depuis Hugging Face Hub: {model_path}")
except Exception as e:  # best-effort: any hub/network failure -> try the local copy
    print(f"Erreur lors du téléchargement: {e}")
    print("Tentative de chargement local...")
    model_path = MODEL_FILENAME

model = Yolo_V1(split_size=7, num_boxes=2, num_classes=len(CLASSES)).to(DEVICE)
# NOTE(security): torch.load unpickles the file. Only load checkpoints from a
# trusted source (here MODEL_REPO_ID or the local working directory).
checkpoint = torch.load(model_path, map_location=DEVICE)
# Accept both a training checkpoint ({"state_dict": ..., ...}) and a bare state dict.
model.load_state_dict(checkpoint.get("state_dict", checkpoint))
model.eval()  # inference behaviour for layers such as dropout/batch-norm, if present
print("Modèle chargé avec succès!")

# Model metadata shown in the UI; missing entries are reported as "N/A".
MODEL_INFO = {
    "mAP": _as_python_scalar(checkpoint.get("mAP", "N/A")),
    "epoch": _as_python_scalar(checkpoint.get("epoch", "N/A")),
    "device": DEVICE,
    "classes": len(CLASSES),
}
print(f"mAP d'entraînement: {MODEL_INFO['mAP']}")
print(f"Device: {DEVICE}")
# Charger des images d'exemple depuis le dossier data
EXAMPLE_IMAGES = []
if os.path.exists("data/images"):
image_files = glob.glob("data/images/*.jpg")[:20] # Prendre 20 images
EXAMPLE_IMAGES = sorted(image_files)
print(f"{len(EXAMPLE_IMAGES)} images d'exemple chargées")
def draw_boxes(image, boxes):
    """Draw labelled bounding boxes on a copy of ``image``.

    Args:
        image: PIL image (or anything ``np.array`` can convert). PIL images in
            a non-RGB mode (e.g. RGBA, L) are converted to RGB first so the
            3-channel class colours apply.
        boxes: iterable of ``[class_pred, prob_score, x_center, y_center,
            width, height]``. Coordinates are fractions of the image
            width/height.

    Returns:
        A new ``PIL.Image`` with one rectangle and one ``"<class>: <score>"``
        label per visible box. Boxes lying entirely outside the image are
        skipped; the others are clipped to the image borders.
    """
    # Class colours are RGB triples, so normalise PIL inputs to 3 channels.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    img_array = np.array(image)
    height, width = img_array.shape[:2]

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    text_thickness = 1
    label_padding = 5  # pixels between the label baseline and the box edge

    for box in boxes:
        # Box layout: [class_pred, prob_score, x_center, y_center, width, height]
        class_pred = int(box[0])
        confidence = float(box[1])
        x_center, y_center, box_width, box_height = (float(v) for v in box[2:6])

        # Midpoint format -> pixel corners. sorted() keeps (x1, y1) top-left
        # even if a predicted width/height is negative.
        x1, x2 = sorted((int((x_center - box_width / 2) * width),
                         int((x_center + box_width / 2) * width)))
        y1, y2 = sorted((int((y_center - box_height / 2) * height),
                         int((y_center + box_height / 2) * height)))

        # Nothing visible to draw for a box entirely outside the image.
        if x2 < 0 or y2 < 0 or x1 >= width or y1 >= height:
            continue
        # Clip to the image so the rectangle and its label stay on screen.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, width - 1), min(y2, height - 1)

        color = tuple(int(c) for c in COLORS[class_pred])
        cv2.rectangle(img_array, (x1, y1), (x2, y2), color, 2)

        label = f"{CLASSES[class_pred]}: {confidence:.2f}"
        (text_width, text_height), _ = cv2.getTextSize(label, font, font_scale, text_thickness)
        # The label normally sits above the box. If there is no room (the box
        # touches the top edge), draw it just inside the box instead.
        label_bottom = y1
        if y1 - text_height - label_padding < 0:
            label_bottom = y1 + text_height + label_padding
        # Filled background in the class colour, then white text on top.
        cv2.rectangle(img_array, (x1, label_bottom - text_height - label_padding),
                      (x1 + text_width, label_bottom), color, -1)
        cv2.putText(img_array, label, (x1, label_bottom - label_padding), font,
                    font_scale, (255, 255, 255), text_thickness)

    return Image.fromarray(img_array)
# Preprocessing, built once instead of on every request. It resizes to the
# 448x448 network input and converts to a float tensor (ToTensor only; no
# mean/std normalisation is applied).
_TRANSFORM = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
])


def _format_report(nms_boxes, inference_time, image_size, show_confidence, max_listed=10):
    """Build the Markdown statistics panel for one detection run.

    Args:
        nms_boxes: boxes kept after NMS, ``[class_pred, score, x, y, w, h]``.
        inference_time: seconds spent in preprocessing, model and post-processing.
        image_size: ``(width, height)`` of the input image.
        show_confidence: append the per-detection confidence list.
        max_listed: maximum number of individual detections listed.

    Returns:
        The Markdown report as a string.
    """
    num_detections = len(nms_boxes)
    # Bullet items so Markdown renders each statistic on its own line.
    lines = [
        "## Résultats de détection",
        "",
        f"**{num_detections} objet(s) détecté(s)**",
        "",
        f"- Temps d'inférence: **{inference_time:.3f}s**",
        f"- Taille image: **{image_size[0]}x{image_size[1]}**",
    ]
    if num_detections == 0:
        lines += ["", "Aucun objet détecté."]
        return "\n".join(lines) + "\n"

    # (class name, confidence) pairs, most confident first, so the "top N" list is truthful.
    detections = sorted(
        ((CLASSES[int(box[0])], float(box[1])) for box in nms_boxes),
        key=lambda item: item[1],
        reverse=True,
    )
    class_counts = {}
    for cls, _ in detections:
        class_counts[cls] = class_counts.get(cls, 0) + 1
    mean_confidence = sum(conf for _, conf in detections) / num_detections

    lines.append(f"- Confiance moyenne: **{mean_confidence:.2%}**")
    lines += ["", "### Objets détectés:", ""]
    lines += [
        f"- **{cls}**: {count}"
        for cls, count in sorted(class_counts.items(), key=lambda kv: kv[1], reverse=True)
    ]
    if show_confidence:
        lines += ["", "### Confiances individuelles:", ""]
        lines += [
            f"{rank}. {cls}: {conf:.1%}"
            for rank, (cls, conf) in enumerate(detections[:max_listed], 1)
        ]
        hidden = num_detections - max_listed
        if hidden > 0:
            lines += ["", f"*...et {hidden} détection(s) de plus*"]
    return "\n".join(lines) + "\n"


def detect_objects(image, confidence_threshold, iou_threshold, show_confidence=True):
    """Run the detector on ``image`` and return display-ready results.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing is loaded.
        confidence_threshold: score threshold passed to non-max suppression.
        iou_threshold: IoU threshold passed to non-max suppression.
        show_confidence: include the per-detection confidence list in the report.

    Returns:
        ``(original_image, annotated_image, markdown_report)``. Both images are
        ``None`` when no input image was given.
    """
    if image is None:
        return None, None, "**Veuillez uploader ou sélectionner une image**"

    start_time = time.perf_counter()
    # The network takes 3-channel input, so RGBA/grayscale images are converted.
    # A copy is kept either way so the caller's image is never modified.
    original_image = image.convert("RGB") if image.mode != "RGB" else image.copy()

    img_tensor = _TRANSFORM(original_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        predictions = model(img_tensor)

    # Decode the raw predictions into per-image box lists, then apply NMS.
    bboxes = cellboxes_to_boxes(predictions)
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=iou_threshold,
        threshold=confidence_threshold,
        box_format="midpoint",
    )
    inference_time = time.perf_counter() - start_time

    result_image = draw_boxes(original_image, nms_boxes)
    stats = _format_report(nms_boxes, inference_time, original_image.size, show_confidence)
    return original_image, result_image, stats
# Gradio user interface
with gr.Blocks(title="YOLO v1 - Détection d'objets", theme=gr.themes.Soft(), css="""
.gradio-container {max-width: 1400px !important}
.example-gallery {height: 400px; overflow-y: auto}
""") as demo:
    # A numeric mAP is shown with 4 decimals; anything else (e.g. "N/A") as-is.
    mAP_value = MODEL_INFO["mAP"]
    mAP_is_number = isinstance(mAP_value, (int, float))
    mAP_display = f"{mAP_value:.4f}" if mAP_is_number else mAP_value
    # Derived from CLASSES so the UI text can never drift from the model's labels.
    class_names = ", ".join(CLASSES)

    # Header
    gr.Markdown(f"""
# YOLO v1 - Détection d'objets en temps réel

**mAP :** {mAP_display} · **Device :** {DEVICE}

---
""")

    with gr.Tabs():
        # Main tab: detection
        with gr.Tab("Détection"):
            gr.Markdown(f"""
### Uploadez votre image ou sélectionnez un exemple
**Classes PASCAL VOC :** {class_names}
""")
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="pil", label="Image d'entrée")
                    with gr.Accordion("Paramètres avancés", open=True):
                        confidence_slider = gr.Slider(
                            minimum=0.05,
                            maximum=0.95,
                            value=0.4,
                            step=0.05,
                            label="Seuil de confiance",
                            info="Plus bas = plus de détections",
                        )
                        # IoU threshold used by non-max suppression.
                        iou_slider = gr.Slider(
                            minimum=0.1,
                            maximum=0.9,
                            value=0.5,
                            step=0.05,
                            label="Seuil IoU (NMS)",
                            info="Plus haut = garde plus de boxes qui se chevauchent",
                        )
                        show_conf_check = gr.Checkbox(
                            value=True,
                            label="Afficher les confiances détaillées",
                        )
                    detect_btn = gr.Button("Détecter les objets", variant="primary", size="lg")
                with gr.Column(scale=2):
                    with gr.Row():
                        original_display = gr.Image(type="pil", label="Image originale")
                        output_image = gr.Image(type="pil", label="Résultat avec détections")
                    output_stats = gr.Markdown("**Uploadez une image et cliquez sur 'Détecter' pour commencer !**")

            # Shared wiring for the examples, the button and input changes.
            detection_inputs = [input_image, confidence_slider, iou_slider, show_conf_check]
            detection_outputs = [original_display, output_image, output_stats]

            # Example gallery (only when example images were found)
            if EXAMPLE_IMAGES:
                gr.Markdown("### Exemples (cliquez pour tester)")
                examples_list = [[img, 0.4, 0.5, True] for img in EXAMPLE_IMAGES[:12]]
                gr.Examples(
                    examples=examples_list,
                    inputs=detection_inputs,
                    outputs=detection_outputs,
                    fn=detect_objects,
                    cache_examples=False,
                    examples_per_page=6,
                )

            # Run detection on button click and whenever the input image changes
            # (a cleared image yields detect_objects' "please upload" message).
            detect_btn.click(fn=detect_objects, inputs=detection_inputs, outputs=detection_outputs)
            input_image.change(fn=detect_objects, inputs=detection_inputs, outputs=detection_outputs)

        # About tab: checkpoint metadata
        with gr.Tab("À propos"):
            mAP_info = mAP_display if mAP_is_number else "N/A"
            epoch_info = MODEL_INFO["epoch"]
            num_classes = MODEL_INFO["classes"]
            gr.Markdown(f"""
## À propos du modèle

- **Modèle :** `{MODEL_REPO_ID}/{MODEL_FILENAME}`
- **mAP :** {mAP_info}
- **Époque :** {epoch_info}
- **Device :** {DEVICE}
- **Classes ({num_classes}) :** {class_names}
""")
# Script entry point: print a start-up summary, then serve the Gradio app.
if __name__ == "__main__":
    rule = "=" * 60
    banner = [
        "\n" + rule,
        "Lancement de l'application Gradio YOLO v1",
        rule,
        f"Modèle: {MODEL_REPO_ID}/{MODEL_FILENAME}",
        f"Device: {DEVICE}",
        f"Exemples chargés: {len(EXAMPLE_IMAGES)}",
        rule + "\n",
    ]
    for line in banner:
        print(line)

    # server_name="0.0.0.0" listens on every network interface (reachable from
    # the LAN); share=True also asks Gradio for a public share link.
    demo.launch(
        show_error=True,
        server_port=7860,
        server_name="0.0.0.0",
        share=True,
    )