sol9x-sagar's picture
initial setup
2979822
"""ONNX inference for face anti-spoofing."""
import numpy as np
import onnxruntime as ort
import sys
from typing import List, Dict
from src.inference.preprocess import preprocess_batch
def process_with_logits(raw_logits: np.ndarray, threshold: float) -> Dict:
    """Convert a pair of raw logits into a real/spoof classification.

    Args:
        raw_logits: Indexable of at least two values; element 0 is read as
            the "real" logit and element 1 as the "spoof" logit.
        threshold: Minimum margin (real logit minus spoof logit) at which the
            face is classified as real. The comparison is inclusive.

    Returns:
        A dict with the keys ``is_real`` (bool), ``status`` (``"real"`` or
        ``"spoof"``), ``logit_diff``, ``real_logit``, ``spoof_logit``,
        ``confidence`` (absolute margin) and ``realness_score`` (sigmoid of
        the margin, in 0..1). All numeric values are rounded to 2 decimals.

    Raises:
        IndexError: If ``raw_logits`` holds fewer than two values.
    """
    real_logit = float(raw_logits[0])
    spoof_logit = float(raw_logits[1])
    logit_diff = real_logit - spoof_logit

    # Probability-like score in 0..1 (sigmoid of the logit margin).
    # Only a non-positive value is ever exponentiated, so np.exp cannot
    # overflow (and warn) when the margin is a large negative number.
    decay = float(np.exp(-abs(logit_diff)))
    if logit_diff >= 0:
        p_real = 1.0 / (1.0 + decay)
    else:
        p_real = decay / (1.0 + decay)

    is_real = logit_diff >= threshold
    confidence = abs(logit_diff)

    return {
        "is_real": bool(is_real),
        "status": "real" if is_real else "spoof",
        "logit_diff": round(float(logit_diff), 2),
        "real_logit": round(float(real_logit), 2),
        "spoof_logit": round(float(spoof_logit), 2),
        "confidence": round(float(confidence), 2),
        "realness_score": round(float(p_real), 2),
    }
def infer(
    face_crops: List[np.ndarray],
    ort_session: ort.InferenceSession,
    input_name: str,
    model_img_size: int,
) -> List[np.ndarray]:
    """Score a batch of cropped faces and return one logit row per face.

    The crops are passed through ``preprocess_batch`` together with
    ``model_img_size``, the result is fed to the session under
    ``input_name``, and the first output of the session is used. That output
    must have shape ``(len(face_crops), 2)``.

    Args:
        face_crops: Cropped face images to classify.
        ort_session: ONNX Runtime session to run; ``None`` yields ``[]``.
        input_name: Name of the model input that receives the batch.
        model_img_size: Size forwarded to ``preprocess_batch``.

    Returns:
        A list with one entry per face, each being that face's row of the
        model output. An empty list is returned when there are no crops, when
        no session is given, or when any step fails (the error is written to
        stderr instead of being raised).
    """
    # Nothing to score, or no model available: report "no results".
    if not face_crops:
        return []
    if ort_session is None:
        return []

    try:
        model_input = preprocess_batch(face_crops, model_img_size)
        outputs = ort_session.run([], {input_name: model_input})
        batch_logits = outputs[0]

        # Exactly one (real, spoof) logit pair is expected per crop.
        expected_shape = (len(face_crops), 2)
        if batch_logits.shape != expected_shape:
            raise ValueError("Model output shape mismatch")

        # Iterating along the first axis gives one row per face.
        return list(batch_logits)
    except Exception as err:
        # Best-effort contract: log the failure and hand back no results.
        print(f"Inference error: {err}", file=sys.stderr)
        return []