Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import cv2 | |
| import time | |
| import numpy as np | |
| from ultralytics import YOLO | |
| from collections import defaultdict | |
| import tempfile | |
| import os | |
| from pathlib import Path | |
| import sys | |
| import torch | |
| # Fonction pour charger le modèle | |
| def load_model(model_path): | |
| try: | |
| return YOLO(model_path, task="detect") | |
| except Exception as e: | |
| st.error(f"Erreur lors du chargement du modèle: {str(e)}") | |
| return None | |
| st.set_page_config(page_title="Détection et Suivi de densité de Trafic sur l'Autoroute de l'Avenir", page_icon="🚗", layout="wide") | |
| # Couleurs pour les IDs de suivi | |
| def get_color_for_id(track_id): | |
| np.random.seed(track_id) | |
| return tuple(np.random.randint(0, 255, size=3).tolist()) | |
| # Fonction pour afficher du texte avec fond | |
| def draw_text_with_background(image, text, position, font=cv2.FONT_HERSHEY_SIMPLEX, | |
| font_scale=1, font_thickness=2, text_color=(255, 255, 255), bg_color=(0, 0, 0), padding=5): | |
| text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0] | |
| text_width, text_height = text_size | |
| x, y = position | |
| top_left = (x, y - text_height - padding) | |
| bottom_right = (x + text_width + padding * 2, y + padding) | |
| cv2.rectangle(image, top_left, bottom_right, bg_color, -1) | |
| cv2.putText(image, text, (x + padding, y), font, font_scale, text_color, font_thickness, cv2.LINE_AA) | |
| class VideoProcessor: | |
| def __init__(self, model_path, video_path, output_path, poly1, poly2, | |
| device=None, conf_threshold=0.25, traffic_threshold=2): | |
| # Détection automatique du périphérique optimal | |
| self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu') | |
| st.sidebar.info(f"Utilisation du périphérique: {self.device}") | |
| # Charger le modèle | |
| self.model = load_model(model_path) | |
| if self.model is None: | |
| st.error("Échec du chargement du modèle.") | |
| st.stop() | |
| self.conf_threshold = conf_threshold | |
| self.track_history = defaultdict(list) | |
| self.video_path = video_path | |
| self.output_path = output_path | |
| self.max_track_history = 15 # Limite de l'historique pour réduire la mémoire | |
| self.traffic_threshold = traffic_threshold | |
| # Utiliser les polygones fournis | |
| self.poly1_np = np.array(poly1, dtype=np.int32) | |
| self.poly2_np = np.array(poly2, dtype=np.int32) | |
| def is_in_region(self, center, poly_np): | |
| return cv2.pointPolygonTest(poly_np, center, False) >= 0 | |
| def process_video(self, stframe, progress_bar): | |
| cap = cv2.VideoCapture(self.video_path) | |
| if not cap.isOpened(): | |
| st.error(f"Impossible d'ouvrir la vidéo à {self.video_path}") | |
| return None | |
| # Métadonnées de la vidéo | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) | |
| frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) | |
| input_fps = cap.get(cv2.CAP_PROP_FPS) | |
| # Configuration de l'écriture vidéo | |
| fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |
| out = cv2.VideoWriter(self.output_path, fourcc, input_fps, (frame_width, frame_height)) | |
| # Intervalle d'affichage pour Streamlit | |
| display_interval = max(1, int(input_fps / 10)) | |
| # Variables pour le traitement | |
| prev_frame_time = time.time() | |
| frame_count = 0 | |
| process_count = 0 | |
| # Configuration du tracker | |
| tracker = "botsort.yaml" | |
| # Préparer les variables de texte | |
| static_texts = { | |
| 'fps': 'FPS: {}', | |
| 'total': 'Total Vehicules: {}', | |
| 'region1': 'Trafic Sens 1: {}', | |
| 'region1_status': 'Situation Trafic: {}', | |
| 'region2': 'Trafic Sens 2: {}', | |
| 'region2_status': 'Situation Trafic: {}' | |
| } | |
| # Préparation des polygones | |
| region1_poly = self.poly1_np.reshape((-1, 1, 2)) | |
| region2_poly = self.poly2_np.reshape((-1, 1, 2)) | |
| # Boucle principale | |
| while cap.isOpened(): | |
| success, frame = cap.read() | |
| if not success: | |
| break | |
| frame_count += 1 | |
| progress_bar.progress(min(frame_count / total_frames, 1.0)) | |
| process_count += 1 | |
| # Effectuer la détection | |
| try: | |
| results = self.model.track( | |
| source=frame, | |
| persist=True, | |
| tracker=tracker, | |
| classes=[0], # Classe pour les véhicules | |
| conf=self.conf_threshold, | |
| verbose=False | |
| ) | |
| except Exception as e: | |
| st.error(f"Erreur lors du traitement de l'image: {str(e)}") | |
| continue | |
| # Initialiser les tableaux | |
| boxes = [] | |
| track_ids = [] | |
| confidences = [] | |
| if results and len(results) > 0 and hasattr(results[0], 'boxes') and len(results[0].boxes) > 0: | |
| try: | |
| boxes = results[0].boxes.xywh.cpu().numpy() | |
| confidences = results[0].boxes.conf.cpu().numpy() | |
| if hasattr(results[0].boxes, 'id'): | |
| track_ids = results[0].boxes.id.int().cpu().tolist() | |
| else: | |
| track_ids = list(range(len(boxes))) | |
| except Exception as e: | |
| st.error(f"Erreur lors de l'extraction des résultats: {str(e)}") | |
| # Créer une copie pour l'annotation | |
| annotated_frame = frame.copy() | |
| # Dessiner les régions | |
| cv2.polylines(annotated_frame, [region1_poly], isClosed=True, color=(255, 0, 0), thickness=2) | |
| cv2.polylines(annotated_frame, [region2_poly], isClosed=True, color=(0, 255, 0), thickness=2) | |
| # Initialiser les ensembles pour les régions | |
| current_region1_ids = set() | |
| current_region2_ids = set() | |
| # Traiter les détections | |
| if boxes.size > 0 and len(track_ids) > 0: | |
| for box, track_id, confidence in zip(boxes, track_ids, confidences): | |
| # Extraire les coordonnées | |
| x, y, w, h = box | |
| # Obtenir la couleur pour cet ID | |
| color = get_color_for_id(track_id) | |
| # Dessiner le rectangle | |
| cv2.rectangle( | |
| annotated_frame, | |
| (int(x - w / 2), int(y - h / 2)), | |
| (int(x + w / 2), int(y + h / 2)), | |
| color, | |
| 2 | |
| ) | |
| # Ajouter du texte | |
| draw_text_with_background( | |
| annotated_frame, | |
| f'ID: {track_id} ({confidence:.2f})', | |
| (int(x - w / 2), int(y - h / 2) - 10), | |
| bg_color=color | |
| ) | |
| # Point central et historique de suivi | |
| center_point = (int(x), int(y)) | |
| cv2.circle(annotated_frame, center_point, radius=4, color=color, thickness=-1) | |
| # Mettre à jour l'historique | |
| self.track_history[track_id].append((float(x), float(y))) | |
| if len(self.track_history[track_id]) > self.max_track_history: | |
| self.track_history[track_id].pop(0) | |
| # Dessiner le chemin | |
| if len(self.track_history[track_id]) > 1: | |
| points = np.array(self.track_history[track_id]).astype(np.int32).reshape((-1, 1, 2)) | |
| cv2.polylines(annotated_frame, [points], isClosed=False, color=color, thickness=2) | |
| # Vérifier si dans une région | |
| if self.is_in_region(center_point, self.poly1_np): | |
| current_region1_ids.add(track_id) | |
| if self.is_in_region(center_point, self.poly2_np): | |
| current_region2_ids.add(track_id) | |
| # Calculer les statistiques | |
| all_cars = len(track_ids) | |
| region1_cars = len(current_region1_ids) | |
| region2_cars = len(current_region2_ids) | |
| # Statut du trafic | |
| region1_status = "Dense" if region1_cars > self.traffic_threshold else "Fluide" | |
| region2_status = "Dense" if region2_cars > self.traffic_threshold else "Fluide" | |
| # Calculer le FPS | |
| new_frame_time = time.time() | |
| fps = 1 / (new_frame_time - prev_frame_time) if prev_frame_time > 0 else 30 | |
| prev_frame_time = new_frame_time | |
| # Afficher les informations | |
| draw_text_with_background(annotated_frame, static_texts['fps'].format(int(fps)), (7, 30), bg_color=(0, 0, 0), text_color=(0, 255, 0)) | |
| draw_text_with_background(annotated_frame, static_texts['total'].format(all_cars), (7, 65), bg_color=(0, 0, 0), text_color=(0, 255, 0)) | |
| draw_text_with_background(annotated_frame, static_texts['region1'].format(region1_cars), (7, 105), bg_color=(255, 0, 0), text_color=(255, 255, 255)) | |
| region1_text_color = (0, 0, 255) if region1_status == "Dense" else (255, 255, 255) | |
| draw_text_with_background(annotated_frame, static_texts['region1_status'].format(region1_status), (7, 139), bg_color=(255, 0, 0), text_color=region1_text_color) | |
| draw_text_with_background(annotated_frame, static_texts['region2'].format(region2_cars), (880, 105), bg_color=(0, 255, 0), text_color=(255, 255, 255)) | |
| region2_text_color = (0, 0, 255) if region2_status == "Dense" else (255, 255, 255) | |
| draw_text_with_background(annotated_frame, static_texts['region2_status'].format(region2_status), (880, 139), bg_color=(0, 255, 0), text_color=region2_text_color) | |
| # Écrire la frame | |
| out.write(annotated_frame) | |
| # Mettre à jour l'affichage Streamlit | |
| if process_count % display_interval == 0: | |
| stframe.image(annotated_frame, channels="BGR", use_container_width=True) | |
| # Libérer les ressources | |
| cap.release() | |
| out.release() | |
| cv2.destroyAllWindows() | |
| # Nettoyer la mémoire | |
| self.track_history.clear() | |
| return self.output_path | |
| def main(): | |
| st.title("🚗 Détection et Suivi de densité de Trafic sur l'Autoroute de l'Avenir") | |
| st.write("Cette application détecte les véhicules, et analyse la densité du trafic dans les deux sens d'un tronçon de l'Autoroute de l'Avenir.") | |
| # Sidebar pour les options | |
| st.sidebar.header("Configuration") | |
| # Détection du matériel | |
| if torch.cuda.is_available(): | |
| device_options = ["cuda", "cpu"] | |
| device_default = "cuda" | |
| st.sidebar.success("🚀 GPU détecté! Traitement accéléré disponible.") | |
| else: | |
| device_options = ["cpu"] | |
| device_default = "cpu" | |
| st.sidebar.warning("⚠️ Pas de GPU détecté. Performances limitées au CPU.") | |
| device = st.sidebar.selectbox("Périphérique de calcul", device_options, index=device_options.index(device_default)) | |
| # Paramètres de détection | |
| confidence = st.sidebar.slider("Seuil de confiance", 0.1, 1.0, 0.25, 0.05, | |
| help="Un seuil plus élevé détecte moins d'objets mais avec plus de précision") | |
| # Configuration des régions d'intérêt (ROI) | |
| st.sidebar.subheader("Régions d'intérêt") | |
| # Valeurs par défaut pour les polygones | |
| # default_poly1 = [(465, 350), (609, 350), (520, 630), (3, 630)] | |
| # default_poly2 = [(678, 350), (815, 350), (1203, 630), (743, 630)] | |
| default_poly1 = [(900, 350), (1150, 350), (650, 630), (200, 630)] | |
| default_poly2 = [(1200, 350), (1400, 350), (1150, 630), (743, 630)] | |
| # Configuration des polygones | |
| st.sidebar.markdown("#### Région 1 (Sens 1)") | |
| poly1 = [] | |
| for i in range(4): | |
| col1, col2 = st.sidebar.columns(2) | |
| with col1: | |
| x = st.number_input(f"Point {i+1} X", value=default_poly1[i][0], min_value=0, step=1, key=f"p1_x_{i}") | |
| with col2: | |
| y = st.number_input(f"Point {i+1} Y", value=default_poly1[i][1], min_value=0, step=1, key=f"p1_y_{i}") | |
| poly1.append((x, y)) | |
| st.sidebar.markdown("#### Région 2 (Sens 2)") | |
| poly2 = [] | |
| for i in range(4): | |
| col1, col2 = st.sidebar.columns(2) | |
| with col1: | |
| x = st.number_input(f"Point {i+1} X", value=default_poly2[i][0], min_value=0, step=1, key=f"p2_x_{i}") | |
| with col2: | |
| y = st.number_input(f"Point {i+1} Y", value=default_poly2[i][1], min_value=0, step=1, key=f"p2_y_{i}") | |
| poly2.append((x, y)) | |
| # Seuil pour trafic dense | |
| traffic_threshold = st.sidebar.slider("Seuil de trafic dense", 1, 10, 2, | |
| help="Nombre de véhicules à partir duquel le trafic est considéré comme dense") | |
| # Upload du modèle | |
| st.sidebar.subheader("Modèle de détection") | |
| model_file = st.sidebar.file_uploader("Uploader votre modèle YOLO (best.pt)", type=["pt"]) | |
| # Upload de la vidéo | |
| st.subheader("Vidéo à analyser") | |
| video_file = st.file_uploader("Uploader une vidéo", type=["mp4", "avi", "mov", "mkv"]) | |
| # Afficher un aperçu des ROI sur la vidéo | |
| if video_file: | |
| try: | |
| # Lire la première frame pour l'aperçu | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=Path(video_file.name).suffix) as tmp_video: | |
| tmp_video.write(video_file.getvalue()) | |
| video_path = tmp_video.name | |
| cap = cv2.VideoCapture(video_path) | |
| ret, frame = cap.read() | |
| if ret: | |
| # Dessiner les polygones | |
| preview = frame.copy() | |
| cv2.polylines(preview, [np.array(poly1)], isClosed=True, color=(255, 0, 0), thickness=2) | |
| cv2.polylines(preview, [np.array(poly2)], isClosed=True, color=(0, 255, 0), thickness=2) | |
| # Ajouter des labels | |
| cv2.putText(preview, "Zone 1", (poly1[0][0], poly1[0][1] - 10), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2) | |
| cv2.putText(preview, "Zone 2", (poly2[0][0], poly2[0][1] - 10), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) | |
| # Afficher l'aperçu | |
| st.subheader("Aperçu des zones de détection") | |
| st.image(preview, channels="BGR", caption="Aperçu des régions d'intérêt configurées") | |
| cap.release() | |
| os.unlink(video_path) | |
| else: | |
| st.warning("Impossible de générer un aperçu des zones de détection.") | |
| except Exception as e: | |
| st.warning(f"Erreur lors de la génération de l'aperçu: {str(e)}") | |
| if video_file and model_file: | |
| # Sauvegarde temporaire des fichiers | |
| with tempfile.NamedTemporaryFile(delete=False, suffix='.pt') as tmp_model: | |
| tmp_model.write(model_file.getvalue()) | |
| model_path = tmp_model.name | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=Path(video_file.name).suffix) as tmp_video: | |
| tmp_video.write(video_file.getvalue()) | |
| video_path = tmp_video.name | |
| output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name | |
| st.success("Fichiers chargés avec succès. Prêt à traiter la vidéo.") | |
| # Création des colonnes pour l'affichage | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.write("Vidéo originale:") | |
| st_video_placeholder = st.empty() | |
| st_video_placeholder.video(video_file) | |
| with col2: | |
| st.write("Vidéo traitée:") | |
| stframe = st.empty() | |
| progress_bar = st.progress(0) | |
| # Bouton pour lancer le traitement | |
| if st.button("Traiter la vidéo"): | |
| start_time = time.time() | |
| try: | |
| with st.spinner("Traitement en cours..."): | |
| # Traitement de la vidéo | |
| video_processor = VideoProcessor( | |
| model_path=model_path, | |
| video_path=video_path, | |
| output_path=output_path, | |
| device=device, | |
| conf_threshold=confidence, | |
| poly1=poly1, | |
| poly2=poly2, | |
| traffic_threshold=traffic_threshold | |
| ) | |
| output_video = video_processor.process_video(stframe, progress_bar) | |
| if output_video: | |
| # Temps de traitement | |
| total_time = time.time() - start_time | |
| st.success(f"✅ Traitement terminé en {total_time:.2f} secondes!") | |
| # Téléchargement de la vidéo | |
| try: | |
| with open(output_video, "rb") as file: | |
| video_bytes = file.read() | |
| st.download_button( | |
| label="Télécharger la vidéo traitée", | |
| data=video_bytes, | |
| file_name="video_traitee.mp4", | |
| mime="video/mp4" | |
| ) | |
| # Afficher la vidéo résultante | |
| st.video(video_bytes) | |
| except Exception as e: | |
| st.error(f"Erreur lors de la préparation du téléchargement: {str(e)}") | |
| else: | |
| st.error("Échec du traitement de la vidéo.") | |
| except Exception as e: | |
| st.error(f"Une erreur s'est produite: {str(e)}") | |
| st.error(f"Détails: {sys.exc_info()}") | |
| # Nettoyer les fichiers temporaires | |
| try: | |
| if model_path and os.path.exists(model_path): | |
| os.unlink(model_path) | |
| if video_path and os.path.exists(video_path): | |
| os.unlink(video_path) | |
| if output_path and os.path.exists(output_path): | |
| os.unlink(output_path) | |
| except Exception as e: | |
| st.warning(f"Problème lors du nettoyage des fichiers temporaires: {str(e)}") | |
| else: | |
| st.info("Veuillez uploader un modèle YOLO (best.pt) et une vidéo pour commencer.") | |
| st.sidebar.markdown("---") | |
| st.sidebar.markdown("### Comment utiliser cette application") | |
| st.sidebar.markdown(""" | |
| 1. Choisissez un périphérique de calcul (GPU recommandé) | |
| 2. Configurez les régions d'intérêt si nécessaire | |
| 3. Uploadez votre modèle YOLO (best.pt) | |
| 4. Uploadez une vidéo à analyser | |
| 5. Cliquez sur "Traiter la vidéo" | |
| 6. Visualisez les résultats et téléchargez la vidéo traitée | |
| """) | |
| if __name__ == "__main__": | |
| main() |