| | """ONNX inference for face anti-spoofing.""" |
| |
|
| | import numpy as np |
| | import onnxruntime as ort |
| | import sys |
| | from typing import List, Dict |
| | from src.inference.preprocess import preprocess_batch |
| |
|
| |
|
def process_with_logits(raw_logits: np.ndarray, threshold: float) -> Dict:
    """Convert a pair of raw logits into a real/spoof classification.

    Args:
        raw_logits: Indexable with at least two entries. Index 0 is read as
            the "real" logit and index 1 as the "spoof" logit.
        threshold: Decision threshold on the logit difference
            ``real_logit - spoof_logit``, not on the probability. A face is
            classified as real when the difference is ``>= threshold``.

    Returns:
        A dict with these keys:

        - ``is_real`` (bool)
        - ``status`` (``"real"`` or ``"spoof"``)
        - ``logit_diff``, ``real_logit``, ``spoof_logit``
        - ``confidence``: the absolute logit difference
        - ``realness_score``: the sigmoid of the logit difference, which
          equals the softmax probability of the real class if the two
          outputs are softmax logits

        All numeric values are rounded to 2 decimals.
    """
    real_logit = float(raw_logits[0])
    spoof_logit = float(raw_logits[1])
    logit_diff = real_logit - spoof_logit

    # Numerically stable sigmoid. For a large negative logit_diff, exp(-x)
    # overflows to inf and emits a RuntimeWarning, so that side uses the
    # algebraically equivalent exp(x) / (1 + exp(x)) form instead.
    if logit_diff >= 0:
        p_real = 1.0 / (1.0 + np.exp(-logit_diff))
    else:
        exp_diff = np.exp(logit_diff)
        p_real = exp_diff / (1.0 + exp_diff)

    is_real = logit_diff >= threshold
    # "Confidence" is the margin between the two logits, not a probability.
    confidence = abs(logit_diff)

    return {
        "is_real": bool(is_real),
        "status": "real" if is_real else "spoof",
        "logit_diff": round(float(logit_diff), 2),
        "real_logit": round(float(real_logit), 2),
        "spoof_logit": round(float(spoof_logit), 2),
        "confidence": round(float(confidence), 2),
        "realness_score": round(float(p_real), 2),
    }
| |
|
| |
|
def infer(
    face_crops: List[np.ndarray],
    ort_session: ort.InferenceSession,
    input_name: str,
    model_img_size: int,
) -> List[np.ndarray]:
    """Run batch inference on cropped face images.

    Args:
        face_crops: Face crops to classify. They are preprocessed together as
            one batch by ``preprocess_batch``.
        ort_session: ONNX Runtime session to run. If ``None``, nothing is run.
        input_name: Name of the model input that receives the preprocessed
            batch.
        model_img_size: Image size forwarded to ``preprocess_batch``.

    Returns:
        One logits row per face, in the same order as ``face_crops``. Rows are
        taken from the model's first output, which must have shape
        ``(len(face_crops), 2)``.

        Returns an empty list in any of these cases:

        - there are no crops;
        - there is no session;
        - preprocessing or inference fails.

        On failure the error is printed to stderr rather than raised.
    """
    if not face_crops or ort_session is None:
        return []

    expected_shape = (len(face_crops), 2)
    try:
        batch_input = preprocess_batch(face_crops, model_img_size)
        # An empty output-name list requests every model output; only the
        # first one (the logits) is used.
        logits = ort_session.run([], {input_name: batch_input})[0]

        if logits.shape != expected_shape:
            raise ValueError(
                f"Model output shape mismatch: expected {expected_shape}, "
                f"got {logits.shape}"
            )

        # Iterating a 2-D array yields its rows: one logits pair per face.
        return list(logits)
    except Exception as e:
        # Deliberate best-effort contract: callers receive [] instead of an
        # exception, with the cause reported on stderr.
        print(f"Inference error: {e}", file=sys.stderr)
        return []
| |
|