""" DeepShield AI — Full-Stack FastAPI Backend (SupCon Version) Serves the frontend UI + deepfake detection API from one HF Space. 98.3% Accuracy — Supervised Contrastive Learning Model """ import os import sys import uuid import shutil import logging import tempfile from pathlib import Path from functools import lru_cache import cv2 import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from PIL import Image, ImageFile from facenet_pytorch import MTCNN from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, FileResponse from fastapi.staticfiles import StaticFiles import torchvision.transforms as T ImageFile.LOAD_TRUNCATED_IMAGES = True logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger(__name__) # ───────────────────────────────────────────── # Model Definition (Self-Contained SupCon Architecture) # ───────────────────────────────────────────── class DINOv2Extractor(nn.Module): def __init__(self, variant: str = "dinov2_vitb14"): super().__init__() logger.info(f"Loading {variant} from torch.hub...") self.backbone = torch.hub.load( "facebookresearch/dinov2", variant, pretrained=True ) self.feature_dim = 768 for p in self.backbone.parameters(): p.requires_grad = False def forward(self, x: torch.Tensor) -> torch.Tensor: return self.backbone(x) class MLPClassifier(nn.Module): def __init__(self, input_dim: int, num_classes: int = 2, dropout: float = 0.4): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, 512), nn.BatchNorm1d(512), nn.GELU(), nn.Dropout(dropout), nn.Linear(512, 256), nn.BatchNorm1d(256), nn.GELU(), nn.Dropout(dropout * 0.75), nn.Linear(256, num_classes), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) class SupConDeepfakeClassifier(nn.Module): """ Supervised Contrastive Version of the DINOv2 Deepfake Detector. Matches the architecture used in scripts3. """ def __init__(self, dual_input: bool = True, proj_dim: int = 128): super().__init__() self.dual_input = dual_input self.extractor = DINOv2Extractor() feat_dim = 768 classifier_input = feat_dim * 2 if dual_input else feat_dim # Projection Head for SupCon (needed for weight loading, even if not used in inference) self.head = nn.Sequential( nn.Linear(classifier_input, classifier_input), nn.BatchNorm1d(classifier_input), nn.ReLU(inplace=True), nn.Linear(classifier_input, proj_dim) ) self.classifier = MLPClassifier(classifier_input) def forward(self, full_image: torch.Tensor, face_crop: torch.Tensor = None): full_feat = self.extractor(full_image) if self.dual_input: face_feat = self.extractor(face_crop if face_crop is not None else full_image) features = torch.cat([full_feat, face_feat], dim=1) else: features = full_feat logits = self.classifier(features) # We don't need 'proj' for inference return logits # ───────────────────────────────────────────── # App Setup # ───────────────────────────────────────────── app = FastAPI( title="DeepShield AI", description="DINO-G50 deepfake detector — SupCon SOTA version", version="3.0.0", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") CHECKPOINT_PATH = Path("best_model.pth") MAX_FRAMES = 20 MAX_FILE_MB = 30 MAX_DURATION_SEC = 60 # MTCNN face detector try: MTCNN_DETECTOR = MTCNN( image_size=224, margin=40, keep_all=False, post_process=False, device='cpu' ) logger.info("MTCNN face detector initialized.") except Exception as e: MTCNN_DETECTOR = None logger.warning(f"MTCNN init failed: {e}") TRANSFORM = T.Compose([ T.Resize((224, 224)), T.CenterCrop(224), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def detect_face_crop(img: Image.Image) -> Image.Image: if MTCNN_DETECTOR is None: return None try: boxes, probs = MTCNN_DETECTOR.detect(img) if boxes is None or len(boxes) == 0: return None best_idx = np.argmax(probs) if probs[best_idx] < 0.9: return None box = boxes[best_idx] w, h = img.size x1, y1, x2, y2 = [int(b) for b in box] margin = 40 x1, y1 = max(0, x1-margin), max(0, y1-margin) x2, y2 = min(w, x2+margin), min(h, y2+margin) face = img.crop((x1, y1, x2, y2)) return face.resize((224, 224), Image.LANCZOS) except Exception: pass return None @lru_cache(maxsize=1) def load_model() -> SupConDeepfakeClassifier: if not CHECKPOINT_PATH.exists(): fallback = Path("models3/checkpoints/best_model.pth") if fallback.exists(): shutil.copy(fallback, CHECKPOINT_PATH) else: raise RuntimeError("best_model.pth not found. Please upload the model from models3/.") logger.info(f"Loading SupCon checkpoint on {DEVICE}...") ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE) state = ckpt.get("model_state_dict", ckpt) # Auto-detect dual input from weights mlp_w = state.get("classifier.net.0.weight", None) dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True model = SupConDeepfakeClassifier(dual_input=dual).to(DEVICE) model.load_state_dict(state, strict=False) model.eval() logger.info(f"SupCon Model ready. dual_input={dual}, device={DEVICE}") return model def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): raise ValueError("Cannot open video file.") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if total_frames <= 0: total_frames = 300 step = max(1, total_frames // num_frames) target_indices = set(range(0, total_frames, step)) saved_paths = [] frame_idx = 0 while len(saved_paths) < num_frames: ret, frame = cap.read() if not ret: break if frame_idx in target_indices: rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg") Image.fromarray(rgb).save(path, quality=90) saved_paths.append(path) frame_idx += 1 cap.release() return saved_paths def run_inference(model: SupConDeepfakeClassifier, frame_paths: list) -> dict: fake_probs = [] with torch.no_grad(): for fpath in frame_paths: try: img = Image.open(fpath).convert("RGB") t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE) t_face = t_img if model.dual_input: face_crop = detect_face_crop(img) if face_crop is not None: t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE) logits = model(t_img, t_face if model.dual_input else None) prob = torch.softmax(logits, dim=1)[0, 1].item() fake_probs.append(prob) except Exception as e: logger.warning(f"Error on {fpath}: {e}") if not fake_probs: raise ValueError("No frames processed.") # Matching test_real.py simple mean logic for consistency video_fake_prob = float(np.mean(fake_probs)) is_fake = video_fake_prob > 0.5 avg_real = 1.0 - video_fake_prob return { "verdict": "FAKE" if is_fake else "REAL", "fake_probability": round(video_fake_prob * 100, 1), "real_probability": round(avg_real * 100, 1), "frame_count": len(fake_probs), "confidence": round(max(video_fake_prob, avg_real) * 100, 1), "per_frame_scores": [round(p * 100, 1) for p in fake_probs], } @app.on_event("startup") async def startup_event(): try: load_model() except Exception as e: logger.error(f"Startup model load failed: {e}") @app.get("/health") def health_check(): return { "status": "ok", "model": "DINO-G50 SupCon Detector", "model_loaded": CHECKPOINT_PATH.exists(), } @app.post("/predict") async def predict(file: UploadFile = File(...)): allowed_exts = {".mp4", ".mov", ".avi", ".mkv", ".jpg", ".jpeg", ".png", ".webp"} ext = Path(file.filename).suffix.lower() if file.filename else "" if ext not in allowed_exts: raise HTTPException(400, f"Unsupported file type '{ext}'.") content = await file.read() size_mb = len(content) / (1024 * 1024) if size_mb > MAX_FILE_MB: raise HTTPException(413, f"File too large ({size_mb:.1f} MB). Max: {MAX_FILE_MB} MB.") job_id = str(uuid.uuid4())[:8] temp_dir = Path(tempfile.gettempdir()) / f"deepshield_{job_id}" frames_dir = temp_dir / "frames" frames_dir.mkdir(parents=True, exist_ok=True) file_path = temp_dir / f"input{ext}" try: with open(file_path, "wb") as f: f.write(content) del content model = load_model() if ext in {".mp4", ".mov", ".avi", ".mkv"}: frame_paths = extract_frames(str(file_path), str(frames_dir)) else: img_path = frames_dir / f"frame_0000{ext}" shutil.copy(file_path, img_path) frame_paths = [str(img_path)] if not frame_paths: raise HTTPException(422, "Failed to extract frames.") result = run_inference(model, frame_paths) result.update({"filename": file.filename, "file_size_mb": round(size_mb, 2)}) return JSONResponse(content=result) except Exception as e: logger.error(f"Error: {e}", exc_info=True) raise HTTPException(500, str(e)) finally: shutil.rmtree(temp_dir, ignore_errors=True) app.mount("/", StaticFiles(directory="static", html=True), name="static")