# Commit eb14ec7 (Pragthedon) — Fix: Applied Physics-Majority logic to
# eliminate high-confidence false positives.
"""
models/ensemble.py
5-model ensemble: HF-Primary + HF-Secondary + CLIP + Frequency + CNN
with confidence-aware weighted voting.
"""
import numpy as np
from typing import Dict, Optional
from PIL import Image
import sys
import os
# Make the project root importable so the absolute import of
# `image_authenticity.config` below resolves when this module is run
# directly rather than as part of an installed package.
try:
    _base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
except NameError:
    # `__file__` is undefined in some embedded/interactive contexts
    # (e.g. `exec` or a REPL); fall back to the current working directory.
    _base_dir = os.path.abspath(os.getcwd())
sys.path.append(_base_dir)
from image_authenticity import config
from .clip_detector import CLIPDetector
from .cnn_detector import CNNDetector
from .frequency_detector import FrequencyDetector
from .hf_detector import DualHFDetector
class EnsembleDetector:
    """
    Confidence-aware weighted ensemble of 5 detectors:

    1. HF Primary — dima806/ai_vs_real_image_detection (98.25% acc)
    2. HF Secondary — prithivMLmods/Deep-Fake-Detector-v2 (deepfake ViT)
    3. CLIP ViT-L/14 — zero-shot semantic signal
    4. Frequency — FFT/DCT/ELA/texture forensic analysis
    5. CNN — EfficientNet-B4 (low weight; untrained head)

    Confidence-aware voting: each model's weight is multiplied by its
    confidence = |fake_prob - 0.5| * 2, floored at ``CONFIDENCE_FLOOR``.
    A model outputting 0.5 contributes only the floor fraction of its base
    weight; a model outputting 0.95 contributes at (nearly) full weight.
    This prevents uncertain models from dragging the score to 50%, while
    the floor keeps neutral models (CLIP/CNN) in the vote as a buffer when
    the specialist models strongly disagree.
    """

    # Minimum certainty multiplier applied to a model's base weight (the
    # "damping field"): a maximally uncertain model (fake_prob == 0.5) still
    # carries 20% of its configured weight.
    CONFIDENCE_FLOOR = 0.20

    def __init__(
        self,
        weights: Optional[Dict[str, float]] = None,
        fake_threshold: Optional[float] = None,
        device=None,
    ):
        """
        Args:
            weights: Base per-model weights keyed by ``hf_primary``,
                ``hf_secondary``, ``clip``, ``frequency`` and ``cnn``.
                Any falsy value falls back to ``config.ENSEMBLE_WEIGHTS``.
            fake_threshold: Ensemble fake-probability at or above which an
                image is labelled FAKE. Defaults to ``config.FAKE_THRESHOLD``.
            device: Device handle forwarded to the sub-detectors.
                Defaults to ``config.DEVICE``.
        """
        self.weights = weights or config.ENSEMBLE_WEIGHTS
        self.fake_threshold = (
            fake_threshold if fake_threshold is not None else config.FAKE_THRESHOLD
        )
        self.device = device or config.DEVICE
        # Master switch for confidence-aware voting; on by default when the
        # config does not define CONFIDENCE_WEIGHTING.
        self.confidence_weighting = getattr(config, "CONFIDENCE_WEIGHTING", True)
        # Initialise sub-detectors (all lazy-loaded on first predict).
        self.hf_detector = DualHFDetector(device=self.device)
        self.clip_detector = CLIPDetector(device=self.device)
        self.cnn_detector = CNNDetector(device=self.device)
        self.freq_detector = FrequencyDetector()

    def predict(self, image: "Image.Image") -> Dict:
        """Run all five detectors on *image* and fuse their fake probabilities.

        Returns:
            Dict with the final ``label`` ("FAKE"/"REAL"), ``confidence``,
            ``fake_prob``/``real_prob``, the raw result dict of each
            sub-detector, per-model ``scores``, the configured base
            ``weights`` and a human-readable ``explanation``.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")

        # ── Run all detectors (failures degrade to a neutral 0.5 vote) ──────
        hf_primary_result = self._safe_run(self.hf_detector.predict_primary, image, "HF-Primary")
        hf_secondary_result = self._safe_run(self.hf_detector.predict_secondary, image, "HF-Secondary")
        clip_result = self._safe_run(self.clip_detector.predict, image, "CLIP")
        cnn_result = self._safe_run(self.cnn_detector.predict, image, "CNN")
        freq_result = self._safe_run(self.freq_detector.predict, image, "Frequency")

        # ── Collect raw fake probabilities ───────────────────────────────────
        raw = {
            "hf_primary": hf_primary_result.get("fake_prob", 0.5),
            "hf_secondary": hf_secondary_result.get("fake_prob", 0.5),
            "clip": clip_result.get("fake_prob", 0.5),
            "frequency": freq_result.get("fake_prob", 0.5),
            "cnn": cnn_result.get("fake_prob", 0.5),
        }

        # ── Confidence-aware weighted average ────────────────────────────────
        ensemble_fake = self._aggregate(raw)
        ensemble_real = 1.0 - ensemble_fake

        # ── Decision ────────────────────────────────────────────────────────
        is_fake = ensemble_fake >= self.fake_threshold
        label = "FAKE" if is_fake else "REAL"
        confidence = ensemble_fake if is_fake else ensemble_real

        # ── Explanation ─────────────────────────────────────────────────────
        explanation = self._generate_explanation(
            label, ensemble_fake, raw, freq_result,
            hf_primary_result, hf_secondary_result,
        )

        return {
            "label": label,
            "confidence": float(confidence),
            "fake_prob": float(ensemble_fake),
            "real_prob": float(ensemble_real),
            "hf_primary_result": hf_primary_result,
            "hf_secondary_result": hf_secondary_result,
            "clip_result": clip_result,
            "cnn_result": cnn_result,
            "freq_result": freq_result,
            "scores": {
                "hf_primary": float(raw["hf_primary"]),
                "hf_secondary": float(raw["hf_secondary"]),
                "clip": float(raw["clip"]),
                "cnn": float(raw["cnn"]),
                "frequency": float(raw["frequency"]),
                "ensemble": float(ensemble_fake),
            },
            "weights": {k: self.weights.get(k, 0) for k in raw},
            "explanation": explanation,
        }

    def _aggregate(self, raw: Dict[str, float]) -> float:
        """Fuse per-model fake probabilities into a single score in [0, 1].

        Each probability is averaged using the effective weight from
        ``_effective_weight``. If every effective weight is zero (e.g. an
        empty weight table), a neutral 0.5 is returned instead of dividing
        by zero.
        """
        total_w = 0.0
        weighted_sum = 0.0
        for model_key, fake_prob in raw.items():
            eff_w = self._effective_weight(self.weights.get(model_key, 0.0), fake_prob)
            weighted_sum += eff_w * fake_prob
            total_w += eff_w
        return float(weighted_sum / total_w) if total_w > 0 else 0.5

    def _effective_weight(self, base_w: float, fake_prob: float) -> float:
        """Scale *base_w* by the model's certainty.

        Certainty maps ``fake_prob`` linearly onto [0, 1]: 0 at 0.5 (a pure
        guess), 1 at 0.0 or 1.0 (absolute certainty). The multiplier is
        floored at ``CONFIDENCE_FLOOR`` so neutral models retain enough
        weight to buffer strong specialist disagreement. When
        ``confidence_weighting`` is off, *base_w* is returned unchanged.
        """
        if not self.confidence_weighting:
            return base_w
        certainty = abs(fake_prob - 0.5) * 2.0
        return base_w * max(certainty, self.CONFIDENCE_FLOOR)

    def _safe_run(self, fn, image, name):
        """Call ``fn(image)``, degrading to a neutral vote on any failure.

        A broken sub-detector must not take down the whole ensemble: the
        exception is reported on stderr and replaced by a 0.5/0.5 result
        carrying the error text under the ``error`` key.
        """
        try:
            return fn(image)
        except Exception as e:
            # Warnings go to stderr so diagnostics do not pollute stdout.
            print(f"[Ensemble] Warning: {name} detector failed — {e}", file=sys.stderr)
            return {"fake_prob": 0.5, "real_prob": 0.5, "error": str(e)}

    def _generate_explanation(
        self, label, ensemble_fake, raw, freq_result,
        hf_primary_result, hf_secondary_result,
    ) -> str:
        """Build the human-readable, per-model markdown explanation."""

        def pct(p: float) -> str:
            # Render a probability as a percentage with one decimal place.
            return f"{p*100:.1f}%"

        lines = []
        lines.append(
            f"The 5-model ensemble classifies this image as **{label}** "
            f"with {pct(ensemble_fake if label == 'FAKE' else 1 - ensemble_fake)} confidence."
        )

        # HF Primary
        hf1 = raw["hf_primary"]
        top1 = hf_primary_result.get("top_label", "?")
        if hf1 > 0.60:
            lines.append(
                f"• AI Detector (primary, {pct(hf1)}): Classified as '{top1}' — "
                "consistent with AI-generated content."
            )
        else:
            lines.append(
                f"• AI Detector (primary, {pct(hf1)}): Classified as '{top1}' — "
                "consistent with authentic photographs."
            )

        # HF Secondary
        hf2 = raw["hf_secondary"]
        top2 = hf_secondary_result.get("top_label", "?")
        if hf2 > 0.60:
            lines.append(
                f"• Deepfake Detector (secondary, {pct(hf2)}): Detected '{top2}' patterns."
            )
        else:
            lines.append(
                f"• Deepfake Detector (secondary, {pct(hf2)}): No deepfake signature detected (label: '{top2}')."
            )

        # CLIP
        clip_f = raw["clip"]
        if clip_f > 0.60:
            lines.append(
                f"• CLIP ViT-L/14 ({pct(clip_f)}): Visual-semantic features resemble AI-generated content."
            )
        else:
            lines.append(
                f"• CLIP ViT-L/14 ({pct(clip_f)}): Visual features consistent with real photographs."
            )

        # Frequency — note the lower 0.55 cutoff than the model-based bullets
        # above (presumably intentional; confirm against detector tuning).
        freq_f = raw["frequency"]
        alpha = freq_result.get("spectral_alpha", 1.8)
        ela_m = freq_result.get("ela_mean", 0)
        if freq_f > 0.55:
            lines.append(
                f"• Frequency analysis ({pct(freq_f)}): Anomalous spectral profile "
                f"(α={alpha:.2f}, ELA mean={ela_m:.1f}). "
                f"AI generators exhibit unnatural frequency signatures."
            )
        else:
            lines.append(
                f"• Frequency analysis ({pct(freq_f)}): Natural spectral profile "
                f"(α={alpha:.2f}) consistent with a real camera sensor."
            )
        return "\n".join(lines)

    def get_gradcam(self, image: "Image.Image") -> np.ndarray:
        """Return the CNN detector's Grad-CAM heatmap for *image*."""
        return self.cnn_detector.get_gradcam(image)

    def get_fft_spectrum(self, image: "Image.Image") -> np.ndarray:
        """Return the frequency detector's FFT spectrum for *image*."""
        return self.freq_detector.get_fft_spectrum(image)

    def __repr__(self):
        return (
            f"EnsembleDetector(5-model, "
            f"confidence_weighting={self.confidence_weighting}, "
            f"threshold={self.fake_threshold})"
        )