sunbal7 commited on
Commit
0a996d2
Β·
verified Β·
1 Parent(s): 4a4f8af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +425 -356
app.py CHANGED
@@ -1,397 +1,466 @@
1
  import streamlit as st
2
- import cv2
 
 
 
 
 
3
  import numpy as np
4
- import mediapipe as mp
5
- from scipy.spatial import distance
6
- import av
7
- from streamlit_webrtc import webrtc_streamer, VideoProcessorBase, RTCConfiguration
8
- import queue
9
- import threading
10
- import time
11
 
12
- # MediaPipe setup
13
- mp_face_mesh = mp.solutions.face_mesh
14
- face_mesh = mp_face_mesh.FaceMesh(
15
- min_detection_confidence=0.5,
16
- min_tracking_confidence=0.5,
17
- max_num_faces=1
18
  )
19
 
20
- # Eye and mouth landmark indices for MediaPipe
21
- LEFT_EYE_INDICES = [33, 160, 158, 133, 153, 144]
22
- RIGHT_EYE_INDICES = [362, 385, 387, 263, 373, 380]
23
- MOUTH_INDICES = [61, 39, 0, 269, 291, 405, 314, 17, 84, 181, 91, 185]
24
-
25
- # Thresholds and parameters
26
- EAR_THRESHOLD = 0.25 # Eye Aspect Ratio threshold
27
- MAR_THRESHOLD = 0.5 # Mouth Aspect Ratio threshold
28
- CONSECUTIVE_FRAMES_EYE = 15 # Frames for eye closure detection
29
- CONSECUTIVE_FRAMES_MOUTH = 20 # Frames for yawn detection
30
- ALERT_DURATION = 3 # Alert display duration in seconds
31
-
32
- # For Streamlit audio alert (using browser sound)
33
- AUDIO_ALERT_HTML = """
34
- <audio id="alertAudio" preload="auto">
35
- <source src="https://assets.mixkit.co/sfx/preview/mixkit-alarm-digital-clock-beep-989.mp3" type="audio/mpeg">
36
- </audio>
37
- <script>
38
- function playAlert() {
39
- var audio = document.getElementById('alertAudio');
40
- audio.play();
41
- }
42
- </script>
43
- """
44
-
45
def eye_aspect_ratio(eye_points):
    """Return the eye aspect ratio (EAR) for a six-point eye contour.

    EAR = (|p2-p6| + |p3-p5|) / (2 * |p1-p4|): the two vertical eyelid
    distances over twice the horizontal eye width.  Values fall toward
    zero as the eye closes.
    """
    p1, p2, p3, p4, p5, p6 = eye_points
    vertical_span = distance.euclidean(p2, p6) + distance.euclidean(p3, p5)
    horizontal_span = distance.euclidean(p1, p4)
    return vertical_span / (2.0 * horizontal_span)
57
-
58
def mouth_aspect_ratio(mouth_points):
    """Return the mouth aspect ratio (MAR) for a twelve-point mouth contour.

    MAR = (|p2-p10| + |p4-p8|) / (2 * |p0-p6|): two vertical lip
    distances over twice the horizontal mouth width.  Large values
    indicate an open mouth (yawning).
    """
    lip_gap = (
        distance.euclidean(mouth_points[2], mouth_points[10])
        + distance.euclidean(mouth_points[4], mouth_points[8])
    )
    mouth_width = distance.euclidean(mouth_points[0], mouth_points[6])
    return lip_gap / (2.0 * mouth_width)
70
-
71
class DrowsinessProcessor(VideoProcessorBase):
    """Per-frame video processor for streamlit-webrtc.

    Counts consecutive frames with closed eyes (low EAR) or an open
    mouth (high MAR) and flags drowsiness when either run exceeds its
    configured frame threshold.  Alerts are rate-limited to one per
    ALERT_DURATION seconds.
    """

    def __init__(self):
        # Consecutive-frame counters for the two drowsiness cues.
        self.eye_closed_frames = 0
        self.mouth_open_frames = 0
        # Alert state plus wall-clock time of the last alert (rate limiting).
        self.alert_active = False
        self.last_alert_time = 0
        # Bounded queue of alert frames; presumably consumed by the frontend
        # alert loop — TODO confirm a consumer exists.
        self.frame_queue = queue.Queue(maxsize=30)

    def recv(self, frame):
        """Process one webcam frame and return it annotated with metrics."""
        img = frame.to_ndarray(format="bgr24")
        # MediaPipe FaceMesh expects RGB input, OpenCV frames are BGR.
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Process with MediaPipe
        results = face_mesh.process(img_rgb)

        drowsiness_detected = False
        eye_status = "OPEN"
        mouth_status = "CLOSED"

        if results.multi_face_landmarks:
            for face_landmarks in results.multi_face_landmarks:
                # Extract eye landmarks
                left_eye_points = []
                right_eye_points = []

                h, w = img.shape[:2]

                # Get left eye points (normalized landmark -> pixel coords)
                for idx in LEFT_EYE_INDICES:
                    landmark = face_landmarks.landmark[idx]
                    x = int(landmark.x * w)
                    y = int(landmark.y * h)
                    left_eye_points.append((x, y))

                # Get right eye points
                for idx in RIGHT_EYE_INDICES:
                    landmark = face_landmarks.landmark[idx]
                    x = int(landmark.x * w)
                    y = int(landmark.y * h)
                    right_eye_points.append((x, y))

                # Calculate EAR for both eyes
                left_ear = eye_aspect_ratio(left_eye_points)
                right_ear = eye_aspect_ratio(right_eye_points)
                ear = (left_ear + right_ear) / 2.0

                # Draw eye landmarks
                for point in left_eye_points + right_eye_points:
                    cv2.circle(img, point, 1, (0, 255, 0), -1)

                # Extract mouth landmarks
                mouth_points = []
                for idx in MOUTH_INDICES:
                    landmark = face_landmarks.landmark[idx]
                    x = int(landmark.x * w)
                    y = int(landmark.y * h)
                    mouth_points.append((x, y))

                # Calculate MAR
                mar = mouth_aspect_ratio(mouth_points)

                # Draw mouth landmarks
                for point in mouth_points:
                    cv2.circle(img, point, 1, (255, 0, 0), -1)

                # Eye detection logic: count consecutive low-EAR frames,
                # reset the run on any open-eye frame.
                if ear < EAR_THRESHOLD:
                    self.eye_closed_frames += 1
                    eye_status = "CLOSED"
                else:
                    self.eye_closed_frames = 0

                # Mouth detection logic: same run-length scheme for yawns.
                if mar > MAR_THRESHOLD:
                    self.mouth_open_frames += 1
                    mouth_status = "OPEN"
                else:
                    self.mouth_open_frames = 0

                # Check for drowsiness
                if (self.eye_closed_frames >= CONSECUTIVE_FRAMES_EYE or
                    self.mouth_open_frames >= CONSECUTIVE_FRAMES_MOUTH):
                    drowsiness_detected = True
                    current_time = time.time()

                    # Trigger alert if enough time has passed since last alert
                    if current_time - self.last_alert_time > ALERT_DURATION:
                        self.alert_active = True
                        self.last_alert_time = current_time
                        # We'll handle the audio alert through the frontend

                # Display metrics
                cv2.putText(img, f"EAR: {ear:.2f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                cv2.putText(img, f"MAR: {mar:.2f}", (10, 60),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
                cv2.putText(img, f"Eyes: {eye_status}", (10, 90),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                            (0, 0, 255) if eye_status == "CLOSED" else (0, 255, 0), 2)
                cv2.putText(img, f"Mouth: {mouth_status}", (10, 120),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7,
                            (0, 0, 255) if mouth_status == "OPEN" else (0, 255, 0), 2)

                # Draw drowsiness warning
                if drowsiness_detected:
                    cv2.putText(img, "DROWSINESS DETECTED!", (w//2 - 150, 50),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)
                    cv2.rectangle(img, (0, 0), (w, h), (0, 0, 255), 10)

                    # Add to frame queue for alert trigger
                    if not self.frame_queue.full():
                        self.frame_queue.put({"alert": True, "frame": img})

        # Reset alert after duration
        if self.alert_active and time.time() - self.last_alert_time > ALERT_DURATION:
            self.alert_active = False

        return av.VideoFrame.from_ndarray(img, format="bgr24")
189
-
190
- def main():
191
- st.set_page_config(
192
- page_title="Real-time Drowsiness Detection",
193
- page_icon="πŸš—",
194
- layout="wide"
195
- )
196
-
197
- # Custom CSS
198
- st.markdown("""
199
- <style>
200
  .main-header {
201
  font-size: 2.5rem;
202
- color: #2E86AB;
203
  text-align: center;
204
- margin-bottom: 2rem;
205
  }
206
  .sub-header {
207
- font-size: 1.5rem;
208
- color: #A23B72;
209
- margin-top: 2rem;
 
210
  }
211
- .metric-box {
212
  background-color: #f0f2f6;
213
  padding: 1rem;
214
  border-radius: 10px;
215
- margin: 1rem 0;
 
216
  }
217
- .alert-box {
218
- background-color: #ffcccc;
219
  padding: 1rem;
220
  border-radius: 10px;
221
- border-left: 5px solid #ff0000;
222
- animation: pulse 2s infinite;
 
 
 
 
223
  }
224
- @keyframes pulse {
225
- 0% { opacity: 1; }
226
- 50% { opacity: 0.7; }
227
- 100% { opacity: 1; }
228
  }
229
- </style>
230
- """, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
- # Header
233
- st.markdown('<h1 class="main-header">πŸš— Real-time Drowsiness Detection System</h1>',
234
- unsafe_allow_html=True)
235
 
236
- # Introduction
237
- col1, col2 = st.columns([2, 1])
238
- with col1:
239
- st.markdown("""
240
- ### πŸ“Š About This System
241
- This AI-powered system detects driver drowsiness in real-time using:
242
- - **Facial Landmark Detection**: Identifies key facial points using MediaPipe
243
- - **Eye Aspect Ratio (EAR)**: Monitors eye closure duration
244
- - **Mouth Aspect Ratio (MAR)**: Detects yawning behavior
245
- - **Real-time Alerting**: Triggers audible alerts when drowsiness is detected
246
 
247
- **How it works:**
248
- 1. The webcam captures video feed
249
- 2. AI model detects facial landmarks
250
- 3. EAR and MAR are calculated for each frame
251
- 4. System triggers alert if metrics indicate drowsiness
252
- """)
253
-
254
- with col2:
255
- st.markdown("""
256
- ### βš™οΈ Parameters
257
- """)
258
- st.code(f"""
259
- EAR Threshold: {EAR_THRESHOLD}
260
- MAR Threshold: {MAR_THRESHOLD}
261
- Eye Closure Frames: {CONSECUTIVE_FRAMES_EYE}
262
- Yawn Detection Frames: {CONSECUTIVE_FRAMES_MOUTH}
263
- """)
264
-
265
- st.markdown("---")
266
-
267
- # Add audio alert HTML
268
- st.markdown(AUDIO_ALERT_HTML, unsafe_allow_html=True)
269
-
270
- # Video stream section
271
- st.markdown('<h2 class="sub-header">πŸŽ₯ Live Drowsiness Detection</h2>',
272
- unsafe_allow_html=True)
273
-
274
- # Warning message
275
- with st.expander("⚠️ Important Note", expanded=True):
276
- st.warning("""
277
- **For proper functionality:**
278
- 1. Ensure good lighting on your face
279
- 2. Position yourself facing the camera
280
- 3. Grant camera permissions when prompted
281
- 4. Keep your face visible to the camera
282
- 5. The system works best in a well-lit environment
283
- """)
284
-
285
- # WebRTC configuration
286
- rtc_configuration = RTCConfiguration({
287
- "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
288
- })
289
 
290
- # Create WebRTC streamer
291
- webrtc_ctx = webrtc_streamer(
292
- key="drowsiness-detection",
293
- mode=WebRtcMode.SENDRECV,
294
- rtc_configuration=rtc_configuration,
295
- video_processor_factory=DrowsinessProcessor,
296
- media_stream_constraints={"video": True, "audio": False},
297
- async_processing=True,
298
- )
299
 
300
- # Status indicators
301
- col1, col2, col3 = st.columns(3)
302
- with col1:
303
- if webrtc_ctx.state.playing:
304
- st.success("βœ… Camera Active")
305
- else:
306
- st.error("❌ Camera Inactive")
307
 
308
- with col2:
309
- if webrtc_ctx.state.playing:
310
- st.info("πŸ” Monitoring Active")
311
- else:
312
- st.warning("⚠️ Monitoring Paused")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- with col3:
315
- st.info("🎯 Waiting for face detection...")
 
 
 
 
 
 
 
 
316
 
317
- # Alert system
318
- if webrtc_ctx.state.playing:
319
- # JavaScript for audio alert
320
- alert_js = """
321
- <script>
322
- function checkForAlert() {
323
- // This would typically check a websocket or server-sent event
324
- // For simplicity, we'll use a placeholder
325
- setTimeout(checkForAlert, 1000);
326
- }
327
- checkForAlert();
328
- </script>
329
- """
330
- st.markdown(alert_js, unsafe_allow_html=True)
331
 
332
- # Metrics explanation
333
- st.markdown("---")
334
- st.markdown('<h2 class="sub-header">πŸ“ˆ Detection Metrics</h2>',
335
- unsafe_allow_html=True)
 
 
 
 
 
336
 
337
- col1, col2 = st.columns(2)
 
 
 
 
 
 
338
 
339
- with col1:
340
- st.markdown('<div class="metric-box">', unsafe_allow_html=True)
341
- st.markdown("### πŸ‘οΈ **Eye Aspect Ratio (EAR)**")
 
 
 
342
  st.markdown("""
343
- - **Normal**: EAR > 0.25
344
- - **Drowsy**: EAR < 0.25 for consecutive frames
345
- - **Calculation**: (Vertical distances) / (2 Γ— Horizontal distance)
346
- """)
347
- st.markdown('</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
- with col2:
350
- st.markdown('<div class="metric-box">', unsafe_allow_html=True)
351
- st.markdown("### πŸ‘„ **Mouth Aspect Ratio (MAR)**")
352
- st.markdown("""
353
- - **Normal**: MAR < 0.5
354
- - **Yawning**: MAR > 0.5 for consecutive frames
355
- - **Calculation**: (Vertical distances) / (2 Γ— Horizontal distance)
356
- """)
357
- st.markdown('</div>', unsafe_allow_html=True)
358
 
359
- # Technical details
360
- with st.expander("πŸ”§ Technical Implementation Details"):
361
- st.markdown("""
362
- ### πŸ—οΈ **Tech Stack**
363
- - **MediaPipe**: Facial landmark detection (468 points)
364
- - **OpenCV**: Real-time video processing
365
- - **Streamlit**: Web interface and deployment
366
- - **SciPy**: Distance calculations for EAR/MAR
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
- ### βš™οΈ **Algorithm**
369
- 1. **Face Detection**: MediaPipe Face Mesh identifies facial landmarks
370
- 2. **Feature Extraction**:
371
- - Eye landmarks (6 points per eye)
372
- - Mouth landmarks (12 points)
373
- 3. **Metric Calculation**:
374
- - EAR = (|p2-p6| + |p3-p5|) / (2 * |p1-p4|)
375
- - MAR = (|p2-p10| + |p4-p8|) / (2 * |p1-p7|)
376
- 4. **Decision Logic**:
377
- - Alert if EAR < threshold for N consecutive frames
378
- - Alert if MAR > threshold for M consecutive frames
379
 
380
- ### πŸš€ **Performance Features**
381
- - Real-time processing (>30 FPS)
382
- - Low latency alert system
383
- - Robust to lighting variations
384
- - Multi-person detection capable
385
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
  # Footer
388
  st.markdown("---")
389
- st.markdown("""
390
- <div style='text-align: center'>
391
- <p><strong>🚨 Safety-critical Application | Real-time Alerting System | Biomedical Signal Processing</strong></p>
392
- <p>Designed for deployment on Hugging Face Spaces β€’ For demonstration purposes only</p>
393
- </div>
394
- """, unsafe_allow_html=True)
 
395
 
396
  if __name__ == "__main__":
397
  main()
 
1
  import streamlit as st
2
+ import torch
3
+ from transformers import DetrImageProcessor, DetrForObjectDetection
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import io
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.patches as patches
8
  import numpy as np
9
+ from collections import Counter
10
+ import warnings
11
+ warnings.filterwarnings('ignore')
 
 
 
 
12
 
13
# Page configuration
# NOTE: st.set_page_config must run before any other Streamlit call,
# which is why this lives at module top level rather than in main().
st.set_page_config(
    page_title="Object Detection Playground",
    page_icon="πŸ”",
    layout="wide"
)

# Custom CSS for better styling (classes referenced via st.markdown
# unsafe_allow_html snippets throughout main()).
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        color: #1E88E5;
        text-align: center;
        margin-bottom: 1rem;
    }
    .sub-header {
        font-size: 1.2rem;
        color: #666;
        text-align: center;
        margin-bottom: 2rem;
    }
    .stat-box {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 10px;
        border-left: 5px solid #1E88E5;
        margin: 0.5rem 0;
    }
    .model-info {
        background-color: #e8f4fd;
        padding: 1rem;
        border-radius: 10px;
        margin: 1rem 0;
    }
    .stButton button {
        background-color: #1E88E5;
        color: white;
        font-weight: bold;
    }
    .confidence-slider {
        margin: 2rem 0;
    }
</style>
""", unsafe_allow_html=True)
58
+
59
@st.cache_resource
def load_model():
    """Fetch the DETR processor/model pair, cached for the process lifetime.

    Streamlit's resource cache keeps the loaded weights alive across
    script reruns.  On any failure the error is shown in the UI and
    (None, None) is returned so callers can bail out gracefully.
    """
    checkpoint = "facebook/detr-resnet-50"
    try:
        pair = (
            DetrImageProcessor.from_pretrained(checkpoint),
            DetrForObjectDetection.from_pretrained(checkpoint),
        )
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None
    return pair
69
+
70
def draw_bounding_boxes(image, results, threshold=0.5, id2label=None):
    """Draw labelled bounding boxes onto *image* (mutated in place).

    Args:
        image: PIL.Image.Image to annotate.
        results: dict with "scores", "labels", "boxes" tensors, as returned
            by DetrImageProcessor.post_process_object_detection.
        threshold: minimum confidence for a detection to be drawn.
        id2label: optional mapping from class id to class name.  If omitted,
            falls back to a module-level ``model`` for backward
            compatibility.  (BUG in the original: ``model`` is a local of
            main(), not a module global, so the fallback raised NameError at
            runtime — callers should pass id2label explicitly.)

    Returns:
        (image, class_colors) where class_colors maps class name -> RGB tuple.
    """
    if id2label is None:
        # Backward-compatible fallback; only works if a module-level
        # ``model`` exists.
        id2label = model.config.id2label

    draw = ImageDraw.Draw(image)

    # One stable color per class, reused across detections.
    class_colors = {}

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if score < threshold:
            continue

        box = [round(i, 2) for i in box.tolist()]
        label_name = id2label[label.item()]

        if label_name not in class_colors:
            # Deterministic per-class color.  (The original used built-in
            # hash(), which is randomized per process via PYTHONHASHSEED,
            # so colors changed on every run.)
            color_hash = sum(label_name.encode("utf-8")) % 256
            class_colors[label_name] = (
                color_hash,
                (color_hash * 37) % 256,
                (color_hash * 73) % 256,
            )
        color = class_colors[label_name]

        # Box outline.
        draw.rectangle(box, outline=color, width=3)

        # Label with a filled background so the text stays readable.
        label_text = f"{label_name}: {score:.2f}"
        text_bbox = draw.textbbox((box[0], box[1]), label_text)
        draw.rectangle(text_bbox, fill=color)
        draw.text((box[0], box[1]), label_text, fill="white")

    return image, class_colors
109
+
110
def plot_detections(image, results, threshold=0.5, id2label=None):
    """Render detections on a matplotlib figure and tally per-class counts.

    Args:
        image: PIL image (or array) to display as the background.
        results: dict with "scores", "labels", "boxes" tensors from
            post_process_object_detection.
        threshold: minimum confidence for a detection to be shown.
        id2label: optional mapping from class id to class name.  If omitted,
            falls back to a module-level ``model`` for backward
            compatibility.  (BUG in the original: ``model`` is a local of
            main(), so the fallback raised NameError at runtime.)

    Returns:
        (fig, class_counts): the matplotlib Figure and a Counter of
        class name -> number of detections above threshold.
    """
    if id2label is None:
        id2label = model.config.id2label

    fig, ax = plt.subplots(1, figsize=(12, 8))
    ax.imshow(image)

    # Count objects per class
    class_counts = Counter()

    for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
        if score < threshold:
            continue

        label_name = id2label[label.item()]
        class_counts[label_name] += 1

        # Convert (xmin, ymin, xmax, ymax) to origin + extent for Rectangle.
        xmin, ymin, xmax, ymax = box.tolist()
        width = xmax - xmin
        height = ymax - ymin

        rect = patches.Rectangle(
            (xmin, ymin), width, height,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        ax.add_patch(rect)

        # Label just above the box.
        ax.text(
            xmin, ymin - 10,
            f"{label_name}: {score:.2f}",
            bbox=dict(facecolor='red', alpha=0.5),
            fontsize=10, color='white'
        )

    plt.axis('off')
    plt.tight_layout()
    return fig, class_counts
148
+
149
def get_statistics(results, threshold=0.5, id2label=None):
    """Summarize raw detection output into display-ready statistics.

    Args:
        results: dict with "scores" and "labels" tensors from
            post_process_object_detection (unfiltered, threshold=0.0).
        threshold: confidence cutoff used to count "confident" detections.
        id2label: optional mapping from class id to class name.  If omitted,
            falls back to a module-level ``model`` for backward
            compatibility.  (BUG in the original: ``model`` is a local of
            main(), so the fallback raised NameError at runtime.)

    Returns:
        dict with total_predictions, confident_detections, avg/max/min
        confidence, unique_classes, and classes_list.  Note: classes and
        confidence aggregates cover ALL predictions, not just those above
        the threshold (preserved from the original behavior).
    """
    if id2label is None:
        id2label = model.config.id2label

    total_detections = 0
    confident_detections = 0
    confidence_scores = []
    classes_detected = set()

    for score, label in zip(results["scores"], results["labels"]):
        total_detections += 1
        confidence_scores.append(score.item())
        classes_detected.add(id2label[label.item()])
        if score >= threshold:
            confident_detections += 1

    stats = {
        "total_predictions": total_detections,
        "confident_detections": confident_detections,
        "avg_confidence": np.mean(confidence_scores) if confidence_scores else 0,
        "max_confidence": max(confidence_scores) if confidence_scores else 0,
        "min_confidence": min(confidence_scores) if confidence_scores else 0,
        "unique_classes": len(classes_detected),
        "classes_list": list(classes_detected)
    }

    return stats
174
+
175
# Main app
def main():
    """Streamlit entry point: sidebar settings, image upload, DETR
    inference, and result visualization/statistics.

    Fixes vs. original: the "Office" sample URL had a malformed query
    string (``?w-800`` instead of ``?w=800``), and the always-true
    ``'image' not in locals()`` guard is simplified (``image`` is bound
    unconditionally above).
    """
    # Header
    st.markdown('<h1 class="main-header">πŸ” Object Detection Playground</h1>', unsafe_allow_html=True)
    st.markdown('<p class="sub-header">Upload images and visualize detections with DETR (DEtection TRansformer)</p>', unsafe_allow_html=True)

    # Sidebar
    with st.sidebar:
        st.header("βš™οΈ Settings")

        # Model info
        st.markdown("### Model Information")
        st.markdown("""
        <div class="model-info">
        <strong>Model:</strong> facebook/detr-resnet-50<br>
        <strong>Architecture:</strong> DETR (DEtection TRansformer)<br>
        <strong>Backbone:</strong> ResNet-50<br>
        <strong>Classes:</strong> 91 COCO classes
        </div>
        """, unsafe_allow_html=True)

        # Confidence threshold slider
        st.markdown("### Detection Settings")
        confidence_threshold = st.slider(
            "Confidence Threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05,
            help="Adjust the minimum confidence score for detections"
        )

        # Visualization options
        st.markdown("### Visualization")
        visualization_mode = st.selectbox(
            "Choose visualization style",
            ["PIL Drawing", "Matplotlib", "Both"]
        )

        # Show class labels
        # NOTE(review): this and the two "Advanced Options" widgets below are
        # collected but never used downstream — wire them up or remove them.
        show_class_labels = st.checkbox("Show class labels on image", value=True)

        # Advanced options
        with st.expander("Advanced Options"):
            max_detections = st.slider(
                "Maximum detections to show",
                min_value=1,
                max_value=50,
                value=25,
                step=1
            )

            detection_color = st.color_picker(
                "Detection color",
                value="#FF0000"
            )

    # Main content area
    col1, col2 = st.columns([2, 1])

    with col1:
        st.markdown("### πŸ“€ Upload Image")

        # Image upload options
        upload_method = st.radio(
            "Choose upload method:",
            ["Upload file", "Use sample image"]
        )

        image = None

        if upload_method == "Upload file":
            uploaded_file = st.file_uploader(
                "Choose an image...",
                type=['jpg', 'jpeg', 'png', 'bmp', 'tiff'],
                help="Upload an image for object detection"
            )

            if uploaded_file is not None:
                # Force RGB so downstream processing never sees RGBA/greyscale.
                image = Image.open(uploaded_file).convert("RGB")
                st.image(image, caption="Uploaded Image", use_column_width=True)

        else:
            # Sample images
            sample_option = st.selectbox(
                "Choose a sample image:",
                ["Street Scene", "Office", "Kitchen", "Animals", "Sports"]
            )

            sample_images = {
                "Street Scene": "https://images.unsplash.com/photo-1449824913935-59a10b8d2000?w=800&auto=format&fit=crop",
                # Fixed: original had "?w-800" (malformed query parameter).
                "Office": "https://images.unsplash.com/photo-1497366754035-f200968a6e72?w=800&auto=format&fit=crop",
                "Kitchen": "https://images.unsplash.com/photo-1556909114-f6e7ad7d3136?w=800&auto=format&fit=crop",
                "Animals": "https://images.unsplash.com/photo-1564349683136-77e08dba1ef7?w=800&auto=format&fit=crop",
                "Sports": "https://images.unsplash.com/photo-1461896836934-ffe607ba8211?w=800&auto=format&fit=crop"
            }

            if st.button("Load Sample Image"):
                # Note: In production, you'd need to download the image
                # For now, we'll use a placeholder
                st.info("Sample images require internet connection. In HuggingFace Spaces, you'll need to implement download.")

    # Load model
    with st.spinner("Loading DETR model..."):
        processor, model = load_model()

    if image is not None and model is not None:
        # Process button
        if st.button("πŸ” Detect Objects", type="primary", use_container_width=True):
            with st.spinner("Processing image..."):
                # Prepare inputs
                inputs = processor(images=image, return_tensors="pt")

                # Get predictions (inference only — no gradients needed).
                with torch.no_grad():
                    outputs = model(**inputs)

                # Process outputs; image.size is (w, h), DETR wants (h, w).
                target_sizes = torch.tensor([image.size[::-1]])
                results = processor.post_process_object_detection(
                    outputs,
                    target_sizes=target_sizes,
                    threshold=0.0  # We'll filter by our own threshold
                )[0]

                # Get statistics
                # NOTE(review): get_statistics/draw_bounding_boxes/plot_detections
                # resolve class names via a global ``model`` that is only a local
                # here — pass id2label explicitly if those helpers support it.
                stats = get_statistics(results, confidence_threshold)

                # Display results
                st.markdown("---")
                st.markdown("### πŸ“Š Detection Results")

                # Create two columns for visualizations
                if visualization_mode in ["PIL Drawing", "Both"]:
                    # PIL visualization (copy so the original stays unannotated).
                    pil_image = image.copy()
                    annotated_image, class_colors = draw_bounding_boxes(
                        pil_image, results, confidence_threshold
                    )
                    st.image(annotated_image, caption="Detected Objects", use_column_width=True)

                if visualization_mode in ["Matplotlib", "Both"]:
                    # Matplotlib visualization
                    fig, class_counts = plot_detections(image, results, confidence_threshold)
                    st.pyplot(fig)
                    plt.close()

                    # Display class distribution
                    if class_counts:
                        st.markdown("#### πŸ“ˆ Class Distribution")
                        for class_name, count in class_counts.most_common():
                            st.progress(count/10 if count < 10 else 1.0,
                                        text=f"{class_name}: {count} objects")

                # Statistics in the right column
                with col2:
                    st.markdown("### πŸ“ˆ Statistics")

                    # Create metrics
                    metrics_col1, metrics_col2 = st.columns(2)

                    with metrics_col1:
                        st.metric(
                            "Total Objects",
                            stats["confident_detections"],
                            f"{stats['total_predictions']} total predictions"
                        )

                        st.metric(
                            "Unique Classes",
                            stats["unique_classes"]
                        )

                    with metrics_col2:
                        st.metric(
                            "Avg Confidence",
                            f"{stats['avg_confidence']:.2%}"
                        )

                        st.metric(
                            "Max Confidence",
                            f"{stats['max_confidence']:.2%}"
                        )

                    # Class list
                    st.markdown("#### 🏷️ Detected Classes")
                    if stats["classes_list"]:
                        for class_name in sorted(stats["classes_list"]):
                            st.markdown(f"- {class_name}")
                    else:
                        st.info("No objects detected above threshold")

                    # Confidence distribution
                    st.markdown("#### πŸ“Š Confidence Distribution")

                    # Get confidence scores for histogram
                    confidence_scores = [score.item() for score in results["scores"]]
                    if confidence_scores:
                        fig_hist, ax_hist = plt.subplots(figsize=(8, 4))
                        ax_hist.hist(confidence_scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
                        ax_hist.axvline(x=confidence_threshold, color='red', linestyle='--',
                                        label=f'Threshold: {confidence_threshold}')
                        ax_hist.set_xlabel('Confidence Score')
                        ax_hist.set_ylabel('Count')
                        ax_hist.set_title('Distribution of Confidence Scores')
                        ax_hist.legend()
                        ax_hist.grid(True, alpha=0.3)
                        st.pyplot(fig_hist)
                        plt.close()

                    # Download button for processed image
                    if visualization_mode in ["PIL Drawing", "Both"]:
                        buffered = io.BytesIO()
                        annotated_image.save(buffered, format="PNG")
                        st.download_button(
                            label="πŸ“₯ Download Processed Image",
                            data=buffered.getvalue(),
                            file_name="detected_objects.png",
                            mime="image/png",
                            use_container_width=True
                        )

    # Instructions in main area if no image
    # (Original checked "'image' not in locals()" too, but image is always
    # bound above, so the plain None check is equivalent.)
    if image is None:
        with col1:
            st.info("πŸ‘ˆ Please upload an image or select a sample image to begin object detection.")

        # Quick guide
        with st.expander("πŸ“š Quick Guide"):
            st.markdown("""
            ### How to use:
            1. **Upload an image** using the file uploader or select a sample image
            2. **Adjust the confidence threshold** in the sidebar (default: 0.5)
            3. **Choose visualization style** (PIL or Matplotlib)
            4. **Click 'Detect Objects'** to run the model

            ### Features:
            - **Real-time statistics** showing object counts
            - **Adjustable confidence threshold** to filter detections
            - **Multiple visualization options**
            - **Download processed images**
            - **Class distribution analysis**

            ### About DETR:
            DETR (DEtection TRansformer) is an end-to-end object detection model that uses
            transformers instead of traditional convolutional approaches.
            """)

        # Model capabilities
        st.markdown("### 🎯 Model Capabilities")
        col_cap1, col_cap2, col_cap3 = st.columns(3)

        with col_cap1:
            st.markdown("""
            **Common Objects:**
            - Person
            - Vehicle
            - Furniture
            - Animal
            - Food items
            """)

        with col_cap2:
            st.markdown("""
            **Detection Types:**
            - 91 COCO classes
            - Real-time processing
            - Bounding boxes
            - Confidence scores
            """)

        with col_cap3:
            st.markdown("""
            **Best For:**
            - General scenes
            - Multiple objects
            - Indoor/outdoor
            - Real-world images
            """)

    # Footer
    st.markdown("---")
    st.markdown(
        "<div style='text-align: center; color: #666;'>"
        "Object Detection Playground β€’ Powered by DETR Transformers β€’ "
        "<a href='https://huggingface.co/facebook/detr-resnet-50' target='_blank'>Model Card</a>"
        "</div>",
        unsafe_allow_html=True
    )
464
 
465
  if __name__ == "__main__":
466
  main()