deepfakeguard / scripts /benchmark_image.py
noorullah1102's picture
Initial commit
a869ab1
"""Baseline benchmark for image deepfake detector against CIFAKE dataset.
Downloads a sample from CIFAKE (real CIFAR-10 vs Stable Diffusion fakes),
runs the full detection pipeline, and reports standard ML metrics.
Usage:
source venv/bin/activate
python scripts/benchmark_image.py [--sample N]
"""
import argparse
import io
import json
import sys
import time
from pathlib import Path
import numpy as np
from PIL import Image
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
f1_score,
precision_score,
recall_score,
)
# Add project root to path so we can import app modules
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))
from app.services.image_detector import ImageDetector
def download_dataset(sample_per_class: int = 100) -> list[tuple[bytes, int, str]]:
    """Stream the test split and return JPEG-encoded, class-balanced samples.

    The dataset ``itsLeen/deepfake_vs_real_image_detection`` is opened in
    streaming mode and shuffled with a fixed seed (42).  Rows whose ``label``
    is 0 are treated as real and rows whose ``label`` is 1 as fake; rows are
    collected until each class holds ``sample_per_class`` items or the stream
    is exhausted.

    Args:
        sample_per_class: Maximum number of images to keep for each class.

    Returns:
        A list of ``(image_bytes, label, source)`` tuples: every real sample
        first (label 0, source ``real_NNNN``), followed by every fake sample
        (label 1, source ``fake_NNNN``).  ``image_bytes`` is the RGB image
        re-encoded as JPEG.
    """
    from datasets import load_dataset

    def _encode(rows, label, prefix):
        # Convert each row's image to RGB and serialise it as in-memory JPEG bytes.
        encoded = []
        for idx, row in enumerate(rows):
            buffer = io.BytesIO()
            row["image"].convert("RGB").save(buffer, format="JPEG")
            encoded.append((buffer.getvalue(), label, f"{prefix}_{idx:04d}"))
        return encoded

    print("Downloading deepfake dataset (itsLeen/deepfake_vs_real_image_detection)...")
    stream = load_dataset(
        "itsLeen/deepfake_vs_real_image_detection", split="test", streaming=True
    ).shuffle(seed=42)

    reals = []
    fakes = []
    for row in stream:
        if row["label"] == 0 and len(reals) < sample_per_class:
            reals.append(row)
        elif row["label"] == 1 and len(fakes) < sample_per_class:
            fakes.append(row)
        # Stop pulling from the stream as soon as neither class still needs rows.
        if not (len(reals) < sample_per_class or len(fakes) < sample_per_class):
            break

    print(f" Real images: {len(reals)}, Fake images: {len(fakes)}")

    return _encode(reals, 0, "real") + _encode(fakes, 1, "fake")
def run_individual_model(detector: ImageDetector, image_bytes: bytes) -> dict | None:
    """Run each loaded sub-model on one image and collect per-model predictions.

    Args:
        detector: Detector exposing ``_model_loaded`` / ``model`` for the
            CommunityForensics ViT and ``_ensemble_loaded`` /
            ``processor_ensemble`` / ``model_ensemble`` for the prithivMLmods
            SigLIP model.
        image_bytes: Encoded image data readable by ``PIL.Image.open``.

    Returns:
        A dict (never ``None`` in practice) with one key per *loaded* model:
        ``"communityforensics"`` and/or ``"prithivmlmods"``.  Each value is
        ``{"fake_prob": float, "verdict": "manipulated" | "real"}``, or
        ``None`` when inference for that model raised.  Failures are reported
        on stderr instead of being silently discarded, so a model that fails
        on every image can be diagnosed.
    """
    import torch

    def _entry(fake_prob: float) -> dict:
        # A fake probability strictly above 0.5 counts as "manipulated".
        return {
            "fake_prob": fake_prob,
            "verdict": "manipulated" if fake_prob > 0.5 else "real",
        }

    results = {}
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # CommunityForensics ViT
    if detector._model_loaded:
        try:
            import torchvision.transforms as T

            preprocess = T.Compose([
                T.Resize((384, 384), interpolation=T.InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=[0.4815, 0.4578, 0.4082], std=[0.2686, 0.2613, 0.2758]),
            ])
            pixel_values = preprocess(image).unsqueeze(0)  # add a batch dimension
            with torch.no_grad():
                logits = detector.model(pixel_values=pixel_values).logits
            probs = torch.softmax(logits, dim=-1)
            # NOTE(review): index 1 is taken as the "fake" class for this
            # checkpoint -- confirm against the model's id2label mapping.
            results["communityforensics"] = _entry(probs[0][1].item())
        except Exception as exc:
            # Best-effort: keep benchmarking, but make the failure visible.
            print(
                f"WARNING: communityforensics inference failed: {exc!r}",
                file=sys.stderr,
            )
            results["communityforensics"] = None

    # prithivMLmods SigLIP
    if detector._ensemble_loaded:
        try:
            inputs = detector.processor_ensemble(images=image, return_tensors="pt")
            with torch.no_grad():
                logits = detector.model_ensemble(**inputs).logits
            probs = torch.softmax(logits, dim=-1)
            # NOTE(review): index 0 is taken as the "fake" class for this
            # checkpoint -- confirm against the model's id2label mapping.
            results["prithivmlmods"] = _entry(probs[0][0].item())
        except Exception as exc:
            # Best-effort: keep benchmarking, but make the failure visible.
            print(
                f"WARNING: prithivmlmods inference failed: {exc!r}",
                file=sys.stderr,
            )
            results["prithivmlmods"] = None

    return results
def _print_confusion_matrix(cm) -> None:
    """Print a 2x2 confusion matrix with actual classes as rows and predicted as columns."""
    print(" Confusion Matrix:")
    print(" Predicted")
    print(" Real Fake")
    print(f" Actual Real {cm[0][0]:>4} {cm[0][1]:>4}")
    print(f" Actual Fake {cm[1][0]:>4} {cm[1][1]:>4}")


def _report_sub_model(title: str, y_true: list, y_pred: list, total: int) -> tuple:
    """Print accuracy, F1 and confusion matrix for one sub-model; return (accuracy, f1)."""
    print("=" * 60)
    print(title)
    print("=" * 60)
    if len(y_pred) < total:
        # Metrics below only cover images this model actually produced a result for.
        print(f" Scored {len(y_pred)}/{total} images (inference failed on the rest)")
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, pos_label=1, zero_division=0)
    print(f" Accuracy: {acc:.3f} F1: {f1:.3f}")
    # labels=[0, 1] keeps the matrix 2x2 even if only one class is present.
    _print_confusion_matrix(confusion_matrix(y_true, y_pred, labels=[0, 1]))
    print()
    return acc, f1


def _sub_model_summary(metrics, scored: int):
    """Build the JSON entry for one sub-model, or None when it scored no images."""
    if metrics is None:
        return None
    acc, f1 = metrics
    return {
        "accuracy": round(float(acc), 4),
        "f1": round(float(f1), 4),
        "samples_scored": scored,
    }


def run_benchmark(sample_per_class: int = 100):
    """Run the full benchmark, print a report, and save the results as JSON.

    Steps: download a class-balanced sample, load the detector, run the full
    pipeline plus each sub-model on every image, print metrics for the
    pipeline and for each sub-model, then write everything to
    ``ROOT / "scripts" / "benchmark_results.json"``.

    Each sub-model is scored only on the images for which it returned a
    prediction, against the matching ground-truth labels, so a single failed
    inference no longer makes the metric computation fail on mismatched
    lengths.

    Args:
        sample_per_class: Number of images requested per class.

    Raises:
        SystemExit: With status 1 when no samples were obtained or no
            detection model could be loaded.
    """
    print("=" * 60)
    print("DeepFakeGuard — Image Detector Baseline Benchmark")
    print("=" * 60)
    print()

    # Download dataset
    samples = download_dataset(sample_per_class)
    print()
    if not samples:
        # Nothing to evaluate; bail out before the per-image average divides by zero.
        print("ERROR: No samples downloaded. Cannot run benchmark.")
        sys.exit(1)
    total = len(samples)

    # Load detector
    print("Loading image detection models...")
    detector = ImageDetector()
    detector.load_model()
    if not detector.is_loaded:
        print("ERROR: No models loaded. Cannot run benchmark.")
        sys.exit(1)
    print(f" Primary (CommunityForensics): {'loaded' if detector._model_loaded else 'FAILED'}")
    print(f" Ensemble (prithivMLmods): {'loaded' if detector._ensemble_loaded else 'FAILED'}")
    print()

    # Run detection
    print(f"Running detection on {total} images...")
    y_true = []
    y_pred = []
    # Per-model predictions with their own ground-truth lists, kept in lockstep
    # so that images a sub-model failed on are simply excluded from its metrics.
    y_true_cf, y_pred_cf = [], []  # CommunityForensics
    y_true_pm, y_pred_pm = [], []  # prithivMLmods
    confidences = []
    details = []

    start = time.time()
    for done, (img_bytes, label, source) in enumerate(samples, start=1):
        result = detector.detect(img_bytes, filename=f"{source}.jpg")
        pred_label = 1 if result["verdict"] == "manipulated" else 0
        y_true.append(label)
        y_pred.append(pred_label)
        confidences.append(result["confidence"])

        # Per-model breakdown
        per_model = run_individual_model(detector, img_bytes)
        cf_result = per_model.get("communityforensics")
        pm_result = per_model.get("prithivmlmods")
        if cf_result:
            y_true_cf.append(label)
            y_pred_cf.append(1 if cf_result["verdict"] == "manipulated" else 0)
        if pm_result:
            y_true_pm.append(label)
            y_pred_pm.append(1 if pm_result["verdict"] == "manipulated" else 0)

        details.append({
            "source": source,
            "label": "real" if label == 0 else "fake",
            "predicted": result["verdict"],
            "confidence": result["confidence"],
            "severity": result["severity"],
            "correct": pred_label == label,
            "communityforensics": cf_result,
            "prithivmlmods": pm_result,
        })

        if done % 20 == 0:
            print(f" Processed {done}/{total}")

    elapsed = time.time() - start
    print(f" Done in {elapsed:.1f}s ({elapsed / total:.2f}s/image)")
    print()

    # Compute metrics
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # --- Full Pipeline ---
    print("=" * 60)
    print("FULL PIPELINE (ML Ensemble + Rule-Based)")
    print("=" * 60)
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, pos_label=1, zero_division=0)
    rec = recall_score(y_true, y_pred, pos_label=1, zero_division=0)
    f1 = f1_score(y_true, y_pred, pos_label=1, zero_division=0)
    # labels=[0, 1] guarantees a 2x2 matrix even when one class never occurs.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    print(f" Accuracy: {acc:.3f}")
    print(f" Precision: {prec:.3f} (of predicted fake, how many actually fake)")
    print(f" Recall: {rec:.3f} (of actual fake, how many caught)")
    print(f" F1 Score: {f1:.3f}")
    print()
    _print_confusion_matrix(cm)
    print()

    real_mask = y_true == 0
    fake_mask = y_true == 1
    print(f" Real images correct: {np.sum((y_pred == 0) & real_mask)}/{np.sum(real_mask)}")
    print(f" Fake images correct: {np.sum((y_pred == 1) & fake_mask)}/{np.sum(fake_mask)}")
    print(f" Avg confidence: {np.mean(confidences):.3f}")
    print()

    # --- Per-Model Breakdown ---
    cf_metrics = None
    if y_pred_cf:
        cf_metrics = _report_sub_model(
            "CommunityForensics ViT (Primary)", y_true_cf, y_pred_cf, total
        )
    pm_metrics = None
    if y_pred_pm:
        pm_metrics = _report_sub_model(
            "prithivMLmods SigLIP (Ensemble)", y_true_pm, y_pred_pm, total
        )

    # --- Summary comparison ---
    print("=" * 60)
    print("COMPARISON")
    print("=" * 60)
    print(f" {'Model':<35} {'Accuracy':>10} {'F1':>8}")
    print(f" {'-'*35} {'-'*10} {'-'*8}")
    print(f" {'Full Pipeline (Ensemble + Rules)':<35} {acc:>10.3f} {f1:>8.3f}")
    if cf_metrics is not None:
        acc_cf, f1_cf = cf_metrics
        print(f" {'CommunityForensics ViT':<35} {acc_cf:>10.3f} {f1_cf:>8.3f}")
    if pm_metrics is not None:
        acc_pm, f1_pm = pm_metrics
        print(f" {'prithivMLmods SigLIP':<35} {acc_pm:>10.3f} {f1_pm:>8.3f}")
    print()

    # Save results
    results = {
        "dataset": "itsLeen/deepfake_vs_real_image_detection",
        "samples": {"real": int(np.sum(real_mask)), "fake": int(np.sum(fake_mask))},
        "full_pipeline": {
            "accuracy": round(float(acc), 4),
            "precision": round(float(prec), 4),
            "recall": round(float(rec), 4),
            "f1": round(float(f1), 4),
            "confusion_matrix": cm.tolist(),
            "avg_confidence": round(float(np.mean(confidences)), 4),
        },
        "communityforensics_vit": _sub_model_summary(cf_metrics, len(y_pred_cf)),
        "prithivmlmods_siglip": _sub_model_summary(pm_metrics, len(y_pred_pm)),
        "per_image_details": details,
    }
    output_path = ROOT / "scripts" / "benchmark_results.json"
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"Results saved to {output_path}")
if __name__ == "__main__":
    # Command-line entry point: a single optional --sample flag sets how many
    # images per class are benchmarked.
    cli = argparse.ArgumentParser(description="Benchmark image deepfake detector")
    cli.add_argument(
        "--sample",
        default=100,
        type=int,
        help="Number of images per class (default: 100)",
    )
    options = cli.parse_args()
    run_benchmark(sample_per_class=options.sample)