# NOTE: Hugging Face Spaces page chrome ("Spaces: / Sleeping") removed from scrape.
"""
models/clip_detector.py

Zero-shot image-authenticity detection using OpenAI CLIP.

Uses cosine similarity between image embeddings and descriptive text prompts
to estimate the probability that an image is real vs. AI-generated/fake.
"""
import os
import sys
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# Make the project root (parent of this file's directory) importable so the
# sibling `image_authenticity` package resolves when this module is run
# directly rather than as part of the package.
try:
    _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
except NameError:
    # __file__ is undefined in some embedded/REPL contexts; fall back to CWD.
    _base_dir = os.path.abspath(os.getcwd())
sys.path.append(_base_dir)

from image_authenticity import config
class CLIPDetector:
    """
    Zero-shot real/fake image detector using OpenAI CLIP.

    Compares image embeddings against two sets of text prompts:
      - "real" prompts (e.g. "a real photograph taken by a camera")
      - "fake" prompts (e.g. "an AI-generated image")

    and returns the probability that the image is fake. The CLIP model is
    loaded lazily on first prediction; if loading fails for any reason other
    than a missing `clip` package, `predict()` degrades gracefully to a
    neutral 0.5/0.5 answer instead of raising.
    """

    def __init__(self, model_name: Optional[str] = None,
                 device: Optional[torch.device] = None):
        """
        Args:
            model_name: CLIP model identifier (e.g. "ViT-L/14").
                Defaults to config.CLIP_MODEL_NAME.
            device: torch device to run inference on. Defaults to config.DEVICE.
        """
        self.device = device or config.DEVICE
        self.model_name = model_name or config.CLIP_MODEL_NAME
        self.model = None        # CLIP model, populated by _load()
        self.preprocess = None   # CLIP image transform, populated by _load()
        self._loaded = False     # True once model + prompt embeddings are ready
        self._failed = False     # True after a load failure; predict() goes neutral

    def _load(self) -> None:
        """Lazy-load CLIP and pre-encode the text prompts (once).

        Tries the primary model first; on RuntimeError (typically CUDA OOM)
        retries with the smaller fallback model. Any other load failure is
        recorded in ``self._failed`` so predict() can degrade gracefully.

        Raises:
            ImportError: if the `clip` package is not installed — a missing
                dependency is a hard error the caller must fix.
        """
        if self._loaded or self._failed:
            return
        try:
            import clip
            print(f"[CLIP] Loading model '{self.model_name}' on {self.device}...")
            try:
                self.model, self.preprocess = clip.load(
                    self.model_name,
                    device=self.device,
                    download_root=config.CACHE_DIR
                )
            except RuntimeError as oom:
                # RuntimeError (typically CUDA OOM) — retry with a smaller
                # model. Other failures (e.g. download errors) fall through to
                # the generic handler below.
                fallback = getattr(config, "CLIP_MODEL_FALLBACK", "ViT-B/16")
                print(f"[CLIP] {oom} — falling back to {fallback}")
                self.model_name = fallback
                self.model, self.preprocess = clip.load(
                    fallback,
                    device=self.device,
                    download_root=config.CACHE_DIR
                )
            self.model.eval()
            # Pre-encode ALL prompt sets once; unit-normalize so the image/text
            # dot product in predict() is a cosine similarity.
            real_tokens = clip.tokenize(config.CLIP_REAL_PROMPTS).to(self.device)
            fake_tokens = clip.tokenize(config.CLIP_FAKE_PROMPTS).to(self.device)
            with torch.no_grad():
                self.real_text_features = F.normalize(
                    self.model.encode_text(real_tokens), dim=-1)
                self.fake_text_features = F.normalize(
                    self.model.encode_text(fake_tokens), dim=-1)
            self._loaded = True
            print(f"[CLIP] Model '{self.model_name}' loaded successfully.")
        except ImportError as err:
            # Chain the original import failure for easier debugging.
            raise ImportError(
                "openai-clip is not installed. Run: pip install openai-clip"
            ) from err
        except Exception as e:
            # Non-import failure (download error, corrupt cache, ...) is
            # non-fatal: remember it and let predict() return neutral scores.
            print(f"[CLIP] WARNING: Failed to load CLIP: {e}")
            self._failed = True

    def predict(self, image: "Image.Image") -> Dict[str, float]:
        """
        Predict whether an image is real or fake using CLIP.

        Args:
            image: PIL image; converted to RGB if needed. Not touched when
                the model failed to load.

        Returns:
            Dict with ``fake_prob`` / ``real_prob`` (complementary softmax
            probabilities) and the raw aggregated cosine scores
            ``clip_real_score`` / ``clip_fake_score``.
            Returns a neutral 0.5/0.5 result if the model failed to load.
        """
        self._load()
        if self._failed or self.model is None:
            return {"fake_prob": 0.5, "real_prob": 0.5,
                    "clip_real_score": 0.0, "clip_fake_score": 0.0}
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(img_tensor)
            image_features = F.normalize(image_features, dim=-1)  # [1, D]
            # Score against every individual prompt, then aggregate below;
            # this avoids centroid averaging losing discriminative power.
            real_scores = (image_features @ self.real_text_features.T)[0]  # [N_real]
            fake_scores = (image_features @ self.fake_text_features.T)[0]  # [N_fake]
        real_sim = real_scores.mean().item()
        fake_sim = fake_scores.mean().item()
        # Max-pooled score (captures the best-matching single prompt).
        real_max = real_scores.max().item()
        fake_max = fake_scores.max().item()
        # Blend mean (70%) + max (30%): mean is more stable across diverse
        # prompts; max captures the best-matching prompt without dominating.
        real_agg = 0.70 * real_sim + 0.30 * real_max
        fake_agg = 0.70 * fake_sim + 0.30 * fake_max
        # Temperature from config (default 40.0) sharpens the 2-way softmax:
        # at 40, cosine differences of ~0.01 already produce ~40%/60% splits —
        # far more discriminative than a temperature of 20.
        temperature = getattr(config, "CLIP_TEMPERATURE", 40.0)
        logits = torch.tensor([real_agg, fake_agg]) * temperature
        probs = torch.softmax(logits, dim=0).numpy()
        real_prob = float(probs[0])
        fake_prob = float(probs[1])
        return {
            "fake_prob": fake_prob,
            "real_prob": real_prob,
            "clip_real_score": float(real_agg),
            "clip_fake_score": float(fake_agg),
        }

    def __repr__(self):
        return f"CLIPDetector(model='{self.model_name}', device={self.device})"