sunbal7 commited on
Commit
045dbac
·
verified ·
1 Parent(s): 545df94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -432
app.py CHANGED
@@ -3,450 +3,148 @@ import torch
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
  from PIL import Image, ImageDraw
5
  import numpy as np
6
- from collections import Counter
7
- import cv2
8
  import time
9
- import tempfile
10
- import os
11
 
12
- # Set page config
13
  st.set_page_config(
14
- page_title="Object Detection Playground",
15
  page_icon="🔍",
16
  layout="wide"
17
  )
18
 
19
- # Custom CSS
20
- st.markdown("""
21
- <style>
22
- .main-header {
23
- font-size: 2.5rem;
24
- color: #1E88E5;
25
- text-align: center;
26
- margin-bottom: 1rem;
27
- font-weight: 700;
28
- }
29
- .sub-header {
30
- font-size: 1.2rem;
31
- color: #666;
32
- text-align: center;
33
- margin-bottom: 2rem;
34
- }
35
- .stat-box {
36
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
37
- color: white;
38
- padding: 1.5rem;
39
- border-radius: 10px;
40
- margin: 0.5rem 0;
41
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
42
- }
43
- .metric-card {
44
- background: white;
45
- padding: 1rem;
46
- border-radius: 10px;
47
- border-left: 5px solid #1E88E5;
48
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
49
- margin: 0.5rem 0;
50
- }
51
- .stButton > button {
52
- background: linear-gradient(135deg, #1E88E5 0%, #0D47A1 100%);
53
- color: white;
54
- border: none;
55
- padding: 0.5rem 2rem;
56
- border-radius: 5px;
57
- font-weight: 600;
58
- }
59
- .stButton > button:hover {
60
- background: linear-gradient(135deg, #0D47A1 0%, #1565C0 100%);
61
- transform: translateY(-2px);
62
- transition: all 0.3s ease;
63
- }
64
- .confidence-slider {
65
- margin: 1rem 0;
66
- }
67
- .model-info-box {
68
- background: #f8f9fa;
69
- padding: 1rem;
70
- border-radius: 10px;
71
- border: 1px solid #dee2e6;
72
- }
73
- </style>
74
- """, unsafe_allow_html=True)
75
 
76
- @st.cache_resource(show_spinner=True)
77
- def load_model():
78
- """Load and cache the DETR model"""
79
- try:
80
- with st.spinner("Loading DETR model (first time may take a minute)..."):
81
- # Load processor and model
82
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
83
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
84
- model.eval() # Set to evaluation mode
85
- return processor, model
86
- except Exception as e:
87
- st.error(f"Failed to load model: {str(e)}")
88
- return None, None
89
 
90
- def process_image(image, processor, model, confidence_threshold):
91
- """Process image and return detections"""
92
- try:
93
- # Convert to RGB if needed
94
- if image.mode != 'RGB':
95
- image = image.convert('RGB')
96
-
97
- # Process image
98
- inputs = processor(images=image, return_tensors="pt")
99
-
100
- # Run inference
101
- with torch.no_grad():
102
- outputs = model(**inputs)
103
-
104
- # Process outputs
105
- target_sizes = torch.tensor([image.size[::-1]]) # [height, width]
106
- results = processor.post_process_object_detection(
107
- outputs,
108
- target_sizes=target_sizes,
109
- threshold=0.01 # Low threshold, we'll filter later
110
- )[0]
111
-
112
- # Filter by confidence threshold
113
- mask = results["scores"] >= confidence_threshold
114
- filtered_results = {
115
- "scores": results["scores"][mask],
116
- "labels": results["labels"][mask],
117
- "boxes": results["boxes"][mask]
118
- }
119
-
120
- return filtered_results
121
- except Exception as e:
122
- st.error(f"Error processing image: {str(e)}")
123
- return None
124
-
125
- def draw_detections(image, results, processor, model):
126
- """Draw bounding boxes on image"""
127
- try:
128
- # Create a copy of the image
129
- img_copy = image.copy()
130
- draw = ImageDraw.Draw(img_copy)
131
-
132
- # Color palette for different classes
133
- colors = [
134
- (255, 0, 0), (0, 255, 0), (0, 0, 255),
135
- (255, 255, 0), (255, 0, 255), (0, 255, 255),
136
- (255, 128, 0), (128, 0, 255), (0, 128, 255)
137
- ]
138
-
139
- # Draw each detection
140
- for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
141
- # Get box coordinates
142
- xmin, ymin, xmax, ymax = box.tolist()
143
-
144
- # Get label name
145
- label_id = label.item()
146
- label_name = model.config.id2label[label_id]
147
-
148
- # Choose color based on label
149
- color = colors[label_id % len(colors)]
150
-
151
- # Draw rectangle
152
- draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)
153
-
154
- # Create label text
155
- label_text = f"{label_name}: {score:.2f}"
156
-
157
- # Draw label background
158
- text_bbox = draw.textbbox((xmin, ymin), label_text)
159
- draw.rectangle(text_bbox, fill=color)
160
-
161
- # Draw text
162
- draw.text((xmin, ymin), label_text, fill="white")
163
-
164
- return img_copy
165
- except Exception as e:
166
- st.error(f"Error drawing detections: {str(e)}")
167
- return image
168
-
169
- def get_statistics(results, model):
170
- """Calculate and return detection statistics"""
171
- if results is None or len(results["scores"]) == 0:
172
- return {
173
- "total_objects": 0,
174
- "avg_confidence": 0,
175
- "class_distribution": {},
176
- "detected_classes": []
177
- }
178
-
179
- # Count objects per class
180
- class_counts = Counter()
181
- confidences = []
182
-
183
- for score, label in zip(results["scores"], results["labels"]):
184
- label_name = model.config.id2label[label.item()]
185
- class_counts[label_name] += 1
186
- confidences.append(score.item())
187
-
188
- # Prepare statistics
189
- stats = {
190
- "total_objects": len(results["scores"]),
191
- "avg_confidence": np.mean(confidences) if confidences else 0,
192
- "max_confidence": max(confidences) if confidences else 0,
193
- "min_confidence": min(confidences) if confidences else 0,
194
- "class_distribution": dict(class_counts),
195
- "detected_classes": list(class_counts.keys())
196
- }
197
-
198
- return stats
199
-
200
- def main():
201
- # Header
202
- st.markdown('<h1 class="main-header">🔍 Object Detection Playground</h1>', unsafe_allow_html=True)
203
- st.markdown('<p class="sub-header">Upload images and detect objects with DETR (Detection Transformer)</p>', unsafe_allow_html=True)
204
-
205
- # Initialize session state
206
- if 'processed_image' not in st.session_state:
207
- st.session_state.processed_image = None
208
- if 'detection_results' not in st.session_state:
209
- st.session_state.detection_results = None
210
 
211
- # Sidebar
212
- with st.sidebar:
213
- st.markdown("### ⚙️ Configuration")
214
-
215
- # Model info
216
- with st.expander("ℹ️ Model Information", expanded=True):
217
- st.markdown("""
218
- **Model:** facebook/detr-resnet-50
219
- **Architecture:** DETR (End-to-End Object Detection)
220
- **Backbone:** ResNet-50
221
- **Training Data:** COCO 2017
222
- **Classes:** 91 categories
223
- """)
224
-
225
- # Confidence threshold
226
- st.markdown("### 🎯 Confidence Settings")
227
- confidence_threshold = st.slider(
228
- "Detection Threshold",
229
- min_value=0.0,
230
- max_value=1.0,
231
- value=0.7,
232
- step=0.05,
233
- help="Objects with confidence below this threshold will be filtered out"
234
- )
235
-
236
- # Display options
237
- st.markdown("### 🎨 Display Options")
238
- show_labels = st.checkbox("Show labels on image", value=True)
239
- show_confidence = st.checkbox("Show confidence scores", value=True)
240
-
241
- # Performance options
242
- st.markdown("### ⚡ Performance")
243
- use_gpu = st.checkbox("Use GPU if available", value=True)
244
-
245
- # Load model button
246
- st.markdown("---")
247
- if st.button("🔄 Load/Reload Model", use_container_width=True):
248
- with st.spinner("Loading model..."):
249
- st.cache_resource.clear()
250
- processor, model = load_model()
251
- if processor and model:
252
  st.success("Model loaded successfully!")
253
-
254
- # Main content area
255
- col1, col2 = st.columns([2, 1])
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
  with col1:
258
- st.markdown("### 📤 Upload Image")
259
-
260
- # File uploader
261
- uploaded_file = st.file_uploader(
262
- "Choose an image file",
263
- type=['jpg', 'jpeg', 'png', 'bmp'],
264
- help="Supported formats: JPG, PNG, BMP"
265
- )
266
-
267
- # Or use sample images
268
- st.markdown("### 📸 Try Sample Images")
269
- sample_col1, sample_col2, sample_col3 = st.columns(3)
270
-
271
- sample_images = {
272
- "Street": "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg",
273
- "Office": "https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/zidane.jpg",
274
- "Animals": "https://images.unsplash.com/photo-1564349683136-77e08dba1ef7?w=800&h=600&fit=crop"
275
- }
276
-
277
- with sample_col1:
278
- if st.button("Street Scene", use_container_width=True):
279
- st.session_state.sample_image = "street"
280
- with sample_col2:
281
- if st.button("Office Scene", use_container_width=True):
282
- st.session_state.sample_image = "office"
283
- with sample_col3:
284
- if st.button("Animals", use_container_width=True):
285
- st.session_state.sample_image = "animals"
286
-
287
- # Display uploaded or sample image
288
- image = None
289
- if uploaded_file is not None:
290
- image = Image.open(uploaded_file)
291
- st.image(image, caption="Uploaded Image", use_column_width=True)
292
- elif 'sample_image' in st.session_state:
293
- # Note: In HuggingFace Spaces, you might need to handle sample images differently
294
- # For now, we'll use placeholder
295
- st.info("Sample images require internet connection. Please upload your own image for testing.")
296
- image = None
297
-
298
- # Process button
299
- if image is not None:
300
- if st.button("🚀 Detect Objects", type="primary", use_container_width=True):
301
- with st.spinner("Processing image..."):
302
- # Load model
303
- processor, model = load_model()
304
 
305
- if processor and model:
306
- # Process image
307
- results = process_image(image, processor, model, confidence_threshold)
308
-
309
- if results:
310
- # Draw detections
311
- if show_labels:
312
- annotated_image = draw_detections(image, results, processor, model)
313
- else:
314
- annotated_image = image
315
-
316
- # Get statistics
317
- stats = get_statistics(results, model)
318
-
319
- # Store in session state
320
- st.session_state.processed_image = annotated_image
321
- st.session_state.detection_results = results
322
- st.session_state.stats = stats
323
- st.session_state.model = model
324
-
325
- st.success(f"Detected {stats['total_objects']} objects!")
326
-
327
- # Display results in right column
328
- with col2:
329
- st.markdown("### 📊 Detection Statistics")
330
-
331
- if 'stats' in st.session_state and st.session_state.stats:
332
- stats = st.session_state.stats
333
-
334
- # Metrics
335
- metric_col1, metric_col2 = st.columns(2)
336
- with metric_col1:
337
- st.metric("Total Objects", stats['total_objects'])
338
- st.metric("Avg Confidence", f"{stats['avg_confidence']:.1%}")
339
- with metric_col2:
340
- st.metric("Max Confidence", f"{stats['max_confidence']:.1%}")
341
- st.metric("Unique Classes", len(stats['detected_classes']))
342
-
343
- # Class distribution
344
- if stats['class_distribution']:
345
- st.markdown("#### 🏷️ Detected Classes")
346
- for class_name, count in sorted(stats['class_distribution'].items(), key=lambda x: x[1], reverse=True):
347
- st.markdown(f"**{class_name}**: {count} objects")
348
-
349
- # Confidence histogram
350
- if st.session_state.detection_results and len(st.session_state.detection_results["scores"]) > 0:
351
- st.markdown("#### 📈 Confidence Distribution")
352
- confidences = [s.item() for s in st.session_state.detection_results["scores"]]
353
- hist_values = np.histogram(confidences, bins=10, range=(0, 1))[0]
354
- st.bar_chart(hist_values)
355
-
356
- else:
357
- st.info("No detection results yet. Upload an image and click 'Detect Objects'.")
358
-
359
- # Display processed image below
360
- if st.session_state.processed_image is not None:
361
- st.markdown("---")
362
- st.markdown("### 🖼️ Detection Results")
363
-
364
- result_col1, result_col2 = st.columns([3, 1])
365
-
366
- with result_col1:
367
- st.image(
368
- st.session_state.processed_image,
369
- caption=f"Detected Objects (Threshold: {confidence_threshold})",
370
- use_column_width=True
371
- )
372
-
373
- with result_col2:
374
- # Download button
375
- if st.session_state.processed_image:
376
- from io import BytesIO
377
- buffered = BytesIO()
378
- st.session_state.processed_image.save(buffered, format="PNG")
379
- st.download_button(
380
- label="💾 Download Result",
381
- data=buffered.getvalue(),
382
- file_name="detection_result.png",
383
- mime="image/png",
384
- use_container_width=True
385
- )
386
-
387
- # Reset button
388
- if st.button("🔄 Clear Results", use_container_width=True):
389
- st.session_state.processed_image = None
390
- st.session_state.detection_results = None
391
- if 'stats' in st.session_state:
392
- del st.session_state.stats
393
- st.rerun()
394
-
395
- # Footer with model capabilities
396
- st.markdown("---")
397
-
398
- # Model capabilities section
399
- st.markdown("### 🎯 What Can DETR Detect?")
400
-
401
- capabilities_col1, capabilities_col2, capabilities_col3 = st.columns(3)
402
-
403
- with capabilities_col1:
404
- st.markdown("""
405
- **👥 People & Animals**
406
- - person
407
- - dog, cat, bird
408
- - horse, sheep, cow
409
- - bear, zebra, giraffe
410
- """)
411
-
412
- with capabilities_col2:
413
- st.markdown("""
414
- **🚗 Vehicles**
415
- - car, truck, bus
416
- - bicycle, motorcycle
417
- - airplane, boat
418
- - train
419
- """)
420
-
421
- with capabilities_col3:
422
- st.markdown("""
423
- **🏠 Everyday Objects**
424
- - chair, sofa, bed
425
- - dining table
426
- - tv, laptop, mouse
427
- - bottle, cup, fork
428
- """)
429
-
430
- # Tips and instructions
431
- with st.expander("💡 Tips for Best Results"):
432
- st.markdown("""
433
- 1. **Use clear images** with good lighting
434
- 2. **Start with threshold 0.7** and adjust as needed
435
- 3. **For crowded scenes**, increase threshold to reduce false positives
436
- 4. **For small objects**, decrease threshold to catch more detections
437
- 5. **Images with multiple objects** work best with DETR
438
- 6. **Allow model to load** on first run (takes about 30 seconds)
439
- """)
440
-
441
- # Footer
442
- st.markdown("---")
443
- st.markdown(
444
- "<div style='text-align: center; color: #666;'>"
445
- "Object Detection Playground • Powered by <a href='https://huggingface.co/facebook/detr-resnet-50' target='_blank'>DETR</a> • "
446
- "Built with ❤️ using Streamlit"
447
- "</div>",
448
- unsafe_allow_html=True
449
- )
450
 
451
- if __name__ == "__main__":
452
- main()
 
 
3
  from transformers import DetrImageProcessor, DetrForObjectDetection
4
  from PIL import Image, ImageDraw
5
  import numpy as np
 
 
6
  import time
 
 
7
 
8
# ---- Page configuration and header ----
# Must run before any other Streamlit call in the script.
st.set_page_config(
    page_title="Simple Object Detection",
    page_icon="🔍",
    layout="wide",
)

st.title("🔍 Simple Object Detection with DETR")
st.markdown("Upload an image to detect objects using Facebook's DETR model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
# ---- One-time session-state defaults ----
# The model is loaded lazily from the sidebar; until then both handles are None.
if 'model_loaded' not in st.session_state:
    st.session_state.update(model_loaded=False, processor=None, model=None)
 
 
 
 
 
 
 
 
24
 
25
# ---- Sidebar: detection settings and one-time model loading ----
with st.sidebar:
    st.header("Settings")

    # Score cutoff passed to post-processing; detections below it are dropped.
    confidence = st.slider(
        "Confidence Threshold",
        min_value=0.1,
        max_value=0.99,
        value=0.7,
        step=0.05
    )

    # Model loading: show the ready banner once loaded, otherwise offer the button.
    if st.session_state.model_loaded:
        st.success("✅ Model is loaded and ready!")
    elif st.button("Load Model", type="primary"):
        with st.spinner("Loading DETR model..."):
            try:
                st.session_state.processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
                st.session_state.model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
                st.session_state.model_loaded = True
                st.success("Model loaded successfully!")
            except Exception as e:
                # Boundary handler: surface download/initialization failures in the UI.
                st.error(f"Error loading model: {e}")
51
+
52
# ---- Main content: upload an image, run DETR, draw and summarize detections ----
uploaded_file = st.file_uploader(
    "Choose an image...",
    type=['jpg', 'jpeg', 'png']
)

if uploaded_file is not None:
    # Force RGB so grayscale/RGBA uploads work with the processor and drawing.
    image = Image.open(uploaded_file).convert("RGB")
    col1, col2 = st.columns(2)

    with col1:
        st.image(image, caption="Original Image", use_column_width=True)

    # The detect button is only usable once the model was loaded from the sidebar.
    if st.session_state.model_loaded and st.button("Detect Objects"):
        with st.spinner("Detecting objects..."):
            try:
                processor = st.session_state.processor
                model = st.session_state.model

                # Preprocess and run inference; no gradients needed at inference time.
                inputs = processor(images=image, return_tensors="pt")
                with torch.no_grad():
                    outputs = model(**inputs)

                # Convert raw outputs to boxes in original-image coordinates.
                # PIL's image.size is (width, height); DETR post-processing
                # expects (height, width), hence the [::-1].
                target_sizes = torch.tensor([image.size[::-1]])
                results = processor.post_process_object_detection(
                    outputs,
                    target_sizes=target_sizes,
                    threshold=confidence
                )[0]

                # Annotate the image in place (col1 already rendered the original).
                draw = ImageDraw.Draw(image)
                colors = ["red", "green", "blue", "yellow", "purple", "orange"]

                detected_objects = []

                for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
                    box = [round(i, 2) for i in box.tolist()]
                    label_name = model.config.id2label[label.item()]

                    # Color keyed on class id so the same class is always the same color.
                    color = colors[label.item() % len(colors)]
                    draw.rectangle(box, outline=color, width=3)

                    # FIX: draw the label on a filled background so it stays
                    # readable over busy image regions — plain colored text
                    # directly on the photo was frequently illegible.
                    label_text = f"{label_name}: {score:.2f}"
                    text_bbox = draw.textbbox((box[0], box[1]), label_text)
                    draw.rectangle(text_bbox, fill=color)
                    draw.text((box[0], box[1]), label_text, fill="white")

                    detected_objects.append((label_name, score.item()))

                # Annotated image side-by-side with the original.
                with col2:
                    st.image(image, caption="Detected Objects", use_column_width=True)

                # Summary statistics.
                st.subheader("📊 Detection Results")

                if detected_objects:
                    col_stats1, col_stats2, col_stats3 = st.columns(3)

                    with col_stats1:
                        st.metric("Objects Found", len(detected_objects))

                    with col_stats2:
                        avg_conf = np.mean([score for _, score in detected_objects])
                        st.metric("Average Confidence", f"{avg_conf:.1%}")

                    with col_stats3:
                        st.metric("Unique Classes", len({label for label, _ in detected_objects}))

                    # Per-object listing with confidence.
                    st.subheader("Detected Objects:")
                    for label, score in detected_objects:
                        st.write(f"- **{label}** (confidence: {score:.1%})")
                else:
                    st.warning("No objects detected above the confidence threshold.")

            except Exception as e:
                # Boundary handler: show inference/drawing failures in the UI.
                st.error(f"Error during detection: {e}")
else:
    st.info("👈 Please upload an image and load the model from the sidebar")
137
+
138
# Usage instructions, kept collapsed by default.
with st.expander("How to use this app"):
    usage_text = """
    1. **Load the model** using the button in the sidebar
    2. **Upload an image** (JPG, PNG formats)
    3. **Adjust confidence threshold** if needed
    4. **Click 'Detect Objects'** to run detection
    5. **View results** and detected objects
    """
    st.markdown(usage_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
# Footer with attribution links.
st.markdown("---")
_detr_url = "https://huggingface.co/facebook/detr-resnet-50"
st.markdown(f"Built with [DETR]({_detr_url}) • [Streamlit](https://streamlit.io)")