MrTsp commited on
Commit
0226ac6
·
verified ·
1 Parent(s): e481e13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -178
app.py CHANGED
@@ -5,7 +5,7 @@ Serves the frontend UI + deepfake detection API from one HF Space.
5
  Routes:
6
  GET / → Serves index.html (the web UI)
7
  GET /health → JSON health check
8
- POST /predict → Video upload → REAL/FAKE prediction
9
  """
10
 
11
  import os
@@ -15,84 +15,27 @@ import shutil
15
  import logging
16
  import tempfile
17
  from pathlib import Path
18
- from functools import lru_cache
19
 
20
  import cv2
21
  import torch
22
- import torch.nn as nn
23
  import numpy as np
24
  from PIL import Image, ImageFile
25
- from facenet_pytorch import MTCNN
26
  from fastapi import FastAPI, File, UploadFile, HTTPException
27
  from fastapi.middleware.cors import CORSMiddleware
28
  from fastapi.responses import JSONResponse, FileResponse
29
  from fastapi.staticfiles import StaticFiles
30
- import torchvision.transforms as T
31
 
32
  ImageFile.LOAD_TRUNCATED_IMAGES = True
33
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
34
  logger = logging.getLogger(__name__)
35
 
36
- # ─────────────────────────────────────────────
37
- # Model Definition (self-contained)
38
- # ─────────────────────────────────────────────
39
-
40
- class DINOv2Extractor(nn.Module):
41
- def __init__(self, variant: str = "dinov2_vitb14"):
42
- super().__init__()
43
- logger.info(f"Loading {variant} from torch.hub...")
44
- self.backbone = torch.hub.load(
45
- "facebookresearch/dinov2", variant, pretrained=True
46
- )
47
- self.feature_dim = 768
48
- for p in self.backbone.parameters():
49
- p.requires_grad = False
50
- logger.info("DINOv2 backbone loaded (frozen).")
51
-
52
- def forward(self, x: torch.Tensor) -> torch.Tensor:
53
- return self.backbone(x)
54
-
55
-
56
- class MLPClassifier(nn.Module):
57
- def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.3):
58
- super().__init__()
59
- self.net = nn.Sequential(
60
- nn.Linear(input_dim, 512),
61
- nn.LayerNorm(512),
62
- nn.GELU(),
63
- nn.Dropout(dropout),
64
- nn.Linear(512, 256),
65
- nn.LayerNorm(256),
66
- nn.GELU(),
67
- nn.Dropout(dropout / 2),
68
- nn.Linear(256, num_classes),
69
- )
70
-
71
- def forward(self, x: torch.Tensor) -> torch.Tensor:
72
- return self.net(x)
73
-
74
-
75
- class DeepfakeDetector(nn.Module):
76
- def __init__(self, dual_input: bool = True):
77
- super().__init__()
78
- self.dual_input = dual_input
79
- self.extractor = DINOv2Extractor()
80
- feat_dim = 1536 if dual_input else 768
81
- self.classifier = MLPClassifier(input_dim=feat_dim)
82
-
83
- def forward(self, full_img: torch.Tensor, face_img: torch.Tensor = None) -> torch.Tensor:
84
- full_feat = self.extractor(full_img)
85
- if self.dual_input and face_img is not None:
86
- face_feat = self.extractor(face_img)
87
- feats = torch.cat([full_feat, face_feat], dim=1)
88
- else:
89
- feats = full_feat
90
- return self.classifier(feats)
91
-
92
-
93
- # ─────────────────────────────────────────────
94
- # App Setup
95
- # ─────────────────────────────────────────────
96
 
97
  app = FastAPI(
98
  title="DeepShield AI",
@@ -108,132 +51,137 @@ app.add_middleware(
108
  allow_headers=["*"],
109
  )
110
 
111
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
112
  CHECKPOINT_PATH = Path("best_model.pth")
113
  MAX_FRAMES = 20
114
  MAX_FILE_MB = 30
115
  MAX_DURATION_SEC = 60
116
 
117
- # MTCNN face detector (initialized once, CPU is fine for detection)
118
- try:
119
- MTCNN_DETECTOR = MTCNN(
120
- image_size=224,
121
- margin=40,
122
- min_face_size=20,
123
- thresholds=[0.6, 0.7, 0.9],
124
- keep_all=False,
125
- device='cpu'
126
- )
127
- logger.info("MTCNN face detector initialized.")
128
- except Exception as e:
129
- MTCNN_DETECTOR = None
130
- logger.warning(f"MTCNN init failed (will use full frame fallback): {e}")
131
-
132
- TRANSFORM = T.Compose([
133
- T.Resize((224, 224)),
134
- T.ToTensor(),
135
- T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
136
- ])
137
-
138
-
139
- def detect_face_crop(img: Image.Image) -> Image.Image:
140
- """Detect face with MTCNN and return cropped face, or None if not found."""
141
- if MTCNN_DETECTOR is None:
142
- return None
143
- try:
144
- # MTCNN returns the cropped tensor directly
145
- face_tensor = MTCNN_DETECTOR(img)
146
- if face_tensor is not None:
147
- # Convert tensor back to PIL Image
148
- face_np = face_tensor.permute(1, 2, 0).numpy()
149
- face_np = ((face_np * 128) + 127.5).clip(0, 255).astype(np.uint8)
150
- return Image.fromarray(face_np)
151
- except Exception:
152
- pass
153
- return None
154
-
155
-
156
- @lru_cache(maxsize=1)
157
- def load_model() -> DeepfakeDetector:
158
- if not CHECKPOINT_PATH.exists():
159
- raise RuntimeError("best_model.pth not found. Upload it to this HF Space.")
160
-
161
- logger.info(f"Loading checkpoint on {DEVICE}...")
162
- ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
163
- state = ckpt.get("model_state_dict", ckpt)
164
-
165
- mlp_w = state.get("classifier.net.0.weight", None)
166
- dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True
167
 
168
- model = DeepfakeDetector(dual_input=dual).to(DEVICE)
169
- model.load_state_dict(state, strict=False)
170
- model.eval()
171
- logger.info(f"Model ready. dual_input={dual}, device={DEVICE}")
172
- return model
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list:
176
- cap = cv2.VideoCapture(video_path)
177
- if not cap.isOpened():
178
- raise ValueError("Cannot open video file.")
179
-
180
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
181
- fps = cap.get(cv2.CAP_PROP_FPS) or 25
182
- duration = total_frames / fps if fps > 0 else 0
183
-
184
- if duration > MAX_DURATION_SEC:
185
- cap.release()
186
- raise ValueError(f"Video too long ({duration:.0f}s). Max: {MAX_DURATION_SEC}s.")
187
 
188
- if total_frames <= 0:
189
- total_frames = int(fps * MAX_DURATION_SEC)
190
 
191
- step = max(1, total_frames // num_frames)
192
- target_indices = set(range(0, total_frames, step))
193
- saved_paths = []
194
- frame_idx = 0
195
 
196
- while len(saved_paths) < num_frames:
 
 
 
 
 
 
 
 
197
  ret, frame = cap.read()
198
- if not ret:
199
- break
200
- if frame_idx in target_indices:
201
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
202
- path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg")
203
- Image.fromarray(rgb).save(path, quality=90)
204
- saved_paths.append(path)
205
- frame_idx += 1
206
-
207
  cap.release()
208
- return saved_paths
209
 
210
 
211
- def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
 
212
  fake_probs = []
 
213
  with torch.no_grad():
214
- for fpath in frame_paths:
215
  try:
216
- img = Image.open(fpath).convert("RGB")
217
- t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)
218
-
219
- # Try MTCNN face detection first (same as test_real.py)
220
- t_face = t_img # default fallback = full frame
221
- if model.dual_input:
222
- face_crop = detect_face_crop(img)
223
- if face_crop is not None:
224
- t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
225
- # else: fallback to full image (face not detected)
226
-
227
- logits = model(t_img, t_face if model.dual_input else None)
228
- prob = torch.softmax(logits, dim=1)[0, 1].item()
229
  fake_probs.append(prob)
230
  except Exception as e:
231
- logger.warning(f"Skipping frame {fpath}: {e}")
232
 
233
  if not fake_probs:
234
  raise ValueError("No frames could be processed.")
235
 
236
- # 1. Simple Aggregation (Mean) to match test_real.py
237
  video_fake_prob = float(np.mean(fake_probs))
238
 
239
  is_fake = video_fake_prob > 0.5
@@ -249,25 +197,26 @@ def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
249
  }
250
 
251
 
252
- # ─────────────────────────────────────────────
253
- # API Routes (must be defined BEFORE static mount)
254
- # ─────────────────────────────────────────────
255
-
256
  @app.on_event("startup")
257
  async def startup_event():
258
  try:
259
- load_model()
260
  except Exception as e:
261
  logger.error(f"Startup model load failed: {e}")
262
 
263
 
264
  @app.get("/health")
265
  def health_check():
 
 
 
 
 
266
  return {
267
  "status": "ok",
268
  "model": "DINO-G50 Deepfake Detector",
269
  "device": str(DEVICE),
270
- "model_loaded": CHECKPOINT_PATH.exists(),
271
  }
272
 
273
 
@@ -295,7 +244,6 @@ async def predict(file: UploadFile = File(...)):
295
  f.write(content)
296
  del content
297
 
298
- model = load_model()
299
  logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
300
 
301
  if ext in {".mp4", ".mov", ".avi", ".mkv"}:
@@ -307,7 +255,7 @@ async def predict(file: UploadFile = File(...)):
307
  shutil.copy(video_path, img_path)
308
  frame_paths = [str(img_path)]
309
 
310
- result = run_inference(model, frame_paths)
311
  result["filename"] = file.filename
312
  result["file_size_mb"] = round(size_mb, 2)
313
  result["job_id"] = job_id
@@ -326,8 +274,4 @@ async def predict(file: UploadFile = File(...)):
326
  shutil.rmtree(temp_dir, ignore_errors=True)
327
  logger.info(f"[{job_id}] Cleanup done.")
328
 
329
-
330
- # ──────────────────────────────────────���──────
331
- # Static Frontend (mounted LAST — serves index.html at /)
332
- # ─────────────────────────────────────────────
333
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
 
5
  Routes:
6
  GET / → Serves index.html (the web UI)
7
  GET /health → JSON health check
8
+ POST /predict → Video/Photo upload → REAL/FAKE prediction
9
  """
10
 
11
  import os
 
15
  import logging
16
  import tempfile
17
  from pathlib import Path
 
18
 
19
  import cv2
20
  import torch
 
21
  import numpy as np
22
  from PIL import Image, ImageFile
 
23
  from fastapi import FastAPI, File, UploadFile, HTTPException
24
  from fastapi.middleware.cors import CORSMiddleware
25
  from fastapi.responses import JSONResponse, FileResponse
26
  from fastapi.staticfiles import StaticFiles
 
27
 
28
  ImageFile.LOAD_TRUNCATED_IMAGES = True
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
30
  logger = logging.getLogger(__name__)
31
 
32
+ # --- START OF IMPORT EXACTLY AS test_real.py ---
33
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
34
+ from src.utils.helpers import load_config, get_device
35
+ from src.data.transforms import get_val_transforms
36
+ from src.models.classifier import DeepfakeClassifier
37
+ from src.utils.face_detect import FaceDetector
38
+ # --- END OF IMPORT ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  app = FastAPI(
41
  title="DeepShield AI",
 
51
  allow_headers=["*"],
52
  )
53
 
54
+ DEVICE = get_device('cuda' if torch.cuda.is_available() else 'cpu')
55
  CHECKPOINT_PATH = Path("best_model.pth")
56
  MAX_FRAMES = 20
57
  MAX_FILE_MB = 30
58
  MAX_DURATION_SEC = 60
59
 
60
+ # Model caching
61
+ global_model = None
62
+ global_face_detector = None
63
+ global_transform = None
64
+ global_dual_input = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ def load_model_and_detector():
67
+ global global_model, global_face_detector, global_transform, global_dual_input
68
+ if global_model is not None:
69
+ return global_model, global_face_detector, global_transform, global_dual_input
 
70
 
71
+ ckpt_path_to_load = None
72
+ if not CHECKPOINT_PATH.exists():
73
+ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
74
+ fallback_path = os.path.join(base_dir, 'models2/checkpoints/best_model.pth')
75
+ if not os.path.exists(fallback_path):
76
+ fallback_path = os.path.join(base_dir, 'models2/checkpoints/best_mlp.pth')
77
+
78
+ if os.path.exists(fallback_path):
79
+ ckpt_path_to_load = fallback_path
80
+ else:
81
+ raise RuntimeError("best_model.pth not found. Upload it to this HF Space.")
82
+ else:
83
+ ckpt_path_to_load = str(CHECKPOINT_PATH)
84
+
85
+ logger.info(f"Loading checkpoint on {DEVICE} from {ckpt_path_to_load} ...")
86
+ checkpoint = torch.load(ckpt_path_to_load, map_location=str(DEVICE))
87
+
88
+ is_fast_mlp = ckpt_path_to_load.endswith('best_mlp.pth')
89
+ dual_input = True
90
+ if is_fast_mlp and 'feat_dim' in checkpoint:
91
+ dual_input = (checkpoint['feat_dim'] == 1536)
92
+
93
+ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
94
+ config_path = os.path.join(base_dir, 'configs/config2.yaml')
95
+ if os.path.exists(config_path):
96
+ config = load_config(config_path)
97
+ else:
98
+ # Fallback if config is missing in HF Spaces standalone
99
+ config = {
100
+ 'model': {'dino_variant': 'dinov2_vitb14', 'unfreeze_last_n_blocks': 0, 'dual_input': True},
101
+ 'face_detection': {'margin': 40, 'confidence_threshold': 0.9},
102
+ 'data': {'image_size': 224}
103
+ }
104
+
105
+ if not is_fast_mlp:
106
+ dual_input = config['model'].get('dual_input', True)
107
+
108
+ face_detector = FaceDetector(
109
+ margin=config['face_detection']['margin'],
110
+ confidence_threshold=config['face_detection']['confidence_threshold'],
111
+ image_size=config['data']['image_size'],
112
+ device=str(DEVICE)
113
+ ) if dual_input else None
114
+
115
+ model = DeepfakeClassifier(
116
+ dino_variant=config['model']['dino_variant'],
117
+ freeze_backbone=not is_fast_mlp,
118
+ unfreeze_last_n_blocks=config['model']['unfreeze_last_n_blocks'] if not is_fast_mlp else 0,
119
+ dual_input=dual_input
120
+ )
121
+
122
+ if is_fast_mlp:
123
+ model.classifier.load_state_dict(checkpoint['model_state_dict'])
124
+ else:
125
+ model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
126
+
127
+ model = model.to(DEVICE).eval()
128
+ transform = get_val_transforms(config['data']['image_size'])
129
 
130
+ global_model = model
131
+ global_face_detector = face_detector
132
+ global_transform = transform
133
+ global_dual_input = dual_input
 
 
 
 
 
 
 
 
134
 
135
+ logger.info(f"Model ready. dual_input={dual_input}, device={DEVICE}, is_fast_mlp={is_fast_mlp}")
136
+ return model, face_detector, transform, dual_input
137
 
 
 
 
 
138
 
139
+ def extract_frames(video_path: str, temp_dir: str, num_frames: int = MAX_FRAMES) -> list:
140
+ cap = cv2.VideoCapture(video_path)
141
+ if not cap.isOpened(): return []
142
+ total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
143
+ step = max(1, (total if total > 0 else 300) // num_frames)
144
+ indices = set(range(0, total if total > 0 else 300, step))
145
+
146
+ saved = []
147
+ for i in range(total if total > 0 else 300):
148
  ret, frame = cap.read()
149
+ if not ret: break
150
+ if i in indices:
 
151
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
152
+ path = os.path.join(temp_dir, f"frame_{len(saved):03d}.jpg")
153
+ Image.fromarray(rgb).save(path)
154
+ saved.append(path)
155
+ if len(saved) >= num_frames: break
 
156
  cap.release()
157
+ return saved
158
 
159
 
160
+ def run_inference(frame_paths: list) -> dict:
161
+ model, face_detector, transform, dual_input = load_model_and_detector()
162
  fake_probs = []
163
+
164
  with torch.no_grad():
165
+ for i, p in enumerate(frame_paths):
166
  try:
167
+ img = Image.open(p).convert('RGB')
168
+ t_img = transform(img).unsqueeze(0).to(DEVICE)
169
+ t_face = t_img
170
+
171
+ if dual_input:
172
+ face = face_detector.detect_and_crop(p)
173
+ if face is not None:
174
+ t_face = transform(face).unsqueeze(0).to(DEVICE)
175
+
176
+ probs = torch.softmax(model(t_img, t_face if dual_input else None), dim=1)
177
+ prob = probs[0, 1].item()
 
 
178
  fake_probs.append(prob)
179
  except Exception as e:
180
+ logger.warning(f"Skipping frame {p}: {e}")
181
 
182
  if not fake_probs:
183
  raise ValueError("No frames could be processed.")
184
 
 
185
  video_fake_prob = float(np.mean(fake_probs))
186
 
187
  is_fake = video_fake_prob > 0.5
 
197
  }
198
 
199
 
 
 
 
 
200
  @app.on_event("startup")
201
  async def startup_event():
202
  try:
203
+ load_model_and_detector()
204
  except Exception as e:
205
  logger.error(f"Startup model load failed: {e}")
206
 
207
 
208
  @app.get("/health")
209
  def health_check():
210
+ try:
211
+ model_loaded = CHECKPOINT_PATH.exists() or os.path.exists(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'models2/checkpoints/best_model.pth'))
212
+ except:
213
+ model_loaded = False
214
+
215
  return {
216
  "status": "ok",
217
  "model": "DINO-G50 Deepfake Detector",
218
  "device": str(DEVICE),
219
+ "model_loaded": model_loaded,
220
  }
221
 
222
 
 
244
  f.write(content)
245
  del content
246
 
 
247
  logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
248
 
249
  if ext in {".mp4", ".mov", ".avi", ".mkv"}:
 
255
  shutil.copy(video_path, img_path)
256
  frame_paths = [str(img_path)]
257
 
258
+ result = run_inference(frame_paths)
259
  result["filename"] = file.filename
260
  result["file_size_mb"] = round(size_mb, 2)
261
  result["job_id"] = job_id
 
274
  shutil.rmtree(temp_dir, ignore_errors=True)
275
  logger.info(f"[{job_id}] Cleanup done.")
276
 
 
 
 
 
277
  app.mount("/", StaticFiles(directory="static", html=True), name="static")