""" models/ensemble.py 5-model ensemble: HF-Primary + HF-Secondary + CLIP + Frequency + CNN with confidence-aware weighted voting. """ import numpy as np from typing import Dict, Optional from PIL import Image import sys import os try: _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) except NameError: _base_dir = os.path.abspath(os.getcwd()) sys.path.append(_base_dir) from image_authenticity import config from .clip_detector import CLIPDetector from .cnn_detector import CNNDetector from .frequency_detector import FrequencyDetector from .hf_detector import DualHFDetector class EnsembleDetector: """ Confidence-aware weighted ensemble of 5 detectors: 1. HF Primary — dima806/ai_vs_real_image_detection (98.25% acc) 2. HF Secondary — prithivMLmods/Deep-Fake-Detector-v2 (deepfake ViT) 3. CLIP ViT-L/14 — zero-shot semantic signal 4. Frequency — FFT/DCT/ELA/texture forensic analysis 5. CNN — EfficientNet-B4 (low weight; untrained head) Confidence-aware voting: each model's weight is multiplied by its confidence = |fake_prob - 0.5| * 2. A model outputting 0.5 contributes 0 weight; a model outputting 0.95 contributes at full weight. This prevents uncertain models from dragging the score to 50%. """ def __init__( self, weights: Optional[Dict[str, float]] = None, fake_threshold: Optional[float] = None, device=None, ): self.weights = weights or config.ENSEMBLE_WEIGHTS self.fake_threshold = fake_threshold if fake_threshold is not None \ else config.FAKE_THRESHOLD self.device = device or config.DEVICE self.confidence_weighting = getattr(config, "CONFIDENCE_WEIGHTING", True) # Initialise sub-detectors (all lazy-loaded on first predict) self.hf_detector = DualHFDetector(device=self.device) self.clip_detector = CLIPDetector(device=self.device) self.cnn_detector = CNNDetector(device=self.device) self.freq_detector = FrequencyDetector() def predict(self, image: Image.Image) -> Dict: if image.mode != "RGB": image = image.convert("RGB") # ── Run all detectors ───────────────────────────────────────────────── hf_primary_result = self._safe_run(self.hf_detector.predict_primary, image, "HF-Primary") hf_secondary_result = self._safe_run(self.hf_detector.predict_secondary, image, "HF-Secondary") clip_result = self._safe_run(self.clip_detector.predict, image, "CLIP") cnn_result = self._safe_run(self.cnn_detector.predict, image, "CNN") freq_result = self._safe_run(self.freq_detector.predict, image, "Frequency") # ── Collect raw fake probabilities ─────────────────────────────────── raw = { "hf_primary": hf_primary_result.get("fake_prob", 0.5), "hf_secondary": hf_secondary_result.get("fake_prob", 0.5), "clip": clip_result.get("fake_prob", 0.5), "frequency": freq_result.get("fake_prob", 0.5), "cnn": cnn_result.get("fake_prob", 0.5), } # ── Confidence-aware weighted average ──────────────────────────────── total_w = 0.0 weighted_sum = 0.0 for model_key, fake_prob in raw.items(): base_w = self.weights.get(model_key, 0.0) if self.confidence_weighting: # If a model outputs 0.99, it is very confident. If it outputs 0.5, it is not. # Transform to a scalar [0, 1] where 1 is absolute certainty (0.0 or 1.0) and 0 is complete uncertainty (0.5) # But power it so highly confident models get an exponential boost over guessing models. certainty = abs(fake_prob - 0.5) * 2.0 # Damping Field: Use legacy linear scaling but with a higher floor (0.20) # This ensures that when specialist models strongly disagree, neutral # models (CLIP/CNN) carry enough weight to act as a buffer. eff_w = base_w * max(certainty, 0.20) else: eff_w = base_w weighted_sum += eff_w * fake_prob total_w += eff_w ensemble_fake = float(weighted_sum / total_w) if total_w > 0 else 0.5 ensemble_real = 1.0 - ensemble_fake # ── Decision ──────────────────────────────────────────────────────── is_fake = ensemble_fake >= self.fake_threshold label = "FAKE" if is_fake else "REAL" confidence = ensemble_fake if is_fake else ensemble_real # ── Explanation ───────────────────────────────────────────────────── explanation = self._generate_explanation( label, ensemble_fake, raw, freq_result, hf_primary_result, hf_secondary_result ) return { "label": label, "confidence": float(confidence), "fake_prob": float(ensemble_fake), "real_prob": float(ensemble_real), "hf_primary_result": hf_primary_result, "hf_secondary_result": hf_secondary_result, "clip_result": clip_result, "cnn_result": cnn_result, "freq_result": freq_result, "scores": { "hf_primary": float(raw["hf_primary"]), "hf_secondary": float(raw["hf_secondary"]), "clip": float(raw["clip"]), "cnn": float(raw["cnn"]), "frequency": float(raw["frequency"]), "ensemble": float(ensemble_fake), }, "weights": {k: self.weights.get(k, 0) for k in raw}, "explanation": explanation, } def _safe_run(self, fn, image, name): try: return fn(image) except Exception as e: print(f"[Ensemble] Warning: {name} detector failed — {e}") return {"fake_prob": 0.5, "real_prob": 0.5, "error": str(e)} def _generate_explanation( self, label, ensemble_fake, raw, freq_result, hf_primary_result, hf_secondary_result ) -> str: pct = lambda p: f"{p*100:.1f}%" lines = [] lines.append( f"The 5-model ensemble classifies this image as **{label}** " f"with {pct(ensemble_fake if label == 'FAKE' else 1 - ensemble_fake)} confidence." ) # HF Primary hf1 = raw["hf_primary"] top1 = hf_primary_result.get("top_label", "?") if hf1 > 0.60: lines.append( f"• AI Detector (primary, {pct(hf1)}): Classified as '{top1}' — " "consistent with AI-generated content." ) else: lines.append( f"• AI Detector (primary, {pct(hf1)}): Classified as '{top1}' — " "consistent with authentic photographs." ) # HF Secondary hf2 = raw["hf_secondary"] top2 = hf_secondary_result.get("top_label", "?") if hf2 > 0.60: lines.append( f"• Deepfake Detector (secondary, {pct(hf2)}): Detected '{top2}' patterns." ) else: lines.append( f"• Deepfake Detector (secondary, {pct(hf2)}): No deepfake signature detected (label: '{top2}')." ) # CLIP clip_f = raw["clip"] if clip_f > 0.60: lines.append( f"• CLIP ViT-L/14 ({pct(clip_f)}): Visual-semantic features resemble AI-generated content." ) else: lines.append( f"• CLIP ViT-L/14 ({pct(clip_f)}): Visual features consistent with real photographs." ) # Frequency freq_f = raw["frequency"] alpha = freq_result.get("spectral_alpha", 1.8) ela_m = freq_result.get("ela_mean", 0) if freq_f > 0.55: lines.append( f"• Frequency analysis ({pct(freq_f)}): Anomalous spectral profile " f"(α={alpha:.2f}, ELA mean={ela_m:.1f}). " f"AI generators exhibit unnatural frequency signatures." ) else: lines.append( f"• Frequency analysis ({pct(freq_f)}): Natural spectral profile " f"(α={alpha:.2f}) consistent with a real camera sensor." ) return "\n".join(lines) def get_gradcam(self, image: Image.Image) -> np.ndarray: return self.cnn_detector.get_gradcam(image) def get_fft_spectrum(self, image: Image.Image) -> np.ndarray: return self.freq_detector.get_fft_spectrum(image) def __repr__(self): return ( f"EnsembleDetector(5-model, " f"confidence_weighting={self.confidence_weighting}, " f"threshold={self.fake_threshold})" )