sunbal7 committed on
Commit
e191a57
·
verified ·
1 Parent(s): e8fa336

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -355
app.py CHANGED
@@ -5,35 +5,28 @@ import numpy as np
5
  from ultralytics import YOLO
6
  import plotly.graph_objects as go
7
  from collections import defaultdict
8
- import time
9
  import requests
10
  import pandas as pd
11
- from scipy.spatial import distance
12
 
13
  # --- Configuration & Initialization ---
14
 
15
- # Page configuration
16
  st.set_page_config(
17
  page_title="YOLOv8 Object Tracking & Counter",
18
  page_icon="🤖",
19
  layout="wide"
20
  )
21
 
22
- # Title and description
23
  st.title("🚦 Smart Object Traffic Analyzer (YOLOv8)")
24
  st.markdown("""
25
  A professional application for real-time **tracking and counting** of people and vehicles in video streams.
26
  It uses YOLOv8 for detection and a simple tracking algorithm to count unique objects crossing a user-defined line.
27
  """)
28
 
29
- # COCO Class Names (Corrected for standard YOLOv8)
30
  COCO_CLASS_NAMES = {
31
- 0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 4: "airplane",
32
- 5: "bus", 6: "train", 7: "truck", 8: "boat", 9: "traffic light",
33
- # ... other classes
34
  }
35
 
36
- # Mapping of checkboxed objects to their standard COCO Class IDs
37
  CLASS_MAPPING = {
38
  "Person": 0,
39
  "Bicycle": 1,
@@ -43,456 +36,190 @@ CLASS_MAPPING = {
43
  "Truck": 7,
44
  }
45
 
46
- # Initialize session state for tracking
47
  if 'processed_data' not in st.session_state:
48
  st.session_state.processed_data = {
49
  'total_counts': defaultdict(int),
50
  'frame_counts': [],
51
  'processed_video': None,
52
  'processing_complete': False,
53
- 'tracked_objects': {}, # Unique ID: {'class': str, 'last_centroid': (x, y), 'counted': bool}
54
  }
55
 
56
- # --- Sidebar for Settings ---
57
  with st.sidebar:
58
  st.header("⚙️ Configuration Settings")
59
-
60
- # Model Selection
61
- st.subheader("Model & Detection")
62
- model_name = st.selectbox("Select YOLO Model", options=['yolov8n.pt', 'yolov8s.pt'], help="Nano (n) is fast, Small (s) is more accurate.")
63
-
64
- # Confidence threshold
65
- confidence = st.slider(
66
- "Detection Confidence Threshold",
67
- min_value=0.1, max_value=1.0, value=0.40, step=0.05,
68
- help="Minimum confidence to consider a detection valid."
69
- )
70
-
71
- # Object classes to detect
72
  st.subheader("Objects for Counting")
73
  selected_classes_ui = {}
74
- for name, id in CLASS_MAPPING.items():
75
- # Default True for Person and Car, False otherwise
76
  default_val = name in ["Person", "Car"]
77
  selected_classes_ui[name] = st.checkbox(name, value=default_val)
78
-
79
- # Line intersection for counting
80
  st.subheader("Counting Line Settings")
81
  show_line = st.checkbox("Show crossing line", value=True)
82
- line_position = st.slider(
83
- "Line Position (Vertical % from left)",
84
- min_value=10, max_value=90, value=50,
85
- help="Line position for counting objects that cross it."
86
- )
87
-
88
- # Processing options
89
  st.subheader("Performance Options")
90
- process_every_nth = st.slider(
91
- "Frame Skip (Process every Nth frame)",
92
- min_value=1, max_value=10, value=2,
93
- help="Higher values significantly speed up processing but reduce tracking smoothness."
94
- )
95
-
96
- max_frames = st.number_input(
97
- "Maximum Frames to Analyze",
98
- min_value=10, max_value=5000, value=500,
99
- help="Limits the processing duration for long videos. Set to a very high number (e.g., 99999) for full video."
100
- )
101
 
102
- # --- Helper Functions ---
103
 
 
104
  @st.cache_resource
105
  def load_model(model_path):
106
- """Caches the YOLO model loading."""
107
  return YOLO(model_path)
108
 
109
  def get_selected_class_ids():
110
- """Returns a list of COCO class IDs selected by the user."""
111
  return [CLASS_MAPPING[name] for name, is_selected in selected_classes_ui.items() if is_selected]
112
 
113
- # --- Core Processing Function (with simple tracking and crossing logic) ---
114
 
 
115
  def process_video(video_path, selected_class_ids, model_path):
116
- """
117
- Processes the video, performs object detection/tracking, and counts line crossings.
118
- """
119
  model = load_model(model_path)
120
  cap = cv2.VideoCapture(video_path)
121
-
122
- # Video properties
123
  fps = int(cap.get(cv2.CAP_PROP_FPS))
124
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
125
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
126
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
127
-
128
  if total_frames > max_frames:
129
- st.warning(f"Video is being processed for the first {max_frames} frames only (configurable in sidebar).")
130
-
131
- # Setup video writer (Using a smaller size for web-friendliness if possible)
132
  temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
133
  output_path = temp_output.name
134
- # mp4v or XVID is generally compatible. mp4v preferred for browser.
135
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
136
  out = cv2.VideoWriter(output_path, fourcc, fps / process_every_nth, (width, height))
137
-
138
- # Initialize state variables for the loop
139
  current_state = st.session_state.processed_data
140
  current_state['total_counts'] = defaultdict(int)
141
  current_state['frame_counts'] = []
142
- current_state['tracked_objects'] = {} # ID: {'class': str, 'last_centroid': (x, y), 'counted': bool}
143
-
144
- # Define counting line
145
  line_x = int(width * line_position / 100)
146
-
147
- # UI Elements for progress
148
  progress_bar = st.progress(0)
149
  status_text = st.empty()
150
-
151
  frame_count = 0
152
  processed_frames = 0
153
-
154
  while cap.isOpened():
155
  ret, frame = cap.read()
156
-
157
- # Stop condition
158
  if not ret or processed_frames >= max_frames:
159
  break
160
-
161
  frame_count += 1
162
-
163
- # Skip frames for performance (still write the frame for a continuous video)
164
  if frame_count % process_every_nth != 0:
165
- # We don't write the skipped frame because we want the output video
166
- # to reflect the lower frame rate for smaller size and faster processing.
167
  continue
168
-
169
- # --- YOLO Detection ---
170
- # NOTE: Using tracker="bytetrack.yaml" for better tracking. Requires ultralytics>=8.0.198
171
- # However, for simplicity and dependency management, we will use simple centroid tracking.
172
  results = model.track(
173
- frame,
174
- conf=confidence,
175
- classes=selected_class_ids,
176
- persist=True,
177
- tracker="bytetrack.yaml", # Use YOLO's built-in tracking!
178
  verbose=False
179
  )
180
-
181
  annotated_frame = frame.copy()
182
-
183
- # Current frame counts
184
  current_frame_counts = defaultdict(int)
185
-
186
- # --- Tracking and Counting Logic ---
187
  if results and results[0].boxes.id is not None:
188
  boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
189
  track_ids = results[0].boxes.id.cpu().numpy().astype(int)
190
  class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
191
-
192
  for box, track_id, class_id in zip(boxes, track_ids, class_ids):
193
  x1, y1, x2, y2 = box
194
-
195
- # Calculate centroid
196
  centroid_x = (x1 + x2) // 2
197
  centroid_y = (y1 + y2) // 2
198
  centroid = (centroid_x, centroid_y)
199
-
200
  class_name = COCO_CLASS_NAMES.get(class_id, "Unknown")
201
  current_frame_counts[class_name] += 1
202
-
203
- # Update/Initialize tracked object
204
  if track_id not in current_state['tracked_objects']:
205
- # New object detected
206
  current_state['tracked_objects'][track_id] = {
207
- 'class': class_name,
208
- 'last_centroid': centroid,
209
  'counted': False
210
  }
211
  else:
212
- # Existing object - Check for line crossing
213
  obj_data = current_state['tracked_objects'][track_id]
214
  prev_x = obj_data['last_centroid'][0]
215
-
216
  if not obj_data['counted']:
217
- # Crossing logic: object crossed the line from one side to the other
218
- if (prev_x < line_x and centroid_x >= line_x) or \
219
- (prev_x > line_x and centroid_x <= line_x):
220
-
221
- # Object crossed the line!
222
  current_state['total_counts'][class_name] += 1
223
- obj_data['counted'] = True # Count only once
224
-
225
- # Update the object's last known position
226
  obj_data['last_centroid'] = centroid
227
-
228
- # Draw bounding box, track ID, and centroid
229
  cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
230
  cv2.circle(annotated_frame, centroid, 5, (0, 0, 255), -1)
231
-
232
- label = f"ID:{track_id} {class_name}"
233
- cv2.putText(annotated_frame, label, (x1, y1 - 10),
234
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
235
-
236
- # --- Visualization & Logging ---
237
-
238
- # Draw counting line
239
  if show_line:
240
- line_color = (0, 255, 255) # Cyan
241
- cv2.line(annotated_frame, (line_x, 0), (line_x, height), line_color, 2)
242
-
243
- # Label for the line
244
- cv2.putText(annotated_frame, "COUNTING LINE", (line_x + 5, 20),
245
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, line_color, 2)
246
-
247
- # Add total counter text
248
  y_offset = 30
249
  for obj_type, count in current_state['total_counts'].items():
250
- cv2.putText(annotated_frame,
251
- f"TOTAL {obj_type.upper()}: {count}",
252
- (width - 300, y_offset),
253
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
254
  y_offset += 35
255
 
256
- # Store frame counts for chart (count of *objects in frame*, not crossings)
257
  frame_data = {'frame': processed_frames * process_every_nth}
258
  for name in CLASS_MAPPING.keys():
259
- frame_data[name.lower()] = current_frame_counts.get(name.lower(), 0)
260
  current_state['frame_counts'].append(frame_data)
261
-
262
- # Write frame to output video
263
  out.write(annotated_frame)
264
  processed_frames += 1
265
-
266
- # Update progress
267
- progress = min(processed_frames / max_frames, 1.0)
268
- progress_bar.progress(progress)
269
  status_text.text(f"Analyzing Frame {frame_count}/{total_frames} (Processed {processed_frames})")
270
-
271
- # --- Cleanup ---
272
  cap.release()
273
  out.release()
274
-
275
- # Update global state
276
  current_state['processing_complete'] = True
277
  current_state['processed_video'] = output_path
278
  st.session_state.processed_data = current_state
279
-
280
  return output_path
281
 
282
- # Function to download video from URL
283
- @st.cache_data
284
- def download_video_from_url(url):
285
- """Downloads video from URL to a temporary file."""
286
- try:
287
- st.info("Attempting to download video. This might take a moment...")
288
- response = requests.get(url, stream=True, timeout=30)
289
- response.raise_for_status() # Raise exception for bad status codes
290
-
291
- # Determine file extension (optional, but good practice)
292
- content_type = response.headers.get('Content-Type', '')
293
- suffix = '.mp4' if 'mp4' in content_type else '.mov'
294
-
295
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
296
- total_size = int(response.headers.get('Content-Length', 0))
297
-
298
- downloaded_size = 0
299
- progress_placeholder = st.empty()
300
-
301
- for chunk in response.iter_content(chunk_size=8192):
302
- temp_file.write(chunk)
303
- downloaded_size += len(chunk)
304
- if total_size > 0:
305
- progress = downloaded_size / total_size
306
- progress_placeholder.progress(progress, text=f"Downloading: {downloaded_size/(1024*1024):.2f}MB / {total_size/(1024*1024):.2f}MB")
307
-
308
- temp_file.close()
309
- progress_placeholder.empty()
310
- return temp_file.name
311
-
312
- except requests.exceptions.RequestException as e:
313
- st.error(f"Failed to download video: {str(e)}. Check URL and file access.")
314
- return None
315
- except Exception as e:
316
- st.error(f"An unexpected error occurred during download: {str(e)}")
317
- return None
318
-
319
- # --- Main App Layout ---
320
- tab1, tab2, tab3 = st.tabs(["📹 Video Input", "📊 Analysis & Results", "ℹ️ Documentation"])
321
 
322
  with tab1:
323
- col1, col2 = st.columns(2)
324
  video_path = None
325
-
326
- with col1:
327
- st.subheader("📁 Upload Video File")
328
- uploaded_file = st.file_uploader(
329
- "Choose a video file",
330
- type=['mp4', 'avi', 'mov', 'mkv'],
331
- help="Supported video formats. Maximum recommended file size: 50MB."
332
- )
333
-
334
- if uploaded_file is not None:
335
- # Save uploaded file to temp location
336
- tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
337
- tfile.write(uploaded_file.read())
338
- video_path = tfile.name
339
- st.info(f"Video ready: {uploaded_file.name}")
340
- st.video(uploaded_file)
341
-
342
- with col2:
343
- st.subheader("🌐 Load from Video URL")
344
- video_url = st.text_input(
345
- "Enter public video URL (e.g., direct link to .mp4)",
346
- placeholder="https://example.com/traffic.mp4"
347
- )
348
-
349
- if st.button("🔗 Load from URL", use_container_width=True) and video_url:
350
- video_path = download_video_from_url(video_url)
351
- if video_path:
352
- st.success("Video downloaded and ready for processing.")
353
- # Try to display a frame if possible
354
- try:
355
- cap = cv2.VideoCapture(video_path)
356
- ret, frame = cap.read()
357
- if ret:
358
- st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), caption="Video Preview", use_column_width=True)
359
- cap.release()
360
- except Exception:
361
- st.warning("Could not display video preview.")
362
-
363
- st.markdown("---")
364
-
365
- # Process button logic
366
- if video_path:
367
- if st.button("🚀 START TRACKING AND COUNTING", type="primary", use_container_width=True):
368
- selected_class_ids = get_selected_class_ids()
369
-
370
- if not selected_class_ids:
371
- st.error("Please select at least one object type to count in the sidebar.")
372
- else:
373
- try:
374
- with st.spinner(f"Analyzing video with {model_name}..."):
375
- process_video(video_path, selected_class_ids, model_name)
376
- st.success("Analysis complete! See results in the 'Analysis & Results' tab.")
377
- # Automatically switch to results tab on completion? (Streamlit doesn't natively support this well)
378
- except Exception as e:
379
- st.error(f"An error occurred during video processing: {e}")
380
- # Optionally print traceback
381
- # import traceback; st.code(traceback.format_exc())
382
- else:
383
- st.info("Upload a video or provide a URL to begin.")
384
 
385
  with tab2:
386
  data = st.session_state.processed_data
387
  if data['processing_complete']:
388
- st.header("Results Summary")
389
-
390
- col1, col2 = st.columns([2, 1])
391
-
392
- with col1:
393
- st.subheader("🎥 Analyzed Video Output")
394
- # Display processed video
395
- with open(data['processed_video'], 'rb') as video_file:
396
- video_bytes = video_file.read()
397
- st.video(video_bytes)
398
-
399
- # Download button
400
- st.download_button(
401
- label="📥 Download Analyzed Video (MP4)",
402
- data=video_bytes,
403
- file_name="analyzed_tracking_video.mp4",
404
- mime="video/mp4",
405
- use_container_width=True
406
- )
407
-
408
- with col2:
409
- st.subheader("✅ Object Crossing Totals")
410
- # Display total counts
411
- if data['total_counts']:
412
- for obj_type, count in data['total_counts'].items():
413
- st.metric(label=f"Total {obj_type.capitalize()} Crossed", value=count)
414
- else:
415
- st.info("No objects crossed the counting line in the analyzed section.")
416
-
417
- st.subheader("📊 Object Presence Over Frames")
418
- if data['frame_counts']:
419
- df = pd.DataFrame(data['frame_counts']).fillna(0)
420
-
421
- # Time series chart (Plotly)
422
- fig = go.Figure()
423
-
424
- # Add a trace for each object type (columns except 'frame')
425
- for column in df.columns:
426
- if column != 'frame':
427
- fig.add_trace(go.Scatter(
428
- x=df['frame'],
429
- y=df[column],
430
- name=column.capitalize(),
431
- mode='lines+markers'
432
- ))
433
-
434
- fig.update_layout(
435
- title="Count of Objects Present in Frame",
436
- xaxis_title="Frame Number",
437
- yaxis_title="Count of Objects (Instance Count)",
438
- hovermode='x unified',
439
- height=400
440
- )
441
-
442
- st.plotly_chart(fig, use_container_width=True)
443
-
444
- st.subheader("Data Export")
445
- st.dataframe(df.tail(10), use_container_width=True, height=200)
446
-
447
- csv = df.to_csv(index=False).encode('utf-8')
448
- st.download_button(
449
- label="⬇️ Download Frame-by-Frame Data (CSV)",
450
- data=csv,
451
- file_name="object_count_data.csv",
452
- mime="text/csv",
453
- )
454
-
455
- else:
456
- st.warning("No tracking data available. Process a video first.")
457
-
458
- else:
459
- st.info("Process a video in the 'Video Input' tab to view analysis results.")
460
-
461
- with tab3:
462
- st.header("Documentation: Smart Object Traffic Analyzer")
463
- st.markdown("""
464
- This application utilizes cutting-edge computer vision techniques for object tracking and crossing counting.
465
-
466
- ### 🔑 Core Technology
467
-
468
- * **YOLOv8**: The primary model for high-accuracy, real-time object detection. We recommend the `yolov8n.pt` (Nano) for speed in browser-based demos.
469
- * **ByteTrack**: Used via the `ultralytics` package for robust object tracking, assigning a unique ID to each detected instance across frames.
470
- * **Streamlit**: Provides the interactive, professional front-end interface.
471
-
472
- ---
473
-
474
- ### ⚙️ How Crossing Counting Works
475
-
476
- Unlike simple detection counters which add to a total for every frame an object is visible, this app counts **unique object crossings** of a vertical line:
477
-
478
- 1. **Tracking**: YOLOv8's integrated tracker assigns a persistent **Track ID** to each object (`person`, `car`, etc.).
479
- 2. **Centroid Calculation**: The center-point (centroid) of the object's bounding box is calculated for every frame.
480
- 3. **Crossing Logic**: The system monitors the object's horizontal position relative to the **Counting Line**. An object is counted **once** when its centroid moves from one side of the line (e.g., left) to the other (e.g., right).
481
-
482
- This ensures an accurate count of unique events, not redundant detections.
483
-
484
- ### 🚀 Deployment on Hugging Face Spaces
485
-
486
- This script is optimized for deployment:
487
-
488
- * **Caching (`@st.cache_resource`)**: The YOLO model is loaded only once, saving significant time.
489
- * **Dependency List**: You will need a `requirements.txt` file in your Space with the following key libraries:
490
- ```text
491
- streamlit
492
- ultralytics
493
- opencv-python-headless
494
- numpy
495
- plotly
496
- pandas
497
- requests
498
- scipy # for distance calculation, though not strictly needed with bytetrack
 
5
  from ultralytics import YOLO
6
  import plotly.graph_objects as go
7
  from collections import defaultdict
 
8
  import requests
9
  import pandas as pd
 
10
 
11
  # --- Configuration & Initialization ---
12
 
 
13
  st.set_page_config(
14
  page_title="YOLOv8 Object Tracking & Counter",
15
  page_icon="🤖",
16
  layout="wide"
17
  )
18
 
 
19
  st.title("🚦 Smart Object Traffic Analyzer (YOLOv8)")
20
  st.markdown("""
21
  A professional application for real-time **tracking and counting** of people and vehicles in video streams.
22
  It uses YOLOv8 for detection and a simple tracking algorithm to count unique objects crossing a user-defined line.
23
  """)
24
 
25
+ # COCO Class Names (subset for demo)
26
  COCO_CLASS_NAMES = {
27
+ 0: "person", 1: "bicycle", 2: "car", 3: "motorcycle", 5: "bus", 7: "truck"
 
 
28
  }
29
 
 
30
  CLASS_MAPPING = {
31
  "Person": 0,
32
  "Bicycle": 1,
 
36
  "Truck": 7,
37
  }
38
 
 
39
  if 'processed_data' not in st.session_state:
40
  st.session_state.processed_data = {
41
  'total_counts': defaultdict(int),
42
  'frame_counts': [],
43
  'processed_video': None,
44
  'processing_complete': False,
45
+ 'tracked_objects': {},
46
  }
47
 
48
+ # --- Sidebar ---
49
  with st.sidebar:
50
  st.header("⚙️ Configuration Settings")
51
+
52
+ model_name = st.selectbox("Select YOLO Model", options=['yolov8n.pt', 'yolov8s.pt'])
53
+ confidence = st.slider("Detection Confidence Threshold", 0.1, 1.0, 0.40, 0.05)
54
+
 
 
 
 
 
 
 
 
 
55
  st.subheader("Objects for Counting")
56
  selected_classes_ui = {}
57
+ for name in CLASS_MAPPING.keys():
 
58
  default_val = name in ["Person", "Car"]
59
  selected_classes_ui[name] = st.checkbox(name, value=default_val)
60
+
 
61
  st.subheader("Counting Line Settings")
62
  show_line = st.checkbox("Show crossing line", value=True)
63
+ line_position = st.slider("Line Position (Vertical % from left)", 10, 90, 50)
64
+
 
 
 
 
 
65
  st.subheader("Performance Options")
66
+ process_every_nth = st.slider("Frame Skip (Process every Nth frame)", 1, 10, 2)
67
+ max_frames = st.number_input("Maximum Frames to Analyze", 10, 5000, 500)
 
 
 
 
 
 
 
 
 
68
 
 
69
 
70
+ # --- Helper Functions ---
71
  @st.cache_resource
72
  def load_model(model_path):
 
73
  return YOLO(model_path)
74
 
75
  def get_selected_class_ids():
 
76
  return [CLASS_MAPPING[name] for name, is_selected in selected_classes_ui.items() if is_selected]
77
 
 
78
 
79
+ # --- Core Processing ---
80
  def process_video(video_path, selected_class_ids, model_path):
 
 
 
81
  model = load_model(model_path)
82
  cap = cv2.VideoCapture(video_path)
83
+
 
84
  fps = int(cap.get(cv2.CAP_PROP_FPS))
85
  width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
86
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
87
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
88
+
89
  if total_frames > max_frames:
90
+ st.warning(f"Processing limited to first {max_frames} frames.")
91
+
 
92
  temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
93
  output_path = temp_output.name
94
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
95
  out = cv2.VideoWriter(output_path, fourcc, fps / process_every_nth, (width, height))
96
+
 
97
  current_state = st.session_state.processed_data
98
  current_state['total_counts'] = defaultdict(int)
99
  current_state['frame_counts'] = []
100
+ current_state['tracked_objects'] = {}
101
+
 
102
  line_x = int(width * line_position / 100)
103
+
 
104
  progress_bar = st.progress(0)
105
  status_text = st.empty()
106
+
107
  frame_count = 0
108
  processed_frames = 0
109
+
110
  while cap.isOpened():
111
  ret, frame = cap.read()
 
 
112
  if not ret or processed_frames >= max_frames:
113
  break
114
+
115
  frame_count += 1
 
 
116
  if frame_count % process_every_nth != 0:
 
 
117
  continue
118
+
 
 
 
119
  results = model.track(
120
+ frame,
121
+ conf=confidence,
122
+ classes=selected_class_ids,
123
+ persist=True,
124
+ tracker="bytetrack.yaml",
125
  verbose=False
126
  )
127
+
128
  annotated_frame = frame.copy()
 
 
129
  current_frame_counts = defaultdict(int)
130
+
 
131
  if results and results[0].boxes.id is not None:
132
  boxes = results[0].boxes.xyxy.cpu().numpy().astype(int)
133
  track_ids = results[0].boxes.id.cpu().numpy().astype(int)
134
  class_ids = results[0].boxes.cls.cpu().numpy().astype(int)
135
+
136
  for box, track_id, class_id in zip(boxes, track_ids, class_ids):
137
  x1, y1, x2, y2 = box
 
 
138
  centroid_x = (x1 + x2) // 2
139
  centroid_y = (y1 + y2) // 2
140
  centroid = (centroid_x, centroid_y)
141
+
142
  class_name = COCO_CLASS_NAMES.get(class_id, "Unknown")
143
  current_frame_counts[class_name] += 1
144
+
 
145
  if track_id not in current_state['tracked_objects']:
 
146
  current_state['tracked_objects'][track_id] = {
147
+ 'class': class_name,
148
+ 'last_centroid': centroid,
149
  'counted': False
150
  }
151
  else:
 
152
  obj_data = current_state['tracked_objects'][track_id]
153
  prev_x = obj_data['last_centroid'][0]
154
+
155
  if not obj_data['counted']:
156
+ if (prev_x < line_x and centroid_x >= line_x) or (prev_x > line_x and centroid_x <= line_x):
 
 
 
 
157
  current_state['total_counts'][class_name] += 1
158
+ obj_data['counted'] = True
159
+
 
160
  obj_data['last_centroid'] = centroid
161
+
 
162
  cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
163
  cv2.circle(annotated_frame, centroid, 5, (0, 0, 255), -1)
164
+ cv2.putText(annotated_frame, f"ID:{track_id} {class_name}", (x1, y1 - 10),
 
 
165
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
166
+
 
 
 
167
  if show_line:
168
+ cv2.line(annotated_frame, (line_x, 0), (line_x, height), (0, 255, 255), 2)
169
+ cv2.putText(annotated_frame, "COUNTING LINE", (line_x + 5, 20),
170
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
171
+
 
 
 
 
172
  y_offset = 30
173
  for obj_type, count in current_state['total_counts'].items():
174
+ cv2.putText(annotated_frame, f"TOTAL {obj_type.upper()}: {count}",
175
+ (width - 300, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
 
 
176
  y_offset += 35
177
 
 
178
  frame_data = {'frame': processed_frames * process_every_nth}
179
  for name in CLASS_MAPPING.keys():
180
+ frame_data[name.lower()] = current_frame_counts.get(name.lower(), 0)
181
  current_state['frame_counts'].append(frame_data)
182
+
 
183
  out.write(annotated_frame)
184
  processed_frames += 1
185
+
186
+ progress_bar.progress(min(processed_frames / max_frames, 1.0))
 
 
187
  status_text.text(f"Analyzing Frame {frame_count}/{total_frames} (Processed {processed_frames})")
188
+
 
189
  cap.release()
190
  out.release()
191
+
 
192
  current_state['processing_complete'] = True
193
  current_state['processed_video'] = output_path
194
  st.session_state.processed_data = current_state
195
+
196
  return output_path
197
 
198
+
199
+ # --- Main Layout ---
200
+ tab1, tab2 = st.tabs(["📹 Video Input", "📊 Analysis & Results"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  with tab1:
 
203
  video_path = None
204
+ uploaded_file = st.file_uploader("Upload Video", type=['mp4', 'avi', 'mov', 'mkv'])
205
+ if uploaded_file:
206
+ tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
207
+ tfile.write(uploaded_file.read())
208
+ video_path = tfile.name
209
+ st.video(uploaded_file)
210
+
211
+ if video_path and st.button("🚀 START TRACKING AND COUNTING"):
212
+ selected_class_ids = get_selected_class_ids()
213
+ if not selected_class_ids:
214
+ st.error("Please select at least one object type.")
215
+ else:
216
+ with st.spinner(f"Analyzing video with {model_name}..."):
217
+ process_video(video_path, selected_class_ids, model_name)
218
+ st.success("Analysis complete! See results tab.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
  with tab2:
221
  data = st.session_state.processed_data
222
  if data['processing_complete']:
223
+ st.subheader("🎥 Processed Video")
224
+ with open(data['processed_video'], 'rb') as video_file:
225
+ video_bytes = video_file.read()