# proofly / image_authenticity / detector.py
# Author: Pragthedon
# Initial backend API deployment
# Commit: 4f48a4e
"""
detector.py — Unified ImageAuthenticityDetector public interface.

Usage:
    from detector import ImageAuthenticityDetector
    detector = ImageAuthenticityDetector()
    result = detector.predict(image)  # PIL Image, numpy array, or path/URL
"""
from typing import Dict, Optional, Tuple, Union

import numpy as np
from PIL import Image

from image_authenticity.models.ensemble import EnsembleDetector
from image_authenticity.utils.preprocessing import load_image, preprocess_for_display
from image_authenticity.utils.visualization import (
    overlay_gradcam,
    visualize_fft_spectrum,
    create_result_card,
)
from image_authenticity import config
class ImageAuthenticityDetector:
    """
    High-level interface for real/fake image detection.

    Wraps EnsembleDetector (CLIP + CNN + Frequency) and provides
    prediction + visualisation in a single call.

    Example::

        detector = ImageAuthenticityDetector()

        # From PIL image
        from PIL import Image
        img = Image.open("photo.jpg")
        result = detector.predict(img)
        print(result["label"], result["confidence"])

        # From path or URL
        result = detector.predict("https://example.com/image.jpg")

        # With visualisations
        result, visuals = detector.predict_with_visuals(img)
    """

    def __init__(
        self,
        ensemble_weights: Optional[Dict[str, float]] = None,
        fake_threshold: Optional[float] = None,
        device=None,
    ):
        """
        Build the underlying EnsembleDetector and log its description.

        Args:
            ensemble_weights: per-model weights, forwarded to EnsembleDetector
                as ``weights``. ``None`` is passed through unchanged
                (presumably "use the ensemble's defaults" — confirm in
                EnsembleDetector).
            fake_threshold: forwarded to EnsembleDetector unchanged; ``None``
                is passed through as-is.
            device: forwarded to EnsembleDetector unchanged.
        """
        self.ensemble = EnsembleDetector(
            weights=ensemble_weights,
            fake_threshold=fake_threshold,
            device=device,
        )
        print(f"[Detector] Initialised: {self.ensemble}")

    def predict(self, source: Union[str, Image.Image, np.ndarray]) -> Dict:
        """
        Predict whether an image is real or fake.

        Args:
            source: PIL Image, file path string, URL string, or numpy array.
                Decoded via ``load_image``.

        Returns:
            The dict produced by ``EnsembleDetector.predict``, expected to
            contain:
              - label       : "REAL" or "FAKE"
              - confidence  : float [0,1]
              - fake_prob   : float [0,1]
              - real_prob   : float [0,1]
              - scores      : dict of per-model fake_prob scores
              - explanation : human-readable string
              - clip_result : raw CLIP detector output dict
              - cnn_result  : raw CNN detector output dict
              - freq_result : raw Frequency detector output dict
        """
        image = load_image(source)
        return self.ensemble.predict(image)

    def predict_with_visuals(
        self,
        source: Union[str, Image.Image, np.ndarray],
        include_gradcam: bool = True,
        include_fft: bool = True,
        include_result_card: bool = True,
    ) -> Tuple[Dict, Dict[str, Image.Image]]:
        """
        Predict and generate visualisations in one pass.

        The image is decoded once and shared by the prediction and every
        visualisation. Each visualisation is best-effort: a failure is
        printed and never aborts the prediction.

        Args:
            source: image source (anything accepted by ``predict``).
            include_gradcam: generate a GradCAM heatmap overlay.
            include_fft: generate an FFT spectrum image.
            include_result_card: generate a results-card figure.

        Returns:
            ``(result_dict, visuals_dict)``. ``result_dict`` is the same dict
            ``predict`` returns. ``visuals_dict`` may contain:
              - "original"     : display-ready copy of the input (always set)
              - "gradcam"      : heatmap overlay; if GradCAM generation fails
                                 this falls back to the plain display image
              - "fft_spectrum" : FFT spectrum image; absent when disabled or
                                 when generation fails
              - "result_card"  : results card; absent when disabled or when
                                 generation fails
        """
        image = load_image(source)
        result = self.ensemble.predict(image)
        disp = preprocess_for_display(image)

        visuals = {"original": disp}

        if include_gradcam:
            try:
                heatmap = self.ensemble.get_gradcam(image)
                # The heatmap may not match the display image's resolution;
                # resize it to disp.size (w, h) before blending.
                visuals["gradcam"] = overlay_gradcam(
                    disp,
                    _resize_heatmap(heatmap, disp.size),
                    alpha=config.GRADCAM_ALPHA,
                )
            except Exception as e:
                # Best-effort: keep the key present, showing the original.
                print(f"[Detector] GradCAM failed: {e}")
                visuals["gradcam"] = disp

        if include_fft:
            try:
                spectrum = self.ensemble.get_fft_spectrum(image)
                visuals["fft_spectrum"] = visualize_fft_spectrum(spectrum)
            except Exception as e:
                print(f"[Detector] FFT vis failed: {e}")

        if include_result_card:
            try:
                visuals["result_card"] = create_result_card(result)
            except Exception as e:
                print(f"[Detector] Result card failed: {e}")

        return result, visuals
# ─────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────
def _resize_heatmap(heatmap: np.ndarray, size: tuple) -> np.ndarray:
    """
    Resize a 2D float heatmap to a PIL-style ``(width, height)`` size.

    The heatmap is quantised to 8-bit so PIL can resample it. Before that,
    NaNs are replaced with 0 and values are clipped to ``[0, 1]``. Without
    this, a value such as 1.02 or -0.01 would overflow the ``uint8`` cast
    and wrap to an unrelated intensity.

    Args:
        heatmap: 2D array of floats, nominally in ``[0, 1]``.
        size: target ``(width, height)``, i.e. the ``.size`` of a PIL image.

    Returns:
        Float array of shape ``(height, width)`` with values in ``[0, 1]``,
        in steps of 1/255.
    """
    from PIL import Image as PILImage

    safe = np.nan_to_num(np.asarray(heatmap, dtype=np.float64), nan=0.0)
    safe = np.clip(safe, 0.0, 1.0)
    h_pil = PILImage.fromarray((safe * 255).astype(np.uint8))
    h_pil = h_pil.resize(size, PILImage.BILINEAR)
    return np.array(h_pil) / 255.0
# ─────────────────────────────────────────────────────────
# Quick CLI usage
# ─────────────────────────────────────────────────────────
def _cli_main(argv=None) -> int:
    """
    Analyse one image from the command line and print a summary.

    Args:
        argv: argument vector in ``sys.argv`` form (program name first,
            image path or URL second). Defaults to ``sys.argv``.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.
    """
    import sys

    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        # Usage errors go to stderr so they don't mix with result output.
        print("Usage: python detector.py <image_path_or_url>", file=sys.stderr)
        return 1

    source = argv[1]
    print(f"\n[CLI] Analysing: {source}\n")

    detector = ImageAuthenticityDetector()
    result = detector.predict(source)

    print("=" * 50)
    print(f" RESULT : {result['label']}")
    print(f" CONFIDENCE: {result['confidence']*100:.1f}%")
    print(f" FAKE PROB : {result['fake_prob']*100:.1f}%")
    print("-" * 50)
    for model, score in result["scores"].items():
        print(f" {model.upper():<12}: {score*100:.1f}%")
    print("=" * 50)
    print(f"\n{result['explanation']}\n")
    return 0


if __name__ == "__main__":
    raise SystemExit(_cli_main())