|
|
""" |
|
|
Facial embedding extraction using FaceNet (InceptionResnetV1).

This module wraps the FaceNet model from ``facenet_pytorch`` to produce
512-dimensional embeddings for detected faces. It uses MTCNN to detect
and crop the largest face in the image. If no face is detected,
``extract_embedding`` returns ``None``.
|
|
""" |
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
|
|
|
try: |
|
|
from facenet_pytorch import MTCNN, InceptionResnetV1 |
|
|
except ImportError as exc: |
|
|
raise ImportError( |
|
|
"facenet_pytorch is required for embedding extraction. Install it with `pip install facenet-pytorch`." |
|
|
) from exc |
|
|
|
|
|
# Lazily created model handles, assigned by ``_get_models()``.
# Both remain ``None`` until that function is first called.
_resnet: Optional[InceptionResnetV1] = None  # FaceNet feature extractor
_mtcnn: Optional[MTCNN] = None  # MTCNN face detector / cropper
|
|
|
|
|
|
|
|
# Per-device cache of (detector, embedder) pairs. It is keyed by the ``device``
# argument so that asking for a different device never returns models whose
# weights live on another device.
_model_cache: dict[str, tuple[MTCNN, InceptionResnetV1]] = {}


def _get_models(device: str = "cpu") -> tuple[MTCNN, InceptionResnetV1]:
    """Initialise and cache MTCNN and InceptionResnet models.

    Models are built lazily on the first request for a given ``device`` and
    reused by later calls that pass the same ``device``. Each distinct
    ``device`` gets its own pair. Switching devices between calls therefore
    never returns models placed on a different device than requested, which
    would otherwise cause a device-mismatch error at inference time.

    Parameters
    ----------
    device: str, optional
        Device on which to run the models. Defaults to ``"cpu"``.

    Returns
    -------
    tuple[MTCNN, InceptionResnetV1]
        The face detector and feature extractor, both created for ``device``.
    """
    global _mtcnn, _resnet

    models = _model_cache.get(device)
    if models is None:
        detector = MTCNN(image_size=160, margin=0, select_largest=True, device=device)
        embedder = InceptionResnetV1(pretrained="vggface2").eval().to(device)
        models = (detector, embedder)
        _model_cache[device] = models

    # Keep the legacy module-level handles pointing at the most recently
    # requested models for any code that inspects them directly.
    _mtcnn, _resnet = models
    return models
|
|
|
|
|
|
|
|
def extract_embedding(image: Image.Image, device: str = "cpu") -> Optional[np.ndarray]:
    """Extract a 512-dimensional face embedding from an image.

    The largest face detected by MTCNN is cropped (160x160) and passed
    through InceptionResnetV1 with VGGFace2 weights.

    Parameters
    ----------
    image: PIL.Image.Image
        The input image containing a face. Images that are not in ``"RGB"``
        mode (e.g. ``"RGBA"``, ``"L"``, ``"P"``) are converted to RGB first,
        because the detector expects three-channel input. The caller's image
        is not modified.
    device: str, optional
        Device on which to run the models. Defaults to ``"cpu"``.

    Returns
    -------
    np.ndarray or None
        A numpy array of shape (512,) containing the embedding. If no face
        is detected, returns ``None``.
    """
    mtcnn, resnet = _get_models(device)

    # MTCNN's networks take 3-channel input; alpha, greyscale or palette
    # images would otherwise fail inside the detector. convert() returns a
    # new image, so the caller's object is left untouched.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # The detection probability is not needed, so it is not requested.
    face = mtcnn(image)
    if face is None:
        return None

    # Add a batch dimension and move the crop onto the embedder's device.
    batch = face.to(device).unsqueeze(0)

    with torch.no_grad():
        embedding = resnet(batch)

    return embedding.squeeze(0).cpu().numpy()
|
|
|
|
|
|
|
|
# Public API of this module: only the embedding extractor is exported.
__all__ = [
    "extract_embedding",
]