from .model_data import ModelData


class ClassificationModel:
    """Base class for all classification models.

    On construction the instance stores the list returned by
    ``initialize_models()`` in ``self.models``. Subclasses are expected to
    override ``load_model()``; the base implementation only raises
    ``NotImplementedError``.
    """

    def __init__(self):
        # Populate the registry through the overridable hook so subclasses
        # can supply their own list of models.
        self.models = self.initialize_models()

    def get_model_names(self):
        """Return the ``name`` of every entry in ``self.models``, in order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """Return the first entry in ``self.models`` named *model_name*.

        Args:
            model_name: Value compared with ``==`` against each entry's
                ``name`` attribute.

        Returns:
            The first matching entry of ``self.models``.

        Raises:
            LookupError: If no entry has a matching ``name``.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        # LookupError subclasses Exception, so callers that already catch
        # Exception keep working, while new callers can catch narrowly.
        raise LookupError(f'Model {model_name} not found')

    def initialize_models(self):
        """Return the list of ``ModelData`` entries stored in ``self.models``.

        Called once from ``__init__``.
        """
        return [
            ModelData('clip-vit-base-patch32'),
            ModelData('mobilenet_v3'),
        ]

    def load_model(self):
        """Load the model from the model path.

        Not implemented in this base class.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError