MrTsp committed on
Commit
07f2243
·
1 Parent(s): 77037c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -34
app.py CHANGED
@@ -54,17 +54,17 @@ class DINOv2Extractor(nn.Module):
54
 
55
 
56
  class MLPClassifier(nn.Module):
57
- def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.3):
58
  super().__init__()
59
  self.net = nn.Sequential(
60
  nn.Linear(input_dim, 512),
61
- nn.LayerNorm(512),
62
  nn.GELU(),
63
  nn.Dropout(dropout),
64
  nn.Linear(512, 256),
65
- nn.LayerNorm(256),
66
  nn.GELU(),
67
- nn.Dropout(dropout / 2),
68
  nn.Linear(256, num_classes),
69
  )
70
 
@@ -119,9 +119,8 @@ try:
119
  MTCNN_DETECTOR = MTCNN(
120
  image_size=224,
121
  margin=40,
122
- min_face_size=20,
123
- thresholds=[0.6, 0.7, 0.9],
124
  keep_all=False,
 
125
  device='cpu'
126
  )
127
  logger.info("MTCNN face detector initialized.")
@@ -131,6 +130,7 @@ except Exception as e:
131
 
132
  TRANSFORM = T.Compose([
133
  T.Resize((224, 224)),
 
134
  T.ToTensor(),
135
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
136
  ])
@@ -141,13 +141,28 @@ def detect_face_crop(img: Image.Image) -> Image.Image:
141
  if MTCNN_DETECTOR is None:
142
  return None
143
  try:
144
- # MTCNN returns the cropped tensor directly
145
- face_tensor = MTCNN_DETECTOR(img)
146
- if face_tensor is not None:
147
- # Convert tensor back to PIL Image
148
- face_np = face_tensor.permute(1, 2, 0).numpy()
149
- face_np = ((face_np * 128) + 127.5).clip(0, 255).astype(np.uint8)
150
- return Image.fromarray(face_np)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  except Exception:
152
  pass
153
  return None
@@ -233,24 +248,10 @@ def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
233
  if not fake_probs:
234
  raise ValueError("No frames could be processed.")
235
 
236
- # 1. Advanced Aggregation (Top 50% Mean)
237
- # Deepfake artifacts might only appear in parts of the video.
238
- # Averaging all frames dilutes the score. We take the top 50% most suspicious frames.
239
- sorted_probs = sorted(fake_probs, reverse=True)
240
- top_k = max(1, len(sorted_probs) // 2)
241
- video_fake_prob = float(np.mean(sorted_probs[:top_k]))
242
-
243
- # 2. Ratio Check
244
- # If at least 30% of frames are distinctly flagged as Fake, mark the whole video as Fake.
245
- fake_frame_count = sum(1 for p in fake_probs if p > 0.5)
246
- fake_ratio = fake_frame_count / len(fake_probs)
247
-
248
- is_fake = (video_fake_prob > 0.5) or (fake_ratio >= 0.3)
249
-
250
- # Ensure UI consistency: If flagged as FAKE by ratio, but probability is low, boost it to 51%
251
- if is_fake and video_fake_prob <= 0.5:
252
- video_fake_prob = 0.51
253
 
 
254
  avg_real = 1.0 - video_fake_prob
255
 
256
  return {
@@ -287,7 +288,7 @@ def health_check():
287
 
288
  @app.post("/predict")
289
  async def predict(file: UploadFile = File(...)):
290
- allowed_exts = {".mp4", ".mov", ".avi", ".mkv"}
291
  ext = Path(file.filename).suffix.lower() if file.filename else ""
292
 
293
  if ext not in allowed_exts:
@@ -312,9 +313,14 @@ async def predict(file: UploadFile = File(...)):
312
  model = load_model()
313
  logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
314
 
315
- frame_paths = extract_frames(str(video_path), str(frames_dir))
316
- if not frame_paths:
317
- raise HTTPException(422, "No frames could be extracted from video.")
 
 
 
 
 
318
 
319
  result = run_inference(model, frame_paths)
320
  result["filename"] = file.filename
 
54
 
55
 
56
  class MLPClassifier(nn.Module):
57
+ def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.4):
58
  super().__init__()
59
  self.net = nn.Sequential(
60
  nn.Linear(input_dim, 512),
61
+ nn.BatchNorm1d(512),
62
  nn.GELU(),
63
  nn.Dropout(dropout),
64
  nn.Linear(512, 256),
65
+ nn.BatchNorm1d(256),
66
  nn.GELU(),
67
+ nn.Dropout(dropout * 0.75),
68
  nn.Linear(256, num_classes),
69
  )
70
 
 
119
  MTCNN_DETECTOR = MTCNN(
120
  image_size=224,
121
  margin=40,
 
 
122
  keep_all=False,
123
+ post_process=False,
124
  device='cpu'
125
  )
126
  logger.info("MTCNN face detector initialized.")
 
130
 
131
  TRANSFORM = T.Compose([
132
  T.Resize((224, 224)),
133
+ T.CenterCrop(224),
134
  T.ToTensor(),
135
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
136
  ])
 
141
  if MTCNN_DETECTOR is None:
142
  return None
143
  try:
144
+ boxes, probs = MTCNN_DETECTOR.detect(img)
145
+ if boxes is None or len(boxes) == 0:
146
+ return None
147
+
148
+ best_idx = np.argmax(probs)
149
+ best_prob = probs[best_idx]
150
+
151
+ if best_prob < 0.9:
152
+ return None
153
+
154
+ box = boxes[best_idx]
155
+ w, h = img.size
156
+ x1, y1, x2, y2 = [int(b) for b in box]
157
+ margin = 40
158
+
159
+ x1 = max(0, x1 - margin)
160
+ y1 = max(0, y1 - margin)
161
+ x2 = min(w, x2 + margin)
162
+ y2 = min(h, y2 + margin)
163
+
164
+ face = img.crop((x1, y1, x2, y2))
165
+ return face.resize((224, 224), Image.LANCZOS)
166
  except Exception:
167
  pass
168
  return None
 
248
  if not fake_probs:
249
  raise ValueError("No frames could be processed.")
250
 
251
+ # 1. Simple Aggregation (Mean) to match test_real.py
252
+ video_fake_prob = float(np.mean(fake_probs))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ is_fake = video_fake_prob > 0.5
255
  avg_real = 1.0 - video_fake_prob
256
 
257
  return {
 
288
 
289
  @app.post("/predict")
290
  async def predict(file: UploadFile = File(...)):
291
+ allowed_exts = {".mp4", ".mov", ".avi", ".mkv", ".jpg", ".jpeg", ".png", ".webp"}
292
  ext = Path(file.filename).suffix.lower() if file.filename else ""
293
 
294
  if ext not in allowed_exts:
 
313
  model = load_model()
314
  logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
315
 
316
+ if ext in {".mp4", ".mov", ".avi", ".mkv"}:
317
+ frame_paths = extract_frames(str(video_path), str(frames_dir))
318
+ if not frame_paths:
319
+ raise HTTPException(422, "No frames could be extracted from video.")
320
+ else:
321
+ img_path = frames_dir / f"frame_0000{ext}"
322
+ shutil.copy(video_path, img_path)
323
+ frame_paths = [str(img_path)]
324
 
325
  result = run_inference(model, frame_paths)
326
  result["filename"] = file.filename