# Deepguard-api / inference.py
# Uploaded by suyash-77 ("Upload 9 files"), commit a02f72f (verified).
"""
DeepGuard — ONNX ViT Inference Module
Loads the deepfake detection model once at startup.
All inference is stateless and in-memory.
Model: onnx-community/Deep-Fake-Detector-v2-Model-ONNX
- Architecture: google/vit-base-patch16-224
- Labels: {0: "Realism", 1: "Deepfake"}
- Input: pixel_values (1, 3, 224, 224) float32
- Output: logits (1, 2) float32
"""
import os
import io
import numpy as np
from PIL import Image
from typing import Optional, Tuple
import onnxruntime as ort
# Per-channel (R, G, B) normalization statistics applied in preprocess().
# NOTE(review): the Hugging Face preprocessor config for
# google/vit-base-patch16-224 uses mean = std = (0.5, 0.5, 0.5), not the
# ImageNet statistics below — confirm against this model's
# preprocessor_config.json before trusting these values.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# Bundled model file: <directory of this module>/models/deepfake_vit.onnx
MODEL_PATH = os.path.join(os.path.dirname(__file__), "models", "deepfake_vit.onnx")

# Module-level singleton state: populated once by load_model(), then reused
# by every request.
_session: Optional[ort.InferenceSession] = None  # None until load_model() runs
_input_name: str = ""  # name of the graph's first input
_output_names: list[str] = []  # every graph output name, in session order
_has_attention_outputs: bool = False  # any output name mentions "attn"/"attention"
def load_model() -> None:
    """
    Create the shared ONNX Runtime session for the deepfake model.

    Call exactly once at startup, before any inference. Populates the
    module-level session, the input name, the list of output names and the
    attention-output flag, then prints a short load summary.

    Raises:
        FileNotFoundError: if no model file exists at MODEL_PATH.
    """
    global _session, _input_name, _output_names, _has_attention_outputs

    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f"Model not found at {MODEL_PATH}. "
            "Please run: python download_model.py"
        )

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Cap both ORT thread pools at 4 threads.
    session_options.inter_op_num_threads = 4
    session_options.intra_op_num_threads = 4

    session = ort.InferenceSession(
        MODEL_PATH,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
    _session = session

    _input_name = session.get_inputs()[0].name
    _output_names = [meta.name for meta in session.get_outputs()]

    # Attention matrices, if the export exposes them, feed the
    # attention-rollout heatmap.
    attention_markers = ("attn", "attention")
    _has_attention_outputs = any(
        marker in name.lower()
        for name in _output_names
        for marker in attention_markers
    )

    print(f"[DeepGuard] Model loaded: {MODEL_PATH}")
    print(f"[DeepGuard] Input: {_input_name}")
    print(f"[DeepGuard] Outputs: {_output_names}")
    print(f"[DeepGuard] Attention outputs available: {_has_attention_outputs}")
def get_session() -> ort.InferenceSession:
    """
    Return the shared inference session created by load_model().

    Raises:
        RuntimeError: if load_model() has not been called yet.
    """
    session = _session
    if session is not None:
        return session
    raise RuntimeError("Model not loaded. Call load_model() first.")
def has_attention_outputs() -> bool:
    """
    Report whether the loaded model exposes attention outputs.

    Returns the flag computed by load_model(); it is False until the model
    has been loaded.
    """
    flag = _has_attention_outputs
    return flag
def get_attention_output_names() -> list[str]:
    """
    Return the model output names that look like attention tensors.

    A name qualifies if it contains "attn" or "attention" (case-insensitive).
    Names keep the session's output order; the list is empty before
    load_model() has run.
    """
    matches: list[str] = []
    for name in _output_names:
        lowered = name.lower()
        if "attn" in lowered or "attention" in lowered:
            matches.append(name)
    return matches
def preprocess(image: Image.Image) -> np.ndarray:
    """
    Convert a PIL image into the model's input tensor.

    Steps: convert to RGB, resize to 224x224 with bilinear resampling,
    scale to [0, 1], normalize per channel with IMAGENET_MEAN / IMAGENET_STD,
    and reorder to NCHW with a leading batch dimension.

    Returns:
        float32 array of shape (1, 3, 224, 224).
    """
    resized = image.convert("RGB").resize((224, 224), Image.BILINEAR)
    hwc = np.array(resized, dtype=np.float32) / 255.0   # (224, 224, 3), values in [0, 1]
    normalized = (hwc - IMAGENET_MEAN) / IMAGENET_STD   # per-channel standardization
    chw = np.transpose(normalized, (2, 0, 1))           # HWC -> CHW
    return chw[np.newaxis, ...]                         # CHW -> NCHW (1, 3, 224, 224)
def softmax(logits: np.ndarray, axis: Optional[int] = None) -> np.ndarray:
    """
    Numerically stable softmax.

    Args:
        logits: array-like of raw scores.
        axis: axis to normalize over. ``None`` (the default, and the original
            behavior) normalizes over every element of the array, which is
            what 1-D logits need. Pass ``axis=-1`` to normalize each row of
            batched logits independently.

    Returns:
        Array of the same shape as ``logits``. Its entries sum to 1 along
        ``axis``, or over the whole array when ``axis`` is None.
    """
    scores = np.asarray(logits)
    # Subtracting the maximum keeps exp() from overflowing on large logits
    # without changing the result.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def run_inference(image: Image.Image) -> Tuple[float, dict]:
    """
    Run the deepfake detection model on a PIL image.

    Args:
        image: input image; preprocess() converts and resizes it.

    Returns:
        confidence_score (float): softmax probability of class index 1
            ("Deepfake" in the module's label mapping), in [0.0, 1.0].
        raw_outputs (dict): every ONNX output keyed by output name (logits
            plus any attention matrices), for the heatmap module.

    Raises:
        RuntimeError: if load_model() has not been called.
    """
    session = get_session()
    tensor = preprocess(image)

    # Request all outputs (None) so attention matrices, if exported, come back too.
    raw_outputs = session.run(None, {_input_name: tensor})
    output_dict = dict(zip(_output_names, raw_outputs))

    # Locate the logits tensor. Prefer an output whose name contains "logit";
    # otherwise use the first output that is NOT an attention tensor; use
    # output 0 only as a last resort. (Previously the fallback was always
    # output 0, even when that was an attention matrix.)
    logits_key = next((n for n in _output_names if "logit" in n.lower()), None)
    if logits_key is None:
        attention_names = set(get_attention_output_names())
        logits_key = next(
            (n for n in _output_names if n not in attention_names),
            _output_names[0],
        )

    logits = output_dict[logits_key].squeeze()  # (1, 2) -> (2,)
    probs = softmax(logits)
    # Label mapping: {0: "Realism", 1: "Deepfake"}
    confidence_score = float(probs[1])  # probability of being Deepfake
    return confidence_score, output_dict
def get_threat_level(score: float) -> str:
    """
    Translate a deepfake confidence score into a threat-level label.

    Bands (inclusive lower bounds): >= 0.90 CRITICAL, >= 0.75 HIGH,
    >= 0.50 MEDIUM, anything else (including NaN) LOW.
    """
    bands = (
        (0.90, "CRITICAL"),
        (0.75, "HIGH"),
        (0.50, "MEDIUM"),
    )
    for threshold, label in bands:
        if score >= threshold:
            return label
    return "LOW"
def get_model_reasoning(score: float, has_exif: bool, software: Optional[str]) -> str:
    """
    Build a human-readable explanation for a detection result.

    Args:
        score: deepfake confidence in [0.0, 1.0] (as returned by run_inference).
        has_exif: whether the image carried EXIF metadata.
        software: software tag read from the image metadata. The string
            "None", an empty or whitespace-only string, and ``None`` all
            mean "no tag".

    Returns:
        The individual reason sentences joined by single spaces.
    """
    reasons: list[str] = []

    if score >= 0.90:
        reasons.append("Very high-confidence AI artifact signatures detected across multiple image regions.")
    elif score >= 0.75:
        reasons.append("Significant statistical anomalies inconsistent with optical camera sensors detected.")
    elif score >= 0.50:
        reasons.append("Moderate AI artifact patterns detected; image may be partially manipulated.")
    else:
        reasons.append("Low probability of AI generation; image statistics consistent with real photography.")

    if not has_exif:
        reasons.append("Absence of EXIF metadata is a strong AI indicator.")

    # Treat a real None or a blank tag the same as the "None" sentinel, so the
    # output never reads "Known AI software tag 'None'" or "''".
    tag = (software or "").strip()
    if tag and tag != "None":
        reasons.append(f"Known AI software tag '{tag}' detected in image metadata.")

    # Cite attention-flagged inconsistencies only when the score points toward
    # manipulation. Appending this unconditionally contradicted the
    # "Low probability of AI generation" verdict.
    if score >= 0.50:
        reasons.append(
            "ViT attention model flagged inconsistencies in background frequency, "
            "texture uniformity, and facial boundary regions."
        )

    return " ".join(reasons)