# proofly/image_authenticity/models/clip_detector.py
# Initial backend API deployment (commit 4f48a4e, by Pragthedon)
"""
models/clip_detector.py
Zero-shot image authenticity detection using OpenAI CLIP.
Uses cosine similarity between image embeddings and descriptive text prompts
to estimate the probability that an image is real vs AI-generated/fake.
"""
import os
import sys
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
# Make the package root importable when this file is run directly as a script:
# __file__ is undefined in some embedded/REPL contexts, hence the NameError guard.
try:
    _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
except NameError:
    _base_dir = os.path.abspath(os.getcwd())
sys.path.append(_base_dir)

from image_authenticity import config
class CLIPDetector:
    """
    Zero-shot real/fake image detector using OpenAI CLIP.

    Compares image embeddings against two sets of text prompts:
      - "real" prompts (e.g. "a real photograph taken by a camera")
      - "fake" prompts (e.g. "an AI-generated image")

    Returns the probability that the image is fake.
    """

    def __init__(self, model_name: Optional[str] = None,
                 device: Optional[torch.device] = None):
        """
        Args:
            model_name: CLIP model identifier (e.g. "ViT-L/14").
                Defaults to ``config.CLIP_MODEL_NAME``.
            device: Torch device to run on. Defaults to ``config.DEVICE``.
        """
        self.device = device or config.DEVICE
        self.model_name = model_name or config.CLIP_MODEL_NAME
        self.model = None
        self.preprocess = None
        # Pre-encoded text-prompt embeddings, populated by _load().
        # Initialized here so attribute access never raises AttributeError
        # if loading fails partway through.
        self.real_text_features = None
        self.fake_text_features = None
        self._loaded = False
        self._failed = False

    def _load(self) -> None:
        """Lazy-load CLIP. Tries the primary model, falls back on RuntimeError
        (typically CUDA OOM or a broken download). No-op after the first call.

        Raises:
            ImportError: if the ``clip`` package is not installed.
        """
        if self._loaded or self._failed:
            return
        try:
            import clip  # local import: heavy optional dependency

            print(f"[CLIP] Loading model '{self.model_name}' on {self.device}...")
            try:
                self.model, self.preprocess = clip.load(
                    self.model_name,
                    device=self.device,
                    download_root=config.CACHE_DIR,
                )
            except RuntimeError as oom:
                # OOM or download failure — retry with a smaller model.
                fallback = getattr(config, "CLIP_MODEL_FALLBACK", "ViT-B/16")
                print(f"[CLIP] {oom} — falling back to {fallback}")
                self.model_name = fallback
                self.model, self.preprocess = clip.load(
                    fallback,
                    device=self.device,
                    download_root=config.CACHE_DIR,
                )
            self.model.eval()

            # Pre-encode ALL prompt sets once; they are reused for every image.
            real_tokens = clip.tokenize(config.CLIP_REAL_PROMPTS).to(self.device)
            fake_tokens = clip.tokenize(config.CLIP_FAKE_PROMPTS).to(self.device)
            with torch.no_grad():
                self.real_text_features = F.normalize(
                    self.model.encode_text(real_tokens), dim=-1)
                self.fake_text_features = F.normalize(
                    self.model.encode_text(fake_tokens), dim=-1)
            self._loaded = True
            print(f"[CLIP] Model '{self.model_name}' loaded successfully.")
        except ImportError:
            # Hard failure: without the package there is nothing to fall back to.
            raise ImportError(
                "openai-clip is not installed. Run: pip install openai-clip"
            )
        except Exception as e:
            # Soft failure: predict() degrades to a neutral 0.5/0.5 answer.
            print(f"[CLIP] WARNING: Failed to load CLIP: {e}")
            self._failed = True

    def predict(self, image: "Image.Image") -> Dict[str, float]:
        """
        Predict whether an image is real or fake using CLIP.

        Args:
            image: input PIL image (converted to RGB if needed).

        Returns:
            Dict with ``fake_prob``, ``real_prob``, ``clip_real_score``,
            ``clip_fake_score``. Returns a neutral 0.5/0.5 split if the
            model failed to load.
        """
        self._load()
        if self._failed or self.model is None:
            return {"fake_prob": 0.5, "real_prob": 0.5,
                    "clip_real_score": 0.0, "clip_fake_score": 0.0}

        if image.mode != "RGB":
            image = image.convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_features = self.model.encode_image(img_tensor)
            image_features = F.normalize(image_features, dim=-1)  # [1, D]
            # Score against every individual prompt, then aggregate —
            # averaging prompt embeddings into a centroid first would
            # lose discriminative power.
            real_scores = (image_features @ self.real_text_features.T)[0]  # [N_real]
            fake_scores = (image_features @ self.fake_text_features.T)[0]  # [N_fake]

        real_sim = real_scores.mean().item()
        fake_sim = fake_scores.mean().item()
        # Max-pooled score captures the single best-matching prompt.
        real_max = real_scores.max().item()
        fake_max = fake_scores.max().item()

        # Blend mean (70%) + max (30%): the mean is stable across diverse
        # prompts; the max rewards the best individual match without dominating.
        real_agg = 0.70 * real_sim + 0.30 * real_max
        fake_agg = 0.70 * fake_sim + 0.30 * fake_max

        # Temperature sharpens the softmax: at 40, cosine differences of ~0.01
        # already produce ~40% vs 60% splits (much flatter at temperature 20).
        temperature = getattr(config, "CLIP_TEMPERATURE", 40.0)
        logits = torch.tensor([real_agg, fake_agg]) * temperature
        probs = torch.softmax(logits, dim=0)

        return {
            "fake_prob": float(probs[1]),
            "real_prob": float(probs[0]),
            "clip_real_score": float(real_agg),
            "clip_fake_score": float(fake_agg),
        }

    def __repr__(self) -> str:
        return f"CLIPDetector(model='{self.model_name}', device={self.device})"