# Facial-Recognition-Verification / src/extract_embeddings.py
# Uploaded by martinbadrous (commit 8a4d3a7, "Upload 11 files").
"""
Facial embedding extraction using FaceNet (InceptionResnetV1).
This module wraps the FaceNet model from facenet_pytorch to produce
512‑dimensional embeddings for detected faces. It relies on MTCNN
for cropping the largest face in the image. If no face is detected, it
returns ``None``.
"""
from typing import Optional
import numpy as np
from PIL import Image
import torch
# facenet_pytorch supplies both the MTCNN face detector and the FaceNet
# (InceptionResnetV1) embedder. Fail at import time with an actionable
# install hint instead of a bare ModuleNotFoundError.
try:
    from facenet_pytorch import MTCNN, InceptionResnetV1
except ImportError as missing_dependency:
    raise ImportError(
        "facenet_pytorch is required for embedding extraction. Install it with `pip install facenet-pytorch`."
    ) from missing_dependency
# Per-device model cache: maps ``str(device)`` to the (detector, embedder)
# pair that was built on that device. Caching per device ensures a model is
# never handed back for a device other than the one it was moved to.
_model_cache: dict[str, tuple[MTCNN, InceptionResnetV1]] = {}

# Retained for backward compatibility: the models returned by the most recent
# call to ``_get_models``. New code should call ``_get_models`` instead.
_mtcnn: Optional[MTCNN] = None
_resnet: Optional[InceptionResnetV1] = None


def _get_models(device: str = "cpu") -> tuple[MTCNN, InceptionResnetV1]:
    """Initialise and cache MTCNN and InceptionResnet models for ``device``.

    Models are built on the first request for a given device and reused on
    subsequent calls with the same device. Requesting a different device
    builds a separate pair instead of returning models that live elsewhere.

    Parameters
    ----------
    device: str, optional
        Device on which to run the models. Defaults to ``"cpu"``.

    Returns
    -------
    tuple[MTCNN, InceptionResnetV1]
        The face detector and feature extractor, both placed on ``device``.
    """
    global _mtcnn, _resnet
    key = str(device)
    models = _model_cache.get(key)
    if models is None:
        detector = MTCNN(image_size=160, margin=0, select_largest=True, device=device)
        embedder = InceptionResnetV1(pretrained="vggface2").eval().to(device)
        models = (detector, embedder)
        _model_cache[key] = models
    # Keep the legacy module globals pointing at the models just served.
    _mtcnn, _resnet = models
    return models
def extract_embedding(image: Image.Image, device: str = "cpu") -> Optional[np.ndarray]:
    """Extract a 512-dimensional face embedding from an image.

    The largest face is detected and cropped by MTCNN, then encoded by
    InceptionResnetV1 (VGGFace2 weights).

    Parameters
    ----------
    image: PIL.Image.Image
        The input image containing a face. Images whose mode is not
        ``"RGB"`` (e.g. ``"RGBA"``, ``"L"``, ``"P"``) are converted to RGB
        before detection.
    device: str, optional
        Device on which to run the models. Defaults to ``"cpu"``.

    Returns
    -------
    np.ndarray or None
        A numpy array of shape (512,) containing the embedding. If no face
        is detected, returns ``None``.
    """
    mtcnn, resnet = _get_models(device)
    # The detector and embedder operate on 3-channel input, so normalise
    # alpha / palette / greyscale images to RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # For a single image MTCNN returns one cropped face tensor
    # (3 x 160 x 160 with the configuration used in _get_models) or None.
    face = mtcnn(image)
    if face is None:
        return None
    # Place the crop on the device the embedder's weights actually live on,
    # so input and model always agree even if the cached models were built
    # for a different ``device`` value.
    model_device = next(resnet.parameters()).device
    batch = face.unsqueeze(0).to(model_device)
    with torch.no_grad():
        embedding = resnet(batch)
    # Drop the batch dimension and return a 1-D numpy array on the CPU.
    return embedding.squeeze(0).cpu().numpy()
# Public API: a star-import of this module exposes only the embedding
# function; the model-loading helper stays private.
__all__ = [
    "extract_embedding",
]