# deepfake-detection / inference.py
import io
import os
from functools import lru_cache
from typing import Any, Dict, List, Tuple

import numpy as np
import requests
from PIL import Image, ImageOps

FAKE_LABEL = "FAKE"
REAL_LABEL = "REAL"
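
# Loose substring hints used to map arbitrary classifier label names onto the
# FAKE/REAL axes (e.g. a label like "ai-generated" matches both "ai" and
# "generated"); matching is deliberately permissive.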
FAKE_HINTS = ("fake", "deepfake", "ai", "generated", "synthetic", "manipulated")
REAL_HINTS = ("real", "human", "authentic", "genuine", "original")


def _resolve_default_path(env_var: str, primary: str, fallback: str) -> str:
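    """Return the env override if set, else `primary` if it exists, else `fallback`."""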
configured = os.getenv(env_var, "").strip()
if configured:
return configured
if os.path.isfile(primary):
return primary
return fallback


def normalize_scores(fake_score: float, real_score: float) -> Tuple[float, float]:
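    """Clamp both scores to be non-negative and rescale them to sum to 1.

    Returns an uninformative (0.5, 0.5) split when both inputs are zero.
    """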
fake = float(max(fake_score, 0.0))
real = float(max(real_score, 0.0))
total = fake + real
if total <= 0:
return 0.5, 0.5
return fake / total, real / total


def _ensure_rgb_uint8(image: np.ndarray) -> np.ndarray:
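    """Coerce an array to HxWx3 uint8 RGB, expanding grayscale and dropping alpha."""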
if image is None:
raise ValueError("Image is required")
arr = np.asarray(image)
if arr.ndim == 2:
arr = np.stack([arr] * 3, axis=-1)
if arr.ndim != 3:
raise ValueError(f"Expected image with 2 or 3 dimensions, got shape {arr.shape}")
if arr.shape[-1] == 4:
arr = arr[..., :3]
if arr.shape[-1] != 3:
raise ValueError(f"Expected 3 channels, got shape {arr.shape}")
if arr.dtype != np.uint8:
arr = np.clip(arr, 0, 255).astype(np.uint8)
return arr


def _to_png_bytes(image: np.ndarray) -> bytes:
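    """Encode an image array as PNG bytes for upload to the inference API."""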
rgb = _ensure_rgb_uint8(image)
pil_img = Image.fromarray(rgb).convert("RGB")
buffer = io.BytesIO()
pil_img.save(buffer, format="PNG")
return buffer.getvalue()


def _extract_state_dict(checkpoint: Any) -> Any:
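    """Unwrap a raw state dict from common checkpoint container keys."""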
if not isinstance(checkpoint, dict):
return checkpoint
for key in ("state_dict", "model_state_dict", "model"):
maybe = checkpoint.get(key)
if isinstance(maybe, dict):
return maybe
return checkpoint


def _predict_payload(fake_score: float, real_score: float) -> Dict[str, Any]:
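    """Build the common response payload from normalized FAKE/REAL scores."""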
fake_score, real_score = normalize_scores(fake_score, real_score)
label = FAKE_LABEL if fake_score >= real_score else REAL_LABEL
confidence = fake_score if label == FAKE_LABEL else real_score
return {
"label": label,
"confidence": confidence,
"scores": {
FAKE_LABEL: round(fake_score, 4),
REAL_LABEL: round(real_score, 4),
},
}


@lru_cache(maxsize=4)
def _load_torch_model(model_path: str, arch: str):
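    """Load and cache a timm classifier together with its eval transform.

    Cached per (model_path, arch). When `arch` is empty, the architecture
    recorded in the checkpoint is used, falling back to efficientnet_b4.
    """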
import torch
from timm import create_model
from torchvision import transforms
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(model_path, map_location=device)
if isinstance(checkpoint, dict) and "arch" in checkpoint and not arch:
arch = checkpoint["arch"]
arch = arch or "efficientnet_b4"
model = create_model(arch, pretrained=False, num_classes=2)
state_dict = _extract_state_dict(checkpoint)
try:
model.load_state_dict(state_dict)
except RuntimeError:
# Common fix for DataParallel checkpoints with "module." prefixes.
remapped = {}
for key, value in state_dict.items():
new_key = key[7:] if key.startswith("module.") else key
remapped[new_key] = value
model.load_state_dict(remapped)
model.to(device).eval()
image_size = int(checkpoint.get("image_size", 380)) if isinstance(checkpoint, dict) else 380
transform = transforms.Compose(
[
transforms.Resize(image_size + 20),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
return model, transform, device


def predict_local_torch(image: np.ndarray) -> Dict[str, Any]:
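    """Classify an image as FAKE/REAL using the local PyTorch checkpoint."""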
import torch
model_path = _resolve_default_path(
"LOCAL_MODEL_PATH",
"models/pytorch_model.pth",
"models/reference/pytorch_model.pth",
)
arch = os.getenv("MODEL_ARCH", "efficientnet_b4")
if not os.path.isfile(model_path):
raise FileNotFoundError(
f"Missing local model: {model_path}. "
"Train one with train.py or set LOCAL_MODEL_PATH to an existing checkpoint."
)
model, transform, device = _load_torch_model(model_path, arch)
rgb = _ensure_rgb_uint8(image)
pil_img = Image.fromarray(rgb).convert("RGB")
tensor = transform(pil_img).unsqueeze(0).to(device)
with torch.no_grad():
logits = model(tensor)
probs = torch.softmax(logits, dim=1).squeeze(0).cpu().numpy().tolist()
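    # Assumed index convention from training: class 0 == FAKE, class 1 == REAL.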
fake_score = float(probs[0])
real_score = float(probs[1])
return _predict_payload(fake_score, real_score)


@lru_cache(maxsize=2)
def _load_onnx_session(onnx_path: str):
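    """Create and cache a CPU-only onnxruntime session for `onnx_path`."""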
import onnxruntime as ort
providers = ["CPUExecutionProvider"]
return ort.InferenceSession(onnx_path, providers=providers)


def predict_local_onnx(image: np.ndarray) -> Dict[str, Any]:
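    """Classify an image as FAKE/REAL using the local ONNX export."""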
onnx_path = _resolve_default_path(
"LOCAL_ONNX_PATH",
"models/model.onnx",
"models/reference/efficientnet.onnx",
)
default_image_size = int(os.getenv("IMAGE_SIZE", "380"))
if not os.path.isfile(onnx_path):
raise FileNotFoundError(
f"Missing ONNX model: {onnx_path}. Export one with export_onnx.py or set LOCAL_ONNX_PATH."
)
session = _load_onnx_session(onnx_path)
input_meta = session.get_inputs()[0]
input_name = input_meta.name
input_shape = list(input_meta.shape)
is_nchw = len(input_shape) == 4 and input_shape[1] == 3
is_nhwc = len(input_shape) == 4 and input_shape[-1] == 3
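    # Derive the expected spatial size from the graph's static input shape;
    # symbolic (dynamic) dimensions fall back to the IMAGE_SIZE default.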
if is_nchw:
height = input_shape[2] if isinstance(input_shape[2], int) else default_image_size
width = input_shape[3] if isinstance(input_shape[3], int) else default_image_size
elif is_nhwc:
height = input_shape[1] if isinstance(input_shape[1], int) else default_image_size
width = input_shape[2] if isinstance(input_shape[2], int) else default_image_size
else:
raise RuntimeError(f"Unsupported ONNX input shape: {input_shape}")
rgb = _ensure_rgb_uint8(image)
pil_img = Image.fromarray(rgb).convert("RGB")
resized = ImageOps.fit(pil_img, (width, height), method=Image.Resampling.BILINEAR)
arr = np.asarray(resized).astype(np.float32) / 255.0
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
arr = (arr - mean) / std
if is_nchw:
arr = np.transpose(arr, (2, 0, 1))[None, ...]
else:
arr = arr[None, ...]
outputs = session.run(None, {input_name: arr})
logits = np.asarray(outputs[0])
logits = logits[0] if logits.ndim > 1 else logits
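    # If the exported graph already ends in softmax, the outputs are valid
    # probabilities; otherwise apply a numerically stable softmax here. The
    # same FAKE-first index convention as the PyTorch path is assumed below.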
if np.all(logits >= 0.0) and np.all(logits <= 1.0) and np.isclose(np.sum(logits), 1.0, atol=1e-3):
probs = logits
else:
logits = logits - np.max(logits)
exps = np.exp(logits)
probs = exps / np.sum(exps)
fake_score = float(probs[0])
real_score = float(probs[1])
return _predict_payload(fake_score, real_score)


def _flatten_hf_output(payload: Any) -> List[Dict[str, Any]]:
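    """Coerce an Inference API response into a flat list of {label, score} dicts."""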
if isinstance(payload, list):
if payload and isinstance(payload[0], list):
payload = payload[0]
return [x for x in payload if isinstance(x, dict)]
if isinstance(payload, dict):
if "label" in payload and "score" in payload:
return [payload]
if "error" in payload:
raise RuntimeError(str(payload["error"]))
return []


def _scores_from_hf(entries: List[Dict[str, Any]]) -> Tuple[float, float]:
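    """Map labelled API scores onto a normalized (fake, real) pair via the hint lists."""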
fake = 0.0
real = 0.0
for entry in entries:
label = str(entry.get("label", "")).lower().strip()
score = float(entry.get("score", 0.0))
if any(h in label for h in FAKE_HINTS):
fake = max(fake, score)
if any(h in label for h in REAL_HINTS):
real = max(real, score)
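    # Nothing matched either hint list: fall back to the positional convention
    # (first entry FAKE, second REAL) or to the single top label.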
if fake == 0.0 and real == 0.0 and entries:
if len(entries) >= 2:
fake = float(entries[0].get("score", 0.0))
real = float(entries[1].get("score", 0.0))
else:
top = entries[0]
top_score = float(top.get("score", 0.0))
top_label = str(top.get("label", "")).lower()
if any(h in top_label for h in REAL_HINTS):
real = top_score
else:
fake = top_score
return normalize_scores(fake, real)


def predict_hf_api(image: np.ndarray) -> Dict[str, Any]:
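    """Classify an image via a hosted Hugging Face inference endpoint or model ID."""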
token = os.getenv("HF_TOKEN", "").strip()
endpoint = os.getenv("HF_INFERENCE_ENDPOINT", "").strip()
model_id = os.getenv("DEEPFAKE_MODEL_ID", "").strip()
if not token:
raise RuntimeError("Missing HF_TOKEN secret.")
if not endpoint and not model_id:
raise RuntimeError("Set either HF_INFERENCE_ENDPOINT or DEEPFAKE_MODEL_ID.")
url = endpoint or f"https://api-inference.huggingface.co/models/{model_id}"
headers = {"Authorization": f"Bearer {token}"}
response = requests.post(url, headers=headers, data=_to_png_bytes(image), timeout=180)
if response.status_code >= 400:
raise RuntimeError(f"HF API error {response.status_code}: {response.text[:240]}")
payload = response.json()
entries = _flatten_hf_output(payload)
if not entries:
raise RuntimeError(f"Unexpected response: {payload}")
fake_score, real_score = _scores_from_hf(entries)
return _predict_payload(fake_score, real_score)
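

if __name__ == "__main__":
    # Minimal CLI smoke-test sketch: `python inference.py path/to/image.jpg`.
    # Tries the local ONNX export first and falls back to the PyTorch
    # checkpoint; swap in predict_hf_api if only the hosted API is configured.
    import sys

    img = np.asarray(Image.open(sys.argv[1]).convert("RGB"))
    try:
        result = predict_local_onnx(img)
    except FileNotFoundError:
        result = predict_local_torch(img)
    print(result)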