MrTsp committed on
Commit
f9633ad
·
1 Parent(s): 8a71d57

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -284
app.py DELETED
@@ -1,284 +0,0 @@
1
- """
2
- DeepShield AI — Full-Stack FastAPI Backend
3
- Serves the frontend UI + deepfake detection API from one HF Space.
4
-
5
- Routes:
6
- GET / → Serves index.html (the web UI)
7
- GET /health → JSON health check
8
- POST /predict → Video upload → REAL/FAKE prediction
9
- """
10
-
11
- import os
12
- import sys
13
- import uuid
14
- import shutil
15
- import logging
16
- import tempfile
17
- from pathlib import Path
18
- from functools import lru_cache
19
-
20
- import cv2
21
- import torch
22
- import torch.nn as nn
23
- import numpy as np
24
- from PIL import Image, ImageFile
25
- from fastapi import FastAPI, File, UploadFile, HTTPException
26
- from fastapi.middleware.cors import CORSMiddleware
27
- from fastapi.responses import JSONResponse, FileResponse
28
- from fastapi.staticfiles import StaticFiles
29
- import torchvision.transforms as T
30
-
31
# Ask PIL to tolerate truncated image data when decoding frames.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Process-wide logging: INFO and above, "<time> <level> <message>" lines.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
34
-
35
- # ─────────────────────────────────────────────
36
- # Model Definition (self-contained)
37
- # ─────────────────────────────────────────────
38
-
39
class DINOv2Extractor(nn.Module):
    """Frozen DINOv2 backbone used as an image feature extractor.

    Args:
        variant: Entry-point name handed to ``torch.hub.load`` for the
            ``facebookresearch/dinov2`` repository.

    Attributes:
        backbone: The hub model; every parameter has ``requires_grad`` off.
        feature_dim: Feature width reported by the backbone (its
            ``embed_dim`` attribute), or 768 when the backbone does not
            expose one.
    """

    def __init__(self, variant: str = "dinov2_vitb14"):
        super().__init__()
        logger.info(f"Loading {variant} from torch.hub...")
        self.backbone = torch.hub.load(
            "facebookresearch/dinov2", variant, pretrained=True
        )
        # Do not assume the ViT-B/14 width: other variants have a different
        # embedding size, so read it from the loaded model. 768 remains the
        # fallback so the default variant behaves exactly as before.
        self.feature_dim = getattr(self.backbone, "embed_dim", 768)
        # The backbone is never fine-tuned here; freeze all of its weights.
        for param in self.backbone.parameters():
            param.requires_grad = False
        logger.info("DINOv2 backbone loaded (frozen).")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's output for the image batch ``x``."""
        return self.backbone(x)
53
-
54
-
55
class MLPClassifier(nn.Module):
    """MLP head mapping a feature vector to class logits.

    The stack is Linear -> LayerNorm -> GELU -> Dropout, twice (512 then 256
    hidden units, the second dropout at half the rate), followed by a final
    Linear to ``num_classes``. The layers live in ``self.net`` in that exact
    order so checkpoint keys of the form ``net.<index>.*`` keep matching.

    Args:
        input_dim: Size of the incoming feature vector.
        num_classes: Number of output logits.
        dropout: Dropout probability after the first hidden layer.
    """

    def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.3):
        super().__init__()
        wide, narrow = 512, 256
        layers = [
            # first hidden stage
            nn.Linear(input_dim, wide),
            nn.LayerNorm(wide),
            nn.GELU(),
            nn.Dropout(dropout),
            # second hidden stage, lighter dropout
            nn.Linear(wide, narrow),
            nn.LayerNorm(narrow),
            nn.GELU(),
            nn.Dropout(dropout / 2),
            # logits
            nn.Linear(narrow, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (pre-softmax) logits for ``x``."""
        logits = self.net(x)
        return logits
72
-
73
-
74
class DeepfakeDetector(nn.Module):
    """DINOv2 feature extractor followed by an MLP real/fake classifier.

    Args:
        dual_input: When true the classifier consumes the concatenation of
            full-frame and face features (twice the extractor width);
            otherwise it consumes the full-frame features only.
    """

    def __init__(self, dual_input: bool = True):
        super().__init__()
        self.dual_input = dual_input
        self.extractor = DINOv2Extractor()
        # Size the head from the extractor instead of repeating its width.
        base_dim = getattr(self.extractor, "feature_dim", 768)
        feat_dim = 2 * base_dim if dual_input else base_dim
        self.classifier = MLPClassifier(input_dim=feat_dim)

    def forward(self, full_img: torch.Tensor, face_img: torch.Tensor = None) -> torch.Tensor:
        """Return class logits for a batch.

        Args:
            full_img: Batch of full-frame images.
            face_img: Optional batch of face crops. Ignored when the model is
                single-input. In dual-input mode a missing crop falls back to
                the full-frame features, so the classifier always receives
                the width it was built for.

        Returns:
            The classifier logits.
        """
        full_feat = self.extractor(full_img)
        if not self.dual_input:
            return self.classifier(full_feat)

        if face_img is None or face_img is full_img:
            # No separate crop (or literally the same tensor): reuse the
            # features already computed. This fixes the shape mismatch that a
            # missing face_img used to cause and skips a redundant backbone
            # pass when the caller passes the frame twice.
            face_feat = full_feat
        else:
            face_feat = self.extractor(face_img)
        return self.classifier(torch.cat([full_feat, face_feat], dim=1))
90
-
91
-
92
- # ─────────────────────────────────────────────
93
- # App Setup
94
- # ─────────────────────────────────────────────
95
-
96
app = FastAPI(
    title="DeepShield AI",
    description="DINO-G50 deepfake detector — full-stack web app",
    version="2.0.0",
)

# Open CORS policy so the API can be called from any origin.
# Credentials stay disabled: a wildcard origin must not be combined with
# allow_credentials=True (the CORS spec forbids it and FastAPI documents the
# combination as invalid). Nothing in this app relies on cookies or auth
# headers, so turning it off does not remove functionality.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
109
-
110
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Relative path: resolved against the process working directory.
CHECKPOINT_PATH = Path("best_model.pth")

# Request limits.
MAX_FRAMES = 20        # frames sampled per video
MAX_FILE_MB = 30       # largest accepted upload, in MiB
MAX_DURATION_SEC = 60  # longest accepted video, in seconds

# Per-frame preprocessing: resize to 224x224, convert to a tensor, then
# normalise each channel with the fixed mean/std below.
TRANSFORM = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
121
-
122
-
123
@lru_cache(maxsize=1)
def load_model() -> DeepfakeDetector:
    """Build the detector, load ``CHECKPOINT_PATH`` into it and cache it.

    The first successful call is cached by ``lru_cache``; a call that raises
    is not cached, so a later call retries.

    Returns:
        The model in eval mode on ``DEVICE``.

    Raises:
        RuntimeError: If the checkpoint file is missing or does not contain a
            state-dict mapping.
    """
    if not CHECKPOINT_PATH.exists():
        raise RuntimeError("best_model.pth not found. Upload it to this HF Space.")

    logger.info(f"Loading checkpoint on {DEVICE}...")
    # NOTE(review): torch.load unpickles the file; only ship checkpoints from
    # a trusted source with this app.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    if not isinstance(ckpt, dict):
        raise RuntimeError(
            f"Unexpected checkpoint format ({type(ckpt).__name__}); expected a state dict."
        )
    # Accept either a training checkpoint wrapping the weights or a bare
    # state dict.
    state = ckpt.get("model_state_dict", ckpt)

    # Infer single- vs dual-input from the width of the first classifier
    # layer; default to dual when that weight is absent.
    mlp_w = state.get("classifier.net.0.weight", None)
    dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True

    model = DeepfakeDetector(dual_input=dual).to(DEVICE)
    # strict=False tolerates partial checkpoints, but a silent mismatch would
    # leave layers at their random initialisation, so surface what was
    # skipped instead of discarding the report.
    load_report = model.load_state_dict(state, strict=False)
    if load_report.missing_keys:
        logger.warning(
            f"Checkpoint is missing {len(load_report.missing_keys)} key(s): "
            f"{load_report.missing_keys[:10]}"
        )
    if load_report.unexpected_keys:
        logger.warning(
            f"Checkpoint has {len(load_report.unexpected_keys)} unexpected key(s): "
            f"{load_report.unexpected_keys[:10]}"
        )
    model.eval()
    logger.info(f"Model ready. dual_input={dual}, device={DEVICE}")
    return model
140
-
141
-
142
def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list:
    """Sample up to ``num_frames`` frames from a video and save them as JPEGs.

    Frames are chosen at evenly spaced positions across the whole clip, read
    sequentially, converted from BGR to RGB and written to ``output_dir`` as
    ``frame_0000.jpg``, ``frame_0001.jpg``, ...

    Args:
        video_path: Path of the video to read.
        output_dir: Existing directory that receives the JPEG files.
        num_frames: Maximum number of frames to save.

    Returns:
        Paths of the saved frames, in order. May be shorter than
        ``num_frames`` (or empty) when the video has fewer readable frames
        or ``num_frames`` is not positive.

    Raises:
        ValueError: If the video cannot be opened or is longer than
            ``MAX_DURATION_SEC``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Cannot open video file.")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 25  # some containers report 0
        duration = total_frames / fps if fps > 0 else 0

        if duration > MAX_DURATION_SEC:
            raise ValueError(f"Video too long ({duration:.0f}s). Max: {MAX_DURATION_SEC}s.")

        if total_frames <= 0:
            # Frame count unknown: assume the longest clip we accept.
            total_frames = int(fps * MAX_DURATION_SEC)

        sample_count = min(num_frames, total_frames)
        if sample_count <= 0:
            return []

        # Evenly spaced indices over [0, total_frames). A fixed integer step
        # would pile every sample into the first part of the clip whenever
        # total_frames is not a multiple of num_frames.
        target_indices = {(i * total_frames) // sample_count for i in range(sample_count)}
        last_target = max(target_indices)

        saved_paths = []
        frame_idx = 0
        # Stop after the last wanted index so the tail is not decoded.
        while frame_idx <= last_target:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx in target_indices:
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg")
                Image.fromarray(rgb).save(path, quality=90)
                saved_paths.append(path)
            frame_idx += 1

        return saved_paths
    finally:
        # Release the capture on every exit path, including errors.
        cap.release()
176
-
177
-
178
def run_inference(model: DeepfakeDetector, frame_paths: list) -> dict:
    """Score every frame and aggregate the scores into a video-level verdict.

    Each frame is loaded as RGB, passed through ``TRANSFORM`` and scored by
    ``model``; the softmax value at class index 1 is taken as that frame's
    fake probability. In dual-input mode the same tensor is supplied as both
    the full image and the face image. A frame that raises is logged and
    skipped.

    Args:
        model: Detector to run.
        frame_paths: Image file paths to score.

    Returns:
        A dict with ``verdict`` ("FAKE" when the mean fake probability is
        above 0.5, else "REAL"), percentage fields rounded to one decimal,
        ``frame_count`` and ``per_frame_scores``.

    Raises:
        ValueError: If no frame could be scored.
    """
    scores = []
    with torch.no_grad():
        for frame_file in frame_paths:
            try:
                rgb = Image.open(frame_file).convert("RGB")
                batch = TRANSFORM(rgb).unsqueeze(0).to(DEVICE)
                if model.dual_input:
                    logits = model(batch, batch)
                else:
                    logits = model(batch, None)
                scores.append(torch.softmax(logits, dim=1)[0, 1].item())
            except Exception as exc:
                # Best effort: one bad frame must not sink the whole video.
                logger.warning(f"Skipping frame {frame_file}: {exc}")

    if len(scores) == 0:
        raise ValueError("No frames could be processed.")

    mean_fake = float(np.mean(scores))
    mean_real = 1.0 - mean_fake

    verdict = "REAL"
    if mean_fake > 0.5:
        verdict = "FAKE"

    def as_percent(value):
        # Percentage rounded to one decimal place.
        return round(value * 100, 1)

    return {
        "verdict": verdict,
        "fake_probability": as_percent(mean_fake),
        "real_probability": as_percent(mean_real),
        "frame_count": len(scores),
        "confidence": as_percent(max(mean_fake, mean_real)),
        "per_frame_scores": [as_percent(score) for score in scores],
    }
206
-
207
-
208
- # ─────────────────────────────────────────────
209
- # API Routes (must be defined BEFORE static mount)
210
- # ─────────────────────────────────────────────
211
-
212
@app.on_event("startup")
async def startup_event():
    """Warm the model cache at startup.

    A failure is logged rather than raised, so the app still starts; the
    load is attempted again on the next ``load_model()`` call.
    """
    try:
        load_model()
    except Exception as exc:
        logger.error(f"Startup model load failed: {exc}")
        return
218
-
219
-
220
@app.get("/health")
def health_check():
    """Return a small JSON status document.

    ``model_loaded`` is true only when ``load_model`` holds a cached model,
    i.e. a load actually succeeded. Checking only that the checkpoint file
    exists would report true even when loading failed.
    """
    model_loaded = load_model.cache_info().currsize > 0
    return {
        "status": "ok",
        "model": "DINO-G50 Deepfake Detector",
        "device": str(DEVICE),
        "model_loaded": model_loaded,
    }
228
-
229
-
230
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded video as REAL or FAKE.

    The upload is copied to a private temporary directory in chunks (never
    held in memory as a whole), frames are sampled from it and scored, and
    the directory is removed before returning.

    Args:
        file: The uploaded video; its extension must be .mp4, .mov, .avi or
            .mkv.

    Returns:
        A JSON response with the ``run_inference`` result plus ``filename``,
        ``file_size_mb`` and ``job_id``.

    Raises:
        HTTPException: 400 for an unsupported extension, 413 when the upload
            exceeds ``MAX_FILE_MB``, 422 when the video cannot be processed,
            500 for any other failure.
    """
    allowed_exts = {".mp4", ".mov", ".avi", ".mkv"}
    ext = Path(file.filename).suffix.lower() if file.filename else ""

    if ext not in allowed_exts:
        raise HTTPException(400, f"Unsupported type '{ext}'. Use: {allowed_exts}")

    max_bytes = MAX_FILE_MB * 1024 * 1024

    # Reject early when the framework already knows the size (the attribute
    # is absent on older versions, hence getattr).
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size > max_bytes:
        declared_mb = declared_size / (1024 * 1024)
        raise HTTPException(413, f"File too large ({declared_mb:.1f} MB). Max: {MAX_FILE_MB} MB.")

    job_id = str(uuid.uuid4())[:8]
    # mkdtemp guarantees a fresh directory, so two requests can never share
    # (and later delete) each other's files even if the short ids collide.
    temp_dir = Path(tempfile.mkdtemp(prefix=f"deepshield_{job_id}_"))
    frames_dir = temp_dir / "frames"
    video_path = temp_dir / f"input{ext}"

    try:
        frames_dir.mkdir(parents=True, exist_ok=True)

        # Stream to disk in 1 MiB chunks and stop as soon as the limit is
        # crossed, instead of reading the whole upload into RAM first.
        bytes_written = 0
        with open(video_path, "wb") as f:
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                bytes_written += len(chunk)
                if bytes_written > max_bytes:
                    raise HTTPException(
                        413, f"File too large (> {MAX_FILE_MB} MB). Max: {MAX_FILE_MB} MB."
                    )
                f.write(chunk)
        size_mb = bytes_written / (1024 * 1024)

        model = load_model()
        logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")

        frame_paths = extract_frames(str(video_path), str(frames_dir))
        if not frame_paths:
            raise HTTPException(422, "No frames could be extracted from video.")

        result = run_inference(model, frame_paths)
        result["filename"] = file.filename
        result["file_size_mb"] = round(size_mb, 2)
        result["job_id"] = job_id

        logger.info(f"[{job_id}] Result: {result['verdict']} ({result['fake_probability']}% fake)")
        return JSONResponse(content=result)

    except HTTPException:
        # Already a proper HTTP error; pass it through untouched.
        raise
    except ValueError as e:
        # Validation problems raised by the frame/inference helpers.
        raise HTTPException(422, str(e)) from e
    except Exception as e:
        logger.error(f"[{job_id}] Error: {e}", exc_info=True)
        raise HTTPException(500, f"Internal error: {str(e)}") from e
    finally:
        # Always remove the upload and extracted frames.
        shutil.rmtree(temp_dir, ignore_errors=True)
        logger.info(f"[{job_id}] Cleanup done.")
279
-
280
-
281
- # ─────────────────────────────────────────────
282
- # Static Frontend (mounted LAST — serves index.html at /)
283
- # ─────────────────────────────────────────────
284
# Catch-all mount for the frontend; it is registered after the API routes
# so that /health and /predict are matched first.
app.mount(
    "/",
    StaticFiles(directory="static", html=True),
    name="static",
)