# Page header captured together with the file (not part of the program):
# nathbns - "Update app.py" - commit 9426cc6 (verified)
import gradio as gr
import torch
import cv2
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
import os
# --- Configuration -----------------------------------------------------------
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

IMAGE_SIZE = 416  # side length (pixels) of the square image fed to the model
NUM_CLASSES = 20  # must match len(PASCAL_CLASSES)

# YOLOv3 anchors (normalized): one group of three (width, height) pairs per
# prediction scale, listed in the same order as the grid sizes in S.
ANCHORS = [
    [(0.28, 0.22), (0.38, 0.48), (0.9, 0.78)],
    [(0.07, 0.15), (0.15, 0.11), (0.14, 0.29)],
    [(0.02, 0.03), (0.04, 0.07), (0.08, 0.06)],
]

# Grid size of each prediction scale: the image side divided by 32, 16 and 8.
S = [IMAGE_SIZE // stride for stride in (32, 16, 8)]

# Pascal VOC class names, indexed by class id.
PASCAL_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog",
    "horse", "motorbike", "person", "pottedplant",
    "sheep", "sofa", "train", "tvmonitor",
]
# Import du modèle
from model import YOLOv3
from utils import cells_to_bboxes, non_max_suppression
class YOLOv3Detector:
    """YOLOv3 object detector for the Pascal VOC classes.

    Wraps a ``YOLOv3`` model restored from a checkpoint and exposes the
    stages of inference separately: preprocessing, detection (forward pass
    followed by non-max suppression) and drawing the resulting boxes.
    """

    def __init__(self, checkpoint_path):
        """Load the model weights and precompute anchors and class colors.

        Args:
            checkpoint_path: Path to a checkpoint readable by ``torch.load``
                that stores the model weights under the ``"state_dict"`` key.
        """
        self.model = YOLOv3(num_classes=NUM_CLASSES).to(DEVICE)
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. The checkpoint is downloaded from a model hub;
        # consider weights_only=True if the checkpoint layout allows it.
        checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.eval()

        # Scale the normalized anchors by the grid size of their own
        # prediction scale. Broadcasting a (num_scales, 1, 1) tensor avoids
        # hard-coding the number of anchors per scale.
        grid_sizes = torch.tensor(S).view(-1, 1, 1)
        self.scaled_anchors = (torch.tensor(ANCHORS) * grid_sizes).to(DEVICE)

        # One fixed color per class. A private RandomState seeded with 42
        # yields the same colors as seeding the global generator would, but
        # does not reseed NumPy's process-wide RNG as a side effect.
        rng = np.random.RandomState(42)
        self.colors = rng.randint(
            0, 255, size=(len(PASCAL_CLASSES), 3), dtype=np.uint8
        )

    @staticmethod
    def _to_rgb_array(image):
        """Return *image* (PIL image or array) as an ``(H, W, 3)`` array."""
        if isinstance(image, Image.Image):
            # Covers RGBA, grayscale, palette, ... uploads in one step.
            return np.array(image.convert("RGB"))
        image = np.asarray(image)
        if image.ndim == 2:
            # Single-channel image: replicate it into three channels.
            return np.stack([image] * 3, axis=-1)
        if image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.shape[2] == 4:
            # Drop the alpha channel; contiguous so OpenCV accepts it as-is.
            return np.ascontiguousarray(image[..., :3])
        return image

    def preprocess_image(self, image):
        """Convert an input image into a model-ready batch of one.

        Args:
            image: A PIL image or a NumPy array (``H x W`` or ``H x W x C``).

        Returns:
            A ``(tensor, original_shape)`` tuple. ``tensor`` has shape
            ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)`` and lives on ``DEVICE``;
            ``original_shape`` is the ``(height, width)`` of the input.
        """
        image = self._to_rgb_array(image)
        original_shape = image.shape[:2]
        resized = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))

        # Dividing by 255 assumes 8-bit pixel values - TODO confirm callers
        # never pass float images already in [0, 1].
        image_tensor = torch.from_numpy(resized).float() / 255.0
        # HWC -> CHW, then add the batch dimension.
        image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
        return image_tensor.to(DEVICE), original_shape

    def detect(self, image, conf_threshold=0.5, iou_threshold=0.45):
        """Detect objects in *image*.

        Args:
            image: A PIL image or a NumPy array.
            conf_threshold: Confidence threshold forwarded to
                ``non_max_suppression`` as ``threshold``.
            iou_threshold: IoU threshold forwarded to ``non_max_suppression``.

        Returns:
            The boxes kept by ``non_max_suppression``, in the layout that
            ``draw_boxes`` consumes.
        """
        image_tensor, _ = self.preprocess_image(image)
        with torch.no_grad():
            predictions = self.model(image_tensor)

        # Gather the candidate boxes of every prediction scale. The batch
        # always holds exactly one image, so only entry 0 is used.
        candidates = []
        for prediction, anchors in zip(predictions, self.scaled_anchors):
            grid_size = prediction.shape[2]
            boxes_per_image = cells_to_bboxes(
                prediction, anchors, S=grid_size, is_preds=True
            )
            candidates.extend(boxes_per_image[0])

        return non_max_suppression(
            candidates,
            iou_threshold=iou_threshold,
            threshold=conf_threshold,
            box_format="midpoint",
        )

    def draw_boxes(self, image, boxes):
        """Draw labelled bounding boxes on a copy of *image*.

        Args:
            image: A PIL image or a NumPy array; it is not modified.
            boxes: Iterable of ``[class_index, confidence, x_center,
                y_center, width, height]`` entries whose coordinates are
                fractions of the image size.

        Returns:
            An ``(annotated_image, detections_info)`` tuple where
            ``annotated_image`` is an ``(H, W, 3)`` array and
            ``detections_info`` holds one summary line per box.
        """
        # Draw on a 3-channel copy so the 3-component colors stay valid for
        # RGBA / grayscale inputs and the caller's image is left untouched.
        canvas = self._to_rgb_array(image).copy()
        height, width = canvas.shape[:2]
        detections_info = []

        for box in boxes:
            class_idx = int(box[0])
            confidence = float(box[1])
            x_center, y_center, box_width, box_height = box[2:]

            # Convert to pixel corners, clamped to the image so that boxes
            # overflowing the frame keep visible edges and label.
            x1 = max(0, min(width - 1, int((x_center - box_width / 2) * width)))
            y1 = max(0, min(height - 1, int((y_center - box_height / 2) * height)))
            x2 = max(0, min(width - 1, int((x_center + box_width / 2) * width)))
            y2 = max(0, min(height - 1, int((y_center + box_height / 2) * height)))

            color = self.colors[class_idx].tolist()
            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)

            label = f"{PASCAL_CLASSES[class_idx]}: {confidence:.2f}"
            (text_width, text_height), _ = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
            )

            # The label normally sits just above the box. When the box
            # touches the top of the image there is no room there, so it is
            # drawn just inside the top edge instead of off-screen.
            label_top = y1 - text_height - 4
            if label_top < 0:
                label_top = y1
            label_bottom = label_top + text_height + 4

            # Filled background in the class color, then white text on top.
            cv2.rectangle(
                canvas,
                (x1, label_top),
                (x1 + text_width, label_bottom),
                color,
                -1,
            )
            cv2.putText(
                canvas,
                label,
                (x1, label_bottom - 4),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (255, 255, 255),
                1,
            )

            detections_info.append(
                f"• {PASCAL_CLASSES[class_idx]}: {confidence:.1%}"
            )

        return canvas, detections_info
# Fetch the trained checkpoint from the Hugging Face Hub.
checkpoint_path = hf_hub_download(
    filename="checkpoint.pth.tar",
    repo_id="nathbns/yolov3_from_scratch",
)

# Module-level detector instance used by the prediction callback.
detector = YOLOv3Detector(checkpoint_path)
def predict(image, conf_threshold, iou_threshold):
    """Gradio callback: detect objects in *image* and summarize them.

    Args:
        image: Input image, or ``None`` when no image is available.
        conf_threshold: Confidence threshold passed to the detector.
        iou_threshold: IoU threshold passed to the detector.

    Returns:
        An ``(annotated_image, markdown_text)`` tuple. Without an input
        image the first element is ``None``.
    """
    # Nothing to detect without an image.
    if image is None:
        return None, "Aucune image fournie"

    # Run the detector, then render its boxes onto the image.
    annotated, summary_lines = detector.draw_boxes(
        image,
        detector.detect(image, conf_threshold, iou_threshold),
    )

    if not summary_lines:
        return annotated, "Aucun objet détecté"

    header = f"**{len(summary_lines)} objet(s) détecté(s) :**\n\n"
    return annotated, header + "\n".join(summary_lines)
# Gradio interface
with gr.Blocks(title="YOLOv3 Object Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("\n# YOLOv3 Object Detection - Pascal VOC\n")

    with gr.Row():
        # Left column: input image, detection thresholds and trigger button.
        with gr.Column():
            input_image = gr.Image(label="Image d'entrée", type="pil")

            with gr.Accordion("Paramètres", open=True):
                conf_slider = gr.Slider(
                    label="Seuil de confiance",
                    info="Plus élevé = moins de détections mais plus sûres",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
                iou_slider = gr.Slider(
                    label="Seuil NMS (IoU)",
                    info="Plus élevé = plus de boîtes qui se chevauchent",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.05,
                    value=0.45,
                )

            detect_btn = gr.Button(
                "Détecter les objets", size="lg", variant="primary"
            )

        # Right column: annotated image and textual summary.
        with gr.Column():
            output_image = gr.Image(label="Résultat")
            output_text = gr.Markdown(label="Détections")

    # The button click and a change of the input image (auto-run on upload)
    # are wired to the same callback with the same inputs and outputs.
    predict_wiring = dict(
        fn=predict,
        inputs=[input_image, conf_slider, iou_slider],
        outputs=[output_image, output_text],
    )
    detect_btn.click(**predict_wiring)
    input_image.change(**predict_wiring)

if __name__ == "__main__":
    demo.launch()