File size: 767 Bytes
c49a9ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from .model_data import ModelData
class ClassificationModel:
    """
    Base class for all classification models.

    On construction the instance stores the list returned by
    ``initialize_models`` in ``self.models``; the lookup helpers below
    operate on that list.
    """

    def __init__(self):
        # Built through the method (not inlined) so an overriding subclass
        # can supply a different list.
        self.models = self.initialize_models()

    def get_model_names(self):
        """
        Returns the ``name`` of every entry in ``self.models``, in list order.
        """
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """
        Returns the first entry of ``self.models`` whose ``name`` equals
        ``model_name``.

        Raises:
            LookupError: if no entry has that name. ``LookupError`` is a
                subclass of ``Exception``, so callers that catch
                ``Exception`` keep working.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        raise LookupError(f'Model {model_name} not found')

    def initialize_models(self):
        """
        Builds the list of ``ModelData`` entries that ``__init__`` stores
        in ``self.models``.
        """
        return [
            ModelData('clip-vit-base-patch32'),
            ModelData('mobilenet_v3'),
        ]

    def load_model(self):
        """
        Loads the model from the model path.

        The base class provides no implementation and always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
|