| from .model_data import ModelData | |
class ClassificationModel:
    """
    Base class for all classification models.

    On construction, ``self.models`` is populated by ``initialize_models()``,
    which subclasses may override to supply a different model list. Each
    entry must expose a ``name`` attribute; the lookup helpers below match
    against it.
    """

    def __init__(self):
        # Delegate to the overridable hook so subclasses can customise the
        # available models without re-implementing __init__.
        self.models = self.initialize_models()

    def get_model_names(self):
        """Return the ``name`` of every entry in ``self.models``, in order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """
        Return the first entry in ``self.models`` whose ``name`` equals
        ``model_name``.

        Raises:
            LookupError: if no entry has that name. LookupError is a subclass
                of Exception, so callers that previously caught ``Exception``
                continue to work, while new callers can catch it precisely.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        raise LookupError(f'Model {model_name} not found')

    def initialize_models(self):
        """
        Build the default list of models for this classifier.

        Called once from ``__init__``; subclasses may override it to register
        a different set of models.
        """
        # NOTE(review): the string argument is presumably what ModelData
        # exposes as ``.name`` (used by the lookup helpers) — confirm against
        # ModelData's definition.
        return [
            ModelData('clip-vit-base-patch32'),
            ModelData('mobilenet_v3'),
        ]

    def load_model(self):
        """
        Load the model. Must be implemented by subclasses.

        NOTE(review): the original docstring said the model is loaded "from
        the model path", but this class defines no path — presumably a
        subclass or ModelData supplies it; confirm.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError