File size: 2,014 Bytes
8a4d3a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
Face detection utility using MTCNN from facenet_pytorch.

This module exposes a simple function to detect faces in a PIL Image.  It
returns bounding boxes for all detected faces.  The detection model is
constructed lazily on the first call to avoid unnecessary GPU/CPU
initialisation when the module is imported.
"""

from typing import Dict, List, Optional, Tuple

import numpy as np
from PIL import Image

try:
    from facenet_pytorch import MTCNN
except ImportError as exc:
    raise ImportError(
        "facenet_pytorch is required for face detection. Install it with `pip install facenet-pytorch`."
    ) from exc

# Detectors are cached per device so that requesting a different ``device``
# never silently reuses a model that was built for another device.
_mtcnn_by_device: Dict[str, MTCNN] = {}

# Most recently returned detector.  Retained for backward compatibility with
# code that inspected the former single-instance cache.
_mtcnn: Optional[MTCNN] = None


def _get_mtcnn(device: str = "cpu") -> MTCNN:
    """Return a cached MTCNN detector for ``device``, creating it on first use.

    One detector is constructed lazily per distinct device and reused on
    subsequent calls with the same device.  Calls with a different device get
    their own detector rather than the one created first.

    Parameters
    ----------
    device: str, optional
        PyTorch device on which to run the detector.  Defaults to ``"cpu"``.
        The cache key is ``str(device)``, so ``"cuda"`` and ``"cuda:0"`` are
        treated as distinct entries.

    Returns
    -------
    MTCNN
        The configured multi-task cascaded CNN detector (``keep_all=True``,
        so all faces in an image are reported).
    """
    global _mtcnn
    key = str(device)
    detector = _mtcnn_by_device.get(key)
    if detector is None:
        detector = MTCNN(image_size=160, margin=0, keep_all=True, device=device)
        _mtcnn_by_device[key] = detector
    _mtcnn = detector
    return detector


def detect_faces(image: Image.Image, device: str = "cpu") -> List[Tuple[float, float, float, float]]:
    """Detect faces in a PIL image.

    Parameters
    ----------
    image: PIL.Image.Image
        The input image in which to detect faces.  Images in any mode other
        than ``"RGB"`` (e.g. ``"RGBA"``, ``"L"``, ``"P"``) are converted to
        RGB on a copy before detection; the caller's image is not modified.
    device: str, optional
        Device on which to run the detector (``"cpu"`` or ``"cuda"``).  Defaults to ``"cpu"``.

    Returns
    -------
    List[Tuple[float, float, float, float]]
        A list of bounding boxes (x1, y1, x2, y2) for each detected face.  If
        no faces are found, returns an empty list.
    """
    # MTCNN's detector expects 3-channel RGB input; alpha, grayscale or
    # palette images would otherwise fail or feed the network the wrong
    # number of channels.  ``convert`` returns a new image.
    if image.mode != "RGB":
        image = image.convert("RGB")

    mtcnn = _get_mtcnn(device)
    # ``detect`` returns (boxes, probs); only the boxes are needed.  When no
    # face is found, ``boxes`` is None.
    boxes, _ = mtcnn.detect(image)
    if boxes is None:
        return []
    # One row per face: convert each (x1, y1, x2, y2) row to plain floats.
    return [tuple(map(float, box)) for box in np.asarray(boxes)]


# Public API of this module: only the detection entry point is exported.
__all__ = [
    "detect_faces",
]