Spaces:
Sleeping
Sleeping
"""
DeepShield AI — Full-Stack FastAPI Backend (SupCon Version)
Serves the frontend UI + deepfake detection API from one HF Space.
98.3% Accuracy — Supervised Contrastive Learning Model
"""
| import os | |
| import sys | |
| import uuid | |
| import shutil | |
| import logging | |
| import tempfile | |
| from pathlib import Path | |
| from functools import lru_cache | |
| import cv2 | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from PIL import Image, ImageFile | |
| from facenet_pytorch import MTCNN | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| import torchvision.transforms as T | |
# Let Pillow decode truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Root logging config: INFO and above, timestamped single-line records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────
# Model Definition (Self-Contained SupCon Architecture)
# ─────────────────────────────────────────────
class DINOv2Extractor(nn.Module):
    """Frozen DINOv2 backbone loaded from torch.hub, used as a feature extractor.

    Args:
        variant: entry-point name in the ``facebookresearch/dinov2`` hub repo
            (default ``"dinov2_vitb14"``).

    Attributes:
        backbone: the loaded hub model; every parameter has ``requires_grad=False``.
        feature_dim: width of the features produced by ``forward``. Taken from the
            backbone's ``embed_dim`` attribute when present, otherwise 768.
    """

    def __init__(self, variant: str = "dinov2_vitb14"):
        super().__init__()
        logger.info(f"Loading {variant} from torch.hub...")
        # NOTE: torch.hub.load may need network access if the repo/weights are
        # not already in the local hub cache.
        self.backbone = torch.hub.load(
            "facebookresearch/dinov2", variant, pretrained=True
        )
        # The width depends on the variant (768 is only right for ViT-B/14), so
        # read it from the backbone instead of hard-coding it.
        self.feature_dim = int(getattr(self.backbone, "embed_dim", 768))
        # Backbone is frozen: only the heads on top of it are trainable.
        for p in self.backbone.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone output for ``x`` unchanged.

        NOTE(review): callers concatenate this along dim=1, so it is presumably
        (batch, feature_dim) — confirm against the hub model's forward.
        """
        return self.backbone(x)
class MLPClassifier(nn.Module):
    """Two-hidden-layer MLP head: input_dim -> 512 -> 256 -> num_classes.

    Each hidden stage is Linear -> BatchNorm1d -> GELU -> Dropout. The second
    stage uses 75% of the first stage's dropout rate. All layers live in a single
    ``nn.Sequential`` named ``net``, so state-dict keys are ``net.<index>.*``.

    Args:
        input_dim: size of the incoming feature vector.
        num_classes: number of output logits (default 2).
        dropout: dropout probability of the first hidden stage (default 0.4).
    """

    def __init__(self, input_dim: int, num_classes: int = 2, dropout: float = 0.4):
        super().__init__()
        stages = (
            (input_dim, 512, dropout),
            (512, 256, dropout * 0.75),
        )
        layers = []
        for fan_in, fan_out, drop_p in stages:
            layers.extend(
                [
                    nn.Linear(fan_in, fan_out),
                    nn.BatchNorm1d(fan_out),
                    nn.GELU(),
                    nn.Dropout(drop_p),
                ]
            )
        layers.append(nn.Linear(256, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of feature vectors to class logits."""
        logits = self.net(x)
        return logits
class SupConDeepfakeClassifier(nn.Module):
    """
    Supervised Contrastive Version of the DINOv2 Deepfake Detector.
    Matches the architecture used in scripts3.

    A frozen DINOv2 extractor feeds an MLP classifier. With ``dual_input=True``
    the classifier sees the concatenation of full-image and face-crop features.
    ``head`` is the SupCon projection head. It is built only so checkpoint
    weights load cleanly; ``forward`` does not use it.

    Args:
        dual_input: concatenate full-image and face features (default True).
        proj_dim: output width of the projection head (default 128).
    """

    def __init__(self, dual_input: bool = True, proj_dim: int = 128):
        super().__init__()
        self.dual_input = dual_input
        self.extractor = DINOv2Extractor()
        # Use the extractor's actual width rather than assuming ViT-B/14's 768.
        feat_dim = getattr(self.extractor, "feature_dim", 768)
        classifier_input = feat_dim * 2 if dual_input else feat_dim
        # Projection Head for SupCon (needed for weight loading, even if not used in inference)
        self.head = nn.Sequential(
            nn.Linear(classifier_input, classifier_input),
            nn.BatchNorm1d(classifier_input),
            nn.ReLU(inplace=True),
            nn.Linear(classifier_input, proj_dim),
        )
        self.classifier = MLPClassifier(classifier_input)

    def forward(self, full_image: torch.Tensor, face_crop: torch.Tensor = None):
        """Return class logits for a batch.

        Args:
            full_image: preprocessed full-frame batch.
            face_crop: preprocessed face batch, used only when ``dual_input``.
                When it is None, or is the very same tensor as ``full_image``,
                the full-image features are reused as the face features.

        Returns:
            Logits from ``classifier``. The projection output is not computed.
        """
        full_feat = self.extractor(full_image)
        if not self.dual_input:
            return self.classifier(full_feat)
        # Running the frozen backbone twice on the same input gives the same
        # features in eval mode, so reuse them instead of paying a second pass.
        if face_crop is None or face_crop is full_image:
            face_feat = full_feat
        else:
            face_feat = self.extractor(face_crop)
        features = torch.cat([full_feat, face_feat], dim=1)
        return self.classifier(features)
# ─────────────────────────────────────────────
# App Setup
# ─────────────────────────────────────────────
app = FastAPI(
    title="DeepShield AI",
    # Was mojibake ("β") in the scraped source; restored to an em dash.
    description="DINO-G50 deepfake detector — SupCon SOTA version",
    version="3.0.0",
)

# Any origin may call the API. Credentials are disabled: wildcard origins with
# credentials is an invalid/unsafe CORS combination, and this app sets no
# cookies or auth headers of its own.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint path, relative to the process working directory.
CHECKPOINT_PATH = Path("best_model.pth")

# Request / processing limits.
MAX_FRAMES = 20        # default number of frames sampled from a video
MAX_FILE_MB = 30       # upload size limit, in MiB
MAX_DURATION_SEC = 60  # NOTE(review): not referenced by any code visible in this file
def _build_face_detector():
    """Create a CPU MTCNN detector, or return None if construction fails."""
    try:
        detector = MTCNN(
            image_size=224,
            margin=40,
            keep_all=False,
            post_process=False,
            device="cpu",
        )
        logger.info("MTCNN face detector initialized.")
    except Exception as e:
        logger.warning(f"MTCNN init failed: {e}")
        return None
    return detector


# MTCNN face detector (None when initialization failed).
MTCNN_DETECTOR = _build_face_detector()
# Standard ImageNet per-channel mean / std used for normalization.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to full frames and face crops before the model.
# NOTE(review): CenterCrop(224) right after Resize((224, 224)) is a no-op; it is
# kept as-is, presumably to mirror the training pipeline — confirm before removing.
TRANSFORM = T.Compose(
    [
        T.Resize((224, 224)),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
def detect_face_crop(
    img: Image.Image,
    *,
    min_prob: float = 0.9,
    margin: int = 40,
    size: int = 224,
) -> "Image.Image | None":
    """Return a square-resized crop of the most confident face in ``img``, or None.

    Uses the module-level ``MTCNN_DETECTOR``. Returns None when:
    - the detector is unavailable,
    - no face is found,
    - the best detection's probability is below ``min_prob``,
    - the padded box is empty, or
    - detection/cropping raises (the error is logged).

    Args:
        img: PIL image to search.
        min_prob: minimum detection probability to accept (default 0.9).
        margin: pixels added on each side of the detected box before cropping,
            clamped to the image bounds (default 40).
        size: side length of the returned crop (default 224). The crop is
            resized to ``size`` x ``size`` without preserving aspect ratio.
    """
    if MTCNN_DETECTOR is None:
        return None
    try:
        boxes, probs = MTCNN_DETECTOR.detect(img)
        if boxes is None or len(boxes) == 0:
            return None
        best_idx = int(np.argmax(probs))
        if probs[best_idx] < min_prob:
            return None
        x1, y1, x2, y2 = (int(b) for b in boxes[best_idx])
        w, h = img.size
        x1, y1 = max(0, x1 - margin), max(0, y1 - margin)
        x2, y2 = min(w, x2 + margin), min(h, y2 + margin)
        if x2 <= x1 or y2 <= y1:
            # Box lies outside the image after clamping; nothing to crop.
            return None
        face = img.crop((x1, y1, x2, y2))
        return face.resize((size, size), Image.LANCZOS)
    except Exception as exc:
        # Best-effort: a failed detection simply means "no face crop", but log it
        # instead of silently swallowing so systematic failures stay visible.
        logger.warning(f"Face detection failed: {exc}")
        return None
@lru_cache(maxsize=1)
def load_model() -> SupConDeepfakeClassifier:
    """Build the SupCon classifier, load the checkpoint into it, and cache it.

    The result is memoized, so repeated calls (e.g. one per request) return the
    same model instead of re-downloading DINOv2 and re-reading the checkpoint.
    Exceptions are not cached, so a failed load is retried on the next call.

    If ``CHECKPOINT_PATH`` is missing, ``models3/checkpoints/best_model.pth`` is
    copied into its place. Dual-input mode is inferred from the width of
    ``classifier.net.0.weight`` (1536 columns => dual). It defaults to dual when
    that key is absent.

    Raises:
        RuntimeError: if neither checkpoint location exists.
    """
    if not CHECKPOINT_PATH.exists():
        fallback = Path("models3/checkpoints/best_model.pth")
        if not fallback.exists():
            raise RuntimeError("best_model.pth not found. Please upload the model from models3/.")
        shutil.copy(fallback, CHECKPOINT_PATH)
    logger.info(f"Loading SupCon checkpoint on {DEVICE}...")
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    state = ckpt.get("model_state_dict", ckpt)
    # Auto-detect dual input from weights
    mlp_w = state.get("classifier.net.0.weight")
    dual = (mlp_w.shape[1] == 1536) if mlp_w is not None else True
    model = SupConDeepfakeClassifier(dual_input=dual).to(DEVICE)
    # strict=False tolerates mismatches; report them instead of hiding them, since
    # a missing classifier key would leave randomly initialized weights in use.
    incompatible = model.load_state_dict(state, strict=False)
    if incompatible.missing_keys:
        logger.warning(
            f"Checkpoint is missing {len(incompatible.missing_keys)} model keys "
            f"(e.g. {incompatible.missing_keys[:5]})"
        )
    if incompatible.unexpected_keys:
        logger.warning(
            f"Checkpoint has {len(incompatible.unexpected_keys)} unexpected keys "
            f"(e.g. {incompatible.unexpected_keys[:5]})"
        )
    model.eval()
    logger.info(f"SupCon Model ready. dual_input={dual}, device={DEVICE}")
    return model
def extract_frames(video_path: str, output_dir: str, num_frames: int = MAX_FRAMES) -> list:
    """Sample up to ``num_frames`` frames at a fixed stride and save them as JPEGs.

    The stride is ``max(1, total_frames // num_frames)``, starting at frame 0.
    When the container reports no frame count, 300 frames are assumed.
    Frames are written as ``frame_0000.jpg``, ``frame_0001.jpg``, ... in
    ``output_dir`` (quality 90).

    Args:
        video_path: path of the video file to read.
        output_dir: existing directory to write frames into.
        num_frames: maximum number of frames to save. Values <= 0 return [].

    Returns:
        List of saved frame paths, in frame order. It may be shorter than
        ``num_frames`` if the video ends early.

    Raises:
        ValueError: if OpenCV cannot open the video.
    """
    if num_frames <= 0:
        # Previously num_frames == 0 crashed with ZeroDivisionError.
        return []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError("Cannot open video file.")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            total_frames = 300
        step = max(1, total_frames // num_frames)
        target_indices = set(range(0, total_frames, step))
        last_target = max(target_indices)

        saved_paths = []
        frame_idx = 0
        # Stop once enough frames are saved or no later frame can be a target.
        while len(saved_paths) < num_frames and frame_idx <= last_target:
            # grab() advances without retrieving the frame; retrieve() only the
            # frames we keep (read() is exactly grab() + retrieve()).
            if not cap.grab():
                break
            if frame_idx in target_indices:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                path = os.path.join(output_dir, f"frame_{len(saved_paths):04d}.jpg")
                Image.fromarray(rgb).save(path, quality=90)
                saved_paths.append(path)
            frame_idx += 1
        return saved_paths
    finally:
        # Release the capture even when opening, decoding, or saving fails.
        cap.release()
def _score_frame(model: SupConDeepfakeClassifier, fpath: str) -> float:
    """Return the model's class-1 ("fake") softmax probability for one image file."""
    # Context manager closes the underlying file handle once pixels are copied.
    with Image.open(fpath) as src:
        img = src.convert("RGB")
    t_img = TRANSFORM(img).unsqueeze(0).to(DEVICE)
    # Without a detected face, the full frame doubles as the face input.
    t_face = t_img
    if model.dual_input:
        face_crop = detect_face_crop(img)
        if face_crop is not None:
            t_face = TRANSFORM(face_crop).unsqueeze(0).to(DEVICE)
    logits = model(t_img, t_face if model.dual_input else None)
    return torch.softmax(logits, dim=1)[0, 1].item()


def run_inference(model: SupConDeepfakeClassifier, frame_paths: list) -> dict:
    """Score each frame and aggregate the scores into a verdict.

    Frames that fail to load or score are logged and skipped. The video-level
    fake probability is the plain mean of the per-frame fake probabilities.
    The verdict is "FAKE" only when that mean is strictly above 0.5.

    Args:
        model: loaded classifier (used in ``torch.no_grad`` mode).
        frame_paths: image file paths to score.

    Returns:
        Dict with keys verdict, fake_probability, real_probability, frame_count,
        confidence and per_frame_scores. Probabilities are percentages rounded
        to one decimal place.

    Raises:
        ValueError: if no frame could be scored.
    """
    fake_probs = []
    with torch.no_grad():
        for fpath in frame_paths:
            try:
                fake_probs.append(_score_frame(model, fpath))
            except Exception as e:
                logger.warning(f"Error on {fpath}: {e}")
    if not fake_probs:
        raise ValueError("No frames processed.")
    # Matching test_real.py simple mean logic for consistency
    video_fake_prob = float(np.mean(fake_probs))
    is_fake = video_fake_prob > 0.5
    avg_real = 1.0 - video_fake_prob
    return {
        "verdict": "FAKE" if is_fake else "REAL",
        "fake_probability": round(video_fake_prob * 100, 1),
        "real_probability": round(avg_real * 100, 1),
        "frame_count": len(fake_probs),
        "confidence": round(max(video_fake_prob, avg_real) * 100, 1),
        "per_frame_scores": [round(p * 100, 1) for p in fake_probs],
    }
# Without this registration the hook was never invoked.
# NOTE(review): newer FastAPI versions deprecate on_event in favour of a lifespan
# handler; on_event still works.
@app.on_event("startup")
async def startup_event():
    """Try to load the model at server startup so checkpoint problems surface early.

    Failures are logged with a traceback rather than raised, so the server still
    starts. ``predict`` calls ``load_model`` itself and will try again.
    """
    try:
        load_model()
    except Exception as e:
        logger.error(f"Startup model load failed: {e}", exc_info=True)
# NOTE(review): no route decorator is present in this file, so this handler is not
# registered with `app`; confirm the intended path (e.g. @app.get(...)) and add it.
def health_check():
    """Report liveness and whether the model is available.

    ``model_loaded`` is True once ``load_model`` has successfully produced a
    cached model. If ``load_model`` is not memoized, this falls back to
    checking that the checkpoint file exists.
    """
    cache_info = getattr(load_model, "cache_info", None)
    if cache_info is not None:
        # Successful loads are the only thing lru_cache stores.
        model_loaded = cache_info().currsize > 0
    else:
        model_loaded = CHECKPOINT_PATH.exists()
    return {
        "status": "ok",
        "model": "DINO-G50 SupCon Detector",
        "model_loaded": model_loaded,
    }
# NOTE(review): no route decorator is present in this file, so this handler is not
# registered with `app`; confirm the intended path (e.g. @app.post(...)) and add it.
# NOTE(review): model loading and inference are blocking calls inside an async
# handler and will stall the event loop while they run.
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded video or image as REAL or FAKE.

    Videos are sampled into frames with ``extract_frames``. An image is scored
    as a single frame. The response is ``run_inference``'s dict plus
    ``filename`` and ``file_size_mb``.

    Raises:
        HTTPException: 400 for an unsupported extension; 413 for files larger
            than ``MAX_FILE_MB``; 422 when no frames can be extracted or scored;
            500 for any other failure.
    """
    video_exts = {".mp4", ".mov", ".avi", ".mkv"}
    allowed_exts = video_exts | {".jpg", ".jpeg", ".png", ".webp"}
    ext = Path(file.filename).suffix.lower() if file.filename else ""
    if ext not in allowed_exts:
        raise HTTPException(400, f"Unsupported file type '{ext}'.")
    content = await file.read()
    size_mb = len(content) / (1024 * 1024)
    if size_mb > MAX_FILE_MB:
        raise HTTPException(413, f"File too large ({size_mb:.1f} MB). Max: {MAX_FILE_MB} MB.")

    # mkdtemp guarantees a fresh, uniquely named directory (a truncated uuid did not).
    temp_dir = Path(tempfile.mkdtemp(prefix="deepshield_"))
    try:
        frames_dir = temp_dir / "frames"
        frames_dir.mkdir()
        file_path = temp_dir / f"input{ext}"
        file_path.write_bytes(content)
        del content
        model = load_model()
        if ext in video_exts:
            frame_paths = extract_frames(str(file_path), str(frames_dir))
        else:
            img_path = frames_dir / f"frame_0000{ext}"
            shutil.copy(file_path, img_path)
            frame_paths = [str(img_path)]
        if not frame_paths:
            raise HTTPException(422, "Failed to extract frames.")
        result = run_inference(model, frame_paths)
        result.update({"filename": file.filename, "file_size_mb": round(size_mb, 2)})
        return JSONResponse(content=result)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 422 above) must not be rewrapped as 500.
        raise
    except ValueError as e:
        # extract_frames / run_inference raise ValueError for unreadable input.
        logger.warning(f"Unprocessable input: {e}")
        raise HTTPException(422, str(e)) from e
    except Exception as e:
        logger.error(f"Error: {e}", exc_info=True)
        raise HTTPException(500, str(e)) from e
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)
# Serve the frontend from ./static at the root (index.html for "/").
# Mounted last so it does not shadow routes registered earlier.
app.mount(path="/", app=StaticFiles(directory="static", html=True), name="static")