kuldeep0204 committed
Commit e5d689b (verified)
1 parent: 8d2f34b

Create utils/detector.py

Files changed (1)
  1. utils/detector.py +369 -0
utils/detector.py ADDED
@@ -0,0 +1,369 @@
import torch
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
import mediapipe as mp
from facenet_pytorch import MTCNN
import time
import warnings
warnings.filterwarnings('ignore')

class DeepfakeDetector:
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.face_detector = MTCNN(keep_all=True, device=device)
        # MediaPipe FaceMesh (initialized for landmark analysis; not used by the current pipeline)
        self.mp_face_mesh = mp.solutions.face_mesh
        self.face_mesh = self.mp_face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5
        )

        # Initialize models
        self.models = self.load_models()
        self.threshold = 0.7

    def load_models(self):
        """Load pretrained models"""
        models = {}

        # Load EfficientNet-B4
        from efficientnet_pytorch import EfficientNet
        models['efficientnet'] = EfficientNet.from_pretrained('efficientnet-b4')
        models['efficientnet']._fc = nn.Linear(1792, 2)

        # Load Xception (torchvision ships no Xception model, so timm is used here)
        import timm
        models['xception'] = timm.create_model('xception', pretrained=False, num_classes=2)

        # Load fine-tuned weights if present, move to device, and set eval mode
        for name, model in models.items():
            model_path = f"models/{name}.pth"
            try:
                model.load_state_dict(torch.load(model_path, map_location=self.device))
                print(f"Loaded {name}")
            except (FileNotFoundError, RuntimeError):
                print(f"Using pretrained {name} without fine-tuning")
            model.to(self.device)
            model.eval()

        return models

    def detect_image(self, image):
        """Detect deepfake in image"""
        start_time = time.time()

        # Convert to numpy if PIL
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Run all detection methods
        results = {}

        # Frequency analysis
        results['frequency_score'] = self.analyze_frequency(image)

        # Face artifact detection
        face_results = self.analyze_faces(image)
        results['face_score'] = face_results['confidence']
        results['num_faces'] = face_results['num_faces']

        # Model predictions
        model_predictions = []
        for name, model in self.models.items():
            pred = self.predict_with_model(image, model)
            model_predictions.append(pred)

        # Ensemble voting
        final_score = np.mean([
            results['frequency_score'],
            results['face_score'],
            *model_predictions
        ])

        results['is_fake'] = final_score > self.threshold
        results['confidence'] = final_score
        results['quality_score'] = self.assess_quality(image)
        results['processing_time'] = time.time() - start_time

        return results

    def detect_video(self, video_path, sample_frames=30):
        """Detect deepfake in video"""
        start_time = time.time()

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        # Get video info
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30.0  # fall back to a nominal frame rate when metadata is missing
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # Sample frames evenly across the video
        frame_indices = np.linspace(0, total_frames - 1, min(sample_frames, total_frames), dtype=int)
        frame_results = []
        used_indices = []  # indices of frames that were actually read and analyzed

        for frame_idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            ret, frame = cap.read()
            if ret:
                # Convert BGR to RGB
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                result = self.detect_image(frame_rgb)
                frame_results.append(result)
                used_indices.append(int(frame_idx))

        cap.release()

        # Aggregate results
        if not frame_results:
            raise ValueError("No frames could be read from video")

        # Calculate video-level metrics
        confidences = [r['confidence'] for r in frame_results]
        fake_flags = [r['is_fake'] for r in frame_results]

        final_result = {
            'is_fake': np.mean(fake_flags) > 0.5,
            'confidence': np.mean(confidences),
            'duration': total_frames / fps,
            'frames_analyzed': len(frame_results),
            'resolution': f"{width}x{height}",
            'fps': fps,
            'frame_results': frame_results,
            'processing_time': time.time() - start_time,
            'fake_segments': self.identify_fake_segments(frame_results, used_indices, fps)
        }

        return final_result

    def analyze_frequency(self, image):
        """Analyze frequency domain"""
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image

        # Fourier Transform
        f = np.fft.fft2(gray)
        fshift = np.fft.fftshift(f)
        magnitude = np.log(np.abs(fshift) + 1)

        # Analyze frequency patterns
        height, width = magnitude.shape
        center_h, center_w = height // 2, width // 2

        # Check for grid-like patterns common in GANs
        low_freq = magnitude[center_h-20:center_h+20, center_w-20:center_w+20]
        high_freq = np.copy(magnitude)
        high_freq[center_h-20:center_h+20, center_w-20:center_w+20] = 0

        low_energy = np.mean(low_freq)
        high_energy = np.mean(high_freq)

        # Deepfakes often have different frequency distributions
        score = min(high_energy / (low_energy + 1e-10) * 0.5, 1.0)

        return score

    def analyze_faces(self, image):
        """Analyze faces in image"""
        # Detect faces
        boxes, probs = self.face_detector.detect(image)

        if boxes is None:
            return {'confidence': 0.0, 'num_faces': 0}

        num_faces = len(boxes)
        face_scores = []

        for i, box in enumerate(boxes):
            if probs[i] < 0.9:
                continue

            # Extract face (clamp to image bounds; MTCNN can return slightly negative coordinates)
            x1, y1, x2, y2 = map(int, box)
            x1, y1 = max(x1, 0), max(y1, 0)
            face = image[y1:y2, x1:x2]

            if face.size == 0:
                continue

            # Analyze face artifacts
            score = self.analyze_face_artifacts(face)
            face_scores.append(score)

        if not face_scores:
            return {'confidence': 0.0, 'num_faces': num_faces}

        return {
            'confidence': np.mean(face_scores),
            'num_faces': num_faces
        }

    def analyze_face_artifacts(self, face_img):
        """Analyze artifacts in face image"""
        # Check for unnatural symmetry
        if face_img.shape[1] > 10:  # Ensure face is wide enough
            left_half = face_img[:, :face_img.shape[1] // 2]
            right_half = face_img[:, face_img.shape[1] // 2:]
            right_half_flipped = np.fliplr(right_half)

            # Crop both halves to the same size
            min_height = min(left_half.shape[0], right_half_flipped.shape[0])
            min_width = min(left_half.shape[1], right_half_flipped.shape[1])

            left_cropped = left_half[:min_height, :min_width].astype(np.float32)
            right_cropped = right_half_flipped[:min_height, :min_width].astype(np.float32)

            # Calculate symmetry (float cast avoids uint8 wrap-around in the difference)
            if left_cropped.size > 0 and right_cropped.size > 0:
                symmetry_error = np.mean(np.abs(left_cropped - right_cropped))
                symmetry_score = min(symmetry_error / 10.0, 1.0)
            else:
                symmetry_score = 0.5
        else:
            symmetry_score = 0.5

        # Check for unnatural edges
        gray = cv2.cvtColor(face_img, cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 100, 200)
        edge_density = np.count_nonzero(edges) / edges.size  # fraction of edge pixels in [0, 1]

        # Combine scores
        final_score = symmetry_score * 0.6 + edge_density * 0.4

        return final_score

    def predict_with_model(self, image, model):
        """Predict using a specific model"""
        # Preprocess image
        transform = self.get_transform()

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert('RGB')  # ensure 3 channels before normalization

        input_tensor = transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = model(input_tensor)
            probabilities = torch.softmax(output, dim=1)
            fake_prob = probabilities[0][1].item()

        return fake_prob

    def get_transform(self):
        """Get image transformation pipeline"""
        from torchvision import transforms

        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def assess_quality(self, image):
        """Assess image quality"""
        # Simple quality metrics
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            gray = image

        # Calculate sharpness (variance of Laplacian)
        laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()
        sharpness_score = min(laplacian_var / 1000.0, 1.0)

        # Calculate contrast
        contrast_score = np.std(gray) / 255.0

        return (sharpness_score + contrast_score) / 2

    def identify_fake_segments(self, frame_results, frame_indices, fps):
        """Identify segments in video that are likely deepfakes"""
        if not frame_results:
            return []

        segments = []
        current_segment = None

        for i, result in enumerate(frame_results):
            if result['is_fake']:
                if current_segment is None:
                    current_segment = {
                        'start': frame_indices[i] / fps,
                        'end': frame_indices[i] / fps,
                        'confidence': [result['confidence']]
                    }
                else:
                    current_segment['end'] = frame_indices[i] / fps
                    current_segment['confidence'].append(result['confidence'])
            else:
                if current_segment is not None:
                    current_segment['confidence'] = np.mean(current_segment['confidence'])
                    segments.append(current_segment)
                    current_segment = None

        # Add last segment if it exists
        if current_segment is not None:
            current_segment['confidence'] = np.mean(current_segment['confidence'])
            segments.append(current_segment)

        return segments

    def visualize_result(self, image, result):
        """Create visualization of detection result"""
        # Convert to BGR for OpenCV
        if isinstance(image, Image.Image):
            image = np.array(image)

        if len(image.shape) == 3 and image.shape[2] == 3:
            vis = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        else:
            vis = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

        # Add result text
        text = "REAL" if not result['is_fake'] else "DEEPFAKE"
        color = (0, 255, 0) if not result['is_fake'] else (0, 0, 255)

        # Add text background
        text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 2, 3)[0]
        cv2.rectangle(vis, (10, 10), (10 + text_size[0] + 20, 10 + text_size[1] + 20), (0, 0, 0), -1)

        # Add text
        cv2.putText(vis, text, (20, 20 + text_size[1]),
                    cv2.FONT_HERSHEY_SIMPLEX, 2, color, 3)

        # Add confidence
        conf_text = f"Confidence: {result['confidence']:.2%}"
        cv2.putText(vis, conf_text, (20, 80),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        # Convert back to RGB
        vis = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

        return vis

    def detect_file(self, file_path):
        """Detect deepfake in file (auto-detect type)"""
        if file_path.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
            # Image file (convert to RGB so downstream analysis gets 3 channels)
            image = Image.open(file_path).convert('RGB')
            result = self.detect_image(image)
            result['type'] = 'image'
        elif file_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
            # Video file
            result = self.detect_video(file_path)
            result['type'] = 'video'
        else:
            raise ValueError(f"Unsupported file type: {file_path}")

        result['filename'] = file_path
        return result
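

# Minimal usage sketch: the sample paths below are placeholders, and the
# fine-tuned weights under models/ are optional (the detector falls back to
# pretrained backbones when they are missing).
if __name__ == "__main__":
    detector = DeepfakeDetector()

    # Single image
    img_result = detector.detect_file("samples/face.jpg")
    print(f"Image verdict: {'FAKE' if img_result['is_fake'] else 'REAL'} "
          f"({img_result['confidence']:.2%})")

    # Video (samples 30 frames by default)
    vid_result = detector.detect_file("samples/clip.mp4")
    print(f"Video verdict: {'FAKE' if vid_result['is_fake'] else 'REAL'} "
          f"({vid_result['confidence']:.2%}), "
          f"{vid_result['frames_analyzed']} frames analyzed")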