# BaggTrack / app.py
# Author: abdwahdia
# Last commit: "Update app.py" (b30d499, verified)
import gradio as gr
import cv2
import os
import numpy as np
import pickle
from ultralytics import YOLO
# Location of the trained detector weights (adjust if the file lives elsewhere).
MODEL_WEIGHTS = "My_best_model.pt"
# Load the trained model once at import time; process_video reuses this
# instance for tracking on every frame.
model = YOLO(MODEL_WEIGHTS)
# Class-index -> label lookup. The indices presumably follow the custom
# model's training order — TODO confirm against the model's own names.
class_names = dict(enumerate(("suitcase", "backpack", "handbag")))
# Label -> drawing colour. OpenCV interprets these tuples as BGR, so they
# render as blue (suitcase), green (backpack) and red (handbag).
class_colors = dict(
    zip(
        class_names.values(),
        ((255, 0, 0), (0, 255, 0), (0, 0, 255)),
    )
)
def _draw_detections(frame, result, cumulative_counts, seen_ids):
    """Draw tracked boxes on *frame* and count every new track ID once.

    Only detections that carry a tracker ID and whose class index is present
    in ``class_names`` are drawn and counted. ``cumulative_counts`` and
    ``seen_ids`` are updated in place.
    """
    boxes = result.boxes
    if boxes.id is None:
        # The tracker assigned no IDs on this frame (e.g. nothing detected).
        return
    xyxy = boxes.xyxy.cpu().numpy().astype(int)
    ids = boxes.id.cpu().numpy().astype(int)
    classes = boxes.cls.cpu().numpy().astype(int)
    confs = boxes.conf.cpu().numpy()
    for box, obj_id, cls, conf in zip(xyxy, ids, classes, confs):
        cls_name = class_names.get(int(cls))
        if cls_name is None:
            continue
        obj_id = int(obj_id)
        # Each track ID contributes to the count only the first time it is seen.
        if obj_id not in seen_ids:
            cumulative_counts[cls_name] += 1
            seen_ids.add(obj_id)
        color = class_colors.get(cls_name, (0, 255, 0))
        x1, y1, x2, y2 = (int(v) for v in box)
        label = f"{cls_name} ID:{obj_id} ({conf:.2f})"
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        # Keep the label inside the frame when the box touches the top edge.
        cv2.putText(
            frame, label, (x1, max(y1 - 10, 15)),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2,
        )


def _draw_count_panel(frame, cumulative_counts):
    """Overlay the running per-class counts and their total in the top-left corner."""
    line_height = 30
    # Background sized for one line per class plus the total line
    # (300x150 px for the three default classes).
    panel_bottom = 60 + line_height * len(cumulative_counts)
    cv2.rectangle(frame, (10, 10), (300, panel_bottom), (0, 0, 0), -1)
    y_offset = 30
    for cls_name, count in cumulative_counts.items():
        color = class_colors.get(cls_name, (0, 255, 0))
        cv2.putText(
            frame, f"{cls_name}: {count}", (20, y_offset),
            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2,
        )
        y_offset += line_height
    total = sum(cumulative_counts.values())
    cv2.putText(
        frame, f"Total: {total}", (20, y_offset + 10),
        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 255), 2,
    )


def _write_counts(path, cumulative_counts):
    """Write one ``"<class>: <count>"`` line per class, then the total, to *path*."""
    total = sum(cumulative_counts.values())
    with open(path, "w", encoding="utf-8") as f:
        for cls_name, count in cumulative_counts.items():
            f.write(f"{cls_name}: {count}\n")
        f.write(f"Total: {total}\n")


def process_video(input_video):
    """Detect, track and count luggage in a video.

    Args:
        input_video: path of the uploaded video, as passed by ``gr.Video``.

    Returns:
        ``(output_path, counts_path)``: the annotated MP4 video and a text
        file holding the final per-class counts and their total.

    Raises:
        gr.Error: if no video was provided or it cannot be opened, so the
            message is shown in the Gradio UI instead of a raw traceback.
    """
    if not input_video:
        raise gr.Error("Aucune vidéo fournie.")

    # Output files (fixed names, overwritten on every call).
    output_path = "output_result.mp4"
    temp_txt = "final_counts.txt"

    cap = cv2.VideoCapture(input_video)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Impossible d'ouvrir la vidéo.")

    # Keep fractional rates (e.g. 29.97) so the output duration matches the
    # input; fall back to 30 fps when the container reports no rate (0).
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    out = cv2.VideoWriter(
        output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height)
    )

    # Tracking state for this video.
    cumulative_counts = {cls: 0 for cls in class_names.values()}
    seen_ids = set()
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # persist=True keeps ByteTrack state between frames. Because the
            # model is module-level, that state presumably also carries over
            # between calls — NOTE(review): confirm whether it should be reset
            # per video.
            results = model.track(frame, persist=True, tracker="bytetrack.yaml")
            _draw_detections(frame, results[0], cumulative_counts, seen_ids)
            _draw_count_panel(frame, cumulative_counts)
            out.write(frame)
    finally:
        # Release handles even if tracking or drawing raises mid-video.
        cap.release()
        out.release()

    # The total is computed from the final counts, so this also works for
    # videos that yield no frames.
    _write_counts(temp_txt, cumulative_counts)
    return output_path, temp_txt
# Gradio interface: one video in; the annotated video and the counts file out.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Charger une vidéo"),
    outputs=[
        gr.Video(label="Vidéo annotée"),
        gr.File(label="Comptage final (fichier texte)"),
    ],
    title="Détection et Comptage d'Objets avec YOLOv8",
    description="Téléversez une vidéo pour détecter et compter les valises, sacs à dos et sacs à main. Utilise YOLO + ByteTrack.",
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not launch it.
if __name__ == "__main__":
    demo.launch()