ModuMLTECH committed on
Commit
c38e692
·
verified ·
1 Parent(s): e5d4423

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -95
app.py CHANGED
@@ -1,102 +1,218 @@
1
- import matplotlib
2
- matplotlib.use('Agg')
3
-
4
  import streamlit as st
5
  import cv2
6
- import numpy as np
7
- from yolov5 import YOLOv5
8
- from sort.sort import Sort
9
  import tempfile
10
- import shutil
11
- from moviepy.editor import VideoFileClip, concatenate_videoclips, ImageSequenceClip
12
  import os
13
-
14
- # Load the pre-trained model and initialize the SORT tracker
15
- model_path = 'yolov5s.pt' # Ensure this path points to the model file
16
- model = YOLOv5(model_path, device='cpu')
17
- tracker = Sort()
18
-
19
- def process_video(uploaded_file):
20
- # Save the uploaded file to a temporary location
21
- temp_file_path = "temp_video.mp4"
22
- with open(temp_file_path, "wb") as temp_file:
23
- temp_file.write(uploaded_file.getbuffer())
24
-
25
- # Use moviepy to read the video file
26
- video_clip = VideoFileClip(temp_file_path)
27
- total_frames = int(video_clip.fps * video_clip.duration)
28
- width, height = video_clip.size
29
-
30
- # Temporary directory to save processed video frames
31
- temp_dir = tempfile.mkdtemp()
32
-
33
- unique_cars = set()
34
- progress_bar = st.progress(0)
35
-
36
- for frame_idx, frame in enumerate(video_clip.iter_frames()):
37
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
38
-
39
- progress_percentage = int((frame_idx + 1) / total_frames * 100)
40
- progress_bar.progress(progress_percentage)
41
-
42
- # Detection and tracking logic
43
- results = model.predict(frame)
44
- preds = results.pandas().xyxy[0]
45
- detections = []
46
-
47
- for index, row in preds.iterrows():
48
- if row['name'] == 'car':
49
- xmin, ymin, xmax, ymax, conf = row['xmin'], row['ymin'], row['xmax'], row['ymax'], row['confidence']
50
- detections.append([xmin, ymin, xmax, ymax, conf])
51
-
52
- if detections:
53
- detections_np = np.array(detections)
54
- trackers = tracker.update(detections_np)
55
-
56
- for d in trackers:
57
- unique_cars.add(int(d[4]))
58
- xmin, ymin, xmax, ymax = map(int, d[:4])
59
- cv2.rectangle(frame, (xmin, ymin), (xmax, ymax), (255, 0, 0), 2)
60
- cv2.putText(frame, f'ID: {int(d[4])}', (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 0, 0), 2)
61
-
62
- cv2.putText(frame, f'Unique Cars: {len(unique_cars)}', (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1.25, (0, 255, 0), 2)
63
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert back to RGB for moviepy
64
- cv2.imwrite(f"{temp_dir}/{frame_idx:04d}.jpg", frame)
65
-
66
- frames_files = [os.path.join(temp_dir, f"{i:04d}.jpg") for i in range(total_frames)]
67
- clip = ImageSequenceClip(frames_files, fps=video_clip.fps)
68
- output_video_path = 'processed_video.mp4'
69
- clip.write_videofile(output_video_path, codec='libx264') # Use libx264 codec for compatibility
70
-
71
- # Remove temporary directory and temporary files
72
- shutil.rmtree(temp_dir)
73
-
74
- return output_video_path
75
-
76
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def main():
78
- # Initialize session state variables if they don't exist
79
- if 'output_video_path' not in st.session_state:
80
- st.session_state.output_video_path = None
81
-
82
- st.sidebar.image("logo.jpg", use_column_width=True)
83
- uploaded_file = st.sidebar.file_uploader("Upload a video", type=['mp4'])
84
-
85
- st.title("Car Detection and Tracking")
86
-
87
- if uploaded_file is not None:
88
- # Process the video only if it hasn't been processed yet or a new file is uploaded
89
- if st.session_state.output_video_path is None or st.session_state.uploaded_file_name != uploaded_file.name:
90
- st.session_state.uploaded_file_name = uploaded_file.name
91
- st.session_state.output_video_path = process_video(uploaded_file)
92
-
93
- # Display the processed video
94
- st.video(st.session_state.output_video_path)
95
-
96
- # Provide a download link for the processed video
97
- with open(st.session_state.output_video_path, "rb") as file:
98
- st.download_button("Download Processed Video", file, file_name="processed_video.mp4")
99
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  if __name__ == "__main__":
102
- main()
 
 
 
 
1
  import streamlit as st
2
  import cv2
 
 
 
3
  import tempfile
 
 
4
  import os
5
+ import time
6
+ import numpy as np
7
+ import pandas as pd
8
+ from collections import defaultdict
9
+ from ultralytics import YOLO
10
+
11
+ # --- FONCTIONS UTILES ---
12
+ def draw_text_with_background(image, text, position, font=cv2.FONT_HERSHEY_SIMPLEX,
13
+ font_scale=1, font_thickness=2, text_color=(255, 255, 255), bg_color=(0, 0, 0), padding=5):
14
+ """Ajoute du texte avec un fond sur une image OpenCV."""
15
+ text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
16
+ text_width, text_height = text_size
17
+
18
+ x, y = position
19
+ top_left = (x, y - text_height - padding)
20
+ bottom_right = (x + text_width + padding * 2, y + padding)
21
+
22
+ cv2.rectangle(image, top_left, bottom_right, bg_color, -1)
23
+ cv2.putText(image, text, (x + padding, y), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
24
+
25
# --- YOLO CLASS ---
class YOLOVideoProcessor:
    """Detect and track vehicles in a video with a YOLO model, counting the
    unique track IDs whose box centers fall inside two polygonal regions
    (one per traffic direction)."""

    def __init__(self, model_path, video_path, output_path, poly1, poly2, tracker_method="bot"):
        self.model = YOLO(model_path, task="detect")
        self.tracker_method = tracker_method
        self.video_path = video_path
        self.output_path = output_path

        # Unique tracker IDs observed inside each region.
        self.unique_region1_ids = set()
        self.unique_region2_ids = set()
        self.poly1 = poly1
        self.poly2 = poly2

    def is_in_region(self, center, poly):
        """Return True if point `center` (x, y) lies inside or on polygon `poly`."""
        poly_np = np.array(poly, dtype=np.int32)
        return cv2.pointPolygonTest(poly_np, center, False) >= 0

    def process_video(self, progress_bar=None):
        """Run tracking over the whole video, write an annotated copy to
        ``self.output_path`` and return ``(count_region1, count_region2)``.

        Returns None if the input video cannot be opened. `progress_bar`,
        when given, is a Streamlit progress widget updated per frame.
        """
        cap = cv2.VideoCapture(self.video_path)
        if not cap.isOpened():
            st.error("⚠️ Erreur : Impossible d'ouvrir la vidéo.")
            return

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        if fps == 0:
            fps = 30  # fallback when the container reports no FPS

        # FIX: use a codec that matches the .mp4 container (the original used
        # XVID, an AVI-oriented codec, inside an .mp4 file — often unplayable
        # in browsers). H.264 would be ideal but requires an ffmpeg build.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(self.output_path, fourcc, fps, (frame_width, frame_height))

        processed_frames = 0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # FIX: loop-invariant — hoisted out of the per-frame loop.
        tracker = "botsort.yaml" if self.tracker_method.lower() == "bot" else "bytetrack.yaml"

        while cap.isOpened():
            success, frame = cap.read()
            if not success:
                break

            # FIX: guard against total_frames == 0 (corrupt/unknown metadata)
            # which caused a ZeroDivisionError in the original.
            if progress_bar is not None and total_frames > 0:
                progress_bar.progress(processed_frames / total_frames)

            results = self.model.track(frame, persist=True, tracker=tracker, conf=0.25)

            # Draw both counting regions on every frame.
            cv2.polylines(frame, [np.array(self.poly1, np.int32)], isClosed=True, color=(0, 255, 0), thickness=2)
            cv2.polylines(frame, [np.array(self.poly2, np.int32)], isClosed=True, color=(255, 0, 0), thickness=2)

            # FIX: only touch results[0].boxes when detections exist — the
            # original evaluated results[0].boxes.xywh inside zip() even when
            # `results` could be empty, risking an IndexError.
            if results and len(results) > 0 and len(results[0].boxes) > 0:
                boxes = results[0].boxes.xywh.cpu().numpy()
                try:
                    track_ids = results[0].boxes.id.int().cpu().tolist()
                except AttributeError:
                    # Tracker produced no IDs this frame; fall back to
                    # positional indices (same behavior as the original).
                    track_ids = list(range(len(boxes)))

                for box, track_id in zip(boxes, track_ids):
                    x, y, w, h = box
                    center_point = (int(x), int(y))

                    if self.is_in_region(center_point, self.poly1):
                        self.unique_region1_ids.add(track_id)
                    if self.is_in_region(center_point, self.poly2):
                        self.unique_region2_ids.add(track_id)

            # Running per-direction totals overlaid on the frame.
            draw_text_with_background(frame, f'Total Sens 1: {len(self.unique_region1_ids)}', (10, frame_height - 50))
            draw_text_with_background(frame, f'Total Sens 2: {len(self.unique_region2_ids)}', (frame_width - 300, frame_height - 50))

            out.write(frame)
            processed_frames += 1

        cap.release()
        out.release()
        cv2.destroyAllWindows()

        if processed_frames == 0:
            st.error("⚠️ Aucune image n'a été écrite dans la vidéo de sortie !")

        return len(self.unique_region1_ids), len(self.unique_region2_ids)
109
+
110
# --- STREAMLIT UI ---
def main():
    """Streamlit entry point: configure the page, collect the user's
    parameters (two counting polygons, tracker method, input video) and run
    the YOLO-based vehicle counter, showing counts and the annotated video.
    """
    st.set_page_config(
        page_title="Détecteur de Véhicules",
        page_icon="🚗",
        layout="wide"
    )

    st.title("🚗 Détection et comptage de Véhicules sur l'Autoroute de l'Avenir")

    # Download the fine-tuned model if it is not already present locally.
    model_path = "best.pt"
    if not os.path.exists(model_path):
        with st.spinner("📥 Chargement du modèle YOLO... Cela peut prendre un moment."):
            try:
                from huggingface_hub import hf_hub_download
                model_path = hf_hub_download(repo_id="ModuMLTECH/projet_trafic_2", filename="best.pt")
                st.success("✅ Modèle chargé avec succès!")
            except Exception as e:
                st.error(f"❌ Erreur lors du chargement du modèle: {e}")
                # Fallback: use a stock YOLO model instead.
                st.warning("⚠️ Utilisation du modèle YOLO standard à la place")
                model_path = "yolov8n.pt"

    # Layout: main column (video) and side column (parameters).
    col1, col2 = st.columns([3, 1])

    with col2:
        st.header("🔹 Paramètres")

        st.subheader("📍 Polygone 1 (vert)")
        poly1_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "465,350 609,350 520,630 3,630")

        st.subheader("📍 Polygone 2 (rouge)")
        poly2_input = st.text_area("Entrez 4 points (x,y) séparés par des espaces", "678,350 815,350 1203,630 743,630")

        tracker_method = st.selectbox("Méthode de tracking", ["bot", "byte"], index=0)

    with col1:
        uploaded_file = st.file_uploader("📂 Upload une vidéo", type=["mp4", "avi", "mov"])

    def parse_polygon(input_text):
        """Parse "x1,y1 x2,y2 ..." into [(x1, y1), ...]; return [] when malformed."""
        try:
            points = [tuple(map(int, point.split(','))) for point in input_text.split()]
        except ValueError:  # FIX: was a bare `except:` that swallowed everything
            return []
        # FIX: reject entries that are not exactly (x, y) pairs.
        if any(len(p) != 2 for p in points):
            return []
        return points

    poly1 = parse_polygon(poly1_input)
    poly2 = parse_polygon(poly2_input)

    if uploaded_file is not None:
        # Stage the upload on disk so OpenCV can read it by path.
        temp_dir = tempfile.mkdtemp()
        input_video_path = os.path.join(temp_dir, "input_video.mp4")
        output_video_path = os.path.join(temp_dir, "output_video.mp4")

        with open(input_video_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        st.video(input_video_path)  # show the raw input video

        if st.button("▶️ Lancer la détection"):
            if len(poly1) == 4 and len(poly2) == 4:
                # FIX: the label was assigned but never shown; attach it to the bar.
                progress_text = "🔄 Traitement de la vidéo en cours..."
                progress_bar = st.progress(0, text=progress_text)

                processor = YOLOVideoProcessor(model_path, input_video_path, output_video_path, poly1, poly2, tracker_method)

                start_time = time.time()
                result = processor.process_video(progress_bar=progress_bar)
                end_time = time.time()

                # FIX: process_video returns None when the video cannot be
                # opened — the original unpacked it unconditionally (TypeError).
                if result is None:
                    return
                count1, count2 = result

                processing_time = end_time - start_time

                progress_bar.progress(1.0)  # complete the progress bar
                st.success(f"✅ Traitement terminé en {processing_time:.2f} secondes!")

                # Per-direction results side by side.
                col_result1, col_result2 = st.columns(2)
                with col_result1:
                    st.metric("Véhicules Sens 1 (Vert)", count1)
                with col_result2:
                    st.metric("Véhicules Sens 2 (Rouge)", count2)

                st.subheader("Vidéo traitée")
                st.video(output_video_path)

                # Download option for the annotated video.
                with open(output_video_path, "rb") as file:
                    st.download_button(
                        label="⬇️ Télécharger la vidéo",
                        data=file,
                        file_name="video_traitee.mp4",
                        mime="video/mp4"
                    )
            else:
                st.error("❌ Les coordonnées des polygones doivent contenir **exactement 4 points**.")
216
 
217
# Script entry point: launch the Streamlit app only when run directly.
if __name__ == "__main__":
    main()