File size: 4,461 Bytes
b5d3a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from __future__ import annotations

# built-in dependencies
from typing import TYPE_CHECKING, Any, Final, TypedDict, Dict

# project dependencies
from deepface.models.facial_recognition import (
    VGGFace,
    OpenFace,
    FbDeepFace,
    DeepID,
    ArcFace,
    SFace,
    Dlib,
    Facenet,
    GhostFaceNet,
    Buffalo_L,
)
from deepface.models.face_detection import (
    FastMtCnn,
    MediaPipe,
    MtCnn,
    OpenCv,
    Dlib as DlibDetector,
    RetinaFace,
    Ssd,
    Yolo as YoloFaceDetector,
    YuNet,
    CenterFace,
)
from deepface.models.demography import Age, Gender, Race, Emotion
from deepface.models.spoofing import FasNet
from deepface.modules.exceptions import UnimplementedError

if TYPE_CHECKING:
    # Imported for static type checkers only: TYPE_CHECKING is False at
    # runtime, and these names appear only inside annotations in this module
    # (which `from __future__ import annotations` keeps unevaluated).
    from deepface.models.Demography import Demography
    from deepface.models.Detector import Detector
    from deepface.models.FacialRecognition import FacialRecognition

    # NOTE: this declaration is visible to type checkers only; at runtime the
    # assignment never executes, so `cached_models` does not exist until
    # build_model() creates it lazily via its `"cached_models" not in globals()`
    # check. Shape: {task: {model_name: model instance}}.
    cached_models: Dict[str, Dict[str, Any]] = {}


class AvailableModels(TypedDict):
    """
    Static type of AVAILABLE_MODELS.

    One entry per task; each value maps a public model name (the string
    callers pass to build_model as `model_name`) to the class that is
    instantiated for it.
    """

    # face recognition model classes
    facial_recognition: dict[str, type[FacialRecognition]]
    # anti-spoofing model classes
    spoofing: dict[str, type[FasNet.Fasnet]]
    # facial attribute (demography) model classes
    facial_attribute: dict[str, type[Demography]]
    # face detector backend classes
    face_detector: dict[str, type[Detector]]


# Registry of every model build_model() can construct, grouped by task.
# Outer keys are the task names accepted by build_model(task=...); inner keys
# are the public model names accepted as `model_name`; values are the classes
# that build_model() instantiates with no arguments on first request.
# `Final` tells type checkers this name must not be rebound.
AVAILABLE_MODELS: Final[AvailableModels] = {
    "facial_recognition": {
        "VGG-Face": VGGFace.VggFaceClient,
        "OpenFace": OpenFace.OpenFaceClient,
        "Facenet": Facenet.FaceNet128dClient,
        "Facenet512": Facenet.FaceNet512dClient,
        "DeepFace": FbDeepFace.DeepFaceClient,
        "DeepID": DeepID.DeepIdClient,
        "Dlib": Dlib.DlibClient,
        "ArcFace": ArcFace.ArcFaceClient,
        "SFace": SFace.SFaceClient,
        "GhostFaceNet": GhostFaceNet.GhostFaceNetClient,
        "Buffalo_L": Buffalo_L.Buffalo_L,
    },
    "spoofing": {
        "Fasnet": FasNet.Fasnet,
    },
    "facial_attribute": {
        "Emotion": Emotion.EmotionClient,
        "Age": Age.ApparentAgeClient,
        "Gender": Gender.GenderClient,
        "Race": Race.RaceClient,
    },
    # Detector modules are imported under aliases (DlibDetector,
    # YoloFaceDetector) so they do not clash with the recognition-side `Dlib`.
    # Note the lowercase keys here versus the capitalised recognition keys.
    "face_detector": {
        "opencv": OpenCv.OpenCvClient,
        "mtcnn": MtCnn.MtCnnClient,
        "ssd": Ssd.SsdClient,
        "dlib": DlibDetector.DlibClient,
        "retinaface": RetinaFace.RetinaFaceClient,
        "mediapipe": MediaPipe.MediaPipeClient,
        # YOLO variants: one key per generation/size combination
        "yolov8n": YoloFaceDetector.YoloDetectorClientV8n,
        "yolov8m": YoloFaceDetector.YoloDetectorClientV8m,
        "yolov8l": YoloFaceDetector.YoloDetectorClientV8l,
        "yolov11n": YoloFaceDetector.YoloDetectorClientV11n,
        "yolov11s": YoloFaceDetector.YoloDetectorClientV11s,
        "yolov11m": YoloFaceDetector.YoloDetectorClientV11m,
        "yolov11l": YoloFaceDetector.YoloDetectorClientV11l,
        "yolov12n": YoloFaceDetector.YoloDetectorClientV12n,
        "yolov12s": YoloFaceDetector.YoloDetectorClientV12s,
        "yolov12m": YoloFaceDetector.YoloDetectorClientV12m,
        "yolov12l": YoloFaceDetector.YoloDetectorClientV12l,
        "yunet": YuNet.YuNetClient,
        "fastmtcnn": FastMtCnn.FastMtCnnClient,
        "centerface": CenterFace.CenterFaceClient,
    },
}


def build_model(task: str, model_name: str) -> Any:
    """
    Build a pre-trained model on first request and return the cached instance.

    Instances are stored in the module-level ``cached_models`` dict, keyed as
    ``{task: {model_name: instance}}``, so repeated calls with the same
    arguments return the same object. The cache is not guarded by a lock.

    Parameters:
        task (str): facial_recognition, facial_attribute, face_detector, spoofing
        model_name (str): model identifier, i.e. a key of AVAILABLE_MODELS[task]
            - VGG-Face, OpenFace, Facenet, Facenet512, DeepFace, DeepID, Dlib,
                ArcFace, SFace, GhostFaceNet and Buffalo_L for face recognition
            - Age, Gender, Emotion, Race for facial attributes
            - opencv, mtcnn, ssd, dlib, retinaface, mediapipe,
                yolov8n, yolov8m, yolov8l,
                yolov11n, yolov11s, yolov11m, yolov11l,
                yolov12n, yolov12s, yolov12m, yolov12l,
                yunet, fastmtcnn or centerface for face detectors
            - Fasnet for spoofing
    Returns:
            built model instance (shared by all callers requesting the same
            task/model_name)
    Raises:
        UnimplementedError: if `task` is not a known task, or `model_name`
            is not registered for that task
    """

    # singleton-ish cache living at module level
    global cached_models

    if task not in AVAILABLE_MODELS:
        raise UnimplementedError(f"unimplemented task - {task}")

    # `cached_models` is only declared under TYPE_CHECKING, so at runtime it
    # does not exist until the first call creates it here.
    if "cached_models" not in globals():
        cached_models = {current_task: {} for current_task in AVAILABLE_MODELS}

    # setdefault keeps this working even if the global already exists but has
    # no slot for this task (previously that raised KeyError).
    task_cache = cached_models.setdefault(task, {})

    model_obj = task_cache.get(model_name)
    if model_obj is None:
        model_class = AVAILABLE_MODELS[task].get(model_name)  # type: ignore[literal-required]
        if model_class is None:
            raise UnimplementedError(f"Invalid model_name passed - {task}/{model_name}")
        model_obj = model_class()
        task_cache[model_name] = model_obj

    return model_obj