# Author: Sectumsempra1402
# Initial deploy: FastAPI deepfake detection API
# Commit: f827a8b
"""
Deepfake detector built on EfficientNet-B4 fine-tuned on DFDC / FF++
(HuggingFace: dima806/deepfake_vs_real_image_detection).
Pipeline:
1. Images -> OpenCV Haar-cascade face detection -> crop largest face -> classifier
   (falls back to the full image when no face is found)
2. Videos -> sample up to MAX_VIDEO_FRAMES evenly spaced frames ->
   run the image pipeline on each -> return mean score + frame count
"""
import logging
import os
import tempfile
from io import BytesIO
from typing import Tuple
import cv2
import numpy as np
import requests
import torch
from PIL import Image
from transformers import pipeline
logger = logging.getLogger(__name__)

# Hugging Face model id loaded into the "image-classification" pipeline.
MODEL_ID = "dima806/deepfake_vs_real_image_detection"
# Upper bound on the number of evenly spaced frames sampled per video.
MAX_VIDEO_FRAMES = 12
# Timeout (seconds) passed to requests.get for every media download.
REQUEST_TIMEOUT = 60
# Lower-cased substrings; a classifier label containing any of them is treated
# as the "deepfake" class. frozenset so the shared constant cannot be mutated.
# NOTE(review): "ai" is a plain substring match and would also match labels such
# as "plain"; confirm against the model's actual label names.
DEEPFAKE_LABEL_KEYWORDS = frozenset({"fake", "deepfake", "ai"})
# Pixels of context added on each side of a detected face before cropping.
FACE_PADDING = 30
# Minimum face width/height (pixels) passed to the Haar cascade as minSize.
MIN_FACE_SIZE = 40
class DeepfakeDetector:
    """Score images and videos for the probability of being a deepfake.

    Holds an OpenCV Haar-cascade face detector, used to crop the largest face,
    and a Hugging Face ``image-classification`` pipeline loaded from
    ``MODEL_ID``. Both are loaded once in ``__init__`` and reused for every call.
    """

    def __init__(self) -> None:
        """Load the face detector and the classifier (first GPU if CUDA is available)."""
        use_cuda = torch.cuda.is_available()
        # transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
        hf_device = 0 if use_cuda else -1

        logger.info("Loading face detector (OpenCV Haar cascade)…")
        cascade_path = cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        self.face_cascade = cv2.CascadeClassifier(cascade_path)
        if self.face_cascade.empty():
            # An empty cascade makes detectMultiScale raise on every call, which
            # _crop_face swallows; report it once here instead of silently
            # scoring full images forever.
            logger.warning(
                "Haar cascade could not be loaded from %s; face cropping disabled",
                cascade_path,
            )

        logger.info("Loading deepfake classifier (%s)…", MODEL_ID)
        self.classifier = pipeline(
            "image-classification",
            model=MODEL_ID,
            device=hf_device,
        )
        logger.info("Models ready (device=%s)", "cuda" if use_cuda else "cpu")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def analyze_image(self, url: str) -> Tuple[float, int]:
        """Return (deepfake_probability, operations=1) for the image at *url*.

        Raises whatever ``requests`` raises for network/HTTP failures (including
        ``HTTPError`` on non-2xx responses), and PIL's error if the payload is
        not a decodable image.
        """
        image = self._fetch_image(url)
        score = self._score_frame(image)
        return score, 1

    def analyze_video(self, url: str) -> Tuple[float, int]:
        """Return (mean_deepfake_probability, frames_analyzed) for the video at *url*.

        Returns ``(0.0, 0)`` when no frame could be decoded.
        """
        frames = self._sample_video_frames(url)
        if not frames:
            return 0.0, 0
        scores = [self._score_frame(f) for f in frames]
        return float(np.mean(scores)), len(scores)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _score_frame(self, image: Image.Image) -> float:
        """Classify one frame (face crop if found) and return the deepfake-label score.

        Only the top-2 labels are inspected; the first whose lower-cased name
        contains a DEEPFAKE_LABEL_KEYWORDS entry supplies the score. Returns 0.0
        if neither matches.
        """
        target = self._crop_face(image)
        # NOTE(review): top_k=2 assumes a two-label (real/fake) model; with more
        # labels the fake label could fall outside the top 2 and score 0.0.
        results = self.classifier(target, top_k=2)
        for r in results:
            if any(kw in r["label"].lower() for kw in DEEPFAKE_LABEL_KEYWORDS):
                return float(r["score"])
        return 0.0

    def _crop_face(self, image: Image.Image) -> Image.Image:
        """Return a padded crop of the largest detected face, or *image* unchanged.

        Face detection is best-effort: any failure (e.g. a non-RGB array or an
        unloaded cascade) is logged at DEBUG level and the full image is used.
        """
        try:
            img_array = np.array(image)
            gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
            faces = self.face_cascade.detectMultiScale(
                gray,
                scaleFactor=1.1,
                minNeighbors=5,
                minSize=(MIN_FACE_SIZE, MIN_FACE_SIZE),
            )
            if len(faces) == 0:
                return image
            # Largest box by area; cast numpy ints to plain ints for PIL.
            x, y, w, h = (int(v) for v in max(faces, key=lambda f: f[2] * f[3]))
            # Pad the box, clamped to the image bounds.
            x1 = max(0, x - FACE_PADDING)
            y1 = max(0, y - FACE_PADDING)
            x2 = min(image.width, x + w + FACE_PADDING)
            y2 = min(image.height, y + h + FACE_PADDING)
            return image.crop((x1, y1, x2, y2))
        except Exception:
            logger.debug("Face detection failed; using full image", exc_info=True)
            return image

    def _fetch_image(self, url: str) -> Image.Image:
        """Download *url* and decode it as an RGB PIL image."""
        resp = requests.get(url, timeout=REQUEST_TIMEOUT)
        resp.raise_for_status()
        return Image.open(BytesIO(resp.content)).convert("RGB")

    def _sample_video_frames(self, url: str) -> list:
        """Download the video at *url* and return up to MAX_VIDEO_FRAMES RGB frames.

        The temporary download is always deleted, whether decoding succeeds,
        returns early, or raises.
        """
        tmp_path = self._download_to_tempfile(url)
        try:
            return self._read_sampled_frames(tmp_path)
        finally:
            os.unlink(tmp_path)

    @staticmethod
    def _download_to_tempfile(url: str) -> str:
        """Stream *url* into a new temp file and return its path (caller deletes it)."""
        # Response as a context manager: the streamed connection is closed
        # even if writing fails part-way.
        with requests.get(url, timeout=REQUEST_TIMEOUT, stream=True) as resp:
            resp.raise_for_status()
            # delete=False: OpenCV reopens the file by path after it is closed.
            tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
            try:
                with tmp:
                    for chunk in resp.iter_content(chunk_size=65536):
                        tmp.write(chunk)
            except BaseException:
                # Don't leave a partial download behind on failure/interrupt.
                os.unlink(tmp.name)
                raise
        return tmp.name

    @staticmethod
    def _read_sampled_frames(path: str) -> list:
        """Decode up to MAX_VIDEO_FRAMES evenly spaced frames from *path* as RGB images.

        Frames that fail to decode are skipped. Returns an empty list when
        OpenCV reports no frames.
        """
        frames: list = []
        cap = cv2.VideoCapture(path)
        try:
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if total <= 0:
                return frames
            indices = np.linspace(0, total - 1, min(MAX_VIDEO_FRAMES, total), dtype=int)
            for idx in indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ok, frame = cap.read()
                if not ok:
                    continue
                # OpenCV decodes as BGR; the classifier pipeline expects RGB.
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            # Release on every path (including the early return above) so the
            # handle is closed before the caller deletes the file.
            cap.release()
        return frames