ModuMLTECH's picture
Update app.py
c38e692 verified
import streamlit as st
import cv2
import tempfile
import os
import time
import numpy as np
import pandas as pd
from collections import defaultdict
from ultralytics import YOLO
# --- FONCTIONS UTILES ---
def draw_text_with_background(image, text, position, font=cv2.FONT_HERSHEY_SIMPLEX,
font_scale=1, font_thickness=2, text_color=(255, 255, 255), bg_color=(0, 0, 0), padding=5):
"""Ajoute du texte avec un fond sur une image OpenCV."""
text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
text_width, text_height = text_size
x, y = position
top_left = (x, y - text_height - padding)
bottom_right = (x + text_width + padding * 2, y + padding)
cv2.rectangle(image, top_left, bottom_right, bg_color, -1)
cv2.putText(image, text, (x + padding, y), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
# --- CLASSE YOLO ---
class YOLOVideoProcessor:
def __init__(self, model_path, video_path, output_path, poly1, poly2, tracker_method="bot"):
self.model = YOLO(model_path, task="detect")
self.tracker_method = tracker_method
self.video_path = video_path
self.output_path = output_path
self.unique_region1_ids = set()
self.unique_region2_ids = set()
self.poly1 = poly1
self.poly2 = poly2
def is_in_region(self, center, poly):
poly_np = np.array(poly, dtype=np.int32)
return cv2.pointPolygonTest(poly_np, center, False) >= 0
def process_video(self, progress_bar=None):
cap = cv2.VideoCapture(self.video_path)
if not cap.isOpened():
st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo.")
return
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = int(cap.get(cv2.CAP_PROP_FPS))
if fps == 0:
fps = 30 # Valeur par défaut si FPS est invalide
# Utiliser XVID qui est généralement mieux supporté
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(self.output_path, fourcc, fps, (frame_width, frame_height))
processed_frames = 0
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
while cap.isOpened():
success, frame = cap.read()
if not success:
break
# Mise à jour de la barre de progression
if progress_bar is not None:
progress_bar.progress(processed_frames / total_frames)
tracker = "botsort.yaml" if self.tracker_method.lower() == "bot" else "bytetrack.yaml"
results = self.model.track(frame, persist=True, tracker=tracker, conf=0.25)
track_ids = []
if results and len(results) > 0 and len(results[0].boxes) > 0:
try:
track_ids = results[0].boxes.id.int().cpu().tolist()
except AttributeError:
track_ids = [i for i in range(len(results[0].boxes.xywh.cpu().numpy()))]
# Dessiner les polygones
cv2.polylines(frame, [np.array(self.poly1, np.int32)], isClosed=True, color=(0, 255, 0), thickness=2)
cv2.polylines(frame, [np.array(self.poly2, np.int32)], isClosed=True, color=(255, 0, 0), thickness=2)
for box, track_id in zip(results[0].boxes.xywh.cpu().numpy(), track_ids):
x, y, w, h = box
center_point = (int(x), int(y))
if self.is_in_region(center_point, self.poly1):
self.unique_region1_ids.add(track_id)
if self.is_in_region(center_point, self.poly2):
self.unique_region2_ids.add(track_id)
# Affichage du comptage des véhicules
draw_text_with_background(frame, f'Total Sens 1: {len(self.unique_region1_ids)}', (10, frame_height - 50))
draw_text_with_background(frame, f'Total Sens 2: {len(self.unique_region2_ids)}', (frame_width - 300, frame_height - 50))
out.write(frame)
processed_frames += 1
cap.release()
out.release()
cv2.destroyAllWindows()
if processed_frames == 0:
st.error("⚠️ Aucune image n'a été écrite dans la vidéo de sortie !")
return len(self.unique_region1_ids), len(self.unique_region2_ids)
# --- INTERFACE STREAMLIT ---
def main():
st.set_page_config(
page_title="Détecteur de Véhicules",
page_icon="🚗",
layout="wide"
)
st.title("🚗 Détection et comptage de Véhicules sur l'Autoroute de l'Avenir")
# Vérifier si le modèle existe déjà ou doit être téléchargé
model_path = "best.pt"
if not os.path.exists(model_path):
with st.spinner("📥 Chargement du modèle YOLO... Cela peut prendre un moment."):
# Utilisez hub.load pour télécharger le modèle depuis Hugging Face Hub
try:
from huggingface_hub import hf_hub_download
model_path = hf_hub_download(repo_id="ModuMLTECH/projet_trafic_2", filename="best.pt")
st.success("✅ Modèle chargé avec succès!")
except Exception as e:
st.error(f"❌ Erreur lors du chargement du modèle: {e}")
# Fallback: utiliser un modèle YOLO standard
st.warning("⚠️ Utilisation du modèle YOLO standard à la place")
model_path = "yolov8n.pt"
# Colonnes pour l'organisation de l'interface
col1, col2 = st.columns([3, 1])
with col2:
st.header("🔹 Paramètres")
# Entrée utilisateur pour les polygones
st.subheader("📍 Polygone 1 (vert)")
poly1_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "465,350 609,350 520,630 3,630")
st.subheader("📍 Polygone 2 (rouge)")
poly2_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "678,350 815,350 1203,630 743,630")
tracker_method = st.selectbox("Méthode de tracking", ["bot", "byte"], index=0)
with col1:
uploaded_file = st.file_uploader("📂 Upload une vidéo", type=["mp4", "avi", "mov"])
def parse_polygon(input_text):
try:
return [tuple(map(int, point.split(','))) for point in input_text.split()]
except:
return []
poly1 = parse_polygon(poly1_input)
poly2 = parse_polygon(poly2_input)
if uploaded_file is not None:
# Créer un dossier temporaire si nécessaire
temp_dir = tempfile.mkdtemp()
input_video_path = os.path.join(temp_dir, "input_video.mp4")
output_video_path = os.path.join(temp_dir, "output_video.mp4")
# Écrire le fichier téléchargé dans un fichier temporaire
with open(input_video_path, "wb") as f:
f.write(uploaded_file.getbuffer())
st.video(input_video_path) # Afficher la vidéo d'entrée
if st.button("▶️ Lancer la détection"):
if len(poly1) == 4 and len(poly2) == 4:
# Afficher la barre de progression
progress_text = "🔄 Traitement de la vidéo en cours..."
progress_bar = st.progress(0)
# Traitement de la vidéo
processor = YOLOVideoProcessor(model_path, input_video_path, output_video_path, poly1, poly2, tracker_method)
# Démarrer le traitement
start_time = time.time()
count1, count2 = processor.process_video(progress_bar=progress_bar)
end_time = time.time()
# Calcul du temps de traitement
processing_time = end_time - start_time
progress_bar.progress(1.0) # Compléter la barre de progression
st.success(f"✅ Traitement terminé en {processing_time:.2f} secondes!")
# Afficher les résultats
col_result1, col_result2 = st.columns(2)
with col_result1:
st.metric("Véhicules Sens 1 (Vert)", count1)
with col_result2:
st.metric("Véhicules Sens 2 (Rouge)", count2)
# Afficher la vidéo traitée
st.subheader("Vidéo traitée")
st.video(output_video_path)
# Option de téléchargement
with open(output_video_path, "rb") as file:
st.download_button(
label="⬇️ Télécharger la vidéo",
data=file,
file_name="video_traitee.mp4",
mime="video/mp4"
)
else:
st.error("❌ Les coordonnées des polygones doivent contenir **exactement 4 points**.")
if __name__ == "__main__":
main()