"""Facial embedding extraction using FaceNet (InceptionResnetV1).

This module wraps the FaceNet model from ``facenet_pytorch`` to produce
512-dimensional embeddings for detected faces. It relies on MTCNN to detect
and crop the largest face in the image. If no face is detected,
:func:`extract_embedding` returns ``None``.
"""

from typing import Optional

import numpy as np
from PIL import Image
import torch

try:
    from facenet_pytorch import MTCNN, InceptionResnetV1
except ImportError as exc:
    raise ImportError(
        "facenet_pytorch is required for embedding extraction. Install it with `pip install facenet-pytorch`."
    ) from exc

# Lazily-built models, cached per device. A single global pair would hand
# models living on one device to a caller that asked for another, causing a
# device-mismatch error once the input tensor is moved to the requested device.
_MODEL_CACHE: dict[str, tuple[MTCNN, InceptionResnetV1]] = {}


def _get_models(device: str = "cpu") -> tuple[MTCNN, InceptionResnetV1]:
    """Initialise (once per device) and return the MTCNN and InceptionResnet models.

    Parameters
    ----------
    device: str, optional
        Device on which to run the models. Defaults to ``"cpu"``.

    Returns
    -------
    tuple[MTCNN, InceptionResnetV1]
        The face detector and the feature extractor, both placed on ``device``.
        Repeated calls with the same ``device`` return the same cached objects.
    """
    key = str(device)
    models = _MODEL_CACHE.get(key)
    if models is None:
        detector = MTCNN(image_size=160, margin=0, select_largest=True, device=device)
        extractor = InceptionResnetV1(pretrained="vggface2").eval().to(device)
        models = (detector, extractor)
        _MODEL_CACHE[key] = models
    return models


def extract_embedding(image: Image.Image, device: str = "cpu") -> Optional[np.ndarray]:
    """Extract a 512-dimensional face embedding from an image.

    Parameters
    ----------
    image: PIL.Image.Image
        The input image containing a face. Non-RGB images (e.g. RGBA,
        grayscale, palette) are converted to RGB before detection.
    device: str, optional
        Device on which to run the models. Defaults to ``"cpu"``.

    Returns
    -------
    np.ndarray or None
        A numpy array of shape (512,) containing the embedding of the largest
        detected face. If no face is detected, returns ``None``.
    """
    mtcnn, resnet = _get_models(device)

    # MTCNN's detector works on 3-channel RGB input; other PIL modes (RGBA,
    # "L", "P", ...) would be rejected or misread, so normalise them first.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Detect the largest face and crop it to a (3, 160, 160) tensor.
    face = mtcnn(image)
    if face is None:
        return None

    # Add a batch dimension and move to the models' device.
    face = face.to(device).unsqueeze(0)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        emb = resnet(face)

    # Drop the batch dimension and return as a 1D numpy array on CPU.
    return emb.squeeze(0).cpu().numpy()


__all__ = ["extract_embedding"]