# VERIDEX.V1 / backend/models/deepfake_detector.py
# shadow55gh — fix: real ELA heatmap, real face boxes, real DCT freq, fix HF URLs (commit c5ec583)
"""
VERIDEX β€” Deepfake Detector (models/deepfake_detector.py)
===========================================================
Loads 6 ML models at startup:
1. MTCNN β€” face detection (facenet-pytorch)
2. FaceNet β€” face embeddings (facenet-pytorch / vggface2)
3. ViT β€” deepfake detection (prithivMLmods/Deep-Fake-Detector-v2-Model)
4. CLIP β€” AI-image scoring (openai/clip-vit-base-patch32)
5. SDXL Detector β€” Stable Diffusion check (Organika/sdxl-detector)
6. GAN Detector β€” GAN artifact check (saltacc/anime-ai-detect)
NOTE: transformers 4.47+ deprecates AutoFeatureExtractor.
We use AutoImageProcessor everywhere instead.
"""
import io
from pathlib import Path
import numpy as np
from PIL import Image
from loguru import logger
# ── Lazy-loaded model globals ────────────────────────────────────
# All populated by load_all_models(); each stays None (or {}) when its
# model fails to load, and the analysis functions degrade gracefully.
_mtcnn = None  # facenet_pytorch.MTCNN — face detection
_facenet = None  # facenet_pytorch.InceptionResnetV1 — face embeddings
_vit_processor = None  # AutoImageProcessor for the ViT deepfake model
_vit_model = None  # AutoModelForImageClassification (ViT deepfake)
_clip_processor = None  # CLIPProcessor
_clip_model = None  # CLIPModel
_sd_processor = None  # AutoImageProcessor (SDXL detector)
_sd_model = None  # AutoModelForImageClassification (SDXL detector)
_gan_processor = None  # AutoImageProcessor (GAN detector)
_gan_model = None  # AutoModelForImageClassification (GAN detector)
_efficientnet = None  # Custom trained EfficientNet-B4 (timm)
_effnet_cfg = {}  # {"img_size": ..., "fake_label": ...} from meta.json
_device: str | None = None  # cached by _get_device(): "cuda" or "cpu"
def _get_device() -> str:
    """Return the torch device string ("cuda" or "cpu"), cached after the first call."""
    global _device
    if _device is not None:
        return _device
    # Import lazily so merely importing this module never pulls in torch.
    import torch
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    return _device
# ════════════════════════════════════════════════════════════════
# Startup loader β€” called from main.py lifespan()
# ════════════════════════════════════════════════════════════════
async def load_all_models(models_dir: Path) -> None:
    """
    Pre-load all detection models. Called once from main.py lifespan().

    Seven models are attempted in order: MTCNN, FaceNet, ViT deepfake
    detector, CLIP, SDXL detector, GAN detector, and a locally-trained
    EfficientNet-B4. Each model failure is logged as a warning — the
    server keeps running and analyze_image() falls back to neutral
    scores / heuristics for the missing models.

    Args:
        models_dir: directory used as the HuggingFace ``cache_dir``.
    """
    global _mtcnn, _facenet
    global _vit_processor, _vit_model
    global _clip_processor, _clip_model
    global _sd_processor, _sd_model
    global _gan_processor, _gan_model
    # BUG FIX: these two were missing from the global declarations, so the
    # EfficientNet loaded below was bound to function locals and silently
    # discarded — the module-level _efficientnet stayed None forever.
    global _efficientnet, _effnet_cfg

    device = _get_device()
    cache = str(models_dir)

    # ── 1. MTCNN ─────────────────────────────────────────────
    logger.info("1/7 MTCNN Face Detector...")
    try:
        from facenet_pytorch import MTCNN
        _mtcnn = MTCNN(keep_all=True, device=device)
        logger.success("   ✅ MTCNN ready")
    except Exception as e:
        logger.warning(f"   ⚠ MTCNN unavailable: {e}")

    # ── 2. FaceNet ───────────────────────────────────────────
    logger.info("2/7 FaceNet InceptionResnetV1...")
    try:
        from facenet_pytorch import InceptionResnetV1
        _facenet = InceptionResnetV1(pretrained="vggface2").eval()
        logger.success("   ✅ FaceNet ready")
    except Exception as e:
        logger.warning(f"   ⚠ FaceNet unavailable: {e}")

    # ── 3. ViT Deepfake Detector ─────────────────────────────
    logger.info("3/7 ViT Deepfake Detector (Hugging Face)...")
    try:
        from transformers import AutoImageProcessor, AutoModelForImageClassification
        _vit_processor = AutoImageProcessor.from_pretrained(
            "prithivMLmods/Deep-Fake-Detector-v2-Model",
            cache_dir=cache,
        )
        _vit_model = AutoModelForImageClassification.from_pretrained(
            "prithivMLmods/Deep-Fake-Detector-v2-Model",
            cache_dir=cache,
        )
        logger.success("   ✅ ViT Deepfake Detector ready")
    except Exception as e:
        logger.warning(f"   ⚠ ViT model unavailable: {e}")

    # ── 4. CLIP ──────────────────────────────────────────────
    logger.info("4/7 CLIP Model (OpenAI)...")
    try:
        from transformers import CLIPProcessor, CLIPModel
        _clip_processor = CLIPProcessor.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir=cache,
        )
        _clip_model = CLIPModel.from_pretrained(
            "openai/clip-vit-base-patch32",
            cache_dir=cache,
        )
        logger.success("   ✅ CLIP ready")
    except Exception as e:
        logger.warning(f"   ⚠ CLIP unavailable: {e}")

    # ── 5. Stable Diffusion Detector ─────────────────────────
    logger.info("5/7 AI Image Detector (SDXL)...")
    try:
        from transformers import AutoImageProcessor, AutoModelForImageClassification
        _sd_processor = AutoImageProcessor.from_pretrained(
            "Organika/sdxl-detector",
            cache_dir=cache,
        )
        _sd_model = AutoModelForImageClassification.from_pretrained(
            "Organika/sdxl-detector",
            cache_dir=cache,
        )
        logger.success("   ✅ SD Detector ready")
    except Exception as e:
        logger.warning(f"   ⚠ SD Detector unavailable: {e}")

    # ── 6. GAN Detector ──────────────────────────────────────
    logger.info("6/7 GAN Artifact Detector...")
    try:
        from transformers import AutoImageProcessor, AutoModelForImageClassification
        _gan_processor = AutoImageProcessor.from_pretrained(
            "saltacc/anime-ai-detect",
            cache_dir=cache,
        )
        _gan_model = AutoModelForImageClassification.from_pretrained(
            "saltacc/anime-ai-detect",
            cache_dir=cache,
        )
        logger.success("   ✅ GAN Detector ready")
    except Exception as e:
        logger.warning(f"   ⚠ GAN Detector unavailable: {e}")

    # ── 7. Custom EfficientNet-B4 (trained weights) ──────────
    logger.info("7/7 Custom EfficientNet-B4 (local weights)...")
    try:
        import json
        import timm
        import torch
        import torch.nn as nn

        weights_path = Path("weights/efficientnet_deepfake.pth")
        meta_path = Path("weights/efficientnet_b4_meta.json")
        if weights_path.exists():
            meta = {}
            if meta_path.exists():
                with open(meta_path) as f:
                    meta = json.load(f)
            img_size = meta.get("img_size", 256)
            fake_label = meta.get("fake_label", 0)

            model = timm.create_model("efficientnet_b4", pretrained=False)
            # The classifier head must mirror the training-time architecture
            # exactly, otherwise load_state_dict() raises on key mismatch.
            model.classifier = nn.Sequential(
                nn.Dropout(0.4),
                nn.Linear(model.num_features, 512),
                nn.GELU(),
                nn.BatchNorm1d(512),
                nn.Dropout(0.3),
                nn.Linear(512, 2),
            )
            state = torch.load(weights_path, map_location=device)
            model.load_state_dict(state)
            model.eval()
            _efficientnet = model.to(device)
            _effnet_cfg = {"img_size": img_size, "fake_label": fake_label}
            logger.success(f"   ✅ EfficientNet-B4 ready (img={img_size}, fake_label={fake_label})")
        else:
            logger.warning("   ⚠ weights/efficientnet_deepfake.pth not found — skipping")
    except Exception as e:
        logger.warning(f"   ⚠ EfficientNet unavailable: {e}")
# ════════════════════════════════════════════════════════════════
# Internal helpers
# ════════════════════════════════════════════════════════════════
def _heuristic_fake_score(arr: np.ndarray) -> float:
"""Lightweight signal-processing fallback when ML models are absent."""
try:
from scipy.ndimage import laplace
lap = laplace(arr.mean(axis=2))
noise = float(np.std(lap) / (np.mean(np.abs(lap)) + 1e-6))
ch_std = float(np.std([arr[:, :, c].mean() for c in range(3)]))
return float(np.clip((noise / 20.0 + ch_std / 30.0) / 2.0, 0.0, 1.0))
except Exception:
return 0.5
def _run_classifier(
    processor,
    model,
    img: Image.Image,
    fake_keywords: tuple = ("fake", "ai", "artificial", "generated"),
) -> float:
    """
    Run a generic HuggingFace image classifier and return a fake probability.

    The "fake" class is located by substring-matching fake_keywords against
    the model's id2label names; if nothing matches, class index 1 is assumed
    (binary {0: real, 1: fake} layout). Works with AutoImageProcessor
    (transformers 4.47+). Returns 0.5 (neutral) on any inference failure.
    """
    try:
        import torch

        batch = processor(images=img, return_tensors="pt")
        with torch.no_grad():
            out = model(**batch)
        class_probs = torch.softmax(out.logits, dim=-1)[0]

        # First class whose label contains a fake keyword, if any.
        labels = model.config.id2label
        matched_idx = next(
            (idx for idx, name in labels.items()
             if any(kw in name.lower() for kw in fake_keywords)),
            None,
        )
        if matched_idx is not None:
            return float(class_probs[matched_idx])

        # No keyword matched — fall back to index 1 (or 0 for 1-class heads).
        fallback_idx = 1 if len(class_probs) > 1 else 0
        return float(class_probs[fallback_idx])
    except Exception as e:
        logger.debug(f"[Classifier] inference failed: {e}")
        return 0.5
def _clip_ai_score(img: Image.Image) -> float:
    """
    CLIP zero-shot score: probability that img is AI-generated.

    Compares the image against two text prompts ('a real photograph' vs
    'an AI generated image') and returns the softmax weight of the second.
    Returns 0.5 (neutral) when CLIP is unavailable or inference fails.
    """
    if _clip_processor is None or _clip_model is None:
        return 0.5
    try:
        import torch

        prompts = ["a real photograph", "an AI generated image"]
        batch = _clip_processor(
            text=prompts, images=img,
            return_tensors="pt", padding=True,
        )
        with torch.no_grad():
            result = _clip_model(**batch)
        pair = torch.softmax(result.logits_per_image, dim=1)[0]
        # Index 1 corresponds to the 'an AI generated image' prompt.
        return float(pair[1])
    except Exception as e:
        logger.debug(f"[CLIP] inference failed: {e}")
        return 0.5
# ════════════════════════════════════════════════════════════════
# Public API
# ════════════════════════════════════════════════════════════════
async def analyze_image(content: bytes) -> dict:
    """
    Run all available ML models on an image and combine their verdicts.

    Args:
        content: raw image bytes in any PIL-readable format.

    Returns dict with keys:
        fake_prob       float — weighted ensemble verdict (0 = real, 1 = fake)
        authentic_prob  float — 1 - fake_prob
        clip_score      float — CLIP zero-shot AI score (1 = AI-generated)
        sd_score        float — SDXL detector (1 = Stable Diffusion)
        gan_score       float — GAN detector (1 = GAN-generated)
        effnet_score    float — custom EfficientNet-B4 (1 = fake)
        face_count      int   — faces detected by MTCNN
        face_boxes      list  — [[x, y, w, h], ...] pixel coords
        ela_data        str   — base64 PNG of the ELA heatmap
        freq_data       dict  — {bins: [...], magnitudes: [...]}
        method          str   — '+'-joined names of models actually used

    On an unreadable image, returns neutral scores with method="error".
    """
    try:
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as e:
        logger.error(f"[DeepfakeDetector] Cannot open image: {e}")
        return {
            "fake_prob": 0.5, "authentic_prob": 0.5,
            "clip_score": 0.5, "sd_score": 0.5, "gan_score": 0.5,
            "face_count": 0, "method": "error",
        }

    arr = np.array(img)
    methods = []

    # ── ViT deepfake score (heuristic fallback if model missing) ──
    if _vit_model is not None and _vit_processor is not None:
        fake_prob = _run_classifier(
            _vit_processor, _vit_model, img,
            fake_keywords=("fake", "deepfake", "manipulated"),
        )
        methods.append("vit")
    else:
        fake_prob = _heuristic_fake_score(arr)
        methods.append("heuristic")

    # ── CLIP AI score ────────────────────────────────────────
    clip_score = _clip_ai_score(img)
    if _clip_model is not None:
        methods.append("clip")

    # ── SD detector score ────────────────────────────────────
    if _sd_model is not None and _sd_processor is not None:
        sd_score = _run_classifier(
            _sd_processor, _sd_model, img,
            fake_keywords=("artificial", "fake", "generated", "ai"),
        )
        methods.append("sdxl")
    else:
        sd_score = 0.5

    # ── GAN detector score ───────────────────────────────────
    if _gan_model is not None and _gan_processor is not None:
        gan_score = _run_classifier(
            _gan_processor, _gan_model, img,
            fake_keywords=("ai", "artificial", "generated", "anime_ai"),
        )
        methods.append("gan")
    else:
        gan_score = 0.5

    # ── EfficientNet-B4 (custom trained) ─────────────────────
    effnet_score = 0.5
    effnet_ran = False
    if _efficientnet is not None:
        try:
            import torch
            import torchvision.transforms as T
            sz = _effnet_cfg.get("img_size", 256)
            fake_label = _effnet_cfg.get("fake_label", 0)
            tf = T.Compose([
                T.Resize((sz, sz)),
                T.ToTensor(),
                # ImageNet normalization — must match the training pipeline.
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
            tensor = tf(img).unsqueeze(0).to(_get_device())
            with torch.no_grad():
                logits = _efficientnet(tensor)
            probs = torch.softmax(logits, dim=-1)[0]
            effnet_score = float(probs[fake_label])
            effnet_ran = True
            methods.append("effnet")
        except Exception as e:
            logger.debug(f"[EfficientNet] inference failed: {e}")

    # ── Combine all scores (weighted average) ────────────────
    # BUG FIX: the 0.5 effnet placeholder used to be averaged in with
    # weight 0.45 even when the model never ran, dragging every verdict
    # toward 0.5. It now contributes only when inference succeeded,
    # matching how the clip/sd/gan placeholders are skipped below.
    scores_list = [fake_prob]
    weights_list = [0.35]
    if effnet_ran:
        # effnet gets more weight (trained on our data)
        scores_list.append(effnet_score); weights_list.append(0.45)
    if clip_score != 0.5:
        scores_list.append(clip_score); weights_list.append(0.10)
    if sd_score != 0.5:
        scores_list.append(sd_score); weights_list.append(0.05)
    if gan_score != 0.5:
        scores_list.append(gan_score); weights_list.append(0.05)
    total_w = sum(weights_list)
    fake_prob = sum(s * w for s, w in zip(scores_list, weights_list)) / total_w

    # ── Face detection via MTCNN — return boxes + count ──────
    face_count = 0
    face_boxes = []  # list of [x, y, w, h] in pixel coords
    if _mtcnn is not None:
        try:
            boxes, _probs = _mtcnn.detect(img)
            if boxes is not None:
                face_count = len(boxes)
                for box in boxes:
                    x1, y1, x2, y2 = [float(v) for v in box]
                    face_boxes.append([x1, y1, x2 - x1, y2 - y1])  # [x, y, w, h]
        except Exception as e:
            logger.debug(f"[MTCNN] detection failed: {e}")

    # ── Real ELA (Error Level Analysis) heatmap data ─────────
    ela_data = _compute_ela(arr)
    # ── Frequency domain anomaly score ───────────────────────
    freq_data = _compute_freq_anomaly(arr)

    logger.debug(
        f"[DeepfakeDetector] fake={fake_prob:.3f} effnet={effnet_score:.3f} "
        f"clip={clip_score:.3f} sd={sd_score:.3f} gan={gan_score:.3f} "
        f"faces={face_count} method={'+'.join(methods)}"
    )
    return {
        "fake_prob": round(fake_prob, 4),
        "authentic_prob": round(1.0 - fake_prob, 4),
        "clip_score": round(clip_score, 4),
        "sd_score": round(sd_score, 4),
        "gan_score": round(gan_score, 4),
        "effnet_score": round(effnet_score, 4),
        "face_count": face_count,
        "face_boxes": face_boxes,  # [[x,y,w,h], ...] pixel coords
        "ela_data": ela_data,      # base64 PNG of ELA map
        "freq_data": freq_data,    # {bins: [...], magnitudes: [...]}
        "method": "+".join(methods),
    }
def _compute_ela(arr: np.ndarray) -> str:
    """
    Real Error Level Analysis heatmap, returned as a base64-encoded PNG.

    The image is re-saved as JPEG at quality 75 and subtracted from the
    original; the amplified difference highlights regions whose compression
    history differs (edited or generated areas). Returns "" on failure.
    """
    try:
        import io, base64
        from PIL import Image

        source = Image.fromarray(arr)

        # Round-trip through JPEG at reduced quality.
        jpeg_buf = io.BytesIO()
        source.save(jpeg_buf, format='JPEG', quality=75)
        jpeg_buf.seek(0)
        recompressed = Image.open(jpeg_buf).convert('RGB')

        # Amplify the per-pixel error level.
        residual = np.abs(
            np.array(source, dtype=np.float32)
            - np.array(recompressed, dtype=np.float32)
        )
        heatmap = Image.fromarray(np.clip(residual * 15, 0, 255).astype(np.uint8))

        # Downscale so the frontend payload stays small.
        heatmap = heatmap.resize((320, 240), Image.LANCZOS)
        png_buf = io.BytesIO()
        heatmap.save(png_buf, format='PNG')
        return base64.b64encode(png_buf.getvalue()).decode('utf-8')
    except Exception as e:
        logger.debug(f"[ELA] computation failed: {e}")
        return ""
def _compute_freq_anomaly(arr: np.ndarray) -> dict:
"""
DCT frequency domain analysis.
Returns histogram of frequency magnitudes β€” AI images have characteristic
patterns in high-frequency bands.
"""
try:
from scipy.fftpack import dct
gray = arr.mean(axis=2).astype(np.float32)
# Sample a center patch
h, w = gray.shape
patch_size = min(256, h, w)
py, px = (h - patch_size) // 2, (w - patch_size) // 2
patch = gray[py:py+patch_size, px:px+patch_size]
# 2D DCT
dct_2d = dct(dct(patch, axis=0, norm='ortho'), axis=1, norm='ortho')
mag = np.abs(dct_2d).flatten()
# Log-scale histogram with 32 bins
mag_log = np.log1p(mag)
hist, _ = np.histogram(mag_log, bins=32)
hist_norm = (hist / (hist.max() + 1e-6)).tolist()
return {"bins": list(range(32)), "magnitudes": hist_norm}
except Exception as e:
logger.debug(f"[FreqAnalysis] computation failed: {e}")
return {"bins": list(range(32)), "magnitudes": [0.5] * 32}