Spaces:
Sleeping
Sleeping
| """ | |
| models/ensemble.py | |
| 5-model ensemble: HF-Primary + HF-Secondary + CLIP + Frequency + CNN | |
| with confidence-aware weighted voting. | |
| """ | |
| import numpy as np | |
| from typing import Dict, Optional | |
| from PIL import Image | |
| import sys | |
| import os | |
try:
    # Make the repository root importable: resolve two directory levels up
    # from this file and append it to sys.path so the absolute
    # ``image_authenticity.config`` import below resolves.
    # NOTE(review): presumably models/ -> project root -> dir containing
    # the ``image_authenticity`` package — confirm against the repo layout.
    _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
except NameError:
    # ``__file__`` is undefined when the source is exec'd or pasted into
    # some embedded/interactive contexts; fall back to the working dir.
    _base_dir = os.path.abspath(os.getcwd())
sys.path.append(_base_dir)
| from image_authenticity import config | |
| from .clip_detector import CLIPDetector | |
| from .cnn_detector import CNNDetector | |
| from .frequency_detector import FrequencyDetector | |
| from .hf_detector import DualHFDetector | |
class EnsembleDetector:
    """
    Confidence-aware weighted ensemble of 5 detectors:

        1. HF Primary    - dima806/ai_vs_real_image_detection (98.25% acc)
        2. HF Secondary  - prithivMLmods/Deep-Fake-Detector-v2 (deepfake ViT)
        3. CLIP ViT-L/14 - zero-shot semantic signal
        4. Frequency     - FFT/DCT/ELA/texture forensic analysis
        5. CNN           - EfficientNet-B4 (low weight; untrained head)

    Confidence-aware voting: each model's base weight is multiplied by its
    certainty = |fake_prob - 0.5| * 2, clamped to ``_CERTAINTY_FLOOR``.
    The scaling is deliberately *linear* with a floor (not exponential):
    a model outputting 0.5 still contributes at 20% of its base weight so
    that neutral models (CLIP/CNN) buffer strong disagreements between
    specialists, while a model outputting 0.95 contributes at (nearly)
    full weight.  This prevents uncertain models from dragging the score
    to 50%.
    """

    # Floor for the certainty multiplier.  Without it a model sitting at
    # exactly 0.5 would contribute zero weight; with it, neutral models
    # retain a small buffering vote when specialists strongly disagree.
    _CERTAINTY_FLOOR = 0.20

    def __init__(
        self,
        weights: Optional[Dict[str, float]] = None,
        fake_threshold: Optional[float] = None,
        device=None,
    ):
        """
        Args:
            weights: per-model base weights keyed by 'hf_primary',
                'hf_secondary', 'clip', 'frequency', 'cnn'.  Falls back to
                ``config.ENSEMBLE_WEIGHTS`` when falsy.
            fake_threshold: ensemble fake-probability at or above which the
                image is labelled FAKE (default ``config.FAKE_THRESHOLD``).
            device: device handed to the sub-detectors
                (default ``config.DEVICE``).
        """
        self.weights = weights or config.ENSEMBLE_WEIGHTS
        self.fake_threshold = (
            fake_threshold if fake_threshold is not None else config.FAKE_THRESHOLD
        )
        self.device = device or config.DEVICE
        self.confidence_weighting = getattr(config, "CONFIDENCE_WEIGHTING", True)

        # Initialise sub-detectors (all lazy-loaded on first predict).
        self.hf_detector = DualHFDetector(device=self.device)
        self.clip_detector = CLIPDetector(device=self.device)
        self.cnn_detector = CNNDetector(device=self.device)
        self.freq_detector = FrequencyDetector()

    def predict(self, image: Image.Image) -> Dict:
        """Run all five detectors on *image* and fuse their fake-probabilities.

        Returns a dict with the final 'label' ("FAKE"/"REAL"), 'confidence',
        'fake_prob'/'real_prob', the per-detector result dicts, per-model
        'scores', the base 'weights', the confidence-scaled
        'effective_weights' actually used in the average, and a
        human-readable 'explanation'.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        # -- Run all detectors (each individually guarded) -------------------
        hf_primary_result = self._safe_run(
            self.hf_detector.predict_primary, image, "HF-Primary")
        hf_secondary_result = self._safe_run(
            self.hf_detector.predict_secondary, image, "HF-Secondary")
        clip_result = self._safe_run(self.clip_detector.predict, image, "CLIP")
        cnn_result = self._safe_run(self.cnn_detector.predict, image, "CNN")
        freq_result = self._safe_run(self.freq_detector.predict, image, "Frequency")

        # -- Collect raw fake probabilities (0.5 = "no opinion") -------------
        raw = {
            "hf_primary": hf_primary_result.get("fake_prob", 0.5),
            "hf_secondary": hf_secondary_result.get("fake_prob", 0.5),
            "clip": clip_result.get("fake_prob", 0.5),
            "frequency": freq_result.get("fake_prob", 0.5),
            "cnn": cnn_result.get("fake_prob", 0.5),
        }

        # -- Confidence-aware weighted average --------------------------------
        total_w = 0.0
        weighted_sum = 0.0
        eff_weights: Dict[str, float] = {}
        for model_key, fake_prob in raw.items():
            base_w = self.weights.get(model_key, 0.0)
            if self.confidence_weighting:
                # certainty is 0 at fake_prob == 0.5 (pure guess) and 1 at
                # 0.0 / 1.0 (absolute certainty).  Linear scaling with a
                # floor, so guessing models keep a small buffering vote when
                # specialist models strongly disagree.
                certainty = abs(fake_prob - 0.5) * 2.0
                eff_w = base_w * max(certainty, self._CERTAINTY_FLOOR)
            else:
                eff_w = base_w
            eff_weights[model_key] = eff_w
            weighted_sum += eff_w * fake_prob
            total_w += eff_w
        # All-zero weights degrade gracefully to a neutral 0.5 verdict.
        ensemble_fake = float(weighted_sum / total_w) if total_w > 0 else 0.5
        ensemble_real = 1.0 - ensemble_fake

        # -- Decision ---------------------------------------------------------
        is_fake = ensemble_fake >= self.fake_threshold
        label = "FAKE" if is_fake else "REAL"
        # Confidence is the probability mass behind whichever label won.
        confidence = ensemble_fake if is_fake else ensemble_real

        # -- Explanation ------------------------------------------------------
        explanation = self._generate_explanation(
            label, ensemble_fake, raw, freq_result,
            hf_primary_result, hf_secondary_result,
        )

        return {
            "label": label,
            "confidence": float(confidence),
            "fake_prob": float(ensemble_fake),
            "real_prob": float(ensemble_real),
            "hf_primary_result": hf_primary_result,
            "hf_secondary_result": hf_secondary_result,
            "clip_result": clip_result,
            "cnn_result": cnn_result,
            "freq_result": freq_result,
            "scores": {
                "hf_primary": float(raw["hf_primary"]),
                "hf_secondary": float(raw["hf_secondary"]),
                "clip": float(raw["clip"]),
                "cnn": float(raw["cnn"]),
                "frequency": float(raw["frequency"]),
                "ensemble": float(ensemble_fake),
            },
            "weights": {k: self.weights.get(k, 0) for k in raw},
            # New (backward-compatible): the confidence-scaled weights that
            # actually entered the average, for transparency/debugging.
            "effective_weights": {k: float(v) for k, v in eff_weights.items()},
            "explanation": explanation,
        }

    def _safe_run(self, fn, image, name):
        """Call ``fn(image)``; on any failure log and return a neutral
        (0.5/0.5) result so one broken detector cannot sink the ensemble."""
        try:
            return fn(image)
        except Exception as e:
            print(f"[Ensemble] Warning: {name} detector failed — {e}")
            return {"fake_prob": 0.5, "real_prob": 0.5, "error": str(e)}

    def _generate_explanation(
        self, label, ensemble_fake, raw, freq_result,
        hf_primary_result, hf_secondary_result,
    ) -> str:
        """Build a multi-line, human-readable summary of the verdict and of
        each detector's individual contribution."""
        def pct(p: float) -> str:
            return f"{p * 100:.1f}%"

        lines = [
            f"The 5-model ensemble classifies this image as **{label}** "
            f"with {pct(ensemble_fake if label == 'FAKE' else 1 - ensemble_fake)} confidence."
        ]
        # HF Primary
        hf1 = raw["hf_primary"]
        top1 = hf_primary_result.get("top_label", "?")
        if hf1 > 0.60:
            lines.append(
                f"• AI Detector (primary, {pct(hf1)}): Classified as '{top1}' — "
                "consistent with AI-generated content."
            )
        else:
            lines.append(
                f"• AI Detector (primary, {pct(hf1)}): Classified as '{top1}' — "
                "consistent with authentic photographs."
            )
        # HF Secondary
        hf2 = raw["hf_secondary"]
        top2 = hf_secondary_result.get("top_label", "?")
        if hf2 > 0.60:
            lines.append(
                f"• Deepfake Detector (secondary, {pct(hf2)}): Detected '{top2}' patterns."
            )
        else:
            lines.append(
                f"• Deepfake Detector (secondary, {pct(hf2)}): No deepfake signature detected (label: '{top2}')."
            )
        # CLIP
        clip_f = raw["clip"]
        if clip_f > 0.60:
            lines.append(
                f"• CLIP ViT-L/14 ({pct(clip_f)}): Visual-semantic features resemble AI-generated content."
            )
        else:
            lines.append(
                f"• CLIP ViT-L/14 ({pct(clip_f)}): Visual features consistent with real photographs."
            )
        # Frequency — note the lower 0.55 threshold: forensic signals are
        # treated as meaningful at weaker levels than the classifier heads.
        freq_f = raw["frequency"]
        alpha = freq_result.get("spectral_alpha", 1.8)
        ela_m = freq_result.get("ela_mean", 0)
        if freq_f > 0.55:
            lines.append(
                f"• Frequency analysis ({pct(freq_f)}): Anomalous spectral profile "
                f"(α={alpha:.2f}, ELA mean={ela_m:.1f}). "
                f"AI generators exhibit unnatural frequency signatures."
            )
        else:
            lines.append(
                f"• Frequency analysis ({pct(freq_f)}): Natural spectral profile "
                f"(α={alpha:.2f}) consistent with a real camera sensor."
            )
        return "\n".join(lines)

    def get_gradcam(self, image: Image.Image) -> np.ndarray:
        """Grad-CAM heatmap from the CNN detector (pass-through)."""
        return self.cnn_detector.get_gradcam(image)

    def get_fft_spectrum(self, image: Image.Image) -> np.ndarray:
        """FFT magnitude spectrum from the frequency detector (pass-through)."""
        return self.freq_detector.get_fft_spectrum(image)

    def __repr__(self):
        return (
            f"EnsembleDetector(5-model, "
            f"confidence_weighting={self.confidence_weighting}, "
            f"threshold={self.fake_threshold})"
        )