# proofly / image_authenticity.py
# Initial backend API deployment by Pragthedon (commit 4f48a4e)
"""
Image Authenticity Detector — Single-File Standalone POC
Combines ViT, CLIP, CNN, and Frequency analysis into one script.
"""
from __future__ import annotations

import argparse
import io
import os
import re
import sys
from typing import Dict, Optional

import cv2
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from scipy.fft import fft2, fftshift, dct
# =============================================================================
# CONFIGURATION
# =============================================================================
class Config:
    """Central configuration: compute device, paths, model names, CLIP prompt
    banks, and ensemble weights/threshold."""

    @staticmethod
    def _pick_device():
        # Prefer CUDA, then Apple-silicon MPS, otherwise plain CPU.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    def __init__(self):
        self.DEVICE = self._pick_device()
        try:
            base = os.path.dirname(os.path.abspath(__file__))
        except NameError:
            # Fallback for interactive shells / Jupyter notebooks
            base = os.path.abspath(os.getcwd())
        self.BASE_DIR = base
        self.CACHE_DIR = os.path.join(base, ".cache")
        os.makedirs(self.CACHE_DIR, exist_ok=True)
        self.CNN_WEIGHTS_PATH = os.path.join(base, "weights", "cnn_detector.pth")

        # --- CLIP zero-shot settings ---
        self.CLIP_MODEL_NAME = "ViT-L/14"
        self.CLIP_TEMPERATURE = 40.0
        self.CLIP_MODEL_FALLBACK = "ViT-B/16"
        # Text prompts describing authentic photographs.
        self.CLIP_REAL_PROMPTS = [
            "a real photograph taken by a camera with natural sensor noise and grain",
            "an authentic candid photo with natural lighting, depth of field and bokeh",
            "a genuine photo with natural lens distortion, chromatic aberration and shadows",
            "a real photograph with authentic skin texture, pores and natural imperfections",
            "a documentary-style photo of real people or objects in real environments",
            "a photo taken with a smartphone or DSLR camera showing real-world scenes",
            "a real image with natural color grading, occasional overexposure and noise",
        ]
        # Text prompts describing AI-generated / manipulated imagery.
        self.CLIP_FAKE_PROMPTS = [
            "an AI-generated image produced by a neural network or diffusion model",
            "a synthetic image created by Stable Diffusion, Midjourney or DALL-E",
            "a GAN-generated photorealistic face that does not exist in reality",
            "a deepfake with artificially smooth skin and unnatural eye reflections",
            "a computer-generated image with unnaturally perfect symmetry and sharp edges",
            "an AI artwork with artificial textures, glitchy backgrounds and fantasy lighting",
            "a digitally manipulated or composited image with inconsistent lighting",
        ]

        # --- HuggingFace classifier models ---
        self.HF_PRIMARY_MODEL = "dima806/ai_vs_real_image_detection"
        self.HF_SECONDARY_MODEL = "prithivMLmods/Deep-Fake-Detector-v2-Model"
        self.HF_CACHE_DIR = self.CACHE_DIR

        # --- CNN detector settings ---
        self.CNN_BACKBONE = "efficientnet_b4"
        self.CNN_IMAGE_SIZE = 224
        self.CNN_PRETRAINED = True

        # --- Frequency-forensics settings ---
        self.FREQ_IMAGE_SIZE = 256
        self.FREQ_DCT_PATCH_SIZE = 8
        self.FREQ_HIGH_FREQ_THRESH = 0.4

        # --- Ensemble combination ---
        self.ENSEMBLE_WEIGHTS = {
            "hf_primary": 0.35, "hf_secondary": 0.25,
            "clip": 0.20, "frequency": 0.15, "cnn": 0.05,
        }
        self.CONFIDENCE_WEIGHTING = True
        self.FAKE_THRESHOLD = 0.46


config = Config()
# =============================================================================
# CLIP DETECTOR
# =============================================================================
class CLIPDetector:
    """Zero-shot real-vs-fake scorer built on CLIP image/text similarity."""

    def __init__(self, model_name: str = None, device: torch.device = None):
        self.device = device or config.DEVICE
        self.model_name = model_name or config.CLIP_MODEL_NAME
        self.model = None
        self.preprocess = None
        self._loaded = False
        self._failed = False

    def _load(self):
        """Lazily load CLIP once; a failure is remembered and never retried."""
        if self._loaded or self._failed:
            return
        try:
            import clip
            print(f"[CLIP] Loading model '{self.model_name}' on {self.device}...")
            try:
                self.model, self.preprocess = clip.load(
                    self.model_name, device=self.device, download_root=config.CACHE_DIR)
            except RuntimeError as oom:
                # Likely out-of-memory on the large variant: retry with a smaller one.
                fallback = getattr(config, "CLIP_MODEL_FALLBACK", "ViT-B/16")
                print(f"[CLIP] {oom} — falling back to {fallback}")
                self.model_name = fallback
                self.model, self.preprocess = clip.load(
                    fallback, device=self.device, download_root=config.CACHE_DIR)
            self.model.eval()
            # Encode both prompt banks once up front; reused for every image.
            real_tokens = clip.tokenize(config.CLIP_REAL_PROMPTS).to(self.device)
            fake_tokens = clip.tokenize(config.CLIP_FAKE_PROMPTS).to(self.device)
            with torch.no_grad():
                self.real_text_features = F.normalize(self.model.encode_text(real_tokens), dim=-1)
                self.fake_text_features = F.normalize(self.model.encode_text(fake_tokens), dim=-1)
            self._loaded = True
            print(f"[CLIP] Model '{self.model_name}' loaded successfully.")
        except Exception as e:
            print(f"[CLIP] WARNING: Failed to load CLIP: {e}")
            self._failed = True

    def predict(self, image: Image.Image) -> Dict[str, float]:
        """Return fake/real probabilities plus the raw aggregated similarities.

        Falls back to a neutral 0.5/0.5 answer when CLIP could not be loaded.
        """
        self._load()
        if self._failed or self.model is None:
            return {"fake_prob": 0.5, "real_prob": 0.5, "clip_real_score": 0.0, "clip_fake_score": 0.0}
        if image.mode != "RGB":
            image = image.convert("RGB")
        batch = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            feats = F.normalize(self.model.encode_image(batch), dim=-1)
            sims_real = (feats @ self.real_text_features.T)[0]
            sims_fake = (feats @ self.fake_text_features.T)[0]
        # Blend mean and best prompt similarity (70/30) for each class.
        agg_real = 0.70 * sims_real.mean().item() + 0.30 * sims_real.max().item()
        agg_fake = 0.70 * sims_fake.mean().item() + 0.30 * sims_fake.max().item()
        # Two-way softmax sharpened by a fixed temperature.
        probs = torch.softmax(
            torch.tensor([agg_real, agg_fake]) * config.CLIP_TEMPERATURE, dim=0).numpy()
        return {"fake_prob": float(probs[1]), "real_prob": float(probs[0]),
                "clip_real_score": float(agg_real), "clip_fake_score": float(agg_fake)}
# =============================================================================
# CNN DETECTOR
# =============================================================================
class _EfficientNetWithHead(nn.Module):
def __init__(self, backbone: nn.Module, head: nn.Sequential):
super().__init__()
self.backbone = backbone
self.head = head
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone.forward_features(x)
return self.head(features)
class CNNDetector:
    """Binary real-vs-fake classifier on an EfficientNet backbone, built lazily.

    Without fine-tuned weights on disk the head is untrained, so scores are
    only meaningful after training; the ensemble gives this model low weight.
    """

    # ImageNet normalization statistics.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, weights_path: str = None, device: torch.device = None):
        self.device = device or config.DEVICE
        self.weights_path = weights_path or config.CNN_WEIGHTS_PATH
        self.model = None
        self._loaded = False

    def _build_model(self):
        """Create backbone + 2-class head; requires the optional timm package."""
        try:
            import timm
        except ImportError:
            raise ImportError("timm is not installed. Run: pip install timm")
        backbone = timm.create_model(config.CNN_BACKBONE, pretrained=config.CNN_PRETRAINED, num_classes=0)
        classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(backbone.num_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 2),
        )
        return _EfficientNetWithHead(backbone, classifier)

    def _load(self):
        """Build the model once; subsequent calls are no-ops."""
        if self._loaded:
            return
        print(f"[CNN] Building {config.CNN_BACKBONE} backbone on {self.device}...")
        self.model = self._build_model().to(self.device)
        self.model.eval()
        if os.path.isfile(self.weights_path):
            print(f"[CNN] Loading fine-tuned weights from {self.weights_path}")
            # strict=False tolerates partial/renamed state dicts.
            self.model.load_state_dict(torch.load(self.weights_path, map_location=self.device), strict=False)
        else:
            print(f"[CNN] No fine-tuned weights found at {self.weights_path}.\n[CNN] Using ImageNet pre-trained backbone + untrained head.")
        self._loaded = True
        print("[CNN] Model ready.")

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Resize, tensorize and normalize a PIL image into a 1-image batch."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        pipeline = T.Compose([
            T.Resize((config.CNN_IMAGE_SIZE, config.CNN_IMAGE_SIZE)),
            T.ToTensor(),
            T.Normalize(mean=self.MEAN, std=self.STD),
        ])
        return pipeline(image).unsqueeze(0).to(self.device)

    def predict(self, image: Image.Image) -> Dict[str, float]:
        """Return {'fake_prob', 'real_prob'} from a softmax over the two logits."""
        self._load()
        with torch.no_grad():
            probs = F.softmax(self.model(self._preprocess(image)), dim=1)[0]
        return {"fake_prob": float(probs[1].item()), "real_prob": float(probs[0].item())}
# =============================================================================
# FREQUENCY DETECTOR (Math/Forensics - Non-AI)
# =============================================================================
class FrequencyDetector:
    """Non-learned forensic detector using frequency/statistical analysis.

    Combines six cheap cues — FFT power-spectrum shape, DCT block statistics,
    error-level analysis (ELA), local texture variance, Benford's-law digit
    statistics of DCT coefficients, and Laplacian noise kurtosis — into a
    single heuristic fake probability via a fixed weighted average.
    """

    def __init__(self):
        # Working resolution and DCT tile size come from the global config.
        self.image_size = config.FREQ_IMAGE_SIZE
        self.patch_size = config.FREQ_DCT_PATCH_SIZE

    def _preprocess(self, image: Image.Image) -> np.ndarray:
        """Resize to the working resolution and return float32 grayscale."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        img = image.resize((self.image_size, self.image_size), Image.LANCZOS)
        return np.array(img.convert("L"), dtype=np.float32)

    def _fft_analysis(self, gray: np.ndarray) -> Dict:
        """Radially-averaged power-spectrum features of the 2-D FFT.

        Returns the fitted 1/f spectral slope (alpha), the high-frequency
        energy ratio, and a "periodicity" score measuring ripple around a
        smooth fit (NOTE(review): presumably targeting upsampling/GAN grid
        artifacts — confirm against training data).
        """
        f = fftshift(fft2(gray))
        power = np.abs(f) ** 2
        h, w = power.shape
        cy, cx = h // 2, w // 2
        Y, X = np.ogrid[:h, :w]
        R = np.sqrt((X - cx)**2 + (Y - cy)**2)
        max_r = np.sqrt(cx**2 + cy**2)
        radial_bins = 48
        bin_edges = np.linspace(1.0, max_r, radial_bins + 1)
        # Average power within each radial annulus (skip empty annuli).
        psd_raw, freq_vals = [], []
        for i in range(radial_bins):
            mask = (R >= bin_edges[i]) & (R < bin_edges[i+1])
            if mask.sum() > 0:
                psd_raw.append(power[mask].mean())
                freq_vals.append((bin_edges[i] + bin_edges[i+1]) / 2.0 / max_r)
        psd_raw = np.array(psd_raw, dtype=np.float64)
        freq_vals = np.array(freq_vals, dtype=np.float64)
        # Slope of log-power vs log-frequency; defaults to 1.8 when the fit
        # is degenerate (too few bins or constant power).
        log_freq = np.log(freq_vals + 1e-12)
        log_power = np.log(psd_raw + 1e-12)
        alpha = float(-np.polyfit(log_freq, log_power, 1)[0]) if len(log_freq) > 3 and log_power.std() > 0 else 1.8
        # Fraction of total spectral energy in the upper half of the bins.
        mid_point = len(psd_raw) // 2
        hf_ratio = float(psd_raw[mid_point:].sum() / (psd_raw[:mid_point].sum() + psd_raw[mid_point:].sum() + 1e-12))
        # Residual variation around a quadratic fit of the log-magnitude PSD.
        log_mag_psd = np.log1p(np.sqrt(psd_raw))
        fitted = np.poly1d(np.polyfit(np.arange(len(log_mag_psd)), log_mag_psd, 2))(np.arange(len(log_mag_psd)))
        periodic_score = float(np.abs(log_mag_psd - fitted).std() / (log_mag_psd.mean() + 1e-8))
        return {"spectral_alpha": alpha, "hf_ratio": hf_ratio, "periodic_score": periodic_score}

    def _dct_analysis(self, gray: np.ndarray) -> Dict:
        """Per-block (p x p) DCT statistics: mean and std of AC energy.

        Note: the original implementation also accumulated per-block DC/AC
        ratios and cross-block stds that were never returned; that dead work
        has been removed without changing the returned values.
        """
        h, w = gray.shape
        p = self.patch_size
        # Crop so the image tiles exactly into p x p blocks.
        gray = gray[:h - h % p, :w - w % p]
        ac_energies = []
        for i in range(0, gray.shape[0], p):
            for j in range(0, gray.shape[1], p):
                dct_block = dct(dct(gray[i:i+p, j:j+p], axis=0, norm='ortho'), axis=1, norm='ortho')
                # AC energy = total energy minus the DC (mean) component.
                ac_energies.append(float((dct_block**2).sum() - dct_block[0, 0]**2))
        return {"dct_ac_mean": float(np.mean(ac_energies)), "dct_ac_std": float(np.std(ac_energies))}

    def _ela_analysis(self, image: Image.Image) -> Dict:
        """Error-level analysis: pixel difference after one JPEG pass (q=95).

        A very low ELA mean suggests the input was already strongly
        compressed, in which case the ELA cues are down-weighted via
        ``ela_weight_scale``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        buf = io.BytesIO()
        image.save(buf, format="JPEG", quality=95)
        buf.seek(0)
        comp = np.array(Image.open(buf).convert("RGB"), dtype=np.float32)
        orig = np.array(image, dtype=np.float32)
        ela_map = np.abs(orig - comp)
        ela_mean = float(ela_map.mean())
        return {"ela_mean": ela_mean, "ela_std": float(ela_map.std()), "ela_weight_scale": 0.25 if ela_mean < 1.5 else 1.0}

    def _texture_analysis(self, gray: np.ndarray) -> Dict:
        """Coefficient of variation of per-tile (32x32) Laplacian variance.

        Low variation across tiles indicates unnaturally uniform texture.
        """
        gray_u8 = np.clip(gray, 0, 255).astype(np.uint8)
        tile = 32
        tile_vars = [float(cv2.Laplacian(gray_u8[i:i+tile, j:j+tile], cv2.CV_64F).var())
                     for i in range(0, gray_u8.shape[0] - tile + 1, tile)
                     for j in range(0, gray_u8.shape[1] - tile + 1, tile)]
        tile_vars = np.array(tile_vars, dtype=np.float64)
        local_mean = float(tile_vars.mean()) if len(tile_vars) > 0 else 1.0
        return {"texture_local_cv": float(tile_vars.std() / (local_mean + 1e-8))}

    def _benford_analysis(self, gray: np.ndarray) -> Dict:
        """MSE between first-digit distribution of DCT AC coefficients and
        Benford's law; returns 0.0 when too few coefficients are available."""
        h, w = gray.shape
        p = self.patch_size
        gray_c = gray[:h - h % p, :w - w % p]
        ac_coeffs = []
        for i in range(0, gray_c.shape[0], p):
            for j in range(0, gray_c.shape[1], p):
                dct_block = dct(dct(gray_c[i:i+p, j:j+p], axis=0, norm='ortho'), axis=1, norm='ortho')
                ac = dct_block.flatten()[1:]
                # Keep only coefficients with magnitude > 1 (stable first digit).
                ac_coeffs.extend(np.abs(ac[ac > 1.0]).tolist())
        if len(ac_coeffs) < 100:
            return {"benford_mse": 0.0}
        coeffs = np.array(ac_coeffs)
        # Leading decimal digit of each coefficient.
        first_digits = np.floor(coeffs / 10.0 ** np.floor(np.log10(coeffs + 1e-12))).astype(int)
        first_digits = first_digits[(first_digits >= 1) & (first_digits <= 9)]
        if len(first_digits) < 50:
            return {"benford_mse": 0.0}
        observed = np.bincount(first_digits, minlength=10)[1:10] / len(first_digits)
        ideal = np.array([np.log10(1 + 1/d) for d in range(1, 10)])
        return {"benford_mse": float(np.mean((observed - ideal) ** 2))}

    def _noise_analysis(self, gray: np.ndarray) -> Dict:
        """Kurtosis of the Laplacian response (3.0 = Gaussian-like noise)."""
        laplacian = cv2.Laplacian(np.clip(gray, 0, 255).astype(np.uint8), cv2.CV_64F)
        flat = laplacian.flatten()
        kurt = float(np.mean(((flat - flat.mean()) / flat.std())**4)) if flat.std() > 0 else 3.0
        return {"noise_kurtosis": kurt}

    def predict(self, image: Image.Image) -> Dict:
        """Run all analyses and combine them into a heuristic fake probability.

        Each cue is clipped to [0, 1] and averaged with a fixed weight; the
        ELA cues are scaled down for already heavily compressed inputs.
        """
        gray = self._preprocess(image)
        fft_s = self._fft_analysis(gray)
        dct_s = self._dct_analysis(gray)
        noise_s = self._noise_analysis(gray)
        ela_s = self._ela_analysis(image)
        texture_s = self._texture_analysis(gray)
        ben_s = self._benford_analysis(gray)
        # Computing the combined score
        score, weight = 0.0, 0.0
        # Deviation of spectral slope from the ~1.8 natural-image baseline.
        alpha_score = float(np.clip(abs(fft_s["spectral_alpha"] - 1.8) / 1.0 - 1.0, 0.0, 1.0))
        score += 1.5 * alpha_score; weight += 1.5
        # Excess high-frequency energy.
        score += 1.0 * float(np.clip((fft_s["hf_ratio"] - 0.30) / 0.25, 0.0, 1.0)); weight += 1.0
        # Periodic spectral ripple.
        score += 0.75 * float(np.clip(fft_s["periodic_score"] / 0.5, 0.0, 1.0)); weight += 0.75
        # Non-Gaussian noise statistics.
        score += 0.75 * float(np.clip(abs(noise_s["noise_kurtosis"] - 3.0) / 15.0, 0.0, 1.0)); weight += 0.75
        # Unusually uniform DCT AC energy across blocks.
        score += 0.75 * float(1.0 - np.clip((dct_s["dct_ac_std"] / (dct_s["dct_ac_mean"] + 1e-8)) / 1.5, 0.0, 1.0)); weight += 0.75
        # ELA cues, down-weighted for heavily compressed inputs.
        es = ela_s.get("ela_weight_scale", 1.0)
        score += (1.25 * es) * float(np.clip((8.0 - ela_s["ela_std"]) / 7.0, 0.0, 1.0)); weight += (1.25 * es)
        score += (1.0 * es) * float(np.clip((ela_s["ela_mean"] - 8.0) / 12.0, 0.0, 1.0)); weight += (1.0 * es)
        # Low texture variation across tiles.
        score += 1.0 * float(np.clip((1.0 - texture_s["texture_local_cv"]) / 0.8, 0.0, 1.0)); weight += 1.0
        # Benford's-law deviation of DCT coefficients.
        score += 1.0 * float(np.clip(ben_s["benford_mse"] / 0.005, 0.0, 1.0)); weight += 1.0
        fake_prob = float(np.clip(score / weight, 0.0, 1.0))
        return {
            "fake_prob": fake_prob, "real_prob": 1.0 - fake_prob,
            "hf_ratio": fft_s["hf_ratio"], "spectral_alpha": fft_s["spectral_alpha"],
            "texture_local_cv": texture_s["texture_local_cv"], "benford_mse": ben_s["benford_mse"]
        }
# =============================================================================
# HUGGINGFACE DETECTOR
# =============================================================================
_LABEL_FAKE_MAP = {
"ai": 1.0, "artificial": 1.0, "fake": 1.0, "ai_generated": 1.0, "ai-generated": 1.0,
"generated": 1.0, "synthetic": 1.0, "deepfake": 1.0,
"real": 0.0, "authentic": 0.0, "genuine": 0.0, "natural": 0.0,
"photo": 0.0, "human": 0.0, "realism": 0.0,
}
def _label_to_fake_prob(label: str, score: float) -> float:
key = label.lower().strip()
if key in _LABEL_FAKE_MAP:
d = _LABEL_FAKE_MAP[key]
return d * score + (1.0 - d) * (1.0 - score)
for k, v in _LABEL_FAKE_MAP.items():
if k in key: return v * score + (1.0 - v) * (1.0 - score)
return 0.5
class HFDetector:
    """Wraps one HuggingFace image-classification pipeline as a fake detector,
    degrading to a neutral 0.5 answer when the model cannot be loaded."""

    def __init__(self, model_name: str, label_name: str = "HF", device=None):
        self.model_name = model_name
        self.label_name = label_name
        self.device = device or config.DEVICE
        self._pipe = None
        self._loaded = False
        self._failed = False

    def _load(self):
        """One-shot lazy load; a failure is remembered and never retried."""
        if self._loaded or self._failed:
            return
        try:
            from transformers import pipeline
            print(f"[{self.label_name}] Loading {self.model_name} ...")
            # transformers pipelines take an int device id: 0 = first GPU, -1 = CPU.
            device_id = 0 if (self.device.type == "cuda") else -1
            self._pipe = pipeline(
                "image-classification", model=self.model_name, device=device_id,
                model_kwargs={"cache_dir": config.HF_CACHE_DIR}, trust_remote_code=False)
            self._loaded = True
            print(f"[{self.label_name}] Loaded successfully.")
        except Exception as e:
            print(f"[{self.label_name}] WARNING: Could not load {self.model_name}: {e}\n[{self.label_name}] Falling back to neutral.")
            self._failed = True

    def predict(self, image: Image.Image) -> Dict:
        """Classify one image; returns {'fake_prob', 'top_label'}."""
        self._load()
        if self._failed or self._pipe is None:
            return {"fake_prob": 0.5, "top_label": "unknown"}
        if image.mode != "RGB":
            image = image.convert("RGB")
        try:
            best = self._pipe(image)[0]
            return {"fake_prob": float(_label_to_fake_prob(best["label"], best["score"])),
                    "top_label": best["label"]}
        except Exception as e:
            print(f"[{self.label_name}] error: {e}")
            return {"fake_prob": 0.5, "top_label": "error"}
class DualHFDetector:
    """Bundles the configured primary and secondary HuggingFace detectors."""

    def __init__(self, device=None):
        # Both wrappers load lazily: nothing is downloaded until first predict().
        self.primary, self.secondary = (
            HFDetector(config.HF_PRIMARY_MODEL, "HF-Primary", device),
            HFDetector(config.HF_SECONDARY_MODEL, "HF-Secondary", device),
        )
# =============================================================================
# ENSEMBLE DETECTOR
# =============================================================================
class EnsembleDetector:
    """Weighted soft-vote over the HF, CLIP, CNN and frequency detectors."""

    def __init__(self, weights=None, threshold=None, device=None):
        self.weights = weights or config.ENSEMBLE_WEIGHTS
        self.threshold = threshold if threshold is not None else config.FAKE_THRESHOLD
        self.device = device or config.DEVICE
        # All sub-detectors load their models lazily on first use.
        self.hf_detector = DualHFDetector(device=self.device)
        self.clip_detector = CLIPDetector(device=self.device)
        self.cnn_detector = CNNDetector(device=self.device)
        self.freq_detector = FrequencyDetector()

    def _safe_run(self, fn, image, name):
        """Run one detector, degrading to a neutral score on any failure."""
        try:
            return fn(image)
        except Exception as e:
            print(f"[Ensemble] Warning: {name} failed — {e}")
            return {"fake_prob": 0.5}

    def predict(self, image: Image.Image) -> Dict:
        """Score one image and return label, probabilities, per-model scores
        and a human-readable explanation."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        hf1 = self._safe_run(self.hf_detector.primary.predict, image, "HF-Primary")
        hf2 = self._safe_run(self.hf_detector.secondary.predict, image, "HF-Secondary")
        clip_res = self._safe_run(self.clip_detector.predict, image, "CLIP")
        cnn_res = self._safe_run(self.cnn_detector.predict, image, "CNN")
        freq = self._safe_run(self.freq_detector.predict, image, "Frequency")
        raw = {
            "hf_primary": hf1.get("fake_prob", 0.5),
            "hf_secondary": hf2.get("fake_prob", 0.5),
            "clip": clip_res.get("fake_prob", 0.5),
            "cnn": cnn_res.get("fake_prob", 0.5),
            "frequency": freq.get("fake_prob", 0.5),
        }
        # Confidence weighting: scores near 0.5 contribute less, with a 10% floor
        # so no model is silenced entirely.
        weighted_sum = total_w = 0.0
        for key, p in raw.items():
            base_w = self.weights.get(key, 0.0)
            if config.CONFIDENCE_WEIGHTING:
                eff_w = base_w * max(abs(p - 0.5) * 2.0, 0.10)
            else:
                eff_w = base_w
            weighted_sum += eff_w * p
            total_w += eff_w
        ensemble_fake = float(weighted_sum / total_w) if total_w > 0 else 0.5
        label = "FAKE" if ensemble_fake >= self.threshold else "REAL"
        # Generate simple explanation
        lines = [
            f"The 5-model ensemble classifies this image as {label}.",
            f" - AI Detector (primary): {hf1.get('top_label','?')} ({raw['hf_primary']*100:.1f}% fake)",
            f" - Deepfake Detector (secondary): {hf2.get('top_label','?')} ({raw['hf_secondary']*100:.1f}% fake)",
            f" - CLIP semantics: {raw['clip']*100:.1f}% fake",
            f" - Frequency spectrum alpha: {freq.get('spectral_alpha', 1.8):.2f}",
        ]
        return {
            "label": label, "fake_prob": ensemble_fake, "real_prob": 1.0 - ensemble_fake,
            "scores": raw, "freq_result": freq, "explanation": "\n".join(lines)
        }
# =============================================================================
# CLI / ENTRY POINT
# =============================================================================
def main():
    """CLI entry point: load an image from a path or URL, run the ensemble,
    and print a formatted verdict with per-model scores."""
    parser = argparse.ArgumentParser(description="Standalone Single-File Image Authenticity Detector")
    parser.add_argument("input", nargs="?", help="Path to local image or URL (optional, will prompt if omitted)")
    parser.add_argument("--threshold", type=float, default=config.FAKE_THRESHOLD, help="Threshold for fake verdict")
    # parse_known_args() is used instead of parse_args() to prevent crashes in Jupyter/Colab
    # environments, which often inject a '-f' argument containing the kernel connection file.
    args, unknown = parser.parse_known_args()
    image_input = args.input
    if not image_input:
        # Prompt interactively if no argument was provided
        print("=== Image Authenticity Detector ===")
        # Use simple input() to prompt for the image inside Colab/Interactive sessions
        image_input = input("Enter path to a local image or paste a URL: ").strip()
    # Re-check: covers both a missing CLI argument and an empty interactive answer.
    if not image_input:
        print("[!] No input provided. Exiting.")
        sys.exit(1)
    print(f"\n[*] Loading Image from: {image_input}")
    try:
        if image_input.startswith("http://") or image_input.startswith("https://"):
            # Remote image: fetch with a browser-like User-Agent, fail fast on HTTP errors.
            r = requests.get(image_input, timeout=15, headers={"User-Agent": "Mozilla/5.0"})
            r.raise_for_status()
            pil_image = Image.open(io.BytesIO(r.content)).convert("RGB")
        else:
            pil_image = Image.open(image_input).convert("RGB")
    except Exception as e:
        print(f"[!] Error loading image: {e}")
        sys.exit(1)
    print("[*] Initializing Ensemble Detector (lazy model loading)...")
    detector = EnsembleDetector(threshold=args.threshold)
    print("[*] Running Analysis (downloads models on first run)...")
    result = detector.predict(pil_image)
    fp = result['fake_prob']
    # Verdict bands: scores hugging 0.5 are reported as uncertain instead of
    # forcing a binary call.  NOTE(review): this branch uses `>` while
    # EnsembleDetector.predict labels with `>=`, and the 0.45-0.55 band can
    # override the threshold — confirm both are intended.
    if 0.45 <= fp <= 0.55: verdict = "UNCERTAIN (Models disagree / Low confidence)"
    elif fp > args.threshold: verdict = "FAKE (AI-Generated / Manipulated)"
    else: verdict = "REAL (Authentic Photograph)"
    print("\n" + "="*50)
    print(" "*15 + "ANALYSIS RESULT")
    print("="*50)
    print(f"VERDICT: {verdict}")
    print(f"Score: {result['real_prob']*100:.1f}% Real | {fp*100:.1f}% Fake")
    print("-" * 50)
    print("Detailed Model Scores:")
    for m, s in result['scores'].items():
        print(f" - {m.upper():<12} : {s*100:5.1f}% Fake")
    print("-" * 50)
    print("\n" + result["explanation"])
    # Extra forensic statistics from the frequency detector, when available.
    freq = result.get("freq_result", {})
    if freq:
        print("\nKey Frequency Stats:")
        print(f" - High-Freq Ratio : {freq.get('hf_ratio',0)*100:.1f}%")
        print(f" - Texture CV : {freq.get('texture_local_cv',0):.3f}")
        print(f" - Benford MSE : {freq.get('benford_mse',0):.5f}")
    print("="*50)


if __name__ == "__main__":
    main()