"""
detector.py — Unified ImageAuthenticityDetector public interface.

Usage:
    from detector import ImageAuthenticityDetector

    detector = ImageAuthenticityDetector()
    result = detector.predict(image)   # PIL Image or path/URL
"""

from typing import Dict, Optional, Tuple, Union

import numpy as np
from PIL import Image

from image_authenticity.models.ensemble import EnsembleDetector
from image_authenticity.utils.preprocessing import load_image, preprocess_for_display
from image_authenticity.utils.visualization import (
    overlay_gradcam,
    visualize_fft_spectrum,
    create_result_card,
)
from image_authenticity import config


class ImageAuthenticityDetector:
    """
    High-level interface for real/fake image detection.

    Wraps EnsembleDetector (CLIP + CNN + Frequency) and provides
    prediction + visualisation in a single call.

    Example::

        detector = ImageAuthenticityDetector()

        # From PIL image
        from PIL import Image
        img = Image.open("photo.jpg")
        result = detector.predict(img)
        print(result["label"], result["confidence"])

        # From path or URL
        result = detector.predict("https://example.com/image.jpg")

        # With visualisations
        result, visuals = detector.predict_with_visuals(img)
    """

    def __init__(
        self,
        ensemble_weights: Optional[Dict[str, float]] = None,
        fake_threshold: Optional[float] = None,
        device=None,
    ):
        """
        Build the underlying EnsembleDetector.

        Args:
            ensemble_weights: per-model weights, forwarded unchanged as
                ``EnsembleDetector(weights=...)``.
            fake_threshold: decision threshold, forwarded unchanged as
                ``EnsembleDetector(fake_threshold=...)``.
            device: compute device, forwarded unchanged.

        Passing ``None`` for any argument defers to however EnsembleDetector
        interprets ``None`` (presumably its own defaults — not visible here).
        """
        self.ensemble = EnsembleDetector(
            weights=ensemble_weights,
            fake_threshold=fake_threshold,
            device=device,
        )
        print(f"[Detector] Initialised: {self.ensemble}")

    def predict(self, source: Union[str, Image.Image, np.ndarray]) -> Dict:
        """
        Predict whether an image is real or fake.

        Args:
            source: PIL Image, file path string, URL string, or numpy array

        Returns:
            dict with:
                - label       : "REAL" or "FAKE"
                - confidence  : float [0,1]
                - fake_prob   : float [0,1]
                - real_prob   : float [0,1]
                - scores      : dict of per-model fake_prob scores
                - explanation : human-readable string
                - clip_result : raw CLIP detector output dict
                - cnn_result  : raw CNN detector output dict
                - freq_result : raw Frequency detector output dict
        """
        image = load_image(source)
        return self.ensemble.predict(image)

    def predict_with_visuals(
        self,
        source: Union[str, Image.Image, np.ndarray],
        include_gradcam: bool = True,
        include_fft: bool = True,
        include_result_card: bool = True,
    ) -> Tuple[Dict, Dict[str, Image.Image]]:
        """
        Predict + generate visualisations.

        Args:
            source: image source (same forms accepted by ``predict``)
            include_gradcam: generate GradCAM heatmap overlay
            include_fft: generate FFT spectrum image
            include_result_card: generate results card figure

        Returns:
            (result_dict, visuals_dict)

            result_dict is the same dict ``predict`` returns.
            visuals_dict may contain the keys
            "original", "gradcam", "fft_spectrum", "result_card".

        Visualisations are best-effort: a failure is printed and does not
        abort the call. If GradCAM fails, "gradcam" falls back to the plain
        display image; if the FFT or result-card step fails, that key is
        simply absent from visuals_dict.
        """
        image = load_image(source)
        result = self.ensemble.predict(image)
        disp = preprocess_for_display(image)

        visuals = {
            "original": disp,
        }

        if include_gradcam:
            try:
                heatmap = self.ensemble.get_gradcam(image)
                visuals["gradcam"] = overlay_gradcam(
                    disp,
                    # disp.size is PIL (width, height), which is what the
                    # resize helper expects.
                    _resize_heatmap(heatmap, disp.size),
                    alpha=config.GRADCAM_ALPHA,
                )
            except Exception as e:
                print(f"[Detector] GradCAM failed: {e}")
                visuals["gradcam"] = disp

        if include_fft:
            try:
                spectrum = self.ensemble.get_fft_spectrum(image)
                visuals["fft_spectrum"] = visualize_fft_spectrum(spectrum)
            except Exception as e:
                print(f"[Detector] FFT vis failed: {e}")

        if include_result_card:
            try:
                visuals["result_card"] = create_result_card(result)
            except Exception as e:
                print(f"[Detector] Result card failed: {e}")

        return result, visuals


# ─────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────

def _resize_heatmap(heatmap: np.ndarray, size: Tuple[int, int]) -> np.ndarray:
    """
    Resize a 2D float heatmap to match a PIL image size ``(width, height)``.

    The heatmap is expected to hold values in [0, 1]. Values are sanitised
    before resizing: NaN -> 0, +inf -> 1, -inf -> 0, then clipped to [0, 1].
    This prevents the wrap-around a direct uint8 cast would produce for
    out-of-range or NaN values. Resizing is done in PIL float mode ("F") so
    precision is not quantised to 1/255.

    Args:
        heatmap: 2D array of activations (assumed (H, W) — TODO confirm the
            shape returned by EnsembleDetector.get_gradcam).
        size: target (width, height), as given by ``PIL.Image.size``.

    Returns:
        float64 array of shape (height, width) with values in [0, 1].
    """
    hm = np.asarray(heatmap, dtype=np.float32)
    hm = np.nan_to_num(hm, nan=0.0, posinf=1.0, neginf=0.0)
    hm = np.clip(hm, 0.0, 1.0)
    # A 2D float32 array maps to PIL mode "F"; bilinear interpolation of
    # values in [0, 1] stays within [0, 1].
    resized = Image.fromarray(hm).resize(size, Image.BILINEAR)
    return np.asarray(resized, dtype=np.float64)


# ─────────────────────────────────────────────────────────
# Quick CLI usage
# ─────────────────────────────────────────────────────────

if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python detector.py <image_path_or_url>", file=sys.stderr)
        sys.exit(1)

    source = sys.argv[1]
    print(f"\n[CLI] Analysing: {source}\n")

    detector = ImageAuthenticityDetector()
    result = detector.predict(source)

    print("=" * 50)
    print(f"  RESULT    : {result['label']}")
    print(f"  CONFIDENCE: {result['confidence']*100:.1f}%")
    print(f"  FAKE PROB : {result['fake_prob']*100:.1f}%")
    print("-" * 50)
    for model, score in result["scores"].items():
        print(f"  {model.upper():<12}: {score*100:.1f}%")
    print("=" * 50)
    print(f"\n{result['explanation']}\n")