"""ONNX inference for face anti-spoofing."""

import math
import sys
from typing import Dict, List

import numpy as np
import onnxruntime as ort

from src.inference.preprocess import preprocess_batch


def _stable_sigmoid(x: float) -> float:
    """Return the logistic sigmoid of *x* without overflowing ``exp``.

    For x >= 0, exp(-x) <= 1; for x < 0, exp(x) < 1, so neither branch can
    overflow even for very large-magnitude inputs. NaN propagates as NaN.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def process_with_logits(raw_logits: np.ndarray, threshold: float) -> Dict:
    """Convert a pair of raw logits to a real/spoof classification.

    Args:
        raw_logits: Indexable model output for one face. Element 0 is read as
            the "real" logit and element 1 as the "spoof" logit; indexing
            fails if fewer than two elements are present.
        threshold: Decision threshold on the logit margin
            ``real_logit - spoof_logit``. The face is classified as real when
            the margin is >= threshold.

    Returns:
        Dict with keys ``is_real``, ``status`` ("real" or "spoof"),
        ``logit_diff``, ``real_logit``, ``spoof_logit``, ``confidence``
        (absolute logit margin) and ``realness_score`` (sigmoid of the
        margin, in 0..1). Numeric values are rounded to 2 decimals.
    """
    real_logit = float(raw_logits[0])
    spoof_logit = float(raw_logits[1])
    logit_diff = real_logit - spoof_logit

    # Probability-like score in 0..1 (sigmoid of the logit margin). The stable
    # form yields the same values as 1 / (1 + exp(-d)) but does not overflow
    # (and warn) for large negative margins.
    p_real = _stable_sigmoid(logit_diff)

    is_real = logit_diff >= threshold
    confidence = abs(logit_diff)

    return {
        "is_real": bool(is_real),
        "status": "real" if is_real else "spoof",
        "logit_diff": round(logit_diff, 2),
        "real_logit": round(real_logit, 2),
        "spoof_logit": round(spoof_logit, 2),
        "confidence": round(confidence, 2),
        "realness_score": round(p_real, 2),
    }


def infer(
    face_crops: List[np.ndarray],
    ort_session: ort.InferenceSession,
    input_name: str,
    model_img_size: int,
) -> List[np.ndarray]:
    """Run batch inference on cropped face images.

    Args:
        face_crops: Cropped face images, passed as a batch to
            ``preprocess_batch``.
        ort_session: ONNX Runtime session; ``None`` yields an empty result.
        input_name: Name of the model input to feed the preprocessed batch to.
        model_img_size: Image size forwarded to ``preprocess_batch``.

    Returns:
        One logits row (2 values) per face crop, in input order. An empty
        list is returned when there are no crops, when there is no session,
        or when any step fails (the error is printed to stderr). Callers
        cannot tell these cases apart from the return value alone.
    """
    if not face_crops or ort_session is None:
        return []

    expected_shape = (len(face_crops), 2)
    try:
        batch_input = preprocess_batch(face_crops, model_img_size)
        # An empty output-name list makes onnxruntime return every model
        # output; the logits are taken from the first one.
        logits = ort_session.run([], {input_name: batch_input})[0]
        if logits.shape != expected_shape:
            raise ValueError(
                f"Model output shape mismatch: expected {expected_shape}, "
                f"got {logits.shape}"
            )
    except Exception as e:
        # Deliberate best-effort boundary: report and degrade to "no results"
        # rather than propagating into the caller.
        print(f"Inference error: {e}", file=sys.stderr)
        return []

    # Iterating the validated (N, 2) array yields its N rows.
    return list(logits)