Spaces:
Build error
Build error
| import streamlit as st | |
| import cv2 | |
| import tempfile | |
| import os | |
| import time | |
| import numpy as np | |
| import pandas as pd | |
| from collections import defaultdict | |
| from ultralytics import YOLO | |
| # --- FONCTIONS UTILES --- | |
def draw_text_with_background(image, text, position, font=cv2.FONT_HERSHEY_SIMPLEX,
                              font_scale=1, font_thickness=2, text_color=(255, 255, 255),
                              bg_color=(0, 0, 0), padding=5):
    """Draw *text* on an OpenCV *image* over a filled background rectangle.

    Args:
        image: BGR image (numpy array) modified in place.
        text: string to render.
        position: (x, y) of the left end of the text baseline.
        font, font_scale, font_thickness: cv2.putText parameters.
        text_color, bg_color: BGR colors for the text and its background.
        padding: pixels of background margin around the text.
    """
    # getTextSize returns ((width, height), baseline); the baseline must be
    # added below y, otherwise descenders ('g', 'y', ...) are drawn outside
    # the background rectangle.
    (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, font_thickness)
    x, y = position
    top_left = (x, y - text_height - padding)
    bottom_right = (x + text_width + padding * 2, y + baseline + padding)
    cv2.rectangle(image, top_left, bottom_right, bg_color, -1)
    cv2.putText(image, text, (x + padding, y), font, font_scale, text_color,
                font_thickness, cv2.LINE_AA)
| # --- CLASSE YOLO --- | |
class YOLOVideoProcessor:
    """Runs YOLO tracking over a video and counts unique track IDs that pass
    through two user-defined polygonal regions.

    Attributes:
        unique_region1_ids: set of track IDs seen inside poly1.
        unique_region2_ids: set of track IDs seen inside poly2.
    """

    def __init__(self, model_path, video_path, output_path, poly1, poly2, tracker_method="bot"):
        self.model = YOLO(model_path, task="detect")
        self.tracker_method = tracker_method
        self.video_path = video_path
        self.output_path = output_path
        self.unique_region1_ids = set()
        self.unique_region2_ids = set()
        self.poly1 = poly1
        self.poly2 = poly2

    def is_in_region(self, center, poly):
        """Return True if *center* (x, y) lies inside or on the edge of *poly*."""
        poly_np = np.array(poly, dtype=np.int32)
        # pointPolygonTest: > 0 inside, 0 on edge, < 0 outside.
        return cv2.pointPolygonTest(poly_np, center, False) >= 0

    def process_video(self, progress_bar=None):
        """Track objects frame by frame, annotate and write the output video.

        Args:
            progress_bar: optional Streamlit progress bar, updated per frame.

        Returns:
            (count_region1, count_region2) — unique track IDs seen in each
            region. Returns (0, 0) if the video cannot be opened, so callers
            that unpack the result never get a TypeError.
        """
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo.")
            # BUG FIX: returning None here broke `count1, count2 = ...` at the
            # call site; return an explicit empty result instead.
            return 0, 0

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        if fps == 0:
            fps = 30  # fallback when the container reports no/invalid FPS

        # BUG FIX: the output path ends in .mp4 — XVID is an AVI codec and the
        # resulting file frequently fails to play in st.video. Use mp4v, which
        # matches the MP4 container.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(self.output_path, fourcc, fps, (frame_width, frame_height))

        # Loop invariants hoisted out of the per-frame loop.
        tracker = "botsort.yaml" if self.tracker_method.lower() == "bot" else "bytetrack.yaml"
        poly1_np = np.array(self.poly1, np.int32)
        poly2_np = np.array(self.poly2, np.int32)

        processed_frames = 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            # BUG FIX: CAP_PROP_FRAME_COUNT can be 0/unknown for some streams;
            # guard against ZeroDivisionError and clamp to the [0, 1] range.
            if progress_bar is not None and total_frames > 0:
                progress_bar.progress(min(processed_frames / total_frames, 1.0))

            results = self.model.track(frame, persist=True, tracker=tracker, conf=0.25)

            track_ids = []
            if results and len(results) > 0 and len(results[0].boxes) > 0:
                try:
                    track_ids = results[0].boxes.id.int().cpu().tolist()
                except AttributeError:
                    # boxes.id is None when the tracker assigned no IDs yet;
                    # fall back to positional indices.
                    track_ids = list(range(len(results[0].boxes)))

                for box, track_id in zip(results[0].boxes.xywh.cpu().numpy(), track_ids):
                    x, y, w, h = box
                    center_point = (int(x), int(y))
                    if self.is_in_region(center_point, self.poly1):
                        self.unique_region1_ids.add(track_id)
                    if self.is_in_region(center_point, self.poly2):
                        self.unique_region2_ids.add(track_id)

            # Draw the counting regions on every frame, detections or not.
            cv2.polylines(frame, [poly1_np], isClosed=True, color=(0, 255, 0), thickness=2)
            cv2.polylines(frame, [poly2_np], isClosed=True, color=(255, 0, 0), thickness=2)

            # Running totals per direction, drawn bottom-left / bottom-right.
            draw_text_with_background(frame, f'Total Sens 1: {len(self.unique_region1_ids)}',
                                      (10, frame_height - 50))
            draw_text_with_background(frame, f'Total Sens 2: {len(self.unique_region2_ids)}',
                                      (frame_width - 300, frame_height - 50))

            out.write(frame)
            processed_frames += 1

        cap.release()
        out.release()
        cv2.destroyAllWindows()

        if processed_frames == 0:
            st.error("⚠️ Aucune image n'a été écrite dans la vidéo de sortie !")

        return len(self.unique_region1_ids), len(self.unique_region2_ids)
| # --- INTERFACE STREAMLIT --- | |
def parse_polygon(input_text):
    """Parse a polygon from text like "x1,y1 x2,y2 ...".

    Returns a list of (x, y) int tuples, or [] when the text is malformed.
    """
    try:
        return [tuple(map(int, point.split(','))) for point in input_text.split()]
    except ValueError:
        # Narrowed from a bare `except:` — only malformed numbers should be
        # treated as "no polygon"; anything else is a real bug.
        return []


def main():
    """Streamlit entry point: upload a video, define two counting regions,
    run YOLO tracking, then display and offer the annotated video."""
    st.set_page_config(
        page_title="Détecteur de Véhicules",
        page_icon="🚗",
        layout="wide"
    )
    st.title("🚗 Détection et comptage de Véhicules sur l'Autoroute de l'Avenir")

    # Use the local weights if present; otherwise download them from the Hub,
    # falling back to a stock YOLO model on failure.
    model_path = "best.pt"
    if not os.path.exists(model_path):
        with st.spinner("📥 Chargement du modèle YOLO... Cela peut prendre un moment."):
            try:
                from huggingface_hub import hf_hub_download
                model_path = hf_hub_download(repo_id="ModuMLTECH/projet_trafic_2", filename="best.pt")
                st.success("✅ Modèle chargé avec succès!")
            except Exception as e:
                st.error(f"❌ Erreur lors du chargement du modèle: {e}")
                st.warning("⚠️ Utilisation du modèle YOLO standard à la place")
                model_path = "yolov8n.pt"

    col1, col2 = st.columns([3, 1])

    with col2:
        st.header("🔹 Paramètres")
        st.subheader("📍 Polygone 1 (vert)")
        poly1_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces",
                                   "465,350 609,350 520,630 3,630")
        st.subheader("📍 Polygone 2 (rouge)")
        poly2_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces",
                                   "678,350 815,350 1203,630 743,630")
        tracker_method = st.selectbox("Méthode de tracking", ["bot", "byte"], index=0)

    with col1:
        uploaded_file = st.file_uploader("📂 Upload une vidéo", type=["mp4", "avi", "mov"])

        poly1 = parse_polygon(poly1_input)
        poly2 = parse_polygon(poly2_input)

        if uploaded_file is not None:
            # Stage the upload in a temp dir so OpenCV can read it by path.
            temp_dir = tempfile.mkdtemp()
            input_video_path = os.path.join(temp_dir, "input_video.mp4")
            output_video_path = os.path.join(temp_dir, "output_video.mp4")
            with open(input_video_path, "wb") as f:
                f.write(uploaded_file.getbuffer())

            st.video(input_video_path)  # preview the input video

            if st.button("▶️ Lancer la détection"):
                if len(poly1) == 4 and len(poly2) == 4:
                    progress_bar = st.progress(0, text="🔄 Traitement de la vidéo en cours...")

                    processor = YOLOVideoProcessor(model_path, input_video_path,
                                                   output_video_path, poly1, poly2,
                                                   tracker_method)

                    start_time = time.time()
                    result = processor.process_video(progress_bar=progress_bar)
                    processing_time = time.time() - start_time

                    # Defensive: treat a missing result as zero counts rather
                    # than crashing on tuple unpacking.
                    count1, count2 = result if result is not None else (0, 0)

                    progress_bar.progress(1.0)
                    st.success(f"✅ Traitement terminé en {processing_time:.2f} secondes!")

                    col_result1, col_result2 = st.columns(2)
                    with col_result1:
                        st.metric("Véhicules Sens 1 (Vert)", count1)
                    with col_result2:
                        st.metric("Véhicules Sens 2 (Rouge)", count2)

                    st.subheader("Vidéo traitée")
                    st.video(output_video_path)

                    with open(output_video_path, "rb") as file:
                        st.download_button(
                            label="⬇️ Télécharger la vidéo",
                            data=file,
                            file_name="video_traitee.mp4",
                            mime="video/mp4"
                        )
                else:
                    st.error("❌ Les coordonnées des polygones doivent contenir **exactement 4 points**.")


if __name__ == "__main__":
    main()