MrTsp committed on
Commit
f7c45de
·
1 Parent(s): 4e349b7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +325 -0
app.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepShield AI — Full-Stack FastAPI Backend
3
+ Serves the frontend UI + deepfake detection API from one HF Space.
4
+
5
+ Routes:
6
+ GET / → Serves index.html (the web UI)
7
+ GET /health → JSON health check
8
+ POST /predict → Video upload → REAL/FAKE prediction
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import uuid
14
+ import shutil
15
+ import logging
16
+ import tempfile
17
+ from pathlib import Path
18
+ from functools import lru_cache
19
+
20
+ import cv2
21
+ import torch
22
+ import torch.nn as nn
23
+ import numpy as np
24
+ from PIL import Image, ImageFile
25
+ from facenet_pytorch import MTCNN
26
+ from fastapi import FastAPI, File, UploadFile, HTTPException
27
+ from fastapi.middleware.cors import CORSMiddleware
28
+ from fastapi.responses import JSONResponse, FileResponse
29
+ from fastapi.staticfiles import StaticFiles
30
+ import torchvision.transforms as T
31
+
32
# Tell Pillow to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Process-wide logging: timestamp, level name, then the message.
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
35
+
36
+ # ─────────────────────────────────────────────
37
+ # Model Definition (self-contained)
38
+ # ─────────────────────────────────────────────
39
+
40
class DINOv2Extractor(nn.Module):
    """Frozen DINOv2 backbone used as an image feature extractor.

    The backbone is fetched through ``torch.hub.load`` from the
    ``facebookresearch/dinov2`` repository and every parameter is frozen
    (``requires_grad = False``), so only modules stacked on top of it train.

    Args:
        variant: Name of the torch.hub entry point to load.

    Attributes:
        backbone: The module returned by ``torch.hub.load``.
        feature_dim: Width of the feature vector reported for the backbone.
    """

    def __init__(self, variant: str = "dinov2_vitb14"):
        super().__init__()
        logger.info(f"Loading {variant} from torch.hub...")
        self.backbone = torch.hub.load(
            "facebookresearch/dinov2", variant, pretrained=True
        )
        # Read the embedding width from the backbone when it exposes one, so a
        # non-default variant does not silently report the default's 768.
        # The previous hard-coded value is kept as the fallback.
        self.feature_dim = int(getattr(self.backbone, "embed_dim", 768))
        for p in self.backbone.parameters():
            p.requires_grad = False
        logger.info("DINOv2 backbone loaded (frozen).")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone and return its output unchanged."""
        return self.backbone(x)
54
+
55
+
56
class MLPClassifier(nn.Module):
    """MLP head mapping ``input_dim`` -> 512 -> 256 -> ``num_classes``.

    Each of the two hidden stages is Linear -> LayerNorm -> GELU -> Dropout;
    the second stage uses half the dropout rate of the first. The stages live
    in a single ``nn.Sequential`` named ``net``, so parameter keys keep the
    form ``net.<index>.weight`` / ``net.<index>.bias``.

    Args:
        input_dim: Width of the incoming feature vector.
        num_classes: Number of output logits.
        dropout: Dropout rate of the first hidden stage.
    """

    def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.3):
        super().__init__()
        stages = []
        width_in = input_dim
        # Hidden stages are created in order so that module indices (and hence
        # state-dict keys) stay 0..8.
        for width_out, rate in ((512, dropout), (256, dropout / 2)):
            stages.extend(
                [
                    nn.Linear(width_in, width_out),
                    nn.LayerNorm(width_out),
                    nn.GELU(),
                    nn.Dropout(rate),
                ]
            )
            width_in = width_out
        stages.append(nn.Linear(width_in, num_classes))
        self.net = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by the stacked stages for ``x``."""
        logits = self.net(x)
        return logits
73
+
74
+
75
class DeepfakeDetector(nn.Module):
    """Feature extractor plus MLP head producing per-image class logits.

    In dual-input mode the classifier receives the full-frame features
    concatenated with the face-crop features (twice the extractor width);
    otherwise it receives the full-frame features only.

    Args:
        dual_input: Whether the classifier consumes full-frame and face
            features together.
    """

    def __init__(self, dual_input: bool = True):
        super().__init__()
        self.dual_input = dual_input
        self.extractor = DINOv2Extractor()
        # Size the head from the extractor's reported width rather than
        # repeating the 768 / 1536 literals here.
        branches = 2 if dual_input else 1
        feat_dim = self.extractor.feature_dim * branches
        self.classifier = MLPClassifier(input_dim=feat_dim)

    def forward(self, full_img: torch.Tensor, face_img: torch.Tensor = None) -> torch.Tensor:
        """Return the classifier output for a batch.

        Args:
            full_img: Batch passed to the extractor as the full-frame input.
            face_img: Optional batch passed to the extractor as the face
                input. Ignored when ``dual_input`` is False. When it is None
                in dual-input mode, the full-frame features are reused for
                the face slot.

        Returns:
            Whatever ``self.classifier`` produces for the assembled features.
        """
        full_feat = self.extractor(full_img)
        if not self.dual_input:
            return self.classifier(full_feat)

        # A dual-input head always expects two concatenated feature vectors.
        # Previously a missing face crop sent a single-width vector to the
        # double-width classifier, which failed with a shape mismatch.
        face_feat = full_feat if face_img is None else self.extractor(face_img)
        feats = torch.cat([full_feat, face_feat], dim=1)
        return self.classifier(feats)
91
+
92
+
93
+ # ─────────────────────────────────────────────
94
+ # App Setup
95
+ # ─────────────────────────────────────────────
96
+
97
# The FastAPI application object; routes and the static mount attach to it.
app = FastAPI(
    title="DeepShield AI",
    description="DINO-G50 deepfake detector — full-stack web app",
    version="2.0.0",
)

# Open CORS policy: any origin, method and header may call the API.
# Credentials stay disabled because a wildcard origin list must not be
# combined with allow_credentials=True (see the FastAPI CORS documentation);
# none of the routes in this file read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
110
+
111
# Inference device: CUDA when torch reports it as available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CHECKPOINT_PATH = Path("best_model.pth")
MAX_FRAMES = 20
MAX_FILE_MB = 30
MAX_DURATION_SEC = 60


def _init_face_detector():
    """Build the MTCNN face detector once; return None if construction fails."""
    try:
        # The detector itself is pinned to the CPU regardless of DEVICE.
        detector = MTCNN(
            image_size=224,
            margin=40,
            min_face_size=20,
            thresholds=[0.6, 0.7, 0.9],
            keep_all=False,
            device='cpu'
        )
        logger.info("MTCNN face detector initialized.")
        return detector
    except Exception as e:
        logger.warning(f"MTCNN init failed (will use full frame fallback): {e}")
        return None


# Shared detector instance; None means face detection is unavailable.
MTCNN_DETECTOR = _init_face_detector()

# Preprocessing applied to every image before it reaches the model:
# resize to 224x224, convert to a tensor, then normalize per channel.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]
TRANSFORM = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])
137
+
138
+
139
def detect_face_crop(img: Image.Image) -> "Image.Image | None":
    """Detect a face with MTCNN and return the crop as a PIL image.

    Args:
        img: Image handed to the module-level ``MTCNN_DETECTOR``.

    Returns:
        The cropped face as an 8-bit PIL image, or None when the detector is
        unavailable, reports no face, or raises while processing ``img``.
    """
    if MTCNN_DETECTOR is None:
        return None
    try:
        # MTCNN returns the cropped tensor directly
        face_tensor = MTCNN_DETECTOR(img)
        if face_tensor is None:
            return None
        # Convert tensor back to PIL Image: channels-first -> channels-last.
        face_np = face_tensor.detach().cpu().permute(1, 2, 0).numpy()
        # NOTE(review): presumably inverts the detector's (x - 127.5) / 128
        # output standardization -- confirm against facenet_pytorch.
        face_np = ((face_np * 128) + 127.5).clip(0, 255).astype(np.uint8)
        return Image.fromarray(face_np)
    except Exception as e:
        # Still best-effort (None lets callers fall back), but no longer
        # silent: record why the crop was unavailable.
        logger.warning(f"Face detection failed, using fallback: {e}")
        return None
154
+
155
+
156
@lru_cache(maxsize=1)
def load_model() -> DeepfakeDetector:
    """Build the detector, load weights from ``CHECKPOINT_PATH``, and cache it.

    The checkpoint may be a bare state dict or a dict holding one under
    ``"model_state_dict"``. Dual-input mode is inferred from the input width
    of the first classifier layer (1536 means dual); when that tensor is
    absent, dual-input is assumed.

    Returns:
        The model on ``DEVICE`` in eval mode. Because of ``lru_cache``, later
        calls return the same instance; a call that raises is not cached.

    Raises:
        RuntimeError: If the checkpoint file does not exist.
    """
    if not CHECKPOINT_PATH.exists():
        raise RuntimeError("best_model.pth not found. Upload it to this HF Space.")

    logger.info(f"Loading checkpoint on {DEVICE}...")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True if the checkpoint allows).
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    state = ckpt.get("model_state_dict", ckpt)

    mlp_w = state.get("classifier.net.0.weight")
    if mlp_w is None:
        logger.warning(
            "Checkpoint has no 'classifier.net.0.weight'; assuming dual_input=True."
        )
        dual = True
    else:
        dual = mlp_w.shape[1] == 1536

    model = DeepfakeDetector(dual_input=dual).to(DEVICE)
    # strict=False tolerates key mismatches, which would otherwise leave parts
    # of the model at their initial values without any signal -- so report them.
    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys:
        logger.warning(
            f"Checkpoint is missing {len(load_result.missing_keys)} key(s), "
            f"e.g. {load_result.missing_keys[:5]}"
        )
    if load_result.unexpected_keys:
        logger.warning(
            f"Checkpoint has {len(load_result.unexpected_keys)} unexpected key(s), "
            f"e.g. {load_result.unexpected_keys[:5]}"
        )
    model.eval()
    logger.info(f"Model ready. dual_input={dual}, device={DEVICE}")
    return model
173
+
174
+
175
def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list:
    """Sample up to ``num_frames`` frames from a video and save them as JPEGs.

    Frames are chosen at evenly spaced indices across the reported frame
    count, decoded sequentially, converted from BGR to RGB and written to
    ``output_dir`` as ``frame_0000.jpg``, ``frame_0001.jpg``, ...

    Args:
        video_path: Path of the video to open with OpenCV.
        output_dir: Existing directory that receives the JPEG files.
        num_frames: Maximum number of frames to save.

    Returns:
        Paths of the saved frames in order; may be shorter than ``num_frames``
        (or empty) when the video yields fewer readable frames.

    Raises:
        ValueError: If the video cannot be opened or its reported duration
            exceeds ``MAX_DURATION_SEC``.
    """
    if num_frames <= 0:
        # Nothing requested; also avoids dividing by zero below.
        return []

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Cannot open video file.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        duration = total_frames / fps if fps > 0 else 0

        if duration > MAX_DURATION_SEC:
            raise ValueError(f"Video too long ({duration:.0f}s). Max: {MAX_DURATION_SEC}s.")

        if total_frames <= 0:
            # Frame count unavailable: assume the longest permitted video and
            # let the read loop stop at the real end of the stream.
            total_frames = int(fps * MAX_DURATION_SEC)

        wanted = min(num_frames, total_frames)
        if wanted <= 0:
            return []

        # Evenly spaced indices over the whole video. A fixed integer step
        # combined with the num_frames cap only covered the beginning whenever
        # total_frames was not a multiple of num_frames.
        target_indices = {(i * total_frames) // wanted for i in range(wanted)}

        saved_paths = []
        frame_idx = 0
        while len(saved_paths) < wanted:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx in target_indices:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg")
                Image.fromarray(rgb).save(path, quality=90)
                saved_paths.append(path)
            frame_idx += 1

        return saved_paths
    finally:
        # Release the capture on every exit path, including exceptions raised
        # while decoding or saving a frame.
        cap.release()
209
+
210
+
211
def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
    """Score every frame and aggregate the scores into a video-level verdict.

    For each frame the model receives the transformed full image and, in
    dual-input mode, the transformed face crop (or the full image again when
    no face crop is available). Column 1 of the softmax output is used as the
    frame's fake probability. Frames that raise are logged and skipped.

    Args:
        model: Detector to evaluate; gradients are disabled during the call.
        frame_paths: Paths of image files to score.

    Returns:
        Dict with ``verdict`` ("FAKE" when the mean fake probability exceeds
        0.5, else "REAL"), ``fake_probability``, ``real_probability`` and
        ``confidence`` as percentages rounded to one decimal, ``frame_count``,
        and ``per_frame_scores`` (per-frame fake percentages).

    Raises:
        ValueError: If no frame could be processed.
    """
    fake_probs = []
    with torch.no_grad():
        for fpath in frame_paths:
            try:
                # Close the file handle as soon as the pixels are decoded
                # instead of leaving it to garbage collection.
                with Image.open(fpath) as raw:
                    img = raw.convert("RGB")
                t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)

                # Try MTCNN face detection first (same as test_real.py)
                t_face = t_img  # default fallback = full frame
                if model.dual_input:
                    face_crop = detect_face_crop(img)
                    if face_crop is not None:
                        t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
                    # else: fallback to full image (face not detected)

                logits = model(t_img, t_face if model.dual_input else None)
                prob = torch.softmax(logits, dim=1)[0, 1].item()
                fake_probs.append(prob)
            except Exception as e:
                logger.warning(f"Skipping frame {fpath}: {e}")

    if not fake_probs:
        raise ValueError("No frames could be processed.")

    avg_fake = float(np.mean(fake_probs))
    avg_real = 1.0 - avg_fake

    return {
        "verdict": "FAKE" if avg_fake > 0.5 else "REAL",
        "fake_probability": round(avg_fake * 100, 1),
        "real_probability": round(avg_real * 100, 1),
        "frame_count": len(fake_probs),
        "confidence": round(max(avg_fake, avg_real) * 100, 1),
        "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
    }
247
+
248
+
249
+ # ─────────────────────────────────────────────
250
+ # API Routes (must be defined BEFORE static mount)
251
+ # ─────────────────────────────────────────────
252
+
253
@app.on_event("startup")
async def startup_event():
    """Load the model at application startup.

    A failure is logged rather than raised, so the app still starts; the next
    ``load_model()`` call tries again.
    """
    try:
        load_model()
    except Exception as e:
        # Include the traceback, matching how /predict logs unexpected errors.
        logger.error(f"Startup model load failed: {e}", exc_info=True)
259
+
260
+
261
@app.get("/health")
def health_check():
    """Return a small JSON status document.

    ``model_loaded`` is True only once ``load_model()`` has completed
    successfully (its ``lru_cache`` holds an entry). It previously reported
    whether the checkpoint file existed, which stays True even when loading
    fails.
    """
    return {
        "status": "ok",
        "model": "DINO-G50 Deepfake Detector",
        "device": str(DEVICE),
        "model_loaded": load_model.cache_info().currsize > 0,
    }
269
+
270
+
271
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded video as REAL or FAKE.

    The upload is copied to a private temporary directory in chunks (so an
    oversized file is rejected without being held in memory), frames are
    extracted, and the model's aggregated result is returned as JSON with
    ``filename``, ``file_size_mb`` and ``job_id`` added. The temporary
    directory is always removed.

    Raises:
        HTTPException: 400 for an unsupported extension, 413 when the upload
            exceeds ``MAX_FILE_MB``, 422 when frames cannot be extracted or
            processed, 500 for any other failure.
    """
    allowed_exts = {".mp4", ".mov", ".avi", ".mkv"}
    ext = Path(file.filename).suffix.lower() if file.filename else ""

    if ext not in allowed_exts:
        raise HTTPException(400, f"Unsupported type '{ext}'. Use: {allowed_exts}")

    max_bytes = MAX_FILE_MB * 1024 * 1024
    chunk_size = 1024 * 1024

    job_id = str(uuid.uuid4())[:8]
    # mkdtemp creates a fresh directory atomically, so nothing else can
    # pre-create or guess the path.
    temp_dir = Path(tempfile.mkdtemp(prefix=f"deepshield_{job_id}_"))
    frames_dir = temp_dir / "frames"
    video_path = temp_dir / f"input{ext}"

    try:
        frames_dir.mkdir(parents=True, exist_ok=True)

        # Stream the upload to disk and stop as soon as the limit is crossed.
        written = 0
        with open(video_path, "wb") as f:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                written += len(chunk)
                if written > max_bytes:
                    raise HTTPException(413, f"File too large. Max: {MAX_FILE_MB} MB.")
                f.write(chunk)
        size_mb = written / (1024 * 1024)

        model = load_model()
        logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")

        frame_paths = extract_frames(str(video_path), str(frames_dir))
        if not frame_paths:
            raise HTTPException(422, "No frames could be extracted from video.")

        result = run_inference(model, frame_paths)
        result["filename"] = file.filename
        result["file_size_mb"] = round(size_mb, 2)
        result["job_id"] = job_id

        logger.info(f"[{job_id}] Result: {result['verdict']} ({result['fake_probability']}% fake)")
        return JSONResponse(content=result)

    except HTTPException:
        raise
    except ValueError as e:
        # Validation-style failures from frame extraction / inference.
        raise HTTPException(422, str(e))
    except Exception as e:
        logger.error(f"[{job_id}] Error: {e}", exc_info=True)
        raise HTTPException(500, f"Internal error: {str(e)}")
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
        logger.info(f"[{job_id}] Cleanup done.")
320
+
321
+
322
+ # ─────────────────────────────────────────────
323
+ # Static Frontend (mounted LAST β€” serves index.html at /)
324
+ # ─────────────────────────────────────────────
325
# Serve the "static" directory at the root path; html=True enables HTML mode
# so that index.html is returned for "/".
_frontend = StaticFiles(directory="static", html=True)
app.mount("/", _frontend, name="static")