# DeepShield-Web2 / app.py
# MrTsp's picture
# Update app.py
# 5a4ffe0 verified
"""
DeepShield AI — Full-Stack FastAPI Backend
Serves the frontend UI + deepfake detection API from one HF Space.
Self-contained version with exact architectural parity to test_real.py
"""
import os
import sys
import uuid
import shutil
import logging
import tempfile
from pathlib import Path
from functools import lru_cache
import cv2
import torch
import torch.nn as nn
import numpy as np
from PIL import Image, ImageFile
from facenet_pytorch import MTCNN
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
import torchvision.transforms as T
# Let Pillow load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Process-wide logging format: timestamp, level, message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
# -------------------------------------------------------------
# EXACT PARITY MODEL DEFINITIONS (Copied from src/ to be standalone)
# -------------------------------------------------------------
class DINOv2Extractor(nn.Module):
    """Frozen DINOv2 backbone fetched through ``torch.hub``.

    The backbone is downloaded/loaded from the ``facebookresearch/dinov2``
    hub repository and every one of its parameters has gradients disabled.

    Attributes:
        embed_dim: Set to 768 unconditionally; it is NOT derived from
            ``variant``, so it is only meaningful for variants whose output
            width is 768.
        backbone: The module returned by ``torch.hub.load``.
    """

    def __init__(self, variant: str = 'dinov2_vitb14'):
        """Load the requested hub ``variant`` and freeze its weights."""
        super().__init__()
        self.embed_dim = 768

        logger.info(f"Loading {variant} from torch.hub ...")
        hub_model = torch.hub.load(
            'facebookresearch/dinov2',
            variant,
            pretrained=True,
        )
        self.backbone = hub_model
        logger.info("DINOv2 loaded.")

        # Feature extractor only: switch off gradients on every parameter.
        for weight in hub_model.parameters():
            weight.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone's output for the image batch ``x``."""
        features = self.backbone(x)
        return features
class MLPClassifier(nn.Module):
    """Three-layer MLP head: ``input_dim -> 512 -> 256 -> num_classes``.

    Each hidden layer is followed by BatchNorm1d, GELU and Dropout. The
    second dropout uses 75% of the rate of the first. The layers live in a
    single ``nn.Sequential`` named ``net``, so state-dict keys are
    ``net.0.*``, ``net.1.*``, ... in the order listed below.

    Args:
        input_dim: Width of the incoming feature vector.
        num_classes: Number of output logits.
        dropout: Dropout probability after the first hidden layer.
    """

    def __init__(self, input_dim: int = 1536, num_classes: int = 2, dropout: float = 0.4):
        super().__init__()
        hidden_wide, hidden_narrow = 512, 256
        # Order matters: indices in the Sequential define the checkpoint keys.
        stack = [
            nn.Linear(input_dim, hidden_wide),
            nn.BatchNorm1d(hidden_wide),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_wide, hidden_narrow),
            nn.BatchNorm1d(hidden_narrow),
            nn.GELU(),
            nn.Dropout(dropout * 0.75),
            nn.Linear(hidden_narrow, num_classes),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for the feature batch ``x``."""
        logits = self.net(x)
        return logits
class SupConDeepfakeClassifier(nn.Module):
    """DINOv2 feature extractor followed by an MLP classifier.

    With ``dual_input=True`` the extractor is run on the full image and on a
    second view (the face crop, or the full image again when no crop is
    given); both feature vectors are concatenated before classification.
    With ``dual_input=False`` only the full-image features are classified.

    Attributes:
        dual_input: Whether two views are concatenated.
        extractor: ``DINOv2Extractor('dinov2_vitb14')``.
        head: Projection MLP (``classifier_input -> classifier_input ->
            proj_dim``). It is constructed here but is not called by
            ``forward``; presumably retained so checkpoint keys match —
            confirm against the training code.
        classifier: ``MLPClassifier`` over the (possibly concatenated) features.
    """

    def __init__(self, dual_input: bool = True, proj_dim: int = 128):
        super().__init__()
        self.dual_input = dual_input
        self.extractor = DINOv2Extractor('dinov2_vitb14')

        feat_dim = 768
        # Two concatenated views double the classifier's input width.
        if dual_input:
            classifier_input = feat_dim * 2
        else:
            classifier_input = feat_dim

        projection_layers = [
            nn.Linear(classifier_input, classifier_input),
            nn.BatchNorm1d(classifier_input),
            nn.ReLU(inplace=True),
            nn.Linear(classifier_input, proj_dim),
        ]
        self.head = nn.Sequential(*projection_layers)
        self.classifier = MLPClassifier(classifier_input)

    def forward(self, full_image: torch.Tensor, face_crop: torch.Tensor = None) -> torch.Tensor:
        """Return class logits for ``full_image`` (and optionally ``face_crop``).

        Args:
            full_image: Batch passed to the extractor as the primary view.
            face_crop: Optional second view; when ``None`` in dual-input mode
                the full image is used for the second view as well.
        """
        full_feat = self.extractor(full_image)
        if not self.dual_input:
            return self.classifier(full_feat)

        second_view = full_image if face_crop is None else face_crop
        face_feat = self.extractor(second_view)
        combined = torch.cat([full_feat, face_feat], dim=1)
        return self.classifier(combined)
# -------------------------------------------------------------
# APP SETTINGS & SETUP
# -------------------------------------------------------------
# FastAPI application object; routes below are registered on it.
app = FastAPI(
    title="DeepShield AI",
    description="DINO-G50 deepfake detector — full-stack web app",
    version="2.0.0",
)
# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# The FastAPI CORS docs say origins must be listed explicitly for
# credentialed requests — confirm whether credentials are needed at all.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Inference device: GPU when available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Primary checkpoint location, relative to the working directory.
CHECKPOINT_PATH = Path("best_model.pth")
# Default number of frames sampled per video (default of extract_frames).
MAX_FRAMES = 50
# Upload size limit in MiB, enforced in predict().
MAX_FILE_MB = 30
# NOTE(review): not referenced anywhere else in the visible file.
MAX_DURATION_SEC = 90
# MTCNN face detector setup to mimic src/utils/face_detect.py precisely
# The detector always runs on CPU, independent of DEVICE. If construction
# fails the module still imports and face cropping is disabled.
try:
    MTCNN_DETECTOR = MTCNN(
        image_size=224,
        margin=40,
        keep_all=False,
        post_process=False,
        device='cpu'
    )
    logger.info("MTCNN face detector initialized.")
except Exception as e:
    MTCNN_DETECTOR = None
    logger.warning(f"MTCNN init failed (will use fallback): {e}")
# Exact transform replication
# Resize forces 224x224 (aspect ratio not preserved), so the following
# CenterCrop(224) does not change the size; then tensor conversion and
# per-channel normalisation (the standard ImageNet mean/std values).
TRANSFORM = T.Compose([
    T.Resize((224, 224)),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def detect_face_crop(img: Image.Image, min_prob: float = 0.9, margin: int = 40) -> "Image.Image | None":
    """Return a 224x224 crop around the most confident face, or ``None``.

    The function is best-effort: any failure results in ``None`` so that the
    caller can fall back to the full frame. Unlike a silent ``pass``, the
    failure is logged so detector problems are visible.

    Args:
        img: RGB PIL image to search.
        min_prob: Minimum detection probability required to accept the face.
        margin: Pixels added on every side of the detected box before
            cropping (clamped to the image bounds).

    Returns:
        The face region resized to 224x224 with LANCZOS resampling, or
        ``None`` when the detector is unavailable, no face is found, the best
        detection is below ``min_prob``, the box is degenerate, or an error
        occurs.
    """
    if MTCNN_DETECTOR is None:
        return None
    try:
        boxes, probs = MTCNN_DETECTOR.detect(img)
        if boxes is None or len(boxes) == 0:
            return None

        best_idx = int(np.argmax(probs))
        if probs[best_idx] < min_prob:
            return None

        width, height = img.size
        x1, y1, x2, y2 = (int(coord) for coord in boxes[best_idx])
        # Expand by the margin but never step outside the image.
        x1 = max(0, x1 - margin)
        y1 = max(0, y1 - margin)
        x2 = min(width, x2 + margin)
        y2 = min(height, y2 + margin)
        # A box that collapses after clamping cannot be cropped meaningfully.
        if x2 <= x1 or y2 <= y1:
            return None

        face = img.crop((x1, y1, x2, y2))
        return face.resize((224, 224), Image.LANCZOS)
    except Exception as exc:
        # Deliberately non-fatal: the caller falls back to the full frame.
        logger.warning(f"Face detection failed, falling back to full frame: {exc}")
        return None
@lru_cache(maxsize=1)
def load_model() -> SupConDeepfakeClassifier:
    """Load the classifier checkpoint once and return the model in eval mode.

    The checkpoint is looked up at ``CHECKPOINT_PATH`` first and then at
    ``<parent of this file's directory>/models2/checkpoints/best_model.pth``.
    Single- vs dual-input architecture is inferred from the input width of
    the first classifier layer (1536 means dual input); when that weight is
    absent, dual input is assumed.

    The result is memoised by ``lru_cache``; a call that raises is not cached,
    so a later call retries.

    Returns:
        The model on ``DEVICE`` with weights loaded and ``eval()`` applied.

    Raises:
        RuntimeError: If no checkpoint file exists at either location.
    """
    # First check default path, then fallback if possible
    if CHECKPOINT_PATH.exists():
        ckpt_path_to_load = str(CHECKPOINT_PATH)
    else:
        fallback_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'models2/checkpoints/best_model.pth',
        )
        if not os.path.exists(fallback_path):
            raise RuntimeError("best_model.pth not found. Upload it to this HF Space.")
        ckpt_path_to_load = fallback_path

    logger.info(f"Loading checkpoint on {DEVICE} from {ckpt_path_to_load} ...")
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    ckpt = torch.load(ckpt_path_to_load, map_location=DEVICE)
    # Accept both a wrapped checkpoint and a bare state dict.
    state = ckpt.get("model_state_dict", ckpt)

    # Determine architecture
    mlp_w = state.get("classifier.net.0.weight")
    dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True

    model = SupConDeepfakeClassifier(dual_input=dual).to(DEVICE)
    # strict=False tolerates key mismatches; report them so that a wrong or
    # partial checkpoint does not silently leave layers randomly initialised.
    load_result = model.load_state_dict(state, strict=False)
    if load_result.missing_keys:
        logger.warning(
            f"Checkpoint is missing {len(load_result.missing_keys)} model key(s), "
            f"e.g. {list(load_result.missing_keys)[:10]}"
        )
    if load_result.unexpected_keys:
        logger.warning(
            f"Checkpoint has {len(load_result.unexpected_keys)} unexpected key(s), "
            f"e.g. {list(load_result.unexpected_keys)[:10]}"
        )
    model.eval()
    logger.info(f"SupCon Model ready. dual_input={dual}, device={DEVICE}")
    return model
def extract_frames(video_path: str, temp_dir: str, num_frames: int = MAX_FRAMES) -> list:
    """Decode a video and save up to ``num_frames`` sampled frames as JPEGs.

    Frames are read sequentially; every ``step``-th frame (``step`` being the
    reported frame count divided by ``num_frames``, at least 1) is converted
    to RGB and written to ``temp_dir`` as ``frame_NNN.jpg``. When the
    container does not report a frame count, 300 frames are assumed.

    NOTE(review): because sampling stops as soon as ``num_frames`` frames are
    saved, a video whose frame count is not a multiple of ``num_frames`` is
    only sampled from its leading part (e.g. 99 frames with the default of 50
    yields frames 0-49). Kept as-is to preserve existing results.

    Args:
        video_path: Path of the video file to decode.
        temp_dir: Existing directory that receives the JPEG files.
        num_frames: Maximum number of frames to save.

    Returns:
        Paths of the saved frames in order; empty when the video cannot be
        opened, yields no frames, or ``num_frames`` is not positive.
    """
    if num_frames <= 0:
        # Avoids a division by zero below; nothing was requested.
        return []

    saved = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return saved

        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_limit = total if total > 0 else 300
        step = max(1, frame_limit // num_frames)

        for i in range(frame_limit):
            ret, frame = cap.read()
            if not ret:
                break
            if i % step != 0:
                continue
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            path = os.path.join(temp_dir, f"frame_{len(saved):03d}.jpg")
            Image.fromarray(rgb).save(path)
            saved.append(path)
            if len(saved) >= num_frames:
                break
    finally:
        # Release the capture on every path, including errors while saving.
        cap.release()
    return saved
def _summarize_frame_scores(fake_probs: list) -> dict:
    """Aggregate per-frame fake probabilities into the response payload."""
    # Top-50% Aggregation (Advanced logic requested): average only the most
    # fake-looking half of the frames (at least one frame).
    sorted_probs = sorted(fake_probs, reverse=True)
    top_k = max(1, len(sorted_probs) // 2)
    video_fake_prob = float(np.mean(sorted_probs[:top_k]))
    is_fake = video_fake_prob > 0.5
    avg_real = 1.0 - video_fake_prob
    return {
        "verdict": "FAKE" if is_fake else "REAL",
        "fake_probability": round(video_fake_prob * 100, 1),
        "real_probability": round(avg_real * 100, 1),
        "frame_count": len(fake_probs),
        "confidence": round(max(video_fake_prob, avg_real) * 100, 1),
        "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
    }


def run_inference(model: SupConDeepfakeClassifier, frame_paths: list) -> dict:
    """Score every frame and combine the scores into a video-level verdict.

    Each frame is opened, transformed and passed to ``model``. In dual-input
    mode a face crop is used as the second view when one is detected,
    otherwise the full frame is used for both views. The softmax probability
    of class index 1 is taken as the frame's fake probability. Frames that
    fail for any reason are logged and skipped.

    Args:
        model: Classifier exposing ``dual_input`` and callable as
            ``model(full_image, face_crop)``.
        frame_paths: Image file paths to score.

    Returns:
        Dict with ``verdict``, ``fake_probability``, ``real_probability``,
        ``frame_count``, ``confidence`` and ``per_frame_scores`` (percentages
        rounded to one decimal).

    Raises:
        ValueError: If no frame could be processed.
    """
    fake_probs = []
    with torch.no_grad():
        for fpath in frame_paths:
            try:
                # Close the file handle promptly; convert() returns a
                # detached copy, so the image stays usable afterwards.
                with Image.open(fpath) as raw:
                    img = raw.convert("RGB")
                t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)
                t_face = t_img
                if model.dual_input:
                    face_crop = detect_face_crop(img)
                    if face_crop is not None:
                        t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
                logits = model(t_img, t_face if model.dual_input else None)
                prob = torch.softmax(logits, dim=1)[0, 1].item()
                fake_probs.append(prob)
            except Exception as e:
                logger.warning(f"Skipping frame {fpath}: {e}")
    if not fake_probs:
        raise ValueError("No frames could be processed.")
    return _summarize_frame_scores(fake_probs)
# -------------------------------------------------------------
# API ROUTES
# -------------------------------------------------------------
@app.on_event("startup")
async def startup_event():
    """Warm the model cache when the app starts.

    A load failure is logged rather than raised, so the application still
    starts even when the model cannot be loaded.
    """
    try:
        load_model()
    except Exception as exc:
        logger.error(f"Startup model load failed: {exc}")
@app.get("/health")
def health_check():
    """Report service status and whether a checkpoint file is present.

    ``model_loaded`` is ``True`` when a checkpoint exists at
    ``CHECKPOINT_PATH`` or at the fallback location next to the parent of
    this file's directory. It reflects file presence only, not whether the
    weights were actually loaded.

    Returns:
        Dict with ``status``, ``model``, ``device`` and ``model_loaded``.
    """
    try:
        fallback_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            'models2/checkpoints/best_model.pth',
        )
        model_loaded = CHECKPOINT_PATH.exists() or os.path.exists(fallback_path)
    except Exception as exc:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate.
        logger.warning(f"Health check could not probe checkpoint: {exc}")
        model_loaded = False
    return {
        "status": "ok",
        "model": "DINO-G50 Deepfake Detector",
        "device": str(DEVICE),
        "model_loaded": model_loaded,
    }
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image or video as REAL or FAKE.

    The upload is validated (extension and size), written to a private
    temporary directory, split into frames when it is a video, and scored
    with ``run_inference``. The temporary directory is always removed.

    Args:
        file: Uploaded media file; the extension decides image vs video.

    Returns:
        JSON response holding the inference result plus ``filename``,
        ``file_size_mb`` and ``job_id``.

    Raises:
        HTTPException: 400 for an unsupported extension, 413 for an oversized
            file, 422 when no frame can be extracted or processed, 500 for
            any other failure.
    """
    # fastapi is already a dependency of this module; imported locally so
    # this block is self-contained.
    from fastapi.concurrency import run_in_threadpool

    video_exts = {".mp4", ".mov", ".avi", ".mkv"}
    allowed_exts = video_exts | {".jpg", ".jpeg", ".png", ".webp"}
    ext = Path(file.filename).suffix.lower() if file.filename else ""
    if ext not in allowed_exts:
        raise HTTPException(400, f"Unsupported type '{ext}'. Use: {allowed_exts}")

    bytes_per_mb = 1024 * 1024
    # When the framework already knows the upload size, reject oversized
    # files before pulling the whole payload into memory.
    declared_size = getattr(file, "size", None)
    if declared_size is not None and declared_size / bytes_per_mb > MAX_FILE_MB:
        raise HTTPException(
            413,
            f"File too large ({declared_size / bytes_per_mb:.1f} MB). Max: {MAX_FILE_MB} MB.",
        )

    content = await file.read()
    size_mb = len(content) / bytes_per_mb
    if size_mb > MAX_FILE_MB:
        raise HTTPException(413, f"File too large ({size_mb:.1f} MB). Max: {MAX_FILE_MB} MB.")

    job_id = str(uuid.uuid4())[:8]
    # mkdtemp guarantees a fresh, owner-only directory; a fixed name built
    # from a truncated UUID could collide with (or be pre-created by)
    # something else in the shared temp dir.
    temp_dir = Path(tempfile.mkdtemp(prefix=f"deepshield_{job_id}_"))
    frames_dir = temp_dir / "frames"
    video_path = temp_dir / f"input{ext}"

    def _analyze() -> dict:
        """Blocking part of the request: model load, frame split, inference."""
        model = load_model()
        if ext in video_exts:
            frame_paths = extract_frames(str(video_path), str(frames_dir))
            if not frame_paths:
                raise HTTPException(422, "No frames could be extracted from video.")
        else:
            img_path = frames_dir / f"frame_0000{ext}"
            shutil.copy(video_path, img_path)
            frame_paths = [str(img_path)]
        return run_inference(model, frame_paths)

    try:
        frames_dir.mkdir(parents=True, exist_ok=True)
        with open(video_path, "wb") as f:
            f.write(content)
        # Drop the in-memory copy before the memory-hungry inference step.
        del content
        logger.info(f"[{job_id}] Processing: {file.filename} ({size_mb:.1f} MB)")
        # Inference is CPU/GPU-bound and synchronous; running it in the
        # threadpool keeps the event loop free to serve other requests.
        result = await run_in_threadpool(_analyze)
        result["filename"] = file.filename
        result["file_size_mb"] = round(size_mb, 2)
        result["job_id"] = job_id
        logger.info(f"[{job_id}] Result: {result['verdict']} ({result['fake_probability']}% fake)")
        return JSONResponse(content=result)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(422, str(e))
    except Exception as e:
        logger.error(f"[{job_id}] Error: {e}", exc_info=True)
        # NOTE(review): the exception text is returned to the client.
        raise HTTPException(500, f"Internal error: {str(e)}")
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
        logger.info(f"[{job_id}] Cleanup done.")
# Serve the frontend from ./static at the site root (html=True serves
# index.html for directories). Registered after the API routes above so the
# catch-all "/" mount does not shadow them.
app.mount("/", StaticFiles(directory="static", html=True), name="static")