molecularmax commited on
Commit
dfb6ca6
·
verified ·
1 Parent(s): 2eeea7e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +449 -0
app.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import cv2
4
+ import gradio as gr
5
+ import torch
6
+ from PIL import Image, ImageDraw, ImageFont
7
+ import matplotlib.pyplot as plt
8
+ from facenet_pytorch import MTCNN, RetinaFace
9
+ from retinaface.pre_trained_models import get_model as get_retinaface_model
10
+ import matplotlib.cm as cm
11
+ from collections import defaultdict
12
+
13
+ # Set up device
14
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
15
+ print(f"Using device: {device}")
16
+
17
+ # Load face detector models for ensemble
18
+ models = {}
19
+
20
+ # Initialize MTCNN
21
+ models['mtcnn'] = MTCNN(keep_all=True, device=device)
22
+
23
+ # Initialize RetinaFace
24
+ models['retinaface'] = get_retinaface_model("resnet50", max_size=1024, device=device.type)
25
+ models['retinaface'].eval()
26
+
27
+ def load_images_from_folder(folder_path):
28
+ """Load all jpg images from the specified folder"""
29
+ image_paths = []
30
+ if os.path.exists(folder_path):
31
+ for filename in os.listdir(folder_path):
32
+ if filename.lower().endswith(('.jpg', '.jpeg')):
33
+ image_paths.append(os.path.join(folder_path, filename))
34
+ return sorted(image_paths)
35
+
36
+ def detect_faces_ensemble(image):
37
+ """
38
+ Detect faces using an ensemble of face detectors
39
+ Returns: List of face bounding boxes with format [x1, y1, x2, y2, confidence]
40
+ """
41
+ # Convert image to RGB if needed
42
+ if isinstance(image, str):
43
+ image = Image.open(image).convert('RGB')
44
+ elif isinstance(image, np.ndarray):
45
+ if image.shape[2] == 3:
46
+ image = Image.fromarray(image)
47
+ else:
48
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
49
+
50
+ # Get MTCNN detections
51
+ boxes_mtcnn, probs_mtcnn = models['mtcnn'].detect(image)
52
+
53
+ # Get RetinaFace detections
54
+ tensor_image = models['retinaface'].preprocess_image(np.array(image))
55
+ with torch.no_grad():
56
+ boxes_retinaface, scores_retinaface = models['retinaface'].predict(tensor_image)
57
+
58
+ # Ensemble the results (in this simple case, we'll just combine them)
59
+ all_boxes = []
60
+
61
+ # Add MTCNN boxes
62
+ if boxes_mtcnn is not None:
63
+ for box, prob in zip(boxes_mtcnn, probs_mtcnn):
64
+ x1, y1, x2, y2 = box
65
+ all_boxes.append([int(x1), int(y1), int(x2), int(y2), float(prob)])
66
+
67
+ # Add RetinaFace boxes
68
+ if len(boxes_retinaface) > 0:
69
+ for box, score in zip(boxes_retinaface, scores_retinaface):
70
+ x1, y1, x2, y2 = box
71
+ all_boxes.append([int(x1), int(y1), int(x2), int(y2), float(score)])
72
+
73
+ # Apply non-maximum suppression to remove duplicate detections
74
+ if len(all_boxes) > 0:
75
+ all_boxes = non_maximum_suppression(all_boxes, 0.5)
76
+
77
+ return all_boxes, image
78
+
79
+ def calculate_iou(box1, box2):
80
+ """Calculate intersection over union between two boxes"""
81
+ x1_1, y1_1, x2_1, y2_1 = box1[:4]
82
+ x1_2, y1_2, x2_2, y2_2 = box2[:4]
83
+
84
+ # Calculate intersection area
85
+ x_left = max(x1_1, x1_2)
86
+ y_top = max(y1_1, y1_2)
87
+ x_right = min(x2_1, x2_2)
88
+ y_bottom = min(y2_1, y2_2)
89
+
90
+ if x_right < x_left or y_bottom < y_top:
91
+ return 0.0
92
+
93
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
94
+
95
+ # Calculate union area
96
+ box1_area = (x2_1 - x1_1) * (y2_1 - y1_1)
97
+ box2_area = (x2_2 - x1_2) * (y2_2 - y1_2)
98
+ union_area = box1_area + box2_area - intersection_area
99
+
100
+ return intersection_area / union_area
101
+
102
+ def non_maximum_suppression(boxes, iou_threshold):
103
+ """Apply non-maximum suppression to remove overlapping boxes"""
104
+ if len(boxes) == 0:
105
+ return []
106
+
107
+ # Sort boxes by confidence (descending)
108
+ boxes = sorted(boxes, key=lambda x: x[4], reverse=True)
109
+ kept_boxes = []
110
+
111
+ while len(boxes) > 0:
112
+ # Add the box with highest confidence
113
+ current_box = boxes.pop(0)
114
+ kept_boxes.append(current_box)
115
+
116
+ # Remove overlapping boxes
117
+ remaining_boxes = []
118
+ for box in boxes:
119
+ if calculate_iou(current_box, box) < iou_threshold:
120
+ remaining_boxes.append(box)
121
+
122
+ boxes = remaining_boxes
123
+
124
+ return kept_boxes
125
+
126
+ def bin_faces_by_size(faces):
127
+ """Group faces into bins based on their size (max of width and height)"""
128
+ face_sizes = []
129
+ bin_size = 20 # Size of each bin in pixels
130
+
131
+ # Calculate face sizes
132
+ for face in faces:
133
+ x1, y1, x2, y2, _ = face
134
+ width = x2 - x1
135
+ height = y2 - y1
136
+ size = max(width, height)
137
+ face_sizes.append(size)
138
+
139
+ # Determine bin range
140
+ if not face_sizes:
141
+ return {}
142
+
143
+ min_size = min(face_sizes)
144
+ max_size = max(face_sizes)
145
+
146
+ # Create bins
147
+ bin_edges = range(
148
+ bin_size * (min_size // bin_size),
149
+ bin_size * (max_size // bin_size + 2),
150
+ bin_size
151
+ )
152
+
153
+ # Place faces in bins
154
+ bin_counts = defaultdict(int)
155
+ bin_faces = defaultdict(list)
156
+
157
+ for i, size in enumerate(face_sizes):
158
+ bin_idx = size // bin_size * bin_size
159
+ bin_counts[bin_idx] += 1
160
+ bin_faces[bin_idx].append((faces[i], size))
161
+
162
+ return {
163
+ 'bin_counts': dict(bin_counts),
164
+ 'bin_faces': dict(bin_faces),
165
+ 'bin_edges': list(bin_edges)
166
+ }
167
+
168
+ def plot_face_histogram(bin_data):
169
+ """Create a histogram of face sizes"""
170
+ if not bin_data or len(bin_data['bin_counts']) == 0:
171
+ # Create empty figure if no data
172
+ fig, ax = plt.subplots(figsize=(10, 6))
173
+ ax.set_title('Face Size Distribution')
174
+ ax.set_xlabel('Face Size (pixels)')
175
+ ax.set_ylabel('Count')
176
+ ax.text(0.5, 0.5, 'No faces detected', ha='center', va='center', transform=ax.transAxes)
177
+ return fig
178
+
179
+ # Extract data
180
+ bins = sorted(bin_data['bin_counts'].keys())
181
+ counts = [bin_data['bin_counts'][b] for b in bins]
182
+
183
+ # Create histogram figure
184
+ fig, ax = plt.subplots(figsize=(10, 6))
185
+ bars = ax.bar(
186
+ [str(b) for b in bins],
187
+ counts,
188
+ color='skyblue',
189
+ edgecolor='navy'
190
+ )
191
+
192
+ # Add value labels
193
+ for bar in bars:
194
+ height = bar.get_height()
195
+ ax.annotate(
196
+ f'{height}',
197
+ xy=(bar.get_x() + bar.get_width() / 2, height),
198
+ xytext=(0, 3),
199
+ textcoords="offset points",
200
+ ha='center', va='bottom'
201
+ )
202
+
203
+ ax.set_title('Face Size Distribution')
204
+ ax.set_xlabel('Face Size (pixels)')
205
+ ax.set_ylabel('Count')
206
+
207
+ # Rotate x-axis labels for better readability
208
+ plt.xticks(rotation=45, ha='right')
209
+ plt.tight_layout()
210
+
211
+ return fig
212
+
213
+ def create_face_examples_grid(image, bin_data, selected_bin=None):
214
+ """Create a grid of face examples from the selected bin"""
215
+ if not bin_data or 'bin_faces' not in bin_data or not bin_data['bin_faces']:
216
+ return None
217
+
218
+ if isinstance(image, str):
219
+ image = Image.open(image).convert('RGB')
220
+ elif isinstance(image, np.ndarray):
221
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
222
+
223
+ # If no bin is selected, return None
224
+ if selected_bin is None:
225
+ return None
226
+
227
+ # Get faces from the selected bin
228
+ if int(selected_bin) not in bin_data['bin_faces']:
229
+ return None
230
+
231
+ bin_faces = bin_data['bin_faces'][int(selected_bin)]
232
+
233
+ # Determine grid size
234
+ num_faces = len(bin_faces)
235
+ cols = min(5, num_faces)
236
+ rows = (num_faces + cols - 1) // cols
237
+
238
+ # Create empty white canvas for the grid
239
+ margin = 10
240
+ face_size = int(selected_bin) + 2 * margin
241
+
242
+ grid_width = cols * face_size + (cols + 1) * margin
243
+ grid_height = rows * face_size + (rows + 1) * margin
244
+
245
+ grid_image = Image.new('RGB', (grid_width, grid_height), color='white')
246
+ draw = ImageDraw.Draw(grid_image)
247
+
248
+ # Extract and place faces on the grid
249
+ for i, (face, size) in enumerate(bin_faces):
250
+ x1, y1, x2, y2, conf = face
251
+
252
+ # Calculate position in the grid
253
+ row = i // cols
254
+ col = i % cols
255
+
256
+ # Extract face with margin
257
+ face_img = image.crop((
258
+ max(0, x1 - margin),
259
+ max(0, y1 - margin),
260
+ min(image.width, x2 + margin),
261
+ min(image.height, y2 + margin)
262
+ ))
263
+
264
+ # Resize to consistent size if needed
265
+ target_size = face_size - 2 * margin
266
+ if face_img.width != target_size or face_img.height != target_size:
267
+ face_img = face_img.resize((target_size, target_size))
268
+
269
+ # Place face in grid
270
+ grid_x = col * face_size + (col + 1) * margin
271
+ grid_y = row * face_size + (row + 1) * margin
272
+
273
+ grid_image.paste(face_img, (grid_x, grid_y))
274
+
275
+ # Add size label
276
+ draw.rectangle(
277
+ [grid_x, grid_y + target_size - 20, grid_x + target_size, grid_y + target_size],
278
+ fill=(0, 0, 0, 128)
279
+ )
280
+ draw.text(
281
+ (grid_x + 5, grid_y + target_size - 15),
282
+ f"{size}px",
283
+ fill=(255, 255, 255)
284
+ )
285
+
286
+ return grid_image
287
+
288
+ def draw_faces_on_image(image, faces):
289
+ """Draw bounding boxes around detected faces"""
290
+ if isinstance(image, str):
291
+ image = Image.open(image).convert('RGB')
292
+ elif isinstance(image, np.ndarray):
293
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
294
+
295
+ # Create a copy of the image
296
+ result_image = image.copy()
297
+ draw = ImageDraw.Draw(result_image)
298
+
299
+ # Generate colors for different face sizes
300
+ if faces:
301
+ sizes = [max(face[2] - face[0], face[3] - face[1]) for face in faces]
302
+ min_size = min(sizes)
303
+ max_size = max(sizes)
304
+ size_range = max(max_size - min_size, 1)
305
+
306
+ # Draw faces
307
+ for face in faces:
308
+ x1, y1, x2, y2, conf = face
309
+ width = x2 - x1
310
+ height = y2 - y1
311
+ size = max(width, height)
312
+
313
+ # Determine color based on face size
314
+ if max_size == min_size:
315
+ normalized_size = 0.5
316
+ else:
317
+ normalized_size = (size - min_size) / size_range
318
+
319
+ # Use a color gradient from blue to red
320
+ color_r = int(255 * normalized_size)
321
+ color_g = 0
322
+ color_b = int(255 * (1 - normalized_size))
323
+
324
+ # Draw rectangle
325
+ draw.rectangle([x1, y1, x2, y2], outline=(color_r, color_g, color_b), width=2)
326
+
327
+ # Draw size and confidence label
328
+ label = f"{size}px ({conf:.2f})"
329
+ draw.rectangle([x1, y1, x1 + 100, y1 - 20], fill=(color_r, color_g, color_b))
330
+ draw.text((x1 + 5, y1 - 15), label, fill=(255, 255, 255))
331
+
332
+ return result_image
333
+
334
+ def process_image(image, selected_bin=None):
335
+ """Main function to process an image and return results"""
336
+ # Detect faces
337
+ faces, img = detect_faces_ensemble(image)
338
+
339
+ # Bin faces by size
340
+ bin_data = bin_faces_by_size(faces)
341
+
342
+ # Create visualizations
343
+ annotated_image = draw_faces_on_image(img, faces)
344
+ histogram = plot_face_histogram(bin_data)
345
+
346
+ # Create face examples grid for selected bin
347
+ examples_grid = create_face_examples_grid(img, bin_data, selected_bin)
348
+
349
+ # Handle the case when no bin is selected
350
+ if selected_bin is None or examples_grid is None:
351
+ available_bins = sorted(bin_data['bin_counts'].keys()) if bin_data else []
352
+ return annotated_image, histogram, None, gr.Dropdown.update(choices=[str(b) for b in available_bins])
353
+
354
+ # Update dropdown choices
355
+ available_bins = sorted(bin_data['bin_counts'].keys()) if bin_data else []
356
+
357
+ return annotated_image, histogram, examples_grid, gr.Dropdown.update(choices=[str(b) for b in available_bins])
358
+
359
+ def update_examples(image, selected_bin):
360
+ """Update face examples when a bin is selected"""
361
+ # Detect faces
362
+ faces, img = detect_faces_ensemble(image)
363
+
364
+ # Bin faces by size
365
+ bin_data = bin_faces_by_size(faces)
366
+
367
+ # Create face examples grid for selected bin
368
+ examples_grid = create_face_examples_grid(img, bin_data, selected_bin)
369
+
370
+ return examples_grid
371
+
372
+ # Create Gradio interface
373
+ with gr.Blocks(title="Face Size Distribution Analysis") as demo:
374
+ gr.Markdown("# Face Size Distribution Analysis")
375
+ gr.Markdown("Upload an image or select from the examples to see the distribution of face sizes")
376
+
377
+ with gr.Row():
378
+ with gr.Column(scale=1):
379
+ # Input components
380
+ input_image = gr.Image(type="pil", label="Input Image")
381
+ example_dropdown = gr.Dropdown(
382
+ choices=[],
383
+ label="Select from available images",
384
+ interactive=True
385
+ )
386
+ run_button = gr.Button("Analyze Image")
387
+
388
+ # Bin selection for examples
389
+ bin_dropdown = gr.Dropdown(
390
+ choices=[],
391
+ label="Select size bin to see examples",
392
+ interactive=True
393
+ )
394
+
395
+ with gr.Column(scale=2):
396
+ # Output components
397
+ output_image = gr.Image(type="pil", label="Detected Faces")
398
+ with gr.Tab("Histogram"):
399
+ histogram_plot = gr.Plot(label="Face Size Distribution")
400
+ with gr.Tab("Face Examples"):
401
+ examples_grid = gr.Image(type="pil", label="Face Examples")
402
+
403
+ # Load example images on startup
404
+ def load_examples():
405
+ examples = load_images_from_folder("data")
406
+ return gr.Dropdown.update(choices=[os.path.basename(path) for path in examples], value=examples[0] if examples else None)
407
+
408
+ # Handle example selection
409
+ def select_example(example_name):
410
+ if not example_name:
411
+ return None
412
+
413
+ # Look for the example in the data folder
414
+ example_path = os.path.join("data", example_name)
415
+ if os.path.exists(example_path):
416
+ return example_path
417
+ return None
418
+
419
+ # Set up event handlers
420
+ run_button.click(
421
+ process_image,
422
+ inputs=[input_image, bin_dropdown],
423
+ outputs=[output_image, histogram_plot, examples_grid, bin_dropdown]
424
+ )
425
+
426
+ example_dropdown.change(
427
+ select_example,
428
+ inputs=[example_dropdown],
429
+ outputs=[input_image]
430
+ )
431
+
432
+ input_image.change(
433
+ process_image,
434
+ inputs=[input_image, None],
435
+ outputs=[output_image, histogram_plot, examples_grid, bin_dropdown]
436
+ )
437
+
438
+ bin_dropdown.change(
439
+ update_examples,
440
+ inputs=[input_image, bin_dropdown],
441
+ outputs=[examples_grid]
442
+ )
443
+
444
+ # Load examples on startup
445
+ demo.load(load_examples, outputs=[example_dropdown])
446
+
447
+ # Launch the demo
448
+ if __name__ == "__main__":
449
+ demo.launch()