MrTsp commited on
Commit
e8f9f2d
·
verified ·
1 Parent(s): 0226ac6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -98
app.py CHANGED
@@ -1,11 +1,7 @@
1
  """
2
  DeepShield AI — Full-Stack FastAPI Backend
3
  Serves the frontend UI + deepfake detection API from one HF Space.
4
-
5
- Routes:
6
- GET / → Serves index.html (the web UI)
7
- GET /health → JSON health check
8
- POST /predict → Video/Photo upload → REAL/FAKE prediction
9
  """
10
 
11
  import os
@@ -15,27 +11,81 @@ import shutil
15
  import logging
16
  import tempfile
17
  from pathlib import Path
 
18
 
19
  import cv2
20
  import torch
 
21
  import numpy as np
22
  from PIL import Image, ImageFile
 
23
  from fastapi import FastAPI, File, UploadFile, HTTPException
24
  from fastapi.middleware.cors import CORSMiddleware
25
  from fastapi.responses import JSONResponse, FileResponse
26
  from fastapi.staticfiles import StaticFiles
 
27
 
28
  ImageFile.LOAD_TRUNCATED_IMAGES = True
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
30
  logger = logging.getLogger(__name__)
31
 
32
- # --- START OF IMPORT EXACTLY AS test_real.py ---
33
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
34
- from src.utils.helpers import load_config, get_device
35
- from src.data.transforms import get_val_transforms
36
- from src.models.classifier import DeepfakeClassifier
37
- from src.utils.face_detect import FaceDetector
38
- # --- END OF IMPORT ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  app = FastAPI(
41
  title="DeepShield AI",
@@ -51,30 +101,69 @@ app.add_middleware(
51
  allow_headers=["*"],
52
  )
53
 
54
- DEVICE = get_device('cuda' if torch.cuda.is_available() else 'cpu')
55
  CHECKPOINT_PATH = Path("best_model.pth")
56
  MAX_FRAMES = 20
57
  MAX_FILE_MB = 30
58
  MAX_DURATION_SEC = 60
59
 
60
- # Model caching
61
- global_model = None
62
- global_face_detector = None
63
- global_transform = None
64
- global_dual_input = True
65
-
66
- def load_model_and_detector():
67
- global global_model, global_face_detector, global_transform, global_dual_input
68
- if global_model is not None:
69
- return global_model, global_face_detector, global_transform, global_dual_input
70
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  ckpt_path_to_load = None
72
  if not CHECKPOINT_PATH.exists():
73
- base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
74
- fallback_path = os.path.join(base_dir, 'models2/checkpoints/best_model.pth')
75
- if not os.path.exists(fallback_path):
76
- fallback_path = os.path.join(base_dir, 'models2/checkpoints/best_mlp.pth')
77
-
78
  if os.path.exists(fallback_path):
79
  ckpt_path_to_load = fallback_path
80
  else:
@@ -83,58 +172,18 @@ def load_model_and_detector():
83
  ckpt_path_to_load = str(CHECKPOINT_PATH)
84
 
85
  logger.info(f"Loading checkpoint on {DEVICE} from {ckpt_path_to_load} ...")
86
- checkpoint = torch.load(ckpt_path_to_load, map_location=str(DEVICE))
87
-
88
- is_fast_mlp = ckpt_path_to_load.endswith('best_mlp.pth')
89
- dual_input = True
90
- if is_fast_mlp and 'feat_dim' in checkpoint:
91
- dual_input = (checkpoint['feat_dim'] == 1536)
92
-
93
- base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
94
- config_path = os.path.join(base_dir, 'configs/config2.yaml')
95
- if os.path.exists(config_path):
96
- config = load_config(config_path)
97
- else:
98
- # Fallback if config is missing in HF Spaces standalone
99
- config = {
100
- 'model': {'dino_variant': 'dinov2_vitb14', 'unfreeze_last_n_blocks': 0, 'dual_input': True},
101
- 'face_detection': {'margin': 40, 'confidence_threshold': 0.9},
102
- 'data': {'image_size': 224}
103
- }
104
-
105
- if not is_fast_mlp:
106
- dual_input = config['model'].get('dual_input', True)
107
-
108
- face_detector = FaceDetector(
109
- margin=config['face_detection']['margin'],
110
- confidence_threshold=config['face_detection']['confidence_threshold'],
111
- image_size=config['data']['image_size'],
112
- device=str(DEVICE)
113
- ) if dual_input else None
114
-
115
- model = DeepfakeClassifier(
116
- dino_variant=config['model']['dino_variant'],
117
- freeze_backbone=not is_fast_mlp,
118
- unfreeze_last_n_blocks=config['model']['unfreeze_last_n_blocks'] if not is_fast_mlp else 0,
119
- dual_input=dual_input
120
- )
121
-
122
- if is_fast_mlp:
123
- model.classifier.load_state_dict(checkpoint['model_state_dict'])
124
- else:
125
- model.load_state_dict(checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint)
126
-
127
- model = model.to(DEVICE).eval()
128
- transform = get_val_transforms(config['data']['image_size'])
129
 
130
- global_model = model
131
- global_face_detector = face_detector
132
- global_transform = transform
133
- global_dual_input = dual_input
134
-
135
- logger.info(f"Model ready. dual_input={dual_input}, device={DEVICE}, is_fast_mlp={is_fast_mlp}")
136
- return model, face_detector, transform, dual_input
137
 
 
 
 
 
 
138
 
139
  def extract_frames(video_path: str, temp_dir: str, num_frames: int = MAX_FRAMES) -> list:
140
  cap = cv2.VideoCapture(video_path)
@@ -157,33 +206,30 @@ def extract_frames(video_path: str, temp_dir: str, num_frames: int = MAX_FRAMES)
157
  return saved
158
 
159
 
160
- def run_inference(frame_paths: list) -> dict:
161
- model, face_detector, transform, dual_input = load_model_and_detector()
162
  fake_probs = []
163
-
164
  with torch.no_grad():
165
- for i, p in enumerate(frame_paths):
166
  try:
167
- img = Image.open(p).convert('RGB')
168
- t_img = transform(img).unsqueeze(0).to(DEVICE)
169
  t_face = t_img
170
 
171
- if dual_input:
172
- face = face_detector.detect_and_crop(p)
173
- if face is not None:
174
- t_face = transform(face).unsqueeze(0).to(DEVICE)
175
-
176
- probs = torch.softmax(model(t_img, t_face if dual_input else None), dim=1)
177
- prob = probs[0, 1].item()
178
  fake_probs.append(prob)
179
  except Exception as e:
180
- logger.warning(f"Skipping frame {p}: {e}")
181
 
182
  if not fake_probs:
183
  raise ValueError("No frames could be processed.")
184
 
185
  video_fake_prob = float(np.mean(fake_probs))
186
-
187
  is_fake = video_fake_prob > 0.5
188
  avg_real = 1.0 - video_fake_prob
189
 
@@ -196,15 +242,17 @@ def run_inference(frame_paths: list) -> dict:
196
  "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
197
  }
198
 
 
 
 
199
 
200
  @app.on_event("startup")
201
  async def startup_event():
202
  try:
203
- load_model_and_detector()
204
  except Exception as e:
205
  logger.error(f"Startup model load failed: {e}")
206
 
207
-
208
  @app.get("/health")
209
  def health_check():
210
  try:
@@ -219,7 +267,6 @@ def health_check():
219
  "model_loaded": model_loaded,
220
  }
221
 
222
-
223
  @app.post("/predict")
224
  async def predict(file: UploadFile = File(...)):
225
  allowed_exts = {".mp4", ".mov", ".avi", ".mkv", ".jpg", ".jpeg", ".png", ".webp"}
@@ -244,6 +291,7 @@ async def predict(file: UploadFile = File(...)):
244
  f.write(content)
245
  del content
246
 
 
247
  logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
248
 
249
  if ext in {".mp4", ".mov", ".avi", ".mkv"}:
@@ -255,7 +303,7 @@ async def predict(file: UploadFile = File(...)):
255
  shutil.copy(video_path, img_path)
256
  frame_paths = [str(img_path)]
257
 
258
- result = run_inference(frame_paths)
259
  result["filename"] = file.filename
260
  result["file_size_mb"] = round(size_mb, 2)
261
  result["job_id"] = job_id
 
1
  """
2
  DeepShield AI — Full-Stack FastAPI Backend
3
  Serves the frontend UI + deepfake detection API from one HF Space.
4
+ Self-contained version with exact architectural parity to test_real.py
 
 
 
 
5
  """
6
 
7
  import os
 
11
  import logging
12
  import tempfile
13
  from pathlib import Path
14
+ from functools import lru_cache
15
 
16
  import cv2
17
  import torch
18
+ import torch.nn as nn
19
  import numpy as np
20
  from PIL import Image, ImageFile
21
+ from facenet_pytorch import MTCNN
22
  from fastapi import FastAPI, File, UploadFile, HTTPException
23
  from fastapi.middleware.cors import CORSMiddleware
24
  from fastapi.responses import JSONResponse, FileResponse
25
  from fastapi.staticfiles import StaticFiles
26
+ import torchvision.transforms as T
27
 
28
  ImageFile.LOAD_TRUNCATED_IMAGES = True
29
  logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
30
  logger = logging.getLogger(__name__)
31
 
32
+ # -------------------------------------------------------------
33
+ # EXACT PARITY MODEL DEFINITIONS (Copied from src/ to be standalone)
34
+ # -------------------------------------------------------------
35
+
36
+ class DINOv2Extractor(nn.Module):
37
+ def __init__(self, variant: str = 'dinov2_vitb14'):
38
+ super().__init__()
39
+ self.embed_dim = 768
40
+ logger.info(f"Loading {variant} from torch.hub ...")
41
+ self.backbone = torch.hub.load(
42
+ 'facebookresearch/dinov2', variant, pretrained=True,
43
+ )
44
+ logger.info("DINOv2 loaded.")
45
+ for p in self.backbone.parameters():
46
+ p.requires_grad = False
47
+
48
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
49
+ return self.backbone(x)
50
+
51
+ class MLPClassifier(nn.Module):
52
+ def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.4):
53
+ super().__init__()
54
+ self.net = nn.Sequential(
55
+ nn.Linear(input_dim, 512),
56
+ nn.BatchNorm1d(512),
57
+ nn.GELU(),
58
+ nn.Dropout(dropout),
59
+ nn.Linear(512, 256),
60
+ nn.BatchNorm1d(256),
61
+ nn.GELU(),
62
+ nn.Dropout(dropout * 0.75),
63
+ nn.Linear(256, num_classes),
64
+ )
65
+
66
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
67
+ return self.net(x)
68
+
69
+ class DeepfakeDetector(nn.Module):
70
+ def __init__(self, dual_input: bool = True):
71
+ super().__init__()
72
+ self.dual_input = dual_input
73
+ self.extractor = DINOv2Extractor('dinov2_vitb14')
74
+ feat_dim = 1536 if dual_input else 768
75
+ self.classifier = MLPClassifier(feat_dim)
76
+
77
+ def forward(self, full_image: torch.Tensor, face_crop: torch.Tensor = None) -> torch.Tensor:
78
+ full_feat = self.extractor(full_image)
79
+ if self.dual_input:
80
+ face_feat = self.extractor(face_crop if face_crop is not None else full_image)
81
+ features = torch.cat([full_feat, face_feat], dim=1)
82
+ else:
83
+ features = full_feat
84
+ return self.classifier(features)
85
+
86
+ # -------------------------------------------------------------
87
+ # APP SETTINGS & SETUP
88
+ # -------------------------------------------------------------
89
 
90
  app = FastAPI(
91
  title="DeepShield AI",
 
101
  allow_headers=["*"],
102
  )
103
 
104
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
105
  CHECKPOINT_PATH = Path("best_model.pth")
106
  MAX_FRAMES = 20
107
  MAX_FILE_MB = 30
108
  MAX_DURATION_SEC = 60
109
 
110
+ # MTCNN face detector setup to mimic src/utils/face_detect.py precisely
111
+ try:
112
+ MTCNN_DETECTOR = MTCNN(
113
+ image_size=224,
114
+ margin=40,
115
+ keep_all=False,
116
+ post_process=False,
117
+ device='cpu'
118
+ )
119
+ logger.info("MTCNN face detector initialized.")
120
+ except Exception as e:
121
+ MTCNN_DETECTOR = None
122
+ logger.warning(f"MTCNN init failed (will use fallback): {e}")
123
+
124
+ # Exact transform replication
125
+ TRANSFORM = T.Compose([
126
+ T.Resize((224, 224)),
127
+ T.CenterCrop(224),
128
+ T.ToTensor(),
129
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
130
+ ])
131
+
132
+ def detect_face_crop(img: Image.Image) -> Image.Image:
133
+ if MTCNN_DETECTOR is None:
134
+ return None
135
+ try:
136
+ boxes, probs = MTCNN_DETECTOR.detect(img)
137
+ if boxes is None or len(boxes) == 0:
138
+ return None
139
+
140
+ best_idx = np.argmax(probs)
141
+ best_prob = probs[best_idx]
142
+ if best_prob < 0.9:
143
+ return None
144
+
145
+ box = boxes[best_idx]
146
+ w, h = img.size
147
+ x1, y1, x2, y2 = [int(b) for b in box]
148
+ margin = 40
149
+
150
+ x1 = max(0, x1 - margin)
151
+ y1 = max(0, y1 - margin)
152
+ x2 = min(w, x2 + margin)
153
+ y2 = min(h, y2 + margin)
154
+
155
+ face = img.crop((x1, y1, x2, y2))
156
+ return face.resize((224, 224), Image.LANCZOS)
157
+ except Exception:
158
+ pass
159
+ return None
160
+
161
+ @lru_cache(maxsize=1)
162
+ def load_model() -> DeepfakeDetector:
163
+ # First check default path, then fallback if possible
164
  ckpt_path_to_load = None
165
  if not CHECKPOINT_PATH.exists():
166
+ fallback_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'models2/checkpoints/best_model.pth')
 
 
 
 
167
  if os.path.exists(fallback_path):
168
  ckpt_path_to_load = fallback_path
169
  else:
 
172
  ckpt_path_to_load = str(CHECKPOINT_PATH)
173
 
174
  logger.info(f"Loading checkpoint on {DEVICE} from {ckpt_path_to_load} ...")
175
+ ckpt = torch.load(ckpt_path_to_load, map_location=DEVICE)
176
+ state = ckpt.get("model_state_dict", ckpt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ # Determine architecture
179
+ mlp_w = state.get("classifier.net.0.weight", None)
180
+ dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True
 
 
 
 
181
 
182
+ model = DeepfakeDetector(dual_input=dual).to(DEVICE)
183
+ model.load_state_dict(state, strict=False)
184
+ model.eval()
185
+ logger.info(f"Model ready. dual_input={dual}, device={DEVICE}")
186
+ return model
187
 
188
  def extract_frames(video_path: str, temp_dir: str, num_frames: int = MAX_FRAMES) -> list:
189
  cap = cv2.VideoCapture(video_path)
 
206
  return saved
207
 
208
 
209
+ def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
 
210
  fake_probs = []
 
211
  with torch.no_grad():
212
+ for fpath in frame_paths:
213
  try:
214
+ img = Image.open(fpath).convert("RGB")
215
+ t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)
216
  t_face = t_img
217
 
218
+ if model.dual_input:
219
+ face_crop = detect_face_crop(img)
220
+ if face_crop is not None:
221
+ t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
222
+
223
+ logits = model(t_img, t_face if model.dual_input else None)
224
+ prob = torch.softmax(logits, dim=1)[0, 1].item()
225
  fake_probs.append(prob)
226
  except Exception as e:
227
+ logger.warning(f"Skipping frame {fpath}: {e}")
228
 
229
  if not fake_probs:
230
  raise ValueError("No frames could be processed.")
231
 
232
  video_fake_prob = float(np.mean(fake_probs))
 
233
  is_fake = video_fake_prob > 0.5
234
  avg_real = 1.0 - video_fake_prob
235
 
 
242
  "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
243
  }
244
 
245
+ # -------------------------------------------------------------
246
+ # API ROUTES
247
+ # -------------------------------------------------------------
248
 
249
  @app.on_event("startup")
250
  async def startup_event():
251
  try:
252
+ load_model()
253
  except Exception as e:
254
  logger.error(f"Startup model load failed: {e}")
255
 
 
256
  @app.get("/health")
257
  def health_check():
258
  try:
 
267
  "model_loaded": model_loaded,
268
  }
269
 
 
270
  @app.post("/predict")
271
  async def predict(file: UploadFile = File(...)):
272
  allowed_exts = {".mp4", ".mov", ".avi", ".mkv", ".jpg", ".jpeg", ".png", ".webp"}
 
291
  f.write(content)
292
  del content
293
 
294
+ model = load_model()
295
  logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
296
 
297
  if ext in {".mp4", ".mov", ".avi", ".mkv"}:
 
303
  shutil.copy(video_path, img_path)
304
  frame_paths = [str(img_path)]
305
 
306
+ result = run_inference(model, frame_paths)
307
  result["filename"] = file.filename
308
  result["file_size_mb"] = round(size_mb, 2)
309
  result["job_id"] = job_id