recklessforlife commited on
Commit
5682a66
·
verified ·
1 Parent(s): 65cccbb

Create Test.py

Browse files
Files changed (1) hide show
  1. Test.py +831 -0
Test.py ADDED
@@ -0,0 +1,831 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # =============================================================================
3
+ # FACE CLASSIFIER TESTING PROGRAM
4
+ # Tests trained model on images with face detection and cropping
5
+ # =============================================================================
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torchvision.transforms as transforms
10
+ import cv2
11
+ import numpy as np
12
+ from PIL import Image, ImageDraw, ImageFont
13
+ import os
14
+ import matplotlib.pyplot as plt
15
+ from pathlib import Path
16
+ import time
17
+ from tqdm import tqdm
18
+
19
+ # =============================================================================
20
+ # CONFIGURATION
21
+ # =============================================================================
22
+
23
+ # Paths
24
+ MODEL_PATH = r"../Training/best_face_classifier_real_data.pth"
25
+ TEST_IMAGES_PATH = r"\Pictures\Saved Pictures"
26
+ OUTPUT_PATH = "test_results"
27
+
28
+ # Model parameters (must match training configuration)
29
+ IMAGE_SIZE = 224
30
+ INPUT_CHANNELS = 3
31
+ NUM_CLASSES = 1
32
+ CONV_FILTERS = [128, 256, 512] # Updated to match TrainV3.py
33
+ FC_SIZES = [1024, 512]
34
+ DROPOUT_RATES = [0.3, 0.5]
35
+
36
+ # Image processing
37
+ FACE_DETECTION_SCALE_FACTOR = 1.1
38
+ FACE_DETECTION_MIN_NEIGHBORS = 5
39
+ MIN_FACE_SIZE = (30, 30)
40
+ IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
41
+
42
+ # Visualization
43
+ CONFIDENCE_THRESHOLD = 0.8
44
+ SAVE_RESULTS = True
45
+ SHOW_PLOTS = True
46
+ SAVE_INDIVIDUAL_IMAGES = True # Save each image with annotations
47
+ CREATE_COMPREHENSIVE_SUMMARY = True # Create complete grid summary
48
+
49
+ # =============================================================================
50
+ # MODEL ARCHITECTURE (Must match training)
51
+ # =============================================================================
52
+
53
+ class ImprovedFaceClassifierCNN(nn.Module):
54
+ """Same architecture as used in training"""
55
+
56
+ def __init__(self):
57
+ super().__init__()
58
+
59
+ # Feature extraction layers
60
+ self.features = nn.Sequential(
61
+ # Block 1: 224x224 -> 112x112
62
+ nn.Conv2d(INPUT_CHANNELS, CONV_FILTERS[0], 3, padding=1),
63
+ nn.BatchNorm2d(CONV_FILTERS[0]),
64
+ nn.ReLU(inplace=True),
65
+ nn.Conv2d(CONV_FILTERS[0], CONV_FILTERS[0], 3, padding=1),
66
+ nn.BatchNorm2d(CONV_FILTERS[0]),
67
+ nn.ReLU(inplace=True),
68
+ nn.MaxPool2d(2, 2),
69
+ nn.Dropout(DROPOUT_RATES[0]),
70
+
71
+ # Block 2: 112x112 -> 56x56
72
+ nn.Conv2d(CONV_FILTERS[0], CONV_FILTERS[1], 3, padding=1),
73
+ nn.BatchNorm2d(CONV_FILTERS[1]),
74
+ nn.ReLU(inplace=True),
75
+ nn.Conv2d(CONV_FILTERS[1], CONV_FILTERS[1], 3, padding=1),
76
+ nn.BatchNorm2d(CONV_FILTERS[1]),
77
+ nn.ReLU(inplace=True),
78
+ nn.MaxPool2d(2, 2),
79
+ nn.Dropout(DROPOUT_RATES[0]),
80
+
81
+ # Block 3: 56x56 -> 28x28
82
+ nn.Conv2d(CONV_FILTERS[1], CONV_FILTERS[2], 3, padding=1),
83
+ nn.BatchNorm2d(CONV_FILTERS[2]),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(CONV_FILTERS[2], CONV_FILTERS[2], 3, padding=1),
86
+ nn.BatchNorm2d(CONV_FILTERS[2]),
87
+ nn.ReLU(inplace=True),
88
+ nn.MaxPool2d(2, 2),
89
+ nn.Dropout(DROPOUT_RATES[0]),
90
+
91
+ # Global Average Pooling
92
+ nn.AdaptiveAvgPool2d((7, 7))
93
+ )
94
+
95
+ # Classifier
96
+ self.classifier = nn.Sequential(
97
+ nn.Linear(CONV_FILTERS[2] * 7 * 7, FC_SIZES[0]),
98
+ nn.BatchNorm1d(FC_SIZES[0]),
99
+ nn.ReLU(inplace=True),
100
+ nn.Dropout(DROPOUT_RATES[1]),
101
+
102
+ nn.Linear(FC_SIZES[0], FC_SIZES[1]),
103
+ nn.ReLU(inplace=True),
104
+ nn.Dropout(DROPOUT_RATES[1]),
105
+
106
+ nn.Linear(FC_SIZES[1], NUM_CLASSES)
107
+ )
108
+
109
+ def forward(self, x):
110
+ x = self.features(x)
111
+ x = x.view(x.size(0), -1)
112
+ return self.classifier(x)
113
+
114
+ # =============================================================================
115
+ # FACE DETECTION AND PROCESSING
116
+ # =============================================================================
117
+
118
+ class FaceProcessor:
119
+ """Face detection and processing for classification"""
120
+
121
+ def __init__(self):
122
+ # Initialize face detector (OpenCV Haar Cascade)
123
+ self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
124
+
125
+ # Alternative: Try to use more accurate DNN face detector if available
126
+ try:
127
+ # Download OpenCV DNN face detector if not present
128
+ self.net = None
129
+ self.use_dnn = False
130
+ # Note: For production, you might want to use a more sophisticated face detector
131
+ except:
132
+ self.net = None
133
+ self.use_dnn = False
134
+
135
+ # Image preprocessing transform
136
+ self.transform = transforms.Compose([
137
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
138
+ transforms.ToTensor(),
139
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
140
+ ])
141
+
142
+ def detect_faces(self, image):
143
+ """Detect faces in image and return bounding boxes with duplicate removal"""
144
+ if isinstance(image, Image.Image):
145
+ # Convert PIL to OpenCV format
146
+ image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
147
+ else:
148
+ image_cv = image.copy()
149
+
150
+ # Convert to grayscale for face detection
151
+ gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
152
+
153
+ # Detect faces
154
+ faces = self.face_cascade.detectMultiScale(
155
+ gray,
156
+ scaleFactor=FACE_DETECTION_SCALE_FACTOR,
157
+ minNeighbors=FACE_DETECTION_MIN_NEIGHBORS,
158
+ minSize=MIN_FACE_SIZE,
159
+ flags=cv2.CASCADE_SCALE_IMAGE
160
+ )
161
+
162
+ # Remove duplicate/overlapping faces using Non-Maximum Suppression
163
+ if len(faces) > 1:
164
+ faces = self._remove_duplicate_faces(faces)
165
+
166
+ return faces
167
+
168
+ def _remove_duplicate_faces(self, faces, overlap_threshold=0.15):
169
+ """Remove duplicate/overlapping face detections using improved NMS"""
170
+ if len(faces) <= 1:
171
+ return faces
172
+
173
+ # Convert to list for easier manipulation
174
+ face_list = list(faces)
175
+
176
+ # Calculate areas and create extended info
177
+ face_info = []
178
+ for i, (x, y, w, h) in enumerate(face_list):
179
+ area = w * h
180
+ face_info.append({
181
+ 'index': i,
182
+ 'bbox': (x, y, w, h),
183
+ 'area': area,
184
+ 'x1': x, 'y1': y, 'x2': x + w, 'y2': y + h
185
+ })
186
+
187
+ # Sort by area (larger faces first - usually more reliable)
188
+ face_info.sort(key=lambda f: f['area'], reverse=True)
189
+
190
+ keep_indices = []
191
+
192
+ for i, current_face in enumerate(face_info):
193
+ should_keep = True
194
+
195
+ # Check against all previously kept faces
196
+ for kept_idx in keep_indices:
197
+ kept_face = face_info[kept_idx]
198
+
199
+ # Calculate intersection
200
+ x1 = max(current_face['x1'], kept_face['x1'])
201
+ y1 = max(current_face['y1'], kept_face['y1'])
202
+ x2 = min(current_face['x2'], kept_face['x2'])
203
+ y2 = min(current_face['y2'], kept_face['y2'])
204
+
205
+ if x1 < x2 and y1 < y2:
206
+ intersection = (x2 - x1) * (y2 - y1)
207
+
208
+ # Calculate IoU
209
+ union = current_face['area'] + kept_face['area'] - intersection
210
+ iou = intersection / union if union > 0 else 0
211
+
212
+ # Also check overlap ratio (intersection over smaller box)
213
+ smaller_area = min(current_face['area'], kept_face['area'])
214
+ overlap_ratio = intersection / smaller_area if smaller_area > 0 else 0
215
+
216
+ # Remove if either IoU or overlap ratio is too high
217
+ if iou > overlap_threshold or overlap_ratio > 0.5:
218
+ should_keep = False
219
+ break
220
+
221
+ if should_keep:
222
+ keep_indices.append(i)
223
+
224
+ # Return filtered faces
225
+ filtered_faces = np.array([face_info[i]['bbox'] for i in keep_indices])
226
+
227
+ # Debug info
228
+ if len(faces) != len(filtered_faces):
229
+ print(f" [NMS] Removed {len(faces) - len(filtered_faces)} duplicate faces "
230
+ f"({len(faces)} → {len(filtered_faces)})")
231
+
232
+ return filtered_faces
233
+
234
+ def crop_face(self, image, face_bbox, expand_ratio=0.2):
235
+ """Crop face from image with some padding"""
236
+ x, y, w, h = face_bbox
237
+
238
+ # Add padding around face
239
+ pad_x = int(w * expand_ratio)
240
+ pad_y = int(h * expand_ratio)
241
+
242
+ # Calculate expanded bounding box
243
+ x1 = max(0, x - pad_x)
244
+ y1 = max(0, y - pad_y)
245
+ x2 = min(image.width if isinstance(image, Image.Image) else image.shape[1], x + w + pad_x)
246
+ y2 = min(image.height if isinstance(image, Image.Image) else image.shape[0], y + h + pad_y)
247
+
248
+ # Crop the face
249
+ if isinstance(image, Image.Image):
250
+ face_crop = image.crop((x1, y1, x2, y2))
251
+ else:
252
+ # OpenCV format
253
+ face_crop = image[y1:y2, x1:x2]
254
+ face_crop = Image.fromarray(cv2.cvtColor(face_crop, cv2.COLOR_BGR2RGB))
255
+
256
+ return face_crop, (x1, y1, x2, y2)
257
+
258
+ def preprocess_face(self, face_image):
259
+ """Preprocess face for model input"""
260
+ # Ensure face is PIL Image
261
+ if not isinstance(face_image, Image.Image):
262
+ face_image = Image.fromarray(face_image)
263
+
264
+ # Apply transforms
265
+ face_tensor = self.transform(face_image)
266
+
267
+ # Add batch dimension
268
+ face_batch = face_tensor.unsqueeze(0)
269
+
270
+ return face_batch
271
+
272
+ # =============================================================================
273
+ # MODEL LOADER AND CLASSIFIER
274
+ # =============================================================================
275
+
276
+ class FaceClassifierTester:
277
+ """Test trained face classifier on new images"""
278
+
279
+ def __init__(self, model_path, device='auto'):
280
+ self.device = self._setup_device(device)
281
+ self.model = self._load_model(model_path)
282
+ self.face_processor = FaceProcessor()
283
+ self.results = []
284
+
285
+ print(f"[*] Face Classifier Tester initialized")
286
+ print(f" Device: {self.device}")
287
+ print(f" Model: {model_path}")
288
+
289
+ def _setup_device(self, device):
290
+ """Setup computing device"""
291
+ if device == 'auto':
292
+ if torch.cuda.is_available():
293
+ device = torch.device('cuda:0')
294
+ print(f"[GPU] Using GPU: {torch.cuda.get_device_name(0)}")
295
+ else:
296
+ device = torch.device('cpu')
297
+ print("[CPU] Using CPU")
298
+ else:
299
+ device = torch.device(device)
300
+
301
+ return device
302
+
303
+ def _load_model(self, model_path):
304
+ """Load trained model from checkpoint"""
305
+ try:
306
+ # Load checkpoint
307
+ checkpoint = torch.load(model_path, map_location=self.device)
308
+
309
+ # Initialize model
310
+ model = ImprovedFaceClassifierCNN()
311
+
312
+ # Load state dict
313
+ if 'model_state_dict' in checkpoint:
314
+ model.load_state_dict(checkpoint['model_state_dict'])
315
+ print(f"[OK] Model loaded from checkpoint")
316
+ print(f" Epoch: {checkpoint.get('epoch', 'Unknown')}")
317
+ print(f" Validation Accuracy: {checkpoint.get('val_acc', 'Unknown'):.2f}%")
318
+ else:
319
+ # Direct state dict
320
+ model.load_state_dict(checkpoint)
321
+ print(f"[OK] Model loaded successfully")
322
+
323
+ model.to(self.device)
324
+ model.eval()
325
+
326
+ return model
327
+
328
+ except Exception as e:
329
+ print(f"[ERROR] Error loading model: {e}")
330
+ print("Make sure the model file exists and matches the architecture")
331
+ raise
332
+
333
+ def classify_face(self, face_image):
334
+ """Classify a single face image"""
335
+ try:
336
+ # Preprocess face
337
+ face_tensor = self.face_processor.preprocess_face(face_image)
338
+ face_tensor = face_tensor.to(self.device)
339
+
340
+ # Run inference
341
+ with torch.no_grad():
342
+ logits = self.model(face_tensor)
343
+ probability = torch.sigmoid(logits).cpu().numpy()[0][0]
344
+
345
+ # Convert probability to prediction
346
+ prediction = "REAL" if probability > CONFIDENCE_THRESHOLD else "FAKE"
347
+ confidence = probability if prediction == "REAL" else (1 - probability)
348
+
349
+ return {
350
+ 'prediction': prediction,
351
+ 'confidence': confidence,
352
+ 'probability': probability,
353
+ 'raw_logit': logits.cpu().numpy()[0][0]
354
+ }
355
+
356
+ except Exception as e:
357
+ print(f"[ERROR] Error in classification: {e}")
358
+ return {
359
+ 'prediction': 'ERROR',
360
+ 'confidence': 0.0,
361
+ 'probability': 0.0,
362
+ 'raw_logit': 0.0
363
+ }
364
+
365
+ def process_image(self, image_path):
366
+ """Process a single image: detect faces and classify them"""
367
+ try:
368
+ # Load image
369
+ image = Image.open(image_path).convert('RGB')
370
+ image_name = os.path.basename(image_path)
371
+
372
+ print(f"\n[PROCESSING] {image_name}")
373
+
374
+ # Detect faces
375
+ faces = self.face_processor.detect_faces(image)
376
+
377
+ if len(faces) == 0:
378
+ print(f" [WARNING] No faces detected in {image_name}")
379
+ return {
380
+ 'image_path': image_path,
381
+ 'image_name': image_name,
382
+ 'num_faces': 0,
383
+ 'faces': [],
384
+ 'status': 'no_faces'
385
+ }
386
+
387
+ print(f" [FACES] Found {len(faces)} face(s)")
388
+
389
+ # Process each detected face
390
+ face_results = []
391
+ for i, face_bbox in enumerate(faces):
392
+ # Crop face
393
+ face_crop, expanded_bbox = self.face_processor.crop_face(image, face_bbox)
394
+
395
+ # Classify face
396
+ classification = self.classify_face(face_crop)
397
+
398
+ # Store results
399
+ face_result = {
400
+ 'face_id': i,
401
+ 'bbox': face_bbox.tolist(),
402
+ 'expanded_bbox': expanded_bbox,
403
+ 'face_crop': face_crop,
404
+ 'classification': classification
405
+ }
406
+ face_results.append(face_result)
407
+
408
+ print(f" Face {i+1}: {classification['prediction']} "
409
+ f"({classification['confidence']:.1%} confidence)")
410
+
411
+ return {
412
+ 'image_path': image_path,
413
+ 'image_name': image_name,
414
+ 'image': image,
415
+ 'num_faces': len(faces),
416
+ 'faces': face_results,
417
+ 'status': 'success'
418
+ }
419
+
420
+ except Exception as e:
421
+ print(f"[ERROR] Error processing {image_path}: {e}")
422
+ return {
423
+ 'image_path': image_path,
424
+ 'image_name': os.path.basename(image_path),
425
+ 'num_faces': 0,
426
+ 'faces': [],
427
+ 'status': 'error',
428
+ 'error': str(e)
429
+ }
430
+
431
+ def test_folder(self, folder_path, max_images=None):
432
+ """Test all images in a folder"""
433
+ print(f"\n[TESTING] FACE CLASSIFIER")
434
+ print(f"="*60)
435
+ print(f"Test folder: {folder_path}")
436
+ print(f"Model: {MODEL_PATH}")
437
+
438
+ # Check if folder exists
439
+ if not os.path.exists(folder_path):
440
+ print(f"[ERROR] Test folder not found: {folder_path}")
441
+ return []
442
+
443
+ # Get all image files (avoid duplicates from case variations)
444
+ image_files_set = set()
445
+ for ext in IMAGE_EXTENSIONS:
446
+ # Use case-insensitive glob patterns
447
+ image_files_set.update(Path(folder_path).glob(f"*{ext}"))
448
+ image_files_set.update(Path(folder_path).glob(f"*{ext.upper()}"))
449
+
450
+ # Convert set back to list and remove duplicates by resolving paths
451
+ image_files = []
452
+ seen_paths = set()
453
+ for file_path in image_files_set:
454
+ resolved_path = file_path.resolve()
455
+ if resolved_path not in seen_paths:
456
+ image_files.append(file_path)
457
+ seen_paths.add(resolved_path)
458
+
459
+ if not image_files:
460
+ print(f"[ERROR] No images found in {folder_path}")
461
+ print(f" Looking for extensions: {IMAGE_EXTENSIONS}")
462
+ return []
463
+
464
+ if max_images:
465
+ image_files = image_files[:max_images]
466
+
467
+ print(f"[FILES] Found {len(image_files)} images to process")
468
+
469
+ # Process each image
470
+ results = []
471
+ start_time = time.time()
472
+
473
+ for image_path in tqdm(image_files, desc="Processing images"):
474
+ result = self.process_image(str(image_path))
475
+ results.append(result)
476
+ self.results.append(result)
477
+
478
+ total_time = time.time() - start_time
479
+
480
+ # Print summary
481
+ self._print_summary(results, total_time)
482
+
483
+ # Save and visualize results
484
+ if SAVE_RESULTS:
485
+ self._save_results(results)
486
+ self._save_individual_images(results) # Save each image with bounding boxes
487
+
488
+ if SHOW_PLOTS:
489
+ #self._visualize_results(results)
490
+ self._create_comprehensive_summary(results) # Create complete grid summary
491
+
492
+ return results
493
+
494
+ def _print_summary(self, results, total_time):
495
+ """Print testing summary"""
496
+ print(f"\n[SUMMARY] TESTING SUMMARY")
497
+ print(f"="*40)
498
+
499
+ total_images = len(results)
500
+ successful = len([r for r in results if r['status'] == 'success'])
501
+ total_faces = sum(r['num_faces'] for r in results)
502
+ no_faces = len([r for r in results if r['status'] == 'no_faces'])
503
+ errors = len([r for r in results if r['status'] == 'error'])
504
+
505
+ print(f"Images processed: {total_images}")
506
+ print(f"Successful: {successful}")
507
+ print(f"No faces detected: {no_faces}")
508
+ print(f"Errors: {errors}")
509
+ print(f"Total faces detected: {total_faces}")
510
+ print(f"Processing time: {total_time:.1f}s")
511
+ print(f"Average time per image: {total_time/total_images:.2f}s")
512
+
513
+ # Classification summary
514
+ if total_faces > 0:
515
+ real_faces = sum(len([f for f in r['faces'] if f['classification']['prediction'] == 'REAL'])
516
+ for r in results if r['status'] == 'success')
517
+ fake_faces = total_faces - real_faces
518
+
519
+ print(f"\n[RESULTS] Classification Results:")
520
+ print(f"Real faces: {real_faces} ({real_faces/total_faces:.1%})")
521
+ print(f"Fake faces: {fake_faces} ({fake_faces/total_faces:.1%})")
522
+
523
+ def _save_results(self, results):
524
+ """Save results to files"""
525
+ os.makedirs(OUTPUT_PATH, exist_ok=True)
526
+
527
+ # Save summary CSV
528
+ import csv
529
+ csv_path = os.path.join(OUTPUT_PATH, 'classification_results.csv')
530
+
531
+ with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
532
+ writer = csv.writer(csvfile)
533
+ writer.writerow(['Image', 'Face_ID', 'Prediction', 'Confidence', 'Probability', 'Bbox_X', 'Bbox_Y', 'Bbox_W', 'Bbox_H'])
534
+
535
+ for result in results:
536
+ if result['status'] == 'success':
537
+ for face in result['faces']:
538
+ bbox = face['bbox']
539
+ cls = face['classification']
540
+ writer.writerow([
541
+ result['image_name'],
542
+ face['face_id'],
543
+ cls['prediction'],
544
+ f"{cls['confidence']:.3f}",
545
+ f"{cls['probability']:.3f}",
546
+ bbox[0], bbox[1], bbox[2], bbox[3]
547
+ ])
548
+
549
+ print(f"[SAVED] Results saved to: {csv_path}")
550
+
551
+ def _save_individual_images(self, results):
552
+ """Save each processed image with bounding boxes and classifications"""
553
+ os.makedirs(OUTPUT_PATH, exist_ok=True)
554
+ individual_dir = os.path.join(OUTPUT_PATH, 'annotated_images')
555
+ os.makedirs(individual_dir, exist_ok=True)
556
+
557
+ saved_count = 0
558
+ for result in results:
559
+ if result['status'] in ['success', 'no_faces']:
560
+ try:
561
+ # Load original image
562
+ if 'image' in result:
563
+ image = result['image'].copy()
564
+ else:
565
+ image = Image.open(result['image_path']).convert('RGB')
566
+
567
+ # Draw bounding boxes and labels
568
+ draw = ImageDraw.Draw(image)
569
+
570
+ # Try to use a larger font
571
+ try:
572
+ font = ImageFont.truetype("arial.ttf", 20)
573
+ except:
574
+ font = ImageFont.load_default()
575
+
576
+ if result['num_faces'] > 0:
577
+ for face in result['faces']:
578
+ bbox = face['bbox']
579
+ cls = face['classification']
580
+
581
+ # Choose color based on prediction
582
+ color = 'green' if cls['prediction'] == 'REAL' else 'red'
583
+
584
+ # Draw bounding box
585
+ x, y, w, h = bbox
586
+ draw.rectangle([x, y, x+w, y+h], outline=color, width=3)
587
+
588
+ # Create label with prediction and confidence
589
+ label = f"{cls['prediction']} ({cls['confidence']:.1%})"
590
+
591
+ # Draw label background
592
+ text_bbox = draw.textbbox((x, y-25), label, font=font)
593
+ draw.rectangle(text_bbox, fill=color)
594
+
595
+ # Draw label text
596
+ draw.text((x, y-25), label, fill='white', font=font)
597
+ else:
598
+ # Add "NO FACES" label for images without faces
599
+ draw.text((10, 10), "NO FACES DETECTED", fill='orange', font=font)
600
+
601
+ # Save annotated image
602
+ base_name = os.path.splitext(result['image_name'])[0]
603
+ output_filename = f"{base_name}_annotated.jpg"
604
+ output_path = os.path.join(individual_dir, output_filename)
605
+
606
+ image.save(output_path, 'JPEG', quality=95)
607
+ saved_count += 1
608
+
609
+ except Exception as e:
610
+ print(f"[WARNING] Could not save annotated image for {result['image_name']}: {e}")
611
+
612
+ print(f"[SAVED] {saved_count} annotated images saved to: {individual_dir}")
613
+
614
+ def _visualize_results(self, results, max_display=6):
615
+ """Visualize results with matplotlib (limited sample)"""
616
+ try:
617
+ # Filter successful results with faces
618
+ display_results = [r for r in results if r['status'] == 'success' and r['num_faces'] > 0]
619
+ display_results = display_results[:max_display]
620
+
621
+ if not display_results:
622
+ print("No results to visualize")
623
+ return
624
+
625
+ # Create subplots
626
+ fig, axes = plt.subplots(2, 3, figsize=(15, 10))
627
+ axes = axes.flatten()
628
+
629
+ for i, result in enumerate(display_results):
630
+ if i >= len(axes):
631
+ break
632
+
633
+ ax = axes[i]
634
+ image = result['image']
635
+
636
+ # Draw bounding boxes on image
637
+ draw_image = image.copy()
638
+ draw = ImageDraw.Draw(draw_image)
639
+
640
+ for face in result['faces']:
641
+ bbox = face['bbox']
642
+ cls = face['classification']
643
+
644
+ # Choose color based on prediction
645
+ color = 'green' if cls['prediction'] == 'REAL' else 'red'
646
+
647
+ # Draw bounding box
648
+ draw.rectangle([bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3]],
649
+ outline=color, width=3)
650
+
651
+ # Add label
652
+ label = f"{cls['prediction']} ({cls['confidence']:.1%})"
653
+ draw.text((bbox[0], bbox[1]-20), label, fill=color)
654
+
655
+ # Display image
656
+ ax.imshow(draw_image)
657
+ ax.set_title(f"{result['image_name']}\n{result['num_faces']} face(s)")
658
+ ax.axis('off')
659
+
660
+ # Hide empty subplots
661
+ for i in range(len(display_results), len(axes)):
662
+ axes[i].axis('off')
663
+
664
+ plt.tight_layout()
665
+ plt.savefig(os.path.join(OUTPUT_PATH, 'sample_classification.png'), dpi=150, bbox_inches='tight')
666
+ plt.show()
667
+
668
+ except Exception as e:
669
+ print(f"[WARNING] Could not create sample visualization: {e}")
670
+
671
+ def _create_comprehensive_summary(self, results):
672
+ """Create a comprehensive grid summary of all processed images"""
673
+ try:
674
+ # Include all results (successful, no_faces, errors)
675
+ all_results = results
676
+
677
+ if not all_results:
678
+ print("No results to create comprehensive summary")
679
+ return
680
+
681
+ # Calculate grid dimensions
682
+ total_images = len(all_results)
683
+ cols = 4 # 4 images per row
684
+ rows = (total_images + cols - 1) // cols # Ceiling division
685
+
686
+ # Create large figure
687
+ fig, axes = plt.subplots(rows, cols, figsize=(20, 5*rows))
688
+
689
+ # Handle single row case
690
+ if rows == 1:
691
+ axes = axes.reshape(1, -1) if total_images > 1 else [axes]
692
+
693
+ # Flatten axes for easier indexing
694
+ axes_flat = axes.flatten() if total_images > 1 else [axes]
695
+
696
+ for i, result in enumerate(all_results):
697
+ ax = axes_flat[i]
698
+
699
+ try:
700
+ # Load image
701
+ if 'image' in result and result['image'] is not None:
702
+ image = result['image'].copy()
703
+ else:
704
+ image = Image.open(result['image_path']).convert('RGB')
705
+
706
+ # Create annotated copy
707
+ draw_image = image.copy()
708
+ draw = ImageDraw.Draw(draw_image)
709
+
710
+ # Set up title based on status
711
+ title_parts = [result['image_name'][:20]] # Truncate long names
712
+
713
+ if result['status'] == 'success':
714
+ if result['num_faces'] > 0:
715
+ # Draw faces with bounding boxes
716
+ for face in result['faces']:
717
+ bbox = face['bbox']
718
+ cls = face['classification']
719
+
720
+ # Choose color
721
+ color = 'green' if cls['prediction'] == 'REAL' else 'red'
722
+
723
+ # Draw bounding box
724
+ x, y, w, h = bbox
725
+ draw.rectangle([x, y, x+w, y+h], outline=color, width=2)
726
+
727
+ # Add small label
728
+ label = f"{cls['prediction']}\n{cls['confidence']:.0%}"
729
+ draw.text((x, y-15), label, fill=color)
730
+
731
+ title_parts.append(f"{result['num_faces']} face(s)")
732
+
733
+ # Count real vs fake
734
+ real_count = sum(1 for f in result['faces'] if f['classification']['prediction'] == 'REAL')
735
+ fake_count = result['num_faces'] - real_count
736
+ if real_count > 0:
737
+ title_parts.append(f"Real: {real_count}")
738
+ if fake_count > 0:
739
+ title_parts.append(f"Fake: {fake_count}")
740
+ else:
741
+ title_parts.append("No faces")
742
+ # Add text overlay
743
+ draw.text((10, 10), "NO FACES", fill='orange')
744
+
745
+ elif result['status'] == 'no_faces':
746
+ title_parts.append("No faces detected")
747
+ draw.text((10, 10), "NO FACES", fill='orange')
748
+
749
+ elif result['status'] == 'error':
750
+ title_parts.append("Error")
751
+ draw.text((10, 10), "ERROR", fill='red')
752
+
753
+ # Display image
754
+ ax.imshow(draw_image)
755
+ ax.set_title('\n'.join(title_parts), fontsize=8)
756
+ ax.axis('off')
757
+
758
+ except Exception as e:
759
+ # Handle individual image errors
760
+ ax.text(0.5, 0.5, f"Error loading\n{result['image_name']}",
761
+ ha='center', va='center', transform=ax.transAxes)
762
+ ax.set_title(f"Error: {result['image_name'][:20]}")
763
+ ax.axis('off')
764
+
765
+ # Hide unused subplots
766
+ for i in range(total_images, len(axes_flat)):
767
+ axes_flat[i].axis('off')
768
+
769
+ # Add overall title with summary stats
770
+ total_faces = sum(r['num_faces'] for r in results if r['status'] == 'success')
771
+ real_faces = sum(len([f for f in r['faces'] if f['classification']['prediction'] == 'REAL'])
772
+ for r in results if r['status'] == 'success')
773
+ fake_faces = total_faces - real_faces
774
+
775
+ fig.suptitle(f"Face Classification Results - {total_images} Images, {total_faces} Faces\n"
776
+ f"Real: {real_faces} ({real_faces/total_faces*100 if total_faces > 0 else 0:.1f}%), "
777
+ f"Fake: {fake_faces} ({fake_faces/total_faces*100 if total_faces > 0 else 0:.1f}%)",
778
+ fontsize=16, y=0.98)
779
+
780
+ plt.tight_layout()
781
+ plt.subplots_adjust(top=0.92)
782
+
783
+ # Save comprehensive summary
784
+ summary_path = os.path.join(OUTPUT_PATH, 'comprehensive_summary.png')
785
+ plt.savefig(summary_path, dpi=200, bbox_inches='tight')
786
+ print(f"[SAVED] Comprehensive summary saved to: {summary_path}")
787
+
788
+ plt.show()
789
+
790
+ except Exception as e:
791
+ print(f"[WARNING] Could not create comprehensive summary: {e}")
792
+
793
+ # =============================================================================
794
+ # MAIN TESTING FUNCTION
795
+ # =============================================================================
796
+
797
+ def main():
798
+ """Main testing function"""
799
+ print("[*] FACE CLASSIFIER TESTING")
800
+ print("="*50)
801
+
802
+ # Check if model exists
803
+ if not os.path.exists(MODEL_PATH):
804
+ print(f"[ERROR] Model file not found: {MODEL_PATH}")
805
+ print("Please make sure you have trained the model first.")
806
+ print("Expected file: best_face_classifier_real_data.pth")
807
+ return
808
+
809
+ # Check if test folder exists
810
+ if not os.path.exists(TEST_IMAGES_PATH):
811
+ print(f"[ERROR] Test images folder not found: {TEST_IMAGES_PATH}")
812
+ print("Please check the path and make sure it contains images.")
813
+ return
814
+
815
+ try:
816
+ # Initialize tester
817
+ tester = FaceClassifierTester(MODEL_PATH)
818
+
819
+ # Run tests
820
+ results = tester.test_folder(TEST_IMAGES_PATH, max_images=20) # Limit for demo
821
+
822
+ print(f"\n[OK] Testing completed!")
823
+ print(f"Check '{OUTPUT_PATH}' folder for detailed results.")
824
+
825
+ except Exception as e:
826
+ print(f"[ERROR] Testing failed: {e}")
827
+ import traceback
828
+ traceback.print_exc()
829
+
830
+ if __name__ == "__main__":
831
+ main()