Image-Classification-Benchmark / src /classification_model.py
AnnasBlackHat's picture
basic ui
c49a9ad
raw
history blame
767 Bytes
from .model_data import ModelData
class ClassificationModel:
    """
    Base class for all classification models.

    On construction the instance populates ``self.models`` by calling
    :meth:`initialize_models`, which subclasses may override to expose a
    different set of models. Each entry is expected to carry a ``name``
    attribute; that attribute is how models are listed and looked up.
    """

    def __init__(self):
        # Delegated to an overridable hook so subclasses can change the
        # model list without re-implementing __init__.
        self.models = self.initialize_models()

    def get_model_names(self) -> "list[str]":
        """Return the ``name`` of every registered model, in registration order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name: str) -> "ModelData":
        """
        Return the registered model whose ``name`` equals *model_name*.

        If several models share the same name, the first one registered wins.

        Raises:
            LookupError: if no registered model has that name. ``LookupError``
                subclasses ``Exception``, so existing ``except Exception``
                handlers keep catching it, and ``str(err)`` is the plain
                message.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        raise LookupError(f'Model {model_name} not found')

    def initialize_models(self) -> "list[ModelData]":
        """
        Build the default list of models offered by this class.

        Override in a subclass to change which models are available.
        """
        # NOTE(review): presumably ModelData's first positional argument
        # becomes its ``name`` (get_model_names reads ``.name``) -- confirm
        # against model_data.py.
        return [
            ModelData('clip-vit-base-patch32'),
            ModelData('mobilenet_v3'),
        ]

    def load_model(self):
        """
        Load the model. Concrete subclasses must override this.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement load_model()'
        )