File size: 767 Bytes
c49a9ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from .model_data import ModelData

class ClassificationModel:
    """
    Base class for all classification models.

    On construction it builds its list of ``ModelData`` entries by calling
    ``initialize_models()``. Because that call goes through ``self``, a
    subclass override is used instead of the default list. Concrete
    subclasses are expected to override ``load_model()``.
    """

    def __init__(self):
        # Dispatch through self so subclasses can supply their own model list.
        self.models = self.initialize_models()

    def get_model_names(self):
        """
        Return the ``name`` of every registered model.

        Returns:
            A list of names, in the same order as ``self.models``.
        """
        return [model.name for model in self.models]

    def get_model_data(self, model_name: str):
        """
        Return the first registered model whose ``name`` equals ``model_name``.

        Args:
            model_name: Name to look up.

        Returns:
            The matching entry from ``self.models``.

        Raises:
            LookupError: If no registered model has that name. This is a
                subclass of ``Exception``, so existing broad handlers still
                catch it.
        """
        match = next(
            (model for model in self.models if model.name == model_name),
            None,
        )
        if match is None:
            raise LookupError(f'Model {model_name} not found')
        return match

    def initialize_models(self):
        """
        Build the default list of models registered on this instance.

        Called once from ``__init__``. Subclasses may override this to
        register a different set of models.
        """
        return [
            ModelData('clip-vit-base-patch32'),
            ModelData('mobilenet_v3'),
        ]

    def load_model(self):
        """
        Load the underlying model. Subclasses must override this.

        NOTE(review): the original docstring mentioned loading "from the model
        path", but this base class defines no path attribute; the source of
        the model is left to the subclass to define.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement load_model()'
        )