MrTsp committed on
Commit
78f257d
·
1 Parent(s): 07f2243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -109
app.py CHANGED
@@ -1,11 +1,7 @@
1
  """
2
- DeepShield AI — Full-Stack FastAPI Backend
3
  Serves the frontend UI + deepfake detection API from one HF Space.
4
-
5
- Routes:
6
- GET / → Serves index.html (the web UI)
7
- GET /health → JSON health check
8
- POST /predict → Video upload → REAL/FAKE prediction
9
  """
10
 
11
  import os
@@ -20,6 +16,7 @@ from functools import lru_cache
20
  import cv2
21
  import torch
22
  import torch.nn as nn
 
23
  import numpy as np
24
  from PIL import Image, ImageFile
25
  from facenet_pytorch import MTCNN
@@ -34,7 +31,7 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(mess
34
  logger = logging.getLogger(__name__)
35
 
36
  # ─────────────────────────────────────────────
37
- # Model Definition (self-contained)
38
  # ─────────────────────────────────────────────
39
 
40
  class DINOv2Extractor(nn.Module):
@@ -47,14 +44,12 @@ class DINOv2Extractor(nn.Module):
47
  self.feature_dim = 768
48
  for p in self.backbone.parameters():
49
  p.requires_grad = False
50
- logger.info("DINOv2 backbone loaded (frozen).")
51
 
52
  def forward(self, x: torch.Tensor) -> torch.Tensor:
53
  return self.backbone(x)
54
 
55
-
56
  class MLPClassifier(nn.Module):
57
- def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.4):
58
  super().__init__()
59
  self.net = nn.Sequential(
60
  nn.Linear(input_dim, 512),
@@ -71,24 +66,40 @@ class MLPClassifier(nn.Module):
71
  def forward(self, x: torch.Tensor) -> torch.Tensor:
72
  return self.net(x)
73
 
74
-
75
- class DeepfakeDetector(nn.Module):
76
- def __init__(self, dual_input: bool = True):
 
 
 
77
  super().__init__()
78
  self.dual_input = dual_input
79
  self.extractor = DINOv2Extractor()
80
- feat_dim = 1536 if dual_input else 768
81
- self.classifier = MLPClassifier(input_dim=feat_dim)
82
-
83
- def forward(self, full_img: torch.Tensor, face_img: torch.Tensor = None) -> torch.Tensor:
84
- full_feat = self.extractor(full_img)
85
- if self.dual_input and face_img is not None:
86
- face_feat = self.extractor(face_img)
87
- feats = torch.cat([full_feat, face_feat], dim=1)
88
- else:
89
- feats = full_feat
90
- return self.classifier(feats)
 
 
91
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
  # ─────────────────────────────────────────────
94
  # App Setup
@@ -96,8 +107,8 @@ class DeepfakeDetector(nn.Module):
96
 
97
  app = FastAPI(
98
  title="DeepShield AI",
99
- description="DINO-G50 deepfake detector — full-stack web app",
100
- version="2.0.0",
101
  )
102
 
103
  app.add_middleware(
@@ -114,7 +125,7 @@ MAX_FRAMES = 20
114
  MAX_FILE_MB = 30
115
  MAX_DURATION_SEC = 60
116
 
117
- # MTCNN face detector (initialized once, CPU is fine for detection)
118
  try:
119
  MTCNN_DETECTOR = MTCNN(
120
  image_size=224,
@@ -126,7 +137,7 @@ try:
126
  logger.info("MTCNN face detector initialized.")
127
  except Exception as e:
128
  MTCNN_DETECTOR = None
129
- logger.warning(f"MTCNN init failed (will use full frame fallback): {e}")
130
 
131
  TRANSFORM = T.Compose([
132
  T.Resize((224, 224)),
@@ -135,9 +146,7 @@ TRANSFORM = T.Compose([
135
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
136
  ])
137
 
138
-
139
  def detect_face_crop(img: Image.Image) -> Image.Image:
140
- """Detect face with MTCNN and return cropped face, or None if not found."""
141
  if MTCNN_DETECTOR is None:
142
  return None
143
  try:
@@ -146,20 +155,15 @@ def detect_face_crop(img: Image.Image) -> Image.Image:
146
  return None
147
 
148
  best_idx = np.argmax(probs)
149
- best_prob = probs[best_idx]
150
-
151
- if best_prob < 0.9:
152
  return None
153
 
154
  box = boxes[best_idx]
155
  w, h = img.size
156
  x1, y1, x2, y2 = [int(b) for b in box]
157
  margin = 40
158
-
159
- x1 = max(0, x1 - margin)
160
- y1 = max(0, y1 - margin)
161
- x2 = min(w, x2 + margin)
162
- y2 = min(h, y2 + margin)
163
 
164
  face = img.crop((x1, y1, x2, y2))
165
  return face.resize((224, 224), Image.LANCZOS)
@@ -167,90 +171,74 @@ def detect_face_crop(img: Image.Image) -> Image.Image:
167
  pass
168
  return None
169
 
170
-
171
  @lru_cache(maxsize=1)
172
- def load_model() -> DeepfakeDetector:
173
  if not CHECKPOINT_PATH.exists():
174
- raise RuntimeError("best_model.pth not found. Upload it to this HF Space.")
 
 
 
 
175
 
176
- logger.info(f"Loading checkpoint on {DEVICE}...")
177
  ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
178
  state = ckpt.get("model_state_dict", ckpt)
179
 
 
180
  mlp_w = state.get("classifier.net.0.weight", None)
181
  dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True
182
 
183
- model = DeepfakeDetector(dual_input=dual).to(DEVICE)
184
  model.load_state_dict(state, strict=False)
185
  model.eval()
186
- logger.info(f"Model ready. dual_input={dual}, device={DEVICE}")
187
  return model
188
 
189
-
190
  def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list:
191
  cap = cv2.VideoCapture(video_path)
192
  if not cap.isOpened():
193
  raise ValueError("Cannot open video file.")
194
-
195
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
196
- fps = cap.get(cv2.CAP_PROP_FPS) or 25
197
- duration = total_frames / fps if fps > 0 else 0
198
-
199
- if duration > MAX_DURATION_SEC:
200
- cap.release()
201
- raise ValueError(f"Video too long ({duration:.0f}s). Max: {MAX_DURATION_SEC}s.")
202
-
203
- if total_frames <= 0:
204
- total_frames = int(fps * MAX_DURATION_SEC)
205
-
206
  step = max(1, total_frames // num_frames)
207
  target_indices = set(range(0, total_frames, step))
208
  saved_paths = []
209
  frame_idx = 0
210
-
211
  while len(saved_paths) < num_frames:
212
  ret, frame = cap.read()
213
- if not ret:
214
- break
215
  if frame_idx in target_indices:
216
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
217
  path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg")
218
  Image.fromarray(rgb).save(path, quality=90)
219
  saved_paths.append(path)
220
  frame_idx += 1
221
-
222
  cap.release()
223
  return saved_paths
224
 
225
-
226
- def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
227
  fake_probs = []
228
  with torch.no_grad():
229
  for fpath in frame_paths:
230
  try:
231
  img = Image.open(fpath).convert("RGB")
232
  t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)
233
-
234
- # Try MTCNN face detection first (same as test_real.py)
235
- t_face = t_img # default fallback = full frame
236
  if model.dual_input:
237
  face_crop = detect_face_crop(img)
238
  if face_crop is not None:
239
  t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
240
- # else: fallback to full image (face not detected)
241
 
242
  logits = model(t_img, t_face if model.dual_input else None)
243
  prob = torch.softmax(logits, dim=1)[0, 1].item()
244
  fake_probs.append(prob)
245
  except Exception as e:
246
- logger.warning(f"Skipping frame {fpath}: {e}")
247
 
248
- if not fake_probs:
249
- raise ValueError("No frames could be processed.")
250
-
251
- # 1. Simple Aggregation (Mean) to match test_real.py
252
  video_fake_prob = float(np.mean(fake_probs))
253
-
254
  is_fake = video_fake_prob > 0.5
255
  avg_real = 1.0 - video_fake_prob
256
 
@@ -263,11 +251,6 @@ def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
263
  "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
264
  }
265
 
266
-
267
- # ─────────────────────────────────────────────
268
- # API Routes (must be defined BEFORE static mount)
269
- # ─────────────────────────────────────────────
270
-
271
  @app.on_event("startup")
272
  async def startup_event():
273
  try:
@@ -275,24 +258,20 @@ async def startup_event():
275
  except Exception as e:
276
  logger.error(f"Startup model load failed: {e}")
277
 
278
-
279
  @app.get("/health")
280
  def health_check():
281
  return {
282
  "status": "ok",
283
- "model": "DINO-G50 Deepfake Detector",
284
- "device": str(DEVICE),
285
  "model_loaded": CHECKPOINT_PATH.exists(),
286
  }
287
 
288
-
289
  @app.post("/predict")
290
  async def predict(file: UploadFile = File(...)):
291
  allowed_exts = {".mp4", ".mov", ".avi", ".mkv", ".jpg", ".jpeg", ".png", ".webp"}
292
  ext = Path(file.filename).suffix.lower() if file.filename else ""
293
-
294
  if ext not in allowed_exts:
295
- raise HTTPException(400, f"Unsupported type '{ext}'. Use: {allowed_exts}")
296
 
297
  content = await file.read()
298
  size_mb = len(content) / (1024 * 1024)
@@ -303,46 +282,30 @@ async def predict(file: UploadFile = File(...)):
303
  temp_dir = Path(tempfile.gettempdir()) / f"deepshield_{job_id}"
304
  frames_dir = temp_dir / "frames"
305
  frames_dir.mkdir(parents=True, exist_ok=True)
306
- video_path = temp_dir / f"input{ext}"
307
 
308
  try:
309
- with open(video_path, "wb") as f:
310
  f.write(content)
311
  del content
312
-
313
  model = load_model()
314
- logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
315
-
316
  if ext in {".mp4", ".mov", ".avi", ".mkv"}:
317
- frame_paths = extract_frames(str(video_path), str(frames_dir))
318
- if not frame_paths:
319
- raise HTTPException(422, "No frames could be extracted from video.")
320
  else:
321
  img_path = frames_dir / f"frame_0000{ext}"
322
- shutil.copy(video_path, img_path)
323
  frame_paths = [str(img_path)]
324
 
 
 
325
  result = run_inference(model, frame_paths)
326
- result["filename"] = file.filename
327
- result["file_size_mb"] = round(size_mb, 2)
328
- result["job_id"] = job_id
329
-
330
- logger.info(f"[{job_id}] Result: {result['verdict']} ({result['fake_probability']}% fake)")
331
  return JSONResponse(content=result)
332
-
333
- except HTTPException:
334
- raise
335
- except ValueError as e:
336
- raise HTTPException(422, str(e))
337
  except Exception as e:
338
- logger.error(f"[{job_id}] Error: {e}", exc_info=True)
339
- raise HTTPException(500, f"Internal error: {str(e)}")
340
  finally:
341
  shutil.rmtree(temp_dir, ignore_errors=True)
342
- logger.info(f"[{job_id}] Cleanup done.")
343
-
344
 
345
- # ─────────────────────────────────────────────
346
- # Static Frontend (mounted LAST — serves index.html at /)
347
- # ─────────────────────────────────────────────
348
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
1
  """
2
+ DeepShield AI — Full-Stack FastAPI Backend (SupCon Version)
3
  Serves the frontend UI + deepfake detection API from one HF Space.
4
+ 98.3% Accuracy — Supervised Contrastive Learning Model
 
 
 
 
5
  """
6
 
7
  import os
 
16
  import cv2
17
  import torch
18
  import torch.nn as nn
19
+ import torch.nn.functional as F
20
  import numpy as np
21
  from PIL import Image, ImageFile
22
  from facenet_pytorch import MTCNN
 
31
  logger = logging.getLogger(__name__)
32
 
33
  # ─────────────────────────────────────────────
34
+ # Model Definition (Self-Contained SupCon Architecture)
35
  # ─────────────────────────────────────────────
36
 
37
  class DINOv2Extractor(nn.Module):
 
44
  self.feature_dim = 768
45
  for p in self.backbone.parameters():
46
  p.requires_grad = False
 
47
 
48
  def forward(self, x: torch.Tensor) -> torch.Tensor:
49
  return self.backbone(x)
50
 
 
51
  class MLPClassifier(nn.Module):
52
+ def __init__(self, input_dim: int, num_classes: int = 2, dropout: float = 0.4):
53
  super().__init__()
54
  self.net = nn.Sequential(
55
  nn.Linear(input_dim, 512),
 
66
  def forward(self, x: torch.Tensor) -> torch.Tensor:
67
  return self.net(x)
68
 
69
class SupConDeepfakeClassifier(nn.Module):
    """
    DINOv2 deepfake detector, supervised-contrastive variant.

    Matches the architecture used in scripts3. The projection head is built so
    that checkpoint weights stored under ``head.*`` have somewhere to load;
    ``forward`` never calls it and returns classifier logits only.
    """

    def __init__(self, dual_input: bool = True, proj_dim: int = 128):
        super().__init__()
        self.dual_input = dual_input
        self.extractor = DINOv2Extractor()

        # 768 is the per-image feature width; dual-input mode concatenates
        # two such vectors (full frame + face crop).
        base_width = 768
        fused_width = base_width * 2 if dual_input else base_width

        # Projection head for SupCon: kept for weight loading, unused at inference.
        projection_layers = [
            nn.Linear(fused_width, fused_width),
            nn.BatchNorm1d(fused_width),
            nn.ReLU(inplace=True),
            nn.Linear(fused_width, proj_dim),
        ]
        self.head = nn.Sequential(*projection_layers)

        self.classifier = MLPClassifier(fused_width)

    def forward(self, full_image: torch.Tensor, face_crop: torch.Tensor = None):
        """Return class logits for a full frame and, in dual mode, a face crop."""
        scene_vec = self.extractor(full_image)

        # Single-input mode: classify the full-frame features directly.
        if not self.dual_input:
            return self.classifier(scene_vec)

        # Dual-input mode: when no face crop is supplied, the full frame
        # stands in for it so the concatenated width stays the same.
        face_source = full_image if face_crop is None else face_crop
        face_vec = self.extractor(face_source)
        fused = torch.cat([scene_vec, face_vec], dim=1)
        return self.classifier(fused)
103
 
104
  # ─────────────────────────────────────────────
105
  # App Setup
 
107
 
108
  app = FastAPI(
109
  title="DeepShield AI",
110
+ description="DINO-G50 deepfake detector — SupCon SOTA version",
111
+ version="3.0.0",
112
  )
113
 
114
  app.add_middleware(
 
125
  MAX_FILE_MB = 30
126
  MAX_DURATION_SEC = 60
127
 
128
+ # MTCNN face detector
129
  try:
130
  MTCNN_DETECTOR = MTCNN(
131
  image_size=224,
 
137
  logger.info("MTCNN face detector initialized.")
138
  except Exception as e:
139
  MTCNN_DETECTOR = None
140
+ logger.warning(f"MTCNN init failed: {e}")
141
 
142
  TRANSFORM = T.Compose([
143
  T.Resize((224, 224)),
 
146
  T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
147
  ])
148
 
 
149
  def detect_face_crop(img: Image.Image) -> Image.Image:
 
150
  if MTCNN_DETECTOR is None:
151
  return None
152
  try:
 
155
  return None
156
 
157
  best_idx = np.argmax(probs)
158
+ if probs[best_idx] < 0.9:
 
 
159
  return None
160
 
161
  box = boxes[best_idx]
162
  w, h = img.size
163
  x1, y1, x2, y2 = [int(b) for b in box]
164
  margin = 40
165
+ x1, y1 = max(0, x1-margin), max(0, y1-margin)
166
+ x2, y2 = min(w, x2+margin), min(h, y2+margin)
 
 
 
167
 
168
  face = img.crop((x1, y1, x2, y2))
169
  return face.resize((224, 224), Image.LANCZOS)
 
171
  pass
172
  return None
173
 
 
@lru_cache(maxsize=1)
def load_model() -> SupConDeepfakeClassifier:
    """
    Build the SupCon detector, load its checkpoint, and return it in eval mode.

    The result is memoised by ``lru_cache(maxsize=1)``, so a successful load
    happens once per process. If ``CHECKPOINT_PATH`` is absent, a copy is made
    from ``models3/checkpoints/best_model.pth`` when that file exists.

    Returns:
        The model moved to ``DEVICE`` with ``eval()`` applied.

    Raises:
        RuntimeError: if neither the checkpoint nor the fallback file exists.
    """
    if not CHECKPOINT_PATH.exists():
        fallback = Path("models3/checkpoints/best_model.pth")
        if fallback.exists():
            shutil.copy(fallback, CHECKPOINT_PATH)
        else:
            raise RuntimeError("best_model.pth not found. Please upload the model from models3/.")

    logger.info(f"Loading SupCon checkpoint on {DEVICE}...")
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    # Accept either a training checkpoint dict or a bare state dict.
    state = ckpt.get("model_state_dict", ckpt)

    # Auto-detect dual input from the first classifier layer's input width
    # (1536 = two concatenated 768-wide feature vectors). Defaults to dual
    # when the key is absent.
    mlp_w = state.get("classifier.net.0.weight", None)
    dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True

    model = SupConDeepfakeClassifier(dual_input=dual).to(DEVICE)

    # strict=False tolerates key mismatches; surface them in the log so a
    # checkpoint that does not match the architecture is not used silently
    # with randomly initialised layers.
    load_report = model.load_state_dict(state, strict=False)
    if load_report.missing_keys:
        logger.warning(
            f"Checkpoint is missing {len(load_report.missing_keys)} key(s): "
            f"{load_report.missing_keys[:10]}"
        )
    if load_report.unexpected_keys:
        logger.warning(
            f"Checkpoint has {len(load_report.unexpected_keys)} unexpected key(s): "
            f"{load_report.unexpected_keys[:10]}"
        )

    model.eval()
    logger.info(f"SupCon Model ready. dual_input={dual}, device={DEVICE}")
    return model
196
 
 
def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list:
    """
    Save up to ``num_frames`` frames from a video as JPEG files.

    Frames are read sequentially; every ``step``-th frame (starting at 0) is
    converted from BGR to RGB and written to ``output_dir`` as
    ``frame_NNNN.jpg`` until ``num_frames`` files exist or the video ends.

    Args:
        video_path: Path of the video file to open.
        output_dir: Existing directory that receives the JPEG files.
        num_frames: Maximum number of frames to save.

    Returns:
        List of saved file paths, in capture order (possibly empty).

    Raises:
        ValueError: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    saved_paths = []
    try:
        if not cap.isOpened():
            raise ValueError("Cannot open video file.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # No usable frame count reported; assume a nominal length so a
            # sampling step can still be computed.
            total_frames = 300

        # max(1, num_frames) keeps a zero/negative request from dividing by zero;
        # the loop below then simply saves nothing.
        step = max(1, total_frames // max(1, num_frames))
        # range supports constant-time membership tests, so no set is needed.
        target_indices = range(0, total_frames, step)

        frame_idx = 0
        while len(saved_paths) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx in target_indices:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg")
                Image.fromarray(rgb).save(path, quality=90)
                saved_paths.append(path)
            frame_idx += 1
    finally:
        # Release the capture even when conversion or saving raises.
        cap.release()
    return saved_paths
218
 
219
+ def run_inference(model: SupConDeepfakeClassifier, frame_paths: list) -> dict:
 
220
  fake_probs = []
221
  with torch.no_grad():
222
  for fpath in frame_paths:
223
  try:
224
  img = Image.open(fpath).convert("RGB")
225
  t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)
226
+ t_face = t_img
 
 
227
  if model.dual_input:
228
  face_crop = detect_face_crop(img)
229
  if face_crop is not None:
230
  t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
 
231
 
232
  logits = model(t_img, t_face if model.dual_input else None)
233
  prob = torch.softmax(logits, dim=1)[0, 1].item()
234
  fake_probs.append(prob)
235
  except Exception as e:
236
+ logger.warning(f"Error on {fpath}: {e}")
237
 
238
+ if not fake_probs: raise ValueError("No frames processed.")
239
+
240
+ # Matching test_real.py simple mean logic for consistency
 
241
  video_fake_prob = float(np.mean(fake_probs))
 
242
  is_fake = video_fake_prob > 0.5
243
  avg_real = 1.0 - video_fake_prob
244
 
 
251
  "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
252
  }
253
 
 
 
 
 
 
254
  @app.on_event("startup")
255
  async def startup_event():
256
  try:
 
258
  except Exception as e:
259
  logger.error(f"Startup model load failed: {e}")
260
 
 
@app.get("/health")
def health_check():
    """Return a small JSON status payload for the /health route."""
    # NOTE(review): "model_loaded" reflects only whether the checkpoint file
    # exists on disk, not whether load_model() has succeeded — confirm that
    # this is what consumers of the endpoint expect.
    checkpoint_present = CHECKPOINT_PATH.exists()
    payload = dict(
        status="ok",
        model="DINO-G50 SupCon Detector",
        model_loaded=checkpoint_present,
    )
    return payload
268
 
 
269
  @app.post("/predict")
270
  async def predict(file: UploadFile = File(...)):
271
  allowed_exts = {".mp4", ".mov", ".avi", ".mkv", ".jpg", ".jpeg", ".png", ".webp"}
272
  ext = Path(file.filename).suffix.lower() if file.filename else ""
 
273
  if ext not in allowed_exts:
274
+ raise HTTPException(400, f"Unsupported file type '{ext}'.")
275
 
276
  content = await file.read()
277
  size_mb = len(content) / (1024 * 1024)
 
282
  temp_dir = Path(tempfile.gettempdir()) / f"deepshield_{job_id}"
283
  frames_dir = temp_dir / "frames"
284
  frames_dir.mkdir(parents=True, exist_ok=True)
285
+ file_path = temp_dir / f"input{ext}"
286
 
287
  try:
288
+ with open(file_path, "wb") as f:
289
  f.write(content)
290
  del content
 
291
  model = load_model()
292
+
 
293
  if ext in {".mp4", ".mov", ".avi", ".mkv"}:
294
+ frame_paths = extract_frames(str(file_path), str(frames_dir))
 
 
295
  else:
296
  img_path = frames_dir / f"frame_0000{ext}"
297
+ shutil.copy(file_path, img_path)
298
  frame_paths = [str(img_path)]
299
 
300
+ if not frame_paths: raise HTTPException(422, "Failed to extract frames.")
301
+
302
  result = run_inference(model, frame_paths)
303
+ result.update({"filename": file.filename, "file_size_mb": round(size_mb, 2)})
 
 
 
 
304
  return JSONResponse(content=result)
 
 
 
 
 
305
  except Exception as e:
306
+ logger.error(f"Error: {e}", exc_info=True)
307
+ raise HTTPException(500, str(e))
308
  finally:
309
  shutil.rmtree(temp_dir, ignore_errors=True)
 
 
310
 
 
 
 
311
  app.mount("/", StaticFiles(directory="static", html=True), name="static")