Spaces:
Sleeping
Sleeping
"""
detector.py — Unified ImageAuthenticityDetector public interface.

Usage:
    from detector import ImageAuthenticityDetector
    detector = ImageAuthenticityDetector()
    result = detector.predict(image)  # PIL Image or path/URL
"""
| from PIL import Image | |
| from typing import Union, Dict | |
| import numpy as np | |
| from image_authenticity.models.ensemble import EnsembleDetector | |
| from image_authenticity.utils.preprocessing import load_image, preprocess_for_display | |
| from image_authenticity.utils.visualization import ( | |
| overlay_gradcam, | |
| visualize_fft_spectrum, | |
| create_result_card, | |
| ) | |
| from image_authenticity import config | |
class ImageAuthenticityDetector:
    """
    High-level interface for real/fake image detection.

    Wraps EnsembleDetector (CLIP + CNN + Frequency) and provides
    prediction + visualisation in a single call.

    Example::

        detector = ImageAuthenticityDetector()

        # From PIL image
        from PIL import Image
        img = Image.open("photo.jpg")
        result = detector.predict(img)
        print(result["label"], result["confidence"])

        # From path or URL
        result = detector.predict("https://example.com/image.jpg")

        # With visualisations
        result, visuals = detector.predict_with_visuals(img)
    """

    def __init__(
        self,
        ensemble_weights: Union[Dict[str, float], None] = None,
        fake_threshold: Union[float, None] = None,
        device=None,
    ):
        """
        Build the underlying EnsembleDetector.

        All three arguments are forwarded unchanged to EnsembleDetector
        (as ``weights``, ``fake_threshold`` and ``device``).

        Args:
            ensemble_weights: optional per-model weights, or None
            fake_threshold:   optional decision threshold, or None
            device:           optional device specifier, or None

        NOTE(review): None presumably selects EnsembleDetector's own
        defaults — confirm against that class.
        """
        self.ensemble = EnsembleDetector(
            weights=ensemble_weights,
            fake_threshold=fake_threshold,
            device=device,
        )
        print(f"[Detector] Initialised: {self.ensemble}")

    def predict(self, source: Union[str, Image.Image, np.ndarray]) -> Dict:
        """
        Predict whether an image is real or fake.

        The source is normalised with ``load_image`` and the result of
        ``EnsembleDetector.predict`` is returned as-is.

        Args:
            source: PIL Image, file path string, URL string, or numpy array

        Returns:
            dict with:
            - label       : "REAL" or "FAKE"
            - confidence  : float [0,1]
            - fake_prob   : float [0,1]
            - real_prob   : float [0,1]
            - scores      : dict of per-model fake_prob scores
            - explanation : human-readable string
            - clip_result : raw CLIP detector output dict
            - cnn_result  : raw CNN detector output dict
            - freq_result : raw Frequency detector output dict
        """
        image = load_image(source)
        return self.ensemble.predict(image)

    def predict_with_visuals(
        self,
        source: Union[str, Image.Image, np.ndarray],
        include_gradcam: bool = True,
        include_fft: bool = True,
        include_result_card: bool = True,
    ) -> tuple[Dict, Dict[str, Image.Image]]:
        """
        Predict + generate visualisations.

        Each visualisation is best-effort: a failure is reported on stdout
        and never prevents the prediction from being returned.

        Args:
            source:              image source (same forms as ``predict``)
            include_gradcam:     generate GradCAM heatmap overlay
            include_fft:         generate FFT spectrum image
            include_result_card: generate results card figure

        Returns:
            (result_dict, visuals_dict)

            visuals_dict keys:
            - "original"     : always present (display-preprocessed image)
            - "gradcam"      : present when requested; falls back to the
                               plain display image if GradCAM fails
            - "fft_spectrum" : present when requested and successful;
                               omitted on failure
            - "result_card"  : present when requested and successful;
                               omitted on failure
        """
        image = load_image(source)
        result = self.ensemble.predict(image)
        disp = preprocess_for_display(image)

        visuals = {
            "original": disp,
        }

        if include_gradcam:
            try:
                heatmap = self.ensemble.get_gradcam(image)
                visuals["gradcam"] = overlay_gradcam(
                    disp,
                    # The heatmap must match the display image's (w, h).
                    _resize_heatmap(heatmap, disp.size),
                    alpha=config.GRADCAM_ALPHA,
                )
            except Exception as e:
                print(f"[Detector] GradCAM failed: {e}")
                # Keep the key populated so callers can still show an image.
                visuals["gradcam"] = disp

        if include_fft:
            try:
                spectrum = self.ensemble.get_fft_spectrum(image)
                visuals["fft_spectrum"] = visualize_fft_spectrum(spectrum)
            except Exception as e:
                print(f"[Detector] FFT vis failed: {e}")

        if include_result_card:
            try:
                visuals["result_card"] = create_result_card(result)
            except Exception as e:
                print(f"[Detector] Result card failed: {e}")

        return result, visuals
# ─────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────
| def _resize_heatmap(heatmap: np.ndarray, size: tuple) -> np.ndarray: | |
| """Resize a 2D float heatmap to match PIL image size (w, h).""" | |
| from PIL import Image as PILImage | |
| h_pil = PILImage.fromarray((heatmap * 255).astype(np.uint8)) | |
| h_pil = h_pil.resize(size, PILImage.BILINEAR) | |
| return np.array(h_pil) / 255.0 | |
# ─────────────────────────────────────────────────────────
# Quick CLI usage
# ─────────────────────────────────────────────────────────
| if __name__ == "__main__": | |
| import sys | |
| import json | |
| if len(sys.argv) < 2: | |
| print("Usage: python detector.py <image_path_or_url>") | |
| sys.exit(1) | |
| source = sys.argv[1] | |
| print(f"\n[CLI] Analysing: {source}\n") | |
| detector = ImageAuthenticityDetector() | |
| result = detector.predict(source) | |
| print("=" * 50) | |
| print(f" RESULT : {result['label']}") | |
| print(f" CONFIDENCE: {result['confidence']*100:.1f}%") | |
| print(f" FAKE PROB : {result['fake_prob']*100:.1f}%") | |
| print("-" * 50) | |
| for model, score in result["scores"].items(): | |
| print(f" {model.upper():<12}: {score*100:.1f}%") | |
| print("=" * 50) | |
| print(f"\n{result['explanation']}\n") | |