# BagTrack / app.py
# Author: LaurianeMD
# Last change: "Update app.py" (revision fae76ad, verified)
# app.py
import gradio as gr
import cv2
from ultralytics import YOLO
import tempfile
# Load the YOLO weights from a path relative to the current working directory.
model = YOLO("./yolo12m.pt")

# Class names (as they appear in model.names) that we detect and count.
target_classes = ["backpack", "suitcase", "handbag"]

# Reverse lookup built from model.names: class name -> numeric class index.
class_name_to_id = {label: index for index, label in model.names.items()}

# Numeric indices handed to the tracker, in the same order as target_classes.
# Raises KeyError if the model does not know one of the target class names.
target_ids = [class_name_to_id[label] for label in target_classes]
def _draw_counters(frame, class_counts, total_count):
    """Draw the per-class counters and the total count onto ``frame`` in place.

    One line per entry of ``target_classes``, starting at y=30 and spaced
    30 px apart, followed by the total in a larger font.
    """
    y_offset = 30
    for name in target_classes:
        cv2.putText(frame, f"{name}: {class_counts[name]}", (20, y_offset),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2)
        y_offset += 30
    cv2.putText(frame, f"Total bagages : {total_count}", (20, y_offset + 10),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)


def process_video(video_path):
    """Detect, track and count luggage crossing a horizontal line in a video.

    Every frame goes through the module-level YOLO tracker, restricted to
    ``target_ids``. Each tracked object gets a box and a label. It is counted
    once per track ID when its box centre reaches a horizontal counting line
    at 60% of the frame height. Running per-class and total counts are drawn
    on every frame, and the annotated frames are written to a temporary MP4.

    Args:
        video_path: Path of the input video (as supplied by ``gr.Video``).

    Returns:
        Path of the annotated MP4 file. It is created with ``delete=False``
        and is not removed by this function.

    Raises:
        gr.Error: If no video was given, its first frame cannot be read, or
            the output video cannot be opened for writing.
    """
    if not video_path:
        raise gr.Error("Erreur lecture vidéo")

    cap = cv2.VideoCapture(video_path)
    try:
        ret, frame = cap.read()
        if not ret:
            raise gr.Error("Erreur lecture vidéo")

        H, W = frame.shape[:2]
        line_y = int(H * 0.6)  # counting line at 60% of the frame height
        tolerance = 25  # half-height (px) of the band around the line

        # Reuse the source frame rate so the output plays at the original
        # speed; fall back to 20 fps when the container reports 0/NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):
            fps = 20

        counted_ids = set()  # track IDs already counted (each counted once)
        last_center_y = {}  # track ID -> centre y on its previous frame
        class_counts = {c: 0 for c in target_classes}
        total_count = 0

        # Temporary output video. Close the handle right away (delete=False
        # keeps the file) so cv2.VideoWriter can open the path itself and no
        # file descriptor is leaked.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_out:
            out_path = temp_out.name
        # NOTE(review): 'mp4v' output may not play in every browser; confirm
        # gr.Video handles it, or switch to an H.264 fourcc if available.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(out_path, fourcc, fps, (W, H))
        try:
            if not out.isOpened():
                raise gr.Error("Erreur écriture vidéo")

            # Start from the frame already read instead of seeking back to
            # frame 0, which not every codec/container supports reliably.
            while ret:
                # NOTE(review): `model` is module-level and persist=True, so
                # tracker state (and ID numbering) presumably carries over
                # between videos; counts stay per-call since counted_ids is local.
                results = model.track(frame, persist=True, classes=target_ids, verbose=False)
                boxes = results[0].boxes
                if boxes.id is not None:  # only boxes with a track ID are handled
                    ids = boxes.id.int().cpu().tolist()
                    clss = boxes.cls.int().cpu().tolist()
                    coords = boxes.xyxy.cpu().tolist()
                    for obj_id, cls, box in zip(ids, clss, coords):
                        x1, y1, x2, y2 = map(int, box)
                        center_y = (y1 + y2) // 2
                        class_name = model.names[cls]

                        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                        cv2.putText(frame, f"{class_name} ID:{obj_id}", (x1, y1 - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

                        prev_y = last_center_y.get(obj_id)
                        last_center_y[obj_id] = center_y
                        in_band = line_y - tolerance < center_y < line_y + tolerance
                        # A fast-moving object can jump over the band between two
                        # frames; also count it when its centre switched side of
                        # the line since its previous frame.
                        crossed = prev_y is not None and (prev_y < line_y) != (center_y < line_y)
                        if obj_id not in counted_ids and (in_band or crossed):
                            counted_ids.add(obj_id)
                            total_count += 1
                            class_counts[class_name] += 1

                # Results overlay: counters and the counting line.
                _draw_counters(frame, class_counts, total_count)
                cv2.line(frame, (0, line_y), (W, line_y), (0, 0, 255), 2)
                out.write(frame)
                ret, frame = cap.read()
        finally:
            out.release()
    finally:
        cap.release()
    return out_path
# Gradio user interface: upload one video, get back the annotated video.
_video_in = gr.Video(label="Importer une vidéo de bagages")
_video_out = gr.Video(label="Vidéo annotée avec comptage")
demo = gr.Interface(
    fn=process_video,
    inputs=_video_in,
    outputs=_video_out,
    title="🎒 Compteur intelligent de bagages",
    description="Détecte et compte les sacs, valises, et sacs à dos avec YOLO.",
)
# Start the Gradio app only when this file is run directly, not on import.
if "__main__" == __name__:
    demo.launch()