"""
DeepGuard — ONNX ViT Inference Module

Loads the deepfake detection model once at startup.
All inference is stateless and in-memory.

Model: onnx-community/Deep-Fake-Detector-v2-Model-ONNX
  - Architecture: google/vit-base-patch16-224
  - Labels: {0: "Realism", 1: "Deepfake"}
  - Input: pixel_values (1, 3, 224, 224) float32
  - Output: logits (1, 2) float32
"""
import io
import os
from typing import Optional, Tuple

import numpy as np
import onnxruntime as ort
from PIL import Image
# Per-channel RGB normalization statistics (ImageNet values), applied by
# preprocess() after pixel values have been scaled to [0, 1].
# NOTE(review): Hugging Face ViT checkpoints commonly ship a preprocessor
# config with mean = std = 0.5 — confirm these values against the exported
# model's preprocessor_config.json.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# The ONNX file is looked up in a "models" directory next to this module.
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "deepfake_vit.onnx")

# Module-level state, filled in by load_model() and reused for every request.
_session: Optional[ort.InferenceSession] = None  # stays None until load_model() succeeds
_input_name: str = ""
_output_names: list[str] = []
_has_attention_outputs: bool = False
def load_model() -> None:
    """
    Create the global ONNX Runtime session for the model at MODEL_PATH.

    Populates the module-level session, the name of the first model input,
    the list of output names, and the flag telling whether any output name
    contains "attn" or "attention". Call this once before running inference.

    Raises:
        FileNotFoundError: if nothing exists at MODEL_PATH.
    """
    global _session, _input_name, _output_names, _has_attention_outputs

    model_missing = not os.path.exists(MODEL_PATH)
    if model_missing:
        raise FileNotFoundError(
            f"Model not found at {MODEL_PATH}. "
            "Please run: python download_model.py"
        )

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.inter_op_num_threads = 4
    session_options.intra_op_num_threads = 4

    cpu_only = ["CPUExecutionProvider"]
    _session = ort.InferenceSession(
        MODEL_PATH,
        sess_options=session_options,
        providers=cpu_only,
    )

    first_input = _session.get_inputs()[0]
    _input_name = first_input.name
    _output_names = [output.name for output in _session.get_outputs()]

    # Attention outputs (used for the attention rollout heatmap) are detected
    # purely by output name.
    attention_found = False
    for output_name in _output_names:
        lowered = output_name.lower()
        if "attn" in lowered or "attention" in lowered:
            attention_found = True
            break
    _has_attention_outputs = attention_found

    startup_report = (
        f"Model loaded: {MODEL_PATH}",
        f"Input: {_input_name}",
        f"Outputs: {_output_names}",
        f"Attention outputs available: {_has_attention_outputs}",
    )
    for entry in startup_report:
        print(f"[DeepGuard] {entry}")
def get_session() -> ort.InferenceSession:
    """
    Return the shared inference session.

    Raises:
        RuntimeError: if load_model() has not created the session yet.
    """
    if _session is not None:
        return _session
    raise RuntimeError("Model not loaded. Call load_model() first.")
def has_attention_outputs() -> bool:
    """
    Report whether the loaded model exposes attention outputs.

    Returns the flag computed by load_model(); it is False until a model
    has been loaded.
    """
    attention_available = _has_attention_outputs
    return attention_available
def get_attention_output_names() -> list[str]:
    """
    Return the model output names that look like attention tensors.

    A name qualifies when its lowercase form contains "attn" or "attention".
    The order of the model's outputs is preserved.
    """
    attention_names = []
    for output_name in _output_names:
        lowered = output_name.lower()
        if "attn" in lowered or "attention" in lowered:
            attention_names.append(output_name)
    return attention_names
def preprocess(image: Image.Image) -> np.ndarray:
    """
    Convert a PIL image into the tensor layout fed to the ViT model.

    The image is converted to RGB, resized to 224x224 with bilinear
    resampling, scaled to [0, 1], normalized per channel with
    IMAGENET_MEAN / IMAGENET_STD, and rearranged from HWC to NCHW with a
    leading batch axis of size 1.

    Returns: float32 NCHW tensor of shape (1, 3, 224, 224)
    """
    rgb_image = image.convert("RGB")
    resized = rgb_image.resize((224, 224), Image.BILINEAR)

    pixels = np.array(resized, dtype=np.float32) / 255.0    # (224, 224, 3) in [0, 1]
    normalized = (pixels - IMAGENET_MEAN) / IMAGENET_STD     # broadcasts over the channel axis
    channels_first = np.transpose(normalized, (2, 0, 1))     # HWC -> CHW
    return channels_first[np.newaxis, ...]                   # CHW -> NCHW (1, 3, 224, 224)
def softmax(logits: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
    """
    Numerically stable softmax.

    Args:
        logits: Array (or array-like) of raw scores.
        axis: Axis along which to normalize. The default (None) normalizes
            over every element of the array, which is what a single 1-D
            logits vector needs. Pass e.g. ``axis=-1`` to obtain one
            independent distribution per row of a batched array.

    Returns:
        Array with the same shape as ``logits`` whose entries sum to 1 along
        ``axis`` (or over the whole array when ``axis`` is None).
    """
    scores = np.asarray(logits)
    # Subtracting the maximum does not change the result mathematically but
    # keeps np.exp from overflowing on large logits.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)
def run_inference(image: Image.Image) -> Tuple[float, dict]:
    """
    Run the deepfake detection model on a PIL image.

    Returns:
        confidence_score (float): softmax probability at index 1 of the
            logits, i.e. the "Deepfake" label under the mapping
            {0: "Realism", 1: "Deepfake"}.
        raw_outputs (dict): every model output keyed by its output name
            (for the heatmap module).
    """
    session = get_session()
    feeds = {_input_name: preprocess(image)}

    # Passing None requests every output (logits plus any attention matrices).
    results = session.run(None, feeds)
    output_dict = {name: value for name, value in zip(_output_names, results)}

    # Use the first output whose name contains "logit"; if there is none,
    # fall back to the model's first output.
    for candidate in _output_names:
        if "logit" in candidate.lower():
            logits_key = candidate
            break
    else:
        logits_key = _output_names[0]

    flat_logits = output_dict[logits_key].squeeze()
    probabilities = softmax(flat_logits)

    deepfake_probability = float(probabilities[1])
    return deepfake_probability, output_dict
def get_threat_level(score: float) -> str:
    """
    Map confidence score to threat level label.

    Returns "CRITICAL" for scores of at least 0.90, "HIGH" for at least
    0.75, "MEDIUM" for at least 0.50, and "LOW" for anything else.
    """
    # Ordered from the highest threshold down; the first one met wins.
    thresholds = (
        (0.90, "CRITICAL"),
        (0.75, "HIGH"),
        (0.50, "MEDIUM"),
    )
    for minimum, label in thresholds:
        if score >= minimum:
            return label
    return "LOW"
def get_model_reasoning(score: float, has_exif: bool, software: str) -> str:
    """
    Generate a human-readable model reasoning string.

    Args:
        score: Model confidence that the image is AI-generated; bucketed
            with the thresholds 0.90 / 0.75 / 0.50.
        has_exif: Whether the image carried EXIF metadata.
        software: Software tag taken from the image metadata. The string
            "None", an empty or whitespace-only string, and None are all
            treated as "no tag present".

    Returns:
        The applicable sentences joined by single spaces.
    """
    if score >= 0.90:
        verdict = "Very high-confidence AI artifact signatures detected across multiple image regions."
    elif score >= 0.75:
        verdict = "Significant statistical anomalies inconsistent with optical camera sensors detected."
    elif score >= 0.50:
        verdict = "Moderate AI artifact patterns detected; image may be partially manipulated."
    else:
        verdict = "Low probability of AI generation; image statistics consistent with real photography."
    reasons = [verdict]

    if not has_exif:
        reasons.append("Absence of EXIF metadata is a strong AI indicator.")

    # Only report a software tag when there really is one: a missing or blank
    # value must not produce "Known AI software tag '' detected ...".
    has_software_tag = bool(software) and bool(software.strip()) and software != "None"
    if has_software_tag:
        reasons.append(f"Known AI software tag '{software}' detected in image metadata.")

    # NOTE(review): this sentence is appended for every score, including the
    # low-probability branch where it contradicts the first sentence —
    # confirm whether it should be gated on the score.
    reasons.append(
        "ViT attention model flagged inconsistencies in background frequency, "
        "texture uniformity, and facial boundary regions."
    )

    return " ".join(reasons)