sunbal7 committed on
Commit
8bfa29e
·
verified ·
1 Parent(s): 925181d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +374 -123
app.py CHANGED
@@ -1,150 +1,401 @@
"""Legacy version: single-image object detection demo built on DETR.

The user loads facebook/detr-resnet-50 once (cached in st.session_state),
uploads an image, and detections above the chosen confidence threshold are
drawn with PIL and summarized in metrics.
"""
import streamlit as st
import torch
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image, ImageDraw
import numpy as np
import time  # NOTE(review): imported but unused in this version

# Page config
st.set_page_config(
    page_title="Simple Object Detection",
    page_icon="🔍",
    layout="wide"
)

# Title
st.title("🔍 Simple Object Detection with DETR")
st.markdown("Upload an image to detect objects using Facebook's DETR model")

# Initialize model in session state so the weights survive Streamlit reruns
if 'model_loaded' not in st.session_state:
    st.session_state.model_loaded = False
    st.session_state.processor = None
    st.session_state.model = None

# Sidebar
with st.sidebar:
    st.header("Settings")

    # Confidence threshold, applied during post-processing below
    confidence = st.slider(
        "Confidence Threshold",
        min_value=0.1,
        max_value=0.99,
        value=0.7,
        step=0.05
    )

    # Load model button (one-time download of processor + weights)
    if not st.session_state.model_loaded:
        if st.button("Load Model", type="primary"):
            with st.spinner("Loading DETR model..."):
                try:
                    st.session_state.processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
                    st.session_state.model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
                    st.session_state.model_loaded = True
                    st.success("Model loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading model: {e}")
    else:
        st.success("✅ Model is loaded and ready!")

# Main content
uploaded_file = st.file_uploader(
    "Choose an image...",
    type=['jpg', 'jpeg', 'png']
)

if uploaded_file is not None:
    # Display original image
    image = Image.open(uploaded_file).convert("RGB")

    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption="Original Image", use_column_width=True)

    if st.session_state.model_loaded and st.button("Detect Objects"):
        with st.spinner("Detecting objects..."):
            try:
                # Process image with the cached processor/model
                processor = st.session_state.processor
                model = st.session_state.model

                inputs = processor(images=image, return_tensors="pt")

                with torch.no_grad():
                    outputs = model(**inputs)

                # Convert outputs; DETR wants (height, width) while PIL
                # .size is (width, height), hence the [::-1]
                target_sizes = torch.tensor([image.size[::-1]])
                results = processor.post_process_object_detection(
                    outputs,
                    target_sizes=target_sizes,
                    threshold=confidence
                )[0]

                # Draw boxes
                # NOTE(review): this draws on `image` in place, mutating the
                # object shown in col1 — confirm this is intended.
                draw = ImageDraw.Draw(image)
                colors = ["red", "green", "blue", "yellow", "purple", "orange"]

                detected_objects = []

                for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
                    box = [round(i, 2) for i in box.tolist()]
                    label_name = model.config.id2label[label.item()]

                    # Draw rectangle (color keyed to class id)
                    color = colors[label.item() % len(colors)]
                    draw.rectangle(box, outline=color, width=3)

                    # Add label
                    label_text = f"{label_name}: {score:.2f}"
                    draw.text((box[0], box[1]), label_text, fill=color)

                    detected_objects.append((label_name, score.item()))

                # Display results
                with col2:
                    st.image(image, caption="Detected Objects", use_column_width=True)

                # Show statistics
                st.subheader("📊 Detection Results")

                if detected_objects:
                    col_stats1, col_stats2, col_stats3 = st.columns(3)

                    with col_stats1:
                        st.metric("Objects Found", len(detected_objects))

                    with col_stats2:
                        avg_conf = np.mean([score for _, score in detected_objects])
                        st.metric("Average Confidence", f"{avg_conf:.1%}")

                    with col_stats3:
                        st.metric("Unique Classes", len(set([label for label, _ in detected_objects])))

                    # Show details
                    st.subheader("Detected Objects:")
                    for label, score in detected_objects:
                        st.write(f"- **{label}** (confidence: {score:.1%})")
                else:
                    st.warning("No objects detected above the confidence threshold.")

            except Exception as e:
                st.error(f"Error during detection: {e}")
else:
    st.info("👈 Please upload an image and load the model from the sidebar")

# Instructions
with st.expander("How to use this app"):
    st.markdown("""
1. **Load the model** using the button in the sidebar
2. **Upload an image** (JPG, PNG formats)
3. **Adjust confidence threshold** if needed
4. **Click 'Detect Objects'** to run detection
5. **View results** and detected objects
""")

# Footer
st.markdown("---")
st.markdown("Built with [DETR](https://huggingface.co/facebook/detr-resnet-50) • [Streamlit](https://streamlit.io)")
 
 
 
 
 
 
1
  import streamlit as st
2
+ import cv2
3
+ import tempfile
 
4
  import numpy as np
5
+ from ultralytics import YOLO
6
+ import plotly.graph_objects as go
7
+ from collections import defaultdict
8
  import time
9
+ import requests
10
+ from PIL import Image
11
+ import io
12
+ import pandas as pd
13
 
14
# Page configuration: sets the browser tab title/icon and a full-width layout.
st.set_page_config(
    page_title="People & Vehicle Counter",
    page_icon="🚗",
    layout="wide"
)

# Title and description
st.title("🚗 People & Vehicle Counter")
st.markdown("""
Upload a video or provide a video URL to count people and vehicles in real-time using YOLOv8.
This app is useful for traffic monitoring, retail analytics, and crowd management.
""")
27
 
28
+ # Sidebar for settings
 
 
 
 
 
 
29
  with st.sidebar:
30
+ st.header("βš™οΈ Settings")
31
 
32
  # Confidence threshold
33
  confidence = st.slider(
34
+ "Detection Confidence",
35
  min_value=0.1,
36
+ max_value=1.0,
37
+ value=0.25,
38
+ step=0.05,
39
+ help="Higher values reduce false positives but might miss some objects"
40
  )
41
 
42
+ # Object classes to detect
43
+ st.subheader("Objects to Count")
44
+ count_person = st.checkbox("Person", value=True)
45
+ count_car = st.checkbox("Car", value=True)
46
+ count_bus = st.checkbox("Bus", value=False)
47
+ count_truck = st.checkbox("Truck", value=False)
48
+ count_motorcycle = st.checkbox("Motorcycle", value=False)
49
+ count_bicycle = st.checkbox("Bicycle", value=False)
50
+
51
+ # Line intersection for counting
52
+ st.subheader("Counting Line")
53
+ show_line = st.checkbox("Show counting line", value=True)
54
+ line_position = st.slider(
55
+ "Line position (%)",
56
+ min_value=0,
57
+ max_value=100,
58
+ value=50,
59
+ help="Vertical line position for counting object crossings"
60
+ )
61
+
62
+ # Processing options
63
+ st.subheader("Processing Options")
64
+ process_every_nth = st.slider(
65
+ "Process every Nth frame",
66
+ min_value=1,
67
+ max_value=10,
68
+ value=3,
69
+ help="Higher values speed up processing but reduce accuracy"
70
+ )
71
+
72
+ max_frames = st.number_input(
73
+ "Maximum frames to process",
74
+ min_value=10,
75
+ max_value=1000,
76
+ value=200,
77
+ help="Limit processing for long videos"
78
+ )
79
 
80
# Initialize session state (persists across Streamlit reruns)
if 'total_counts' not in st.session_state:
    # cumulative detections per class name over the whole video
    st.session_state.total_counts = defaultdict(int)
if 'frame_counts' not in st.session_state:
    # one dict per processed frame, used for the time-series chart in tab2
    st.session_state.frame_counts = []
if 'processing_complete' not in st.session_state:
    st.session_state.processing_complete = False
if 'processed_video' not in st.session_state:
    # filesystem path of the annotated output video
    st.session_state.processed_video = None
89
+
90
# COCO class names used by YOLOv8 (standard 80-class COCO ordering).
# Only the classes this app can count are listed; any other ID falls back
# to "class_<id>" at the lookup site in process_video().
# Fix: the previous entry `64: "chair"` was wrong — COCO ID 64 is "mouse"
# (chair is 56) — and ID 64 is never selectable, so it is removed.
CLASS_NAMES = {
    0: "person", 1: "bicycle", 2: "car", 3: "motorcycle",
    5: "bus", 7: "truck",
}
95
 
96
# Map the sidebar checkboxes to YOLO/COCO class IDs.
def get_selected_classes():
    """Return the list of COCO class IDs the user ticked in the sidebar.

    Reads the module-level checkbox globals (count_person, count_bicycle,
    count_car, count_motorcycle, count_bus, count_truck) and returns the
    matching IDs in ascending order: 0, 1, 2, 3, 5, 7.

    Fix: the previous version carried an unused `class_mapping` dict whose
    values were wrong (it mapped "bicycle" to 2, colliding with "car");
    that dead code is removed. The returned IDs are unchanged.
    """
    # (flag, class_id) pairs in the same order the original appended them
    toggles = [
        (count_person, 0),
        (count_bicycle, 1),
        (count_car, 2),
        (count_motorcycle, 3),
        (count_bus, 5),
        (count_truck, 7),
    ]
    return [cls_id for enabled, cls_id in toggles if enabled]
122
+
123
# Load YOLO model with caching
@st.cache_resource  # cached across reruns/sessions: weights load only once
def load_model():
    """Return a shared YOLOv8-nano detector (smallest/fastest variant)."""
    return YOLO('yolov8n.pt')  # Using nano model for speed
127
+
128
# Function to process a video: run YOLOv8 on sampled frames, annotate them,
# count detections, and write an annotated output video.
def process_video(video_path, selected_classes):
    """Run detection over ``video_path`` and write an annotated .mp4.

    Reads the sidebar globals (confidence, process_every_nth, max_frames,
    show_line, line_position), accumulates per-class counts into
    st.session_state, and returns the path of the annotated output file.

    Fixes vs. the original: the NamedTemporaryFile handle is closed right
    away (it leaked an open fd and can block cv2 on Windows), the capture
    and writer are released in a ``finally`` so an inference error cannot
    leak them, and an fps of 0 (reported by some containers) falls back
    to 30 instead of producing a broken VideoWriter.
    """
    model = load_model()

    cap = cv2.VideoCapture(video_path)
    # Some containers report 0 fps; guard against a broken output file.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Initialize video writer; close the temp-file handle immediately so
    # only cv2 holds the file open.
    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    output_path = temp_output.name
    temp_output.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    # Reset counts from any previous run
    st.session_state.total_counts = defaultdict(int)
    st.session_state.frame_counts = []

    # Progress reporting widgets
    progress_bar = st.progress(0)
    status_text = st.empty()

    frame_count = 0       # frames read from the source
    processed_frames = 0  # frames actually run through YOLO

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret or processed_frames >= max_frames:
                break

            frame_count += 1

            # Frame sampling for speed: skipped frames are copied through
            # to the output unannotated.
            if frame_count % process_every_nth != 0:
                out.write(frame)
                continue

            # Run YOLO inference restricted to the selected COCO classes
            results = model(frame, conf=confidence, classes=selected_classes, verbose=False)

            # Draw boxes/labels (plot() returns an annotated copy)
            annotated_frame = results[0].plot()

            # Count this frame's detections and fold into the running totals
            frame_counts = defaultdict(int)
            for box in results[0].boxes:
                cls_id = int(box.cls[0])
                class_name = CLASS_NAMES.get(cls_id, f"class_{cls_id}")
                frame_counts[class_name] += 1
                st.session_state.total_counts[class_name] += 1

            # Store frame counts for the time-series chart in tab2
            st.session_state.frame_counts.append({
                'frame': processed_frames,
                **frame_counts
            })

            # Draw the vertical counting line if enabled (visual only)
            if show_line:
                line_x = int(width * line_position / 100)
                cv2.line(annotated_frame, (line_x, 0), (line_x, height), (0, 255, 255), 2)

            # Overlay per-frame counter text, one row per object type
            y_offset = 30
            for obj_type, count in frame_counts.items():
                cv2.putText(annotated_frame,
                            f"{obj_type}: {count}",
                            (10, y_offset),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.7, (0, 255, 0), 2)
                y_offset += 25

            # Write annotated frame to the output video
            out.write(annotated_frame)
            processed_frames += 1

            # Update progress
            progress = min(processed_frames / max_frames, 1.0)
            progress_bar.progress(progress)
            status_text.text(f"Processing frame {processed_frames}/{max_frames}")
    finally:
        # Always release the capture/writer, even if inference raises
        cap.release()
        out.release()

    st.session_state.processing_complete = True
    st.session_state.processed_video = output_path

    return output_path
217
+
218
# Function to download a video from a URL into a local temp file.
def download_video_from_url(url):
    """Stream the video at ``url`` to a temp .mp4 and return its path.

    Returns None (after showing a Streamlit error) on a non-200 response
    or any network/IO failure.

    Fixes vs. the original: ``requests.get`` now has a timeout (the app
    previously hung forever on an unresponsive host), the temp file is
    context-managed so its handle is always closed, and empty keep-alive
    chunks are skipped.
    """
    try:
        # timeout covers both connect and per-read stalls
        response = requests.get(url, stream=True, timeout=30)
        if response.status_code == 200:
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive chunks
                        temp_file.write(chunk)
                return temp_file.name
        else:
            st.error(f"Failed to download video. Status code: {response.status_code}")
            return None
    except Exception as e:
        st.error(f"Error downloading video: {str(e)}")
        return None
234
+
235
# Main app layout: three tabs for input, results, and documentation.
tab1, tab2, tab3 = st.tabs(["📹 Video Input", "📊 Results", "ℹ️ About"])

with tab1:
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Upload Video")
        uploaded_file = st.file_uploader(
            "Choose a video file",
            type=['mp4', 'avi', 'mov', 'mkv']
        )

        if uploaded_file is not None:
            # Save uploaded file to a temp location — OpenCV needs a real
            # filesystem path, not an in-memory buffer.
            tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
            tfile.write(uploaded_file.read())
            video_path = tfile.name

            # Display video info
            st.video(uploaded_file)
            st.info(f"Uploaded: {uploaded_file.name}")

    with col2:
        st.subheader("Video URL")
        video_url = st.text_input(
            "Enter video URL",
            placeholder="https://example.com/video.mp4"
        )

        if st.button("Load from URL") and video_url:
            with st.spinner("Downloading video..."):
                video_path = download_video_from_url(video_url)
                if video_path:
                    st.success("Video downloaded successfully!")
                    # Display first frame as a preview
                    cap = cv2.VideoCapture(video_path)
                    ret, frame = cap.read()
                    if ret:
                        # NOTE(review): cv2 frames are BGR while st.image
                        # assumes RGB — colors may look swapped; confirm.
                        st.image(frame, caption="First frame of video", use_column_width=True)
                    cap.release()

    # Process button
    # NOTE(review): `video_path` is a plain local, not session state. A
    # URL-loaded path is therefore lost on the rerun triggered by clicking
    # "Start Counting" (the button state resets) — verify this flow.
    if ('video_path' in locals() and video_path) or ('uploaded_file' in locals() and uploaded_file):
        if st.button("🚀 Start Counting", type="primary"):
            selected_classes = get_selected_classes()

            if not selected_classes:
                st.warning("Please select at least one object type to count.")
            else:
                with st.spinner("Processing video with YOLOv8..."):
                    output_path = process_video(video_path, selected_classes)
                    st.success("Processing complete!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
# Results tab: annotated video playback, download, counts, and charts.
with tab2:
    if st.session_state.processing_complete:
        col1, col2 = st.columns([2, 1])

        with col1:
            st.subheader("Processed Video")
            # Read the annotated video once; the context manager closes the
            # handle (the original leaked an open file object every rerun).
            with open(st.session_state.processed_video, 'rb') as video_file:
                video_bytes = video_file.read()
            st.video(video_bytes)

            # Download button
            st.download_button(
                label="📥 Download Processed Video",
                data=video_bytes,
                file_name="processed_video.mp4",
                mime="video/mp4"
            )

        with col2:
            st.subheader("📈 Total Counts")

            # Cumulative per-class totals accumulated by process_video()
            if st.session_state.total_counts:
                for obj_type, count in st.session_state.total_counts.items():
                    st.metric(label=obj_type.capitalize(), value=count)
            else:
                st.info("No objects detected")

            # Summary statistics: table of the last 10 processed frames
            st.subheader("📊 Summary")
            if st.session_state.frame_counts:
                df = pd.DataFrame(st.session_state.frame_counts)
                st.dataframe(df.tail(10), use_container_width=True)

        # Time-series chart of counts per processed frame
        st.subheader("📈 Objects Over Time")
        if st.session_state.frame_counts:
            df = pd.DataFrame(st.session_state.frame_counts)

            fig = go.Figure()

            # One trace per object type; 'frame' is the x axis
            for column in df.columns:
                if column != 'frame':
                    fig.add_trace(go.Scatter(
                        x=df['frame'],
                        y=df[column],
                        name=column.capitalize(),
                        mode='lines+markers'
                    ))

            fig.update_layout(
                title="Object Counts per Frame",
                xaxis_title="Frame Number",
                yaxis_title="Count",
                hovermode='x unified',
                height=400
            )

            st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("Process a video first to see results here.")
352
+
353
# About tab: static documentation rendered as Markdown.
with tab3:
    st.markdown("""
## About This App

### 🔧 Technology Stack
- **YOLOv8**: State-of-the-art object detection model
- **Streamlit**: Interactive web app framework
- **OpenCV**: Computer vision library for video processing
- **Plotly**: Interactive visualizations

### 📋 Features
1. **Multiple Input Sources**: Upload videos or use URLs
2. **Customizable Detection**: Select specific object classes
3. **Real-time Counting**: Track objects frame by frame
4. **Visual Analytics**: Interactive charts and statistics
5. **Export Results**: Download processed videos and data

### 🎯 Use Cases
- **Traffic Monitoring**: Count vehicles on roads
- **Retail Analytics**: Track customer movements
- **Crowd Management**: Monitor people density
- **Security**: Detect and count objects of interest

### ⚠️ Limitations
- Processing speed depends on video length and resolution
- Maximum 200 frames processed in this demo
- Accuracy depends on model confidence settings

### 📚 How to Use
1. Upload a video or provide a URL
2. Configure detection settings in the sidebar
3. Click "Start Counting"
4. View results in the Results tab
5. Download processed video and data

---

**Note**: This app runs on Hugging Face Spaces with limited resources.
For heavy processing, consider running locally with GPU support.
""")

# Footer
st.markdown("---")
st.markdown(
    "<div style='text-align: center'>"
    "Built with ❤️ using YOLOv8, Streamlit, and Hugging Face Spaces"
    "</div>",
    unsafe_allow_html=True  # required to render the centered <div>
)