from pathlib import Path
import glob
import os

from .face_detection_yunet.yunet import YuNet
from .text_recognition_crnn.crnn import CRNN
from .face_recognition_sface.sface import SFace
from .image_classification_ppresnet.ppresnet import PPResNet
from .human_segmentation_pphumanseg.pphumanseg import PPHumanSeg
from .person_detection_mediapipe.mp_persondet import MPPersonDet
from .pose_estimation_mediapipe.mp_pose import MPPose
from .qrcode_wechatqrcode.wechatqrcode import WeChatQRCode
from .person_reid_youtureid.youtureid import YoutuReID
from .image_classification_mobilenet.mobilenet import MobileNet
from .palm_detection_mediapipe.mp_palmdet import MPPalmDet
from .handpose_estimation_mediapipe.mp_handpose import MPHandPose
from .license_plate_detection_yunet.lpd_yunet import LPD_YuNet
from .object_detection_nanodet.nanodet import NanoDet
from .object_detection_yolox.yolox import YoloX
from .facial_expression_recognition.facial_fer_model import FacialExpressionRecog
from .object_tracking_vittrack.vittrack import VitTrack
from .text_detection_ppocr.ppocr_det import PPOCRDet
from .image_segmentation_efficientsam.efficientSAM import EfficientSAM

class ModuleRegistery:
    def __init__(self, name):
        self._name = name
        self._dict = dict()

        self._base_path = Path(__file__).parent

    def get(self, key):
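        '''
        Returns a tuple of:
        - the registered module handler (the class itself),
        - a dict mapping each precision key ("fp32", "fp16", "int8", "int8bq")
          to a list of model file path lists
        '''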
        return self._dict[key]

    def register(self, item):
        '''
        Registers the given module handler along with the paths of its model
        files, grouped by precision.
        '''
        # search for model files
        model_dir = str(self._base_path / item.__module__.split(".")[1])
        fp32_model_paths = []
        fp16_model_paths = []
        int8_model_paths = []
        int8bq_model_paths = []
        # onnx
        ret_onnx = sorted(glob.glob(os.path.join(model_dir, "*.onnx")))
        if "object_tracking" in item.__module__:
            # object tracking models usually have multiple parts
            fp32_model_paths = [ret_onnx]
        else:
            for r in ret_onnx:
                if "int8" in r:
                    int8_model_paths.append([r])
                elif "fp16" in r:  # keep fp16 models out of the fp32 list
                    fp16_model_paths.append([r])
                elif "blocked" in r:
                    int8bq_model_paths.append([r])
                else:
                    fp32_model_paths.append([r])
        # caffe
        ret_caffemodel = sorted(glob.glob(os.path.join(model_dir, "*.caffemodel")))
        ret_prototxt = sorted(glob.glob(os.path.join(model_dir, "*.prototxt")))
        caffe_models = []
        for caffemodel, prototxt in zip(ret_caffemodel, ret_prototxt):
            caffe_models += [prototxt, caffemodel]
        if caffe_models:
            fp32_model_paths.append(caffe_models)

        all_model_paths = dict(
            fp32=fp32_model_paths,
            fp16=fp16_model_paths,
            int8=int8_model_paths,
            int8bq=int8bq_model_paths
        )

        self._dict[item.__name__] = (item, all_model_paths)

MODELS = ModuleRegistery('Models')
MODELS.register(YuNet)
MODELS.register(CRNN)
MODELS.register(SFace)
MODELS.register(PPResNet)
MODELS.register(PPHumanSeg)
MODELS.register(MPPersonDet)
MODELS.register(MPPose)
MODELS.register(WeChatQRCode)
MODELS.register(YoutuReID)
MODELS.register(MobileNet)
MODELS.register(MPPalmDet)
MODELS.register(MPHandPose)
MODELS.register(LPD_YuNet)
MODELS.register(NanoDet)
MODELS.register(YoloX)
MODELS.register(FacialExpressionRecog)
MODELS.register(VitTrack)
MODELS.register(PPOCRDet)
MODELS.register(EfficientSAM)
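
# Usage sketch (illustrative only, not part of the registry itself). It
# assumes this file is the package's __init__.py and that the package is
# importable as `models`; adjust the import to your layout.
#
#   from models import MODELS
#
#   # Keys are the registered class names, e.g. 'YuNet'.
#   handler_cls, model_paths = MODELS.get('YuNet')
#
#   # model_paths maps a precision key to a list of file groups. Each group
#   # is a list holding one ONNX file, the Caffe prototxt/caffemodel files,
#   # or every ONNX part of an object-tracking model.
#   for paths in model_paths['int8']:
#       print(handler_cls.__name__, paths)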