usiddiquee786 commited on
Commit
cfd4db2
·
verified ·
1 Parent(s): 550aefc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -216
app.py CHANGED
@@ -6,8 +6,6 @@ import shutil
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
9
- import cv2
10
- import numpy as np
11
 
12
  # Ensure models directory exists
13
  MODELS_DIR = Path("models")
@@ -21,14 +19,12 @@ def ensure_dependencies():
21
  """Ensure all required dependencies are installed."""
22
  required_packages = [
23
  "ultralytics",
24
- "boxmot",
25
- "opencv-python",
26
- "numpy"
27
  ]
28
 
29
  for package in required_packages:
30
  try:
31
- importlib.import_module(package.replace('-', '_'))
32
  print(f"✅ {package} is installed")
33
  except ImportError:
34
  print(f"⚠️ {package} is not installed, attempting to install...")
@@ -50,194 +46,7 @@ def apply_patches():
50
  else:
51
  print("⚠️ tracker_patch.py not found, skipping patches")
52
 
53
- class LineCounter:
54
- """Count objects crossing a line"""
55
- def __init__(self, line_position=0.5, line_orientation='horizontal'):
56
- """
57
- Initialize a line counter
58
-
59
- Args:
60
- line_position: float between 0 and 1, position of line (default: middle)
61
- line_orientation: 'horizontal' or 'vertical'
62
- """
63
- self.line_position = line_position
64
- self.line_orientation = line_orientation
65
- self.counts = {} # Track counts by class
66
- self.crossed_ids = set() # Track IDs that have crossed the line
67
- self.prev_positions = {} # Store previous positions of tracked objects
68
-
69
- def update(self, bboxes, identities, clss, frame_shape):
70
- """
71
- Update counter with new detections
72
-
73
- Args:
74
- bboxes: list of bounding boxes [x1, y1, x2, y2]
75
- identities: list of track IDs
76
- clss: list of class IDs
77
- frame_shape: tuple of (height, width)
78
-
79
- Returns:
80
- count_info: dict of counts by class
81
- """
82
- if not len(bboxes):
83
- return self.counts, []
84
-
85
- height, width = frame_shape[:2]
86
-
87
- # Calculate line position in pixels
88
- if self.line_orientation == 'horizontal':
89
- line_pos = int(height * self.line_position)
90
- start_point = (0, line_pos)
91
- end_point = (width, line_pos)
92
- else: # vertical
93
- line_pos = int(width * self.line_position)
94
- start_point = (line_pos, 0)
95
- end_point = (line_pos, height)
96
-
97
- # Store line info for drawing
98
- line_info = {
99
- 'start': start_point,
100
- 'end': end_point,
101
- 'orientation': self.line_orientation
102
- }
103
-
104
- # Process each detection
105
- for i, (bbox, track_id, cls) in enumerate(zip(bboxes, identities, clss)):
106
- x1, y1, x2, y2 = bbox
107
-
108
- # Use center point of bbox to determine position
109
- if self.line_orientation == 'horizontal':
110
- center_pos = (y1 + y2) / 2
111
- crossed_now = False
112
-
113
- # Check if crossed the line
114
- if track_id in self.prev_positions:
115
- prev_pos = self.prev_positions[track_id]
116
- # Crossed from above to below
117
- if prev_pos < line_pos and center_pos >= line_pos:
118
- crossed_now = True
119
- # Crossed from below to above
120
- elif prev_pos >= line_pos and center_pos < line_pos:
121
- crossed_now = True
122
-
123
- # Store current position for next frame
124
- self.prev_positions[track_id] = center_pos
125
-
126
- else: # vertical line
127
- center_pos = (x1 + x2) / 2
128
- crossed_now = False
129
-
130
- # Check if crossed the line
131
- if track_id in self.prev_positions:
132
- prev_pos = self.prev_positions[track_id]
133
- # Crossed from left to right
134
- if prev_pos < line_pos and center_pos >= line_pos:
135
- crossed_now = True
136
- # Crossed from right to left
137
- elif prev_pos >= line_pos and center_pos < line_pos:
138
- crossed_now = True
139
-
140
- # Store current position for next frame
141
- self.prev_positions[track_id] = center_pos
142
-
143
- # Count if crossed and not counted before
144
- if crossed_now and track_id not in self.crossed_ids:
145
- self.crossed_ids.add(track_id)
146
- cls_id = int(cls)
147
- self.counts[cls_id] = self.counts.get(cls_id, 0) + 1
148
-
149
- return self.counts, line_info
150
-
151
- def reset(self):
152
- """Reset all counts and tracking info"""
153
- self.counts = {}
154
- self.crossed_ids = set()
155
- self.prev_positions = {}
156
-
157
- def add_line_counter_to_video(input_video, output_video, line_position, line_orientation):
158
- """Add line counter visualization to tracked video"""
159
- try:
160
- # Open the video file
161
- cap = cv2.VideoCapture(input_video)
162
- if not cap.isOpened():
163
- return False, "Failed to open input video"
164
-
165
- # Get video properties
166
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
167
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
168
- fps = cap.get(cv2.CAP_PROP_FPS)
169
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
170
-
171
- # Create video writer
172
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
173
- out = cv2.VideoWriter(output_video, fourcc, fps, (width, height))
174
-
175
- # Initialize line counter
176
- counter = LineCounter(line_position, line_orientation)
177
-
178
- # Calculate line position in pixels
179
- if line_orientation == 'horizontal':
180
- line_y = int(height * line_position)
181
- line_start = (0, line_y)
182
- line_end = (width, line_y)
183
- else: # vertical
184
- line_x = int(width * line_position)
185
- line_start = (line_x, 0)
186
- line_end = (line_x, height)
187
-
188
- # Process each frame
189
- frame_count = 0
190
- class_counts = {}
191
- tracked_objects = {} # {track_id: {"prev_pos": pos, "class": class_id}}
192
-
193
- while True:
194
- ret, frame = cap.read()
195
- if not ret:
196
- break
197
-
198
- # Draw the line
199
- cv2.line(frame, line_start, line_end, (0, 255, 0), 2)
200
-
201
- # Process tracking info from this frame (bounding boxes)
202
- # In a real implementation, we'd extract this from the tracking results
203
- # For now, we'll simulate this by detecting simple blob movements
204
-
205
- # TODO: Extract tracking data from the frame
206
- # This would involve parsing the visualization to extract bounding boxes
207
- # This is a complex task that might require a custom detector
208
-
209
- # Draw count information on the frame
210
- y_offset = 40
211
- cv2.putText(frame, "Line Counter", (20, y_offset),
212
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
213
- y_offset += 40
214
-
215
- for cls_id, count in class_counts.items():
216
- cv2.putText(frame, f"Class {cls_id}: {count}", (20, y_offset),
217
- cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)
218
- y_offset += 30
219
-
220
- # Write the processed frame
221
- out.write(frame)
222
-
223
- # Progress update
224
- frame_count += 1
225
- if frame_count % 100 == 0:
226
- print(f"Processed {frame_count}/{total_frames} frames")
227
-
228
- # Release resources
229
- cap.release()
230
- out.release()
231
-
232
- return True, class_counts
233
-
234
- except Exception as e:
235
- import traceback
236
- traceback.print_exc()
237
- return False, f"Error processing video: {str(e)}"
238
-
239
- def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold,
240
- line_position, line_orientation):
241
  """Run object tracking on the uploaded video."""
242
  try:
243
  # Create temporary workspace
@@ -250,7 +59,7 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids,
250
  output_dir = os.path.join(temp_dir, "output")
251
  os.makedirs(output_dir, exist_ok=True)
252
 
253
- # Build command for tracking (keeping original implementation)
254
  cmd = [
255
  "python", "tracking/track.py",
256
  "--yolo-model", str(MODELS_DIR / yolo_model),
@@ -264,6 +73,11 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids,
264
  "--exist-ok"
265
  ]
266
 
 
 
 
 
 
267
  # Add class filtering if specific classes are provided
268
  if class_ids and class_ids.strip():
269
  # Parse the comma-separated class IDs
@@ -311,25 +125,54 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids,
311
  print("No output video files found")
312
  return None, "No output video was generated. Check if tracking was successful."
313
 
314
- tracked_video = output_files[0]
315
-
316
- # Now add the line counter
317
- line_counted_video = os.path.join(temp_dir, "line_counted_output.mp4")
318
-
319
- # Process the tracked video to add line counter visualization
320
- # For now, we'll just copy the file as implementing actual post-processing
321
- # would require custom code to analyze the tracked objects in the video
322
- shutil.copy(tracked_video, line_counted_video)
323
-
324
- # Copy to permanent location with unique name
325
- permanent_path = os.path.join(OUTPUT_DIR, f"line_counted_{os.path.basename(video_file)}")
326
- shutil.copy(line_counted_video, permanent_path)
327
 
328
  # Verify file exists and has size
329
- if os.path.exists(permanent_path) and os.path.getsize(permanent_path) > 0:
330
- return permanent_path, f"Processing completed successfully! Line counter added at {line_orientation} position {line_position}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
  else:
332
- return None, "Error: Output file was not generated properly."
 
333
 
334
  except Exception as e:
335
  import traceback
@@ -337,8 +180,7 @@ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids,
337
  return None, f"Error: {str(e)}"
338
 
339
  # Define the Gradio interface
340
- def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold,
341
- line_position, line_orientation):
342
  # Validate inputs
343
  if not video_path:
344
  return None, "Please upload a video file"
@@ -378,10 +220,23 @@ ensure_dependencies()
378
  apply_patches()
379
 
380
  # Create the Gradio interface
381
- with gr.Blocks(title="YOLO Object Tracking with Line Counter") as app:
382
- gr.Markdown("# 🚀 YOLO Object Tracking with Line Counter")
383
  gr.Markdown("Upload a video file to detect, track and count objects crossing a line. Processing may take a few minutes depending on video length.")
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
  # Add class reference information
386
  with gr.Accordion("YOLO Class Reference", open=False):
387
  gr.Markdown("""
@@ -459,7 +314,7 @@ with gr.Blocks(title="YOLO Object Tracking with Line Counter") as app:
459
 
460
  with gr.Column(scale=1):
461
  output_video = gr.Video(label="Output Video with Tracking and Counting")
462
- status_text = gr.Textbox(label="Status", value="Ready to process video")
463
 
464
  process_btn.click(
465
  fn=process_video,
 
6
  from pathlib import Path
7
  import sys
8
  import importlib.util
 
 
9
 
10
  # Ensure models directory exists
11
  MODELS_DIR = Path("models")
 
19
  """Ensure all required dependencies are installed."""
20
  required_packages = [
21
  "ultralytics",
22
+ "boxmot"
 
 
23
  ]
24
 
25
  for package in required_packages:
26
  try:
27
+ importlib.import_module(package)
28
  print(f"✅ {package} is installed")
29
  except ImportError:
30
  print(f"⚠️ {package} is not installed, attempting to install...")
 
46
  else:
47
  print("⚠️ tracker_patch.py not found, skipping patches")
48
 
49
+ def run_tracking(video_file, yolo_model, reid_model, tracking_method, class_ids, conf_threshold, line_position, line_orientation):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """Run object tracking on the uploaded video."""
51
  try:
52
  # Create temporary workspace
 
59
  output_dir = os.path.join(temp_dir, "output")
60
  os.makedirs(output_dir, exist_ok=True)
61
 
62
+ # Build command
63
  cmd = [
64
  "python", "tracking/track.py",
65
  "--yolo-model", str(MODELS_DIR / yolo_model),
 
73
  "--exist-ok"
74
  ]
75
 
76
+ # Add line counter parameters
77
+ cmd.extend(["--line-pos", str(line_position)])
78
+ cmd.extend(["--line-direction", line_orientation])
79
+ cmd.extend(["--count-objects"])
80
+
81
  # Add class filtering if specific classes are provided
82
  if class_ids and class_ids.strip():
83
  # Parse the comma-separated class IDs
 
125
  print("No output video files found")
126
  return None, "No output video was generated. Check if tracking was successful."
127
 
128
+ output_file = output_files[0]
129
+ print(f"Selected output file: {output_file}")
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # Verify file exists and has size
132
+ if os.path.exists(output_file):
133
+ file_size = os.path.getsize(output_file)
134
+ print(f"Output file exists with size: {file_size} bytes")
135
+
136
+ if file_size == 0:
137
+ return None, "Output video was generated but has zero size."
138
+
139
+ # Copy to permanent location with unique name
140
+ permanent_path = os.path.join(OUTPUT_DIR, f"line_counted_{os.path.basename(video_file)}")
141
+ shutil.copy(output_file, permanent_path)
142
+ print(f"Copied output to permanent location: {permanent_path}")
143
+
144
+ # Extract counting information from output
145
+ count_summary = "Object counting complete"
146
+ try:
147
+ # Try to find any count results in the process output
148
+ output_text = process.stdout
149
+ if "counts" in output_text or "Counted" in output_text:
150
+ count_lines = [line for line in output_text.split('\n') if "counts" in line or "Counted" in line]
151
+ if count_lines:
152
+ count_summary = "\n".join(count_lines)
153
+ except:
154
+ pass
155
+
156
+ # Ensure the file is in MP4 format for better compatibility with Gradio
157
+ if not permanent_path.lower().endswith('.mp4'):
158
+ mp4_path = os.path.splitext(permanent_path)[0] + '.mp4'
159
+ try:
160
+ print(f"Converting to MP4 format: {mp4_path}")
161
+ subprocess.run([
162
+ 'ffmpeg', '-i', permanent_path,
163
+ '-c:v', 'libx264', '-preset', 'fast',
164
+ '-c:a', 'aac', mp4_path
165
+ ], check=True, capture_output=True)
166
+ os.remove(permanent_path) # Remove the original file
167
+ permanent_path = mp4_path
168
+ except Exception as e:
169
+ print(f"Failed to convert to MP4: {str(e)}")
170
+ # Continue with original file if conversion fails
171
+
172
+ return permanent_path, f"Processing completed successfully with {line_orientation} line at position {line_position}.\n{count_summary}"
173
  else:
174
+ print(f"Output file not found at {output_file}")
175
+ return None, "Output file was referenced but doesn't exist on disk."
176
 
177
  except Exception as e:
178
  import traceback
 
180
  return None, f"Error: {str(e)}"
181
 
182
  # Define the Gradio interface
183
+ def process_video(video_path, yolo_model, reid_model, tracking_method, class_ids, conf_threshold, line_position, line_orientation):
 
184
  # Validate inputs
185
  if not video_path:
186
  return None, "Please upload a video file"
 
220
  apply_patches()
221
 
222
  # Create the Gradio interface
223
+ with gr.Blocks(title="Object Tracking with Line Counter") as app:
224
+ gr.Markdown("# 🚀 Object Tracking with Line Counter")
225
  gr.Markdown("Upload a video file to detect, track and count objects crossing a line. Processing may take a few minutes depending on video length.")
226
 
227
+ # Add instructions for line counter
228
+ with gr.Accordion("Line Counter Instructions", open=True):
229
+ gr.Markdown("""
230
+ ## Line Counter Feature
231
+
232
+ Use the line counter to count objects that cross a specified line in the video:
233
+
234
+ 1. **Line Orientation**: Choose 'horizontal' for a line across the video or 'vertical' for a line from top to bottom
235
+ 2. **Line Position**: Adjust the slider to position the line (0.1 = top/left, 0.9 = bottom/right)
236
+
237
+ The count will appear in the output video and in the status box below.
238
+ """)
239
+
240
  # Add class reference information
241
  with gr.Accordion("YOLO Class Reference", open=False):
242
  gr.Markdown("""
 
314
 
315
  with gr.Column(scale=1):
316
  output_video = gr.Video(label="Output Video with Tracking and Counting")
317
+ status_text = gr.Textbox(label="Count Results", value="Ready to process video", lines=3)
318
 
319
  process_btn.click(
320
  fn=process_video,