MrTsp commited on
Commit
6b3c4f1
·
verified ·
1 Parent(s): 4dbf321

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DeepShield AI — Full-Stack FastAPI Backend (SupCon Version)
3
+ Serves the frontend UI + deepfake detection API from one HF Space.
4
+ 98.3% Accuracy — Supervised Contrastive Learning Model
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import uuid
10
+ import shutil
11
+ import logging
12
+ import tempfile
13
+ from pathlib import Path
14
+ from functools import lru_cache
15
+
16
+ import cv2
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import numpy as np
21
+ from PIL import Image, ImageFile
22
+ from facenet_pytorch import MTCNN
23
+ from fastapi import FastAPI, File, UploadFile, HTTPException
24
+ from fastapi.middleware.cors import CORSMiddleware
25
+ from fastapi.responses import JSONResponse, FileResponse
26
+ from fastapi.staticfiles import StaticFiles
27
+ import torchvision.transforms as T
28
+
29
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
30
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
31
+ logger = logging.getLogger(__name__)
32
+
33
+ # ─────────────────────────────────────────────
34
+ # Model Definition (Self-Contained SupCon Architecture)
35
+ # ─────────────────────────────────────────────
36
+
37
+ class DINOv2Extractor(nn.Module):
38
+ def __init__(self, variant: str = "dinov2_vitb14"):
39
+ super().__init__()
40
+ logger.info(f"Loading {variant} from torch.hub...")
41
+ self.backbone = torch.hub.load(
42
+ "facebookresearch/dinov2", variant, pretrained=True
43
+ )
44
+ self.feature_dim = 768
45
+ for p in self.backbone.parameters():
46
+ p.requires_grad = False
47
+
48
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
49
+ return self.backbone(x)
50
+
51
+ class MLPClassifier(nn.Module):
52
+ def __init__(self, input_dim: int, num_classes: int = 2, dropout: float = 0.4):
53
+ super().__init__()
54
+ self.net = nn.Sequential(
55
+ nn.Linear(input_dim, 512),
56
+ nn.BatchNorm1d(512),
57
+ nn.GELU(),
58
+ nn.Dropout(dropout),
59
+ nn.Linear(512, 256),
60
+ nn.BatchNorm1d(256),
61
+ nn.GELU(),
62
+ nn.Dropout(dropout * 0.75),
63
+ nn.Linear(256, num_classes),
64
+ )
65
+
66
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
67
+ return self.net(x)
68
+
69
+ class SupConDeepfakeClassifier(nn.Module):
70
+ """
71
+ Supervised Contrastive Version of the DINOv2 Deepfake Detector.
72
+ Matches the architecture used in scripts3.
73
+ """
74
+ def __init__(self, dual_input: bool = True, proj_dim: int = 128):
75
+ super().__init__()
76
+ self.dual_input = dual_input
77
+ self.extractor = DINOv2Extractor()
78
+
79
+ feat_dim = 768
80
+ classifier_input = feat_dim * 2 if dual_input else feat_dim
81
+
82
+ # Projection Head for SupCon (needed for weight loading, even if not used in inference)
83
+ self.head = nn.Sequential(
84
+ nn.Linear(classifier_input, classifier_input),
85
+ nn.BatchNorm1d(classifier_input),
86
+ nn.ReLU(inplace=True),
87
+ nn.Linear(classifier_input, proj_dim)
88
+ )
89
+
90
+ self.classifier = MLPClassifier(classifier_input)
91
+
92
+ def forward(self, full_image: torch.Tensor, face_crop: torch.Tensor = None):
93
+ full_feat = self.extractor(full_image)
94
+ if self.dual_input:
95
+ face_feat = self.extractor(face_crop if face_crop is not None else full_image)
96
+ features = torch.cat([full_feat, face_feat], dim=1)
97
+ else:
98
+ features = full_feat
99
+
100
+ logits = self.classifier(features)
101
+ # We don't need 'proj' for inference
102
+ return logits
103
+
104
+ # ─────────────────────────────────────────────
105
+ # App Setup
106
+ # ─────────────────────────────────────────────
107
+
108
+ app = FastAPI(
109
+ title="DeepShield AI",
110
+ description="DINO-G50 deepfake detector — SupCon SOTA version",
111
+ version="3.0.0",
112
+ )
113
+
114
+ app.add_middleware(
115
+ CORSMiddleware,
116
+ allow_origins=["*"],
117
+ allow_credentials=True,
118
+ allow_methods=["*"],
119
+ allow_headers=["*"],
120
+ )
121
+
122
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
123
+ CHECKPOINT_PATH = Path("best_model.pth")
124
+ MAX_FRAMES = 20
125
+ MAX_FILE_MB = 30
126
+ MAX_DURATION_SEC = 60
127
+
128
+ # MTCNN face detector
129
+ try:
130
+ MTCNN_DETECTOR = MTCNN(
131
+ image_size=224,
132
+ margin=40,
133
+ keep_all=False,
134
+ post_process=False,
135
+ device='cpu'
136
+ )
137
+ logger.info("MTCNN face detector initialized.")
138
+ except Exception as e:
139
+ MTCNN_DETECTOR = None
140
+ logger.warning(f"MTCNN init failed: {e}")
141
+
142
+ TRANSFORM = T.Compose([
143
+ T.Resize((224, 224)),
144
+ T.CenterCrop(224),
145
+ T.ToTensor(),
146
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
147
+ ])
148
+
149
+ def detect_face_crop(img: Image.Image) -> Image.Image:
150
+ if MTCNN_DETECTOR is None:
151
+ return None
152
+ try:
153
+ boxes, probs = MTCNN_DETECTOR.detect(img)
154
+ if boxes is None or len(boxes) == 0:
155
+ return None
156
+
157
+ best_idx = np.argmax(probs)
158
+ if probs[best_idx] < 0.9:
159
+ return None
160
+
161
+ box = boxes[best_idx]
162
+ w, h = img.size
163
+ x1, y1, x2, y2 = [int(b) for b in box]
164
+ margin = 40
165
+ x1, y1 = max(0, x1-margin), max(0, y1-margin)
166
+ x2, y2 = min(w, x2+margin), min(h, y2+margin)
167
+
168
+ face = img.crop((x1, y1, x2, y2))
169
+ return face.resize((224, 224), Image.LANCZOS)
170
+ except Exception:
171
+ pass
172
+ return None
173
+
174
+ @lru_cache(maxsize=1)
175
+ def load_model() -> SupConDeepfakeClassifier:
176
+ if not CHECKPOINT_PATH.exists():
177
+ fallback = Path("models3/checkpoints/best_model.pth")
178
+ if fallback.exists():
179
+ shutil.copy(fallback, CHECKPOINT_PATH)
180
+ else:
181
+ raise RuntimeError("best_model.pth not found. Please upload the model from models3/.")
182
+
183
+ logger.info(f"Loading SupCon checkpoint on {DEVICE}...")
184
+ ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
185
+ state = ckpt.get("model_state_dict", ckpt)
186
+
187
+ # Auto-detect dual input from weights
188
+ mlp_w = state.get("classifier.net.0.weight", None)
189
+ dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True
190
+
191
+ model = SupConDeepfakeClassifier(dual_input=dual).to(DEVICE)
192
+ model.load_state_dict(state, strict=False)
193
+ model.eval()
194
+ logger.info(f"SupCon Model ready. dual_input={dual}, device={DEVICE}")
195
+ return model
196
+
197
+ def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list:
198
+ cap = cv2.VideoCapture(video_path)
199
+ if not cap.isOpened():
200
+ raise ValueError("Cannot open video file.")
201
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
202
+ if total_frames <= 0: total_frames = 300
203
+ step = max(1, total_frames // num_frames)
204
+ target_indices = set(range(0, total_frames, step))
205
+ saved_paths = []
206
+ frame_idx = 0
207
+ while len(saved_paths) < num_frames:
208
+ ret, frame = cap.read()
209
+ if not ret: break
210
+ if frame_idx in target_indices:
211
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
212
+ path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg")
213
+ Image.fromarray(rgb).save(path, quality=90)
214
+ saved_paths.append(path)
215
+ frame_idx += 1
216
+ cap.release()
217
+ return saved_paths
218
+
219
+ def run_inference(model: SupConDeepfakeClassifier, frame_paths: list) -> dict:
220
+ fake_probs = []
221
+ with torch.no_grad():
222
+ for fpath in frame_paths:
223
+ try:
224
+ img = Image.open(fpath).convert("RGB")
225
+ t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)
226
+ t_face = t_img
227
+ if model.dual_input:
228
+ face_crop = detect_face_crop(img)
229
+ if face_crop is not None:
230
+ t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
231
+
232
+ logits = model(t_img, t_face if model.dual_input else None)
233
+ prob = torch.softmax(logits, dim=1)[0, 1].item()
234
+ fake_probs.append(prob)
235
+ except Exception as e:
236
+ logger.warning(f"Error on {fpath}: {e}")
237
+
238
+ if not fake_probs: raise ValueError("No frames processed.")
239
+
240
+ # Matching test_real.py simple mean logic for consistency
241
+ video_fake_prob = float(np.mean(fake_probs))
242
+ is_fake = video_fake_prob > 0.5
243
+ avg_real = 1.0 - video_fake_prob
244
+
245
+ return {
246
+ "verdict": "FAKE" if is_fake else "REAL",
247
+ "fake_probability": round(video_fake_prob * 100, 1),
248
+ "real_probability": round(avg_real * 100, 1),
249
+ "frame_count": len(fake_probs),
250
+ "confidence": round(max(video_fake_prob, avg_real) * 100, 1),
251
+ "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
252
+ }
253
+
254
+ @app.on_event("startup")
255
+ async def startup_event():
256
+ try:
257
+ load_model()
258
+ except Exception as e:
259
+ logger.error(f"Startup model load failed: {e}")
260
+
261
+ @app.get("/health")
262
+ def health_check():
263
+ return {
264
+ "status": "ok",
265
+ "model": "DINO-G50 SupCon Detector",
266
+ "model_loaded": CHECKPOINT_PATH.exists(),
267
+ }
268
+
269
+ @app.post("/predict")
270
+ async def predict(file: UploadFile = File(...)):
271
+ allowed_exts = {".mp4", ".mov", ".avi", ".mkv", ".jpg", ".jpeg", ".png", ".webp"}
272
+ ext = Path(file.filename).suffix.lower() if file.filename else ""
273
+ if ext not in allowed_exts:
274
+ raise HTTPException(400, f"Unsupported file type '{ext}'.")
275
+
276
+ content = await file.read()
277
+ size_mb = len(content) / (1024 * 1024)
278
+ if size_mb > MAX_FILE_MB:
279
+ raise HTTPException(413, f"File too large ({size_mb:.1f} MB). Max: {MAX_FILE_MB} MB.")
280
+
281
+ job_id = str(uuid.uuid4())[:8]
282
+ temp_dir = Path(tempfile.gettempdir()) / f"deepshield_{job_id}"
283
+ frames_dir = temp_dir / "frames"
284
+ frames_dir.mkdir(parents=True, exist_ok=True)
285
+ file_path = temp_dir / f"input{ext}"
286
+
287
+ try:
288
+ with open(file_path, "wb") as f:
289
+ f.write(content)
290
+ del content
291
+ model = load_model()
292
+
293
+ if ext in {".mp4", ".mov", ".avi", ".mkv"}:
294
+ frame_paths = extract_frames(str(file_path), str(frames_dir))
295
+ else:
296
+ img_path = frames_dir / f"frame_0000{ext}"
297
+ shutil.copy(file_path, img_path)
298
+ frame_paths = [str(img_path)]
299
+
300
+ if not frame_paths: raise HTTPException(422, "Failed to extract frames.")
301
+
302
+ result = run_inference(model, frame_paths)
303
+ result.update({"filename": file.filename, "file_size_mb": round(size_mb, 2)})
304
+ return JSONResponse(content=result)
305
+ except Exception as e:
306
+ logger.error(f"Error: {e}", exc_info=True)
307
+ raise HTTPException(500, str(e))
308
+ finally:
309
+ shutil.rmtree(temp_dir, ignore_errors=True)
310
+
311
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")