mohamedkh001
Deploy AEFRS complete system with models and services
ea93121
"""Embedding service with ArcFace ONNX runtime and deterministic fallback."""
from __future__ import annotations
import hashlib
import logging
import os
from pathlib import Path
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from services.common.logging_config import setup_logging
from services.common.runtime import decode_image_b64, l2_normalize, maybe_load_onnx
# Invoke the shared logging setup for this service before creating the module
# logger (presumably installs handlers/formatters -- confirm in
# services.common.logging_config).
setup_logging("embedding")
logger = logging.getLogger(__name__)

# ASGI application object; the route decorators below register against it.
app = FastAPI(title="AEFRS Embedding Service", version="1.1.0")

# Location of the ONNX model; overridable via the EMBEDDING_MODEL_PATH
# environment variable, otherwise the repository-relative default is used.
MODEL_PATH = Path(os.getenv("EMBEDDING_MODEL_PATH", "artifacts/models/arcface_iresnet100.onnx"))

# Created once at import time. The endpoints treat a falsy SESSION as
# "use the deterministic fallback" (presumably maybe_load_onnx returns None
# when the model cannot be loaded -- confirm in services.common.runtime).
SESSION = maybe_load_onnx(MODEL_PATH)
class EmbedRequest(BaseModel):
    """Face crop payload for embedding extraction.

    Request body model for the ``/embed`` endpoint.

    Attributes:
        aligned_face_b64: Image data as a base64 string; it is handed to
            ``decode_image_b64`` by the endpoint. Per the field name it is
            expected to be an already-aligned face crop (not checked here).
    """

    aligned_face_b64: str
def _fallback_embedding(raw_bytes: bytes) -> np.ndarray:
    """Derive a deterministic 512-D vector from the SHA-512 digest of the input.

    The same input bytes always produce the same vector, so this can stand in
    for the model when no ONNX session is available.

    Args:
        raw_bytes: Arbitrary bytes to hash.

    Returns:
        The result of ``l2_normalize`` applied to a 512-element float32 vector
        whose components are digest bytes scaled into [0, 1].
    """
    seed = np.frombuffer(hashlib.sha512(raw_bytes).digest(), dtype=np.uint8)
    # A SHA-512 digest is 64 bytes; repeating it 8 times yields 512 components.
    repeated = np.tile(seed, 8)[:512]
    scaled = repeated.astype(np.float32) / 255.0
    return l2_normalize(scaled)
def _onnx_embedding(img: np.ndarray) -> np.ndarray:
    """Run ArcFace ONNX embedding extraction.

    Args:
        img: 3-D image array. Its axes are reordered with ``(2, 0, 1)`` and a
            leading batch axis is added before inference (presumably
            HWC -> NCHW; confirm against ``decode_image_b64``).

    Returns:
        The first model output, flattened to float32, zero-padded or truncated
        to exactly 512 values and passed through ``l2_normalize``.

    Raises:
        RuntimeError: If no ONNX session has been loaded.
    """
    if SESSION is None:
        # Explicit raise rather than ``assert``: assertions are stripped under
        # ``python -O``, which would turn this into an AttributeError on None.
        raise RuntimeError("ONNX session is not loaded")

    input_name = SESSION.get_inputs()[0].name
    batch = np.transpose(img, (2, 0, 1))[None, :, :, :].astype(np.float32)
    outputs = SESSION.run(None, {input_name: batch})
    emb = np.array(outputs[0]).reshape(-1).astype(np.float32)

    # Force a fixed 512-D result regardless of the model's output width.
    if emb.size < 512:
        emb = np.pad(emb, (0, 512 - emb.size), mode="constant")
    emb = emb[:512]
    return l2_normalize(emb)
@app.get("/healthz")
def healthz() -> dict:
    """Report service liveness and which embedding backend is active.

    Returns:
        A dict with ``status`` fixed to ``"ok"`` and ``runtime`` set to
        ``"onnx"`` when a session is loaded, otherwise ``"fallback"``.
    """
    if SESSION:
        mode = "onnx"
    else:
        mode = "fallback"
    return {"status": "ok", "runtime": mode}
@app.post("/embed")
def embed(req: EmbedRequest) -> dict:
    """Generate embedding vector from aligned face image.

    Args:
        req: Request body carrying the base64-encoded face image.

    Returns:
        A dict with ``embedding`` (list of floats) and ``dim`` (always 512).

    Raises:
        HTTPException: 400 when the payload cannot be decoded into an image;
            500 when embedding computation itself fails.
    """
    # Stage 1: decoding the client payload. A failure here is reported as a
    # client error. The image is decoded even in fallback mode, so an
    # undecodable payload fails the same way in both runtimes.
    try:
        img = decode_image_b64(req.aligned_face_b64, size=112)
    except Exception as exc:
        logger.exception("Embedding failed")
        raise HTTPException(status_code=400, detail=f"embedding failed: {exc}") from exc

    # Stage 2: inference. A failure here is a server-side fault, so it is
    # reported as 500 and internal exception text is kept out of the response
    # (the full traceback is still logged).
    try:
        if SESSION:
            emb = _onnx_embedding(img)
        else:
            # The fallback hashes the base64 text itself, not the decoded pixels.
            emb = _fallback_embedding(req.aligned_face_b64.encode("utf-8"))
        vector = emb.tolist()
    except Exception as exc:
        logger.exception("Embedding failed")
        raise HTTPException(
            status_code=500,
            detail="embedding failed: internal inference error",
        ) from exc

    return {"embedding": vector, "dim": 512}