"""
DeepGuard — ONNX ViT Inference Module

Loads the deepfake detection model once at startup.
All inference is stateless and in-memory.

Model: onnx-community/Deep-Fake-Detector-v2-Model-ONNX
  - Architecture: google/vit-base-patch16-224
  - Labels: {0: "Realism", 1: "Deepfake"}
  - Input:  pixel_values (1, 3, 224, 224) float32
  - Output: logits (1, 2) float32
"""

import os
import io
import numpy as np
from PIL import Image
from typing import Optional, Tuple
import onnxruntime as ort

# Per-channel (R, G, B) normalization constants applied in preprocess().
# These are the standard ImageNet statistics.
# NOTE(review): Hugging Face ViT image processors for google/vit-base-patch16-224
# are commonly configured with mean = std = 0.5; confirm these values against
# the exported model's preprocessor_config.json.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "deepfake_vit.onnx")

# Substrings (matched case-insensitively) that mark an ONNX output as an
# attention tensor rather than the classification logits.
_ATTENTION_MARKERS = ("attn", "attention")

# Module-level singleton — loaded once, reused for every request
_session: Optional[ort.InferenceSession] = None
_input_name: str = ""
_output_names: list[str] = []
_has_attention_outputs: bool = False


def _is_attention_output(name: str) -> bool:
    """Return True if an ONNX output name looks like an attention tensor."""
    lowered = name.lower()
    return any(marker in lowered for marker in _ATTENTION_MARKERS)


def load_model(num_threads: int = 4) -> None:
    """
    Load the ONNX model into the module-level session.

    Must be called once (typically at startup) before any inference.
    Calling it again replaces the existing session.

    Args:
        num_threads: Value used for both ONNX Runtime's inter-op and
            intra-op thread counts (default 4, the previously hard-coded value).

    Raises:
        FileNotFoundError: If no model file exists at MODEL_PATH.
    """
    global _session, _input_name, _output_names, _has_attention_outputs

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"Model not found at {MODEL_PATH}. "
            "Please run: python download_model.py"
        )

    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.inter_op_num_threads = num_threads
    opts.intra_op_num_threads = num_threads

    _session = ort.InferenceSession(
        MODEL_PATH,
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )

    _input_name = _session.get_inputs()[0].name
    _output_names = [o.name for o in _session.get_outputs()]

    # Check whether model exposes attention weights (for attention rollout heatmap)
    _has_attention_outputs = any(_is_attention_output(n) for n in _output_names)

    print(f"[DeepGuard] Model loaded: {MODEL_PATH}")
    print(f"[DeepGuard] Input: {_input_name}")
    print(f"[DeepGuard] Outputs: {_output_names}")
    print(f"[DeepGuard] Attention outputs available: {_has_attention_outputs}")


def get_session() -> ort.InferenceSession:
    """
    Return the loaded inference session.

    Raises:
        RuntimeError: If load_model() has not been called yet.
    """
    if _session is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")
    return _session


def has_attention_outputs() -> bool:
    """Return whether the loaded model exposes any attention-like outputs."""
    return _has_attention_outputs


def get_attention_output_names() -> list[str]:
    """Return the names of model outputs whose names mark them as attention tensors."""
    return [n for n in _output_names if _is_attention_output(n)]


def preprocess(image: Image.Image) -> np.ndarray:
    """
    Preprocess a PIL Image for ViT inference.

    Converts to RGB, resizes to 224x224 (bilinear), scales to [0, 1],
    normalizes per channel with IMAGENET_MEAN / IMAGENET_STD, and
    reorders to NCHW.

    Returns:
        C-contiguous float32 tensor of shape (1, 3, 224, 224).
    """
    img = image.convert("RGB").resize((224, 224), Image.BILINEAR)
    arr = np.array(img, dtype=np.float32) / 255.0  # (224, 224, 3) in [0, 1]
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD      # per-channel normalize
    arr = arr.transpose(2, 0, 1)                    # HWC → CHW (strided view)
    arr = np.expand_dims(arr, axis=0)               # CHW → NCHW (1, 3, 224, 224)
    # Materialize a contiguous buffer so the runtime receives a plain dense
    # array rather than a transposed view.
    return np.ascontiguousarray(arr)


def softmax(logits: np.ndarray) -> np.ndarray:
    """
    Numerically stable softmax over all elements of `logits`.

    Intended for a 1-D logits vector; the max is subtracted before
    exponentiation to avoid overflow.
    """
    e = np.exp(logits - np.max(logits))
    return e / e.sum()


def _select_logits_name(output_names: list[str]) -> str:
    """Pick the output holding class logits: a name containing 'logit', else the first non-attention output, else the first output."""
    named = next((n for n in output_names if "logit" in n.lower()), None)
    if named is not None:
        return named
    non_attention = next((n for n in output_names if not _is_attention_output(n)), None)
    return non_attention if non_attention is not None else output_names[0]


def run_inference(image: Image.Image) -> Tuple[float, dict]:
    """
    Run the deepfake detection model on a PIL image.

    Returns:
        confidence_score (float): Probability of being AI-generated [0.0, 1.0]
        raw_outputs (dict): Full ONNX output dict keyed by output name
            (for heatmap module)

    Raises:
        RuntimeError: If the model has not been loaded.
        ValueError: If the selected logits output does not contain exactly
            two class scores.
    """
    session = get_session()
    tensor = preprocess(image)

    # Run with all outputs (logits + any attention matrices)
    raw_outputs = session.run(None, {_input_name: tensor})
    output_dict = dict(zip(_output_names, raw_outputs))

    logits_key = _select_logits_name(_output_names)
    logits = np.asarray(output_dict[logits_key]).squeeze()  # expected shape (2,)
    if logits.shape != (2,):
        # Guards against silently reading the wrong index when the chosen
        # output is not the 2-class logits vector.
        raise ValueError(
            f"Expected 2 logits (Realism, Deepfake) from output '{logits_key}', "
            f"got shape {logits.shape}."
        )
    probs = softmax(logits)

    # Label mapping: {0: "Realism", 1: "Deepfake"}
    confidence_score = float(probs[1])  # probability of being Deepfake

    return confidence_score, output_dict


def get_threat_level(score: float) -> str:
    """Map a confidence score to a threat level: CRITICAL (>=0.90), HIGH (>=0.75), MEDIUM (>=0.50), else LOW."""
    if score >= 0.90:
        return "CRITICAL"
    elif score >= 0.75:
        return "HIGH"
    elif score >= 0.50:
        return "MEDIUM"
    else:
        return "LOW"


def get_model_reasoning(score: float, has_exif: bool, software: Optional[str]) -> str:
    """
    Generate a human-readable model reasoning string.

    Args:
        score: Deepfake probability in [0, 1].
        has_exif: Whether the image carried EXIF metadata.
        software: Software tag found in metadata. The string "None", an
            empty string, or None all mean "no tag".

    Returns:
        Space-joined explanation sentences.
    """
    reasons = []

    if score >= 0.90:
        reasons.append("Very high-confidence AI artifact signatures detected across multiple image regions.")
    elif score >= 0.75:
        reasons.append("Significant statistical anomalies inconsistent with optical camera sensors detected.")
    elif score >= 0.50:
        reasons.append("Moderate AI artifact patterns detected; image may be partially manipulated.")
    else:
        reasons.append("Low probability of AI generation; image statistics consistent with real photography.")

    if not has_exif:
        reasons.append("Absence of EXIF metadata is a strong AI indicator.")

    # "None" is the caller's string sentinel; also treat a real None or ""
    # as absent so we never report a tag named 'None' or ''.
    if software not in (None, "", "None"):
        reasons.append(f"Known AI software tag '{software}' detected in image metadata.")

    # This sentence is fixed text, not derived from the attention outputs.
    # Only include it when the model leans towards "Deepfake"; for low
    # scores it contradicted the first sentence above.
    if score >= 0.50:
        reasons.append(
            "ViT attention model flagged inconsistencies in background frequency, "
            "texture uniformity, and facial boundary regions."
        )

    return " ".join(reasons)