# Image-Classification-Benchmark / src/classification_model.py
# MEYTI BECI BAGUNDA
# Update 4 files
# 9087ee6
from typing import List
from urllib.request import urlopen
from PIL import Image
from .data.model_data import ModelData
from .models.mobilenet_v3 import MobilenetV3
from .models.clip_vit import ClipVit
from .models.google_vit import GoogleVit
from .models.resnet_50 import Resnet50
from .data.classification_result import ClassificationResult
class ClassificationModel:
    """Registry and dispatcher for the available image-classification models.

    Constructing an instance calls ``load_model``, which instantiates every
    supported backend and stores them in ``self.models``. ``classify`` then
    routes an image to the backend selected by name.
    """

    def __init__(self):
        """Populate ``self.models`` by calling ``load_model``."""
        self.load_model()

    def get_model_names(self):
        """Return the ``name`` of every registered model, in registry order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """Return the registry entry whose ``name`` equals ``model_name``.

        Raises:
            ValueError: if no registered model has that name. ``ValueError``
                is a subclass of ``Exception``, so callers that catch
                ``Exception`` keep working; the message is unchanged.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        raise ValueError(f'Model {model_name} not found')

    def load_model(self):
        """Instantiate every supported backend and store them in ``self.models``.

        All four backends are constructed eagerly, on every call.
        """
        self.models = [
            ModelData('clip-vit-base-patch32', model_class=ClipVit()),
            ModelData('mobilenet_v3', model_class=MobilenetV3()),
            ModelData('google-vit-base-patch16-224', model_class=GoogleVit()),
            ModelData('microsoft/resnet-50', model_class=Resnet50()),
        ]

    def classify(self, model_name, image) -> "List[ClassificationResult]":
        """Classify ``image`` with the model registered as ``model_name``.

        Args:
            model_name: name of a registered model (see ``get_model_names``).
            image: anything ``image_to_pil`` accepts.

        Returns:
            Whatever the backend's ``classify_image`` returns for the image.

        Raises:
            ValueError: if ``model_name`` is not registered.
        """
        # Debug trace of the incoming image type (kept from the original).
        print('>> image type -->',type(image))
        # Normalise the input to a PIL image before looking up the backend.
        img = self.image_to_pil(image)
        model = self.get_model_data(model_name)
        return model.model_class.classify_image(img)

    def image_to_pil(self, image):
        """Convert ``image`` to a PIL image.

        Args:
            image: one of
                * a ``str`` starting with ``http://`` or ``https://`` - the
                  image is downloaded;
                * any other ``str`` - opened as a local file path;
                * an existing PIL image - returned unchanged;
                * anything else (path-like or file object) - passed to
                  ``Image.open``.

        Returns:
            A PIL image. Downloaded images are fully loaded before the
            connection is closed; other images are opened as ``Image.open``
            returns them.
        """
        if isinstance(image, str):
            # Match the full scheme so plain-http URLs are downloaded too and
            # a local file whose name merely starts with "https" is not
            # mistaken for a URL.
            if image.startswith(('http://', 'https://')):
                with urlopen(image) as response:
                    remote_image = Image.open(response)
                    # Force the pixel data to be read while the response is
                    # still open, then let the context manager close it.
                    remote_image.load()
                return remote_image
            return Image.open(image)
        if isinstance(image, Image.Image):
            # Already a PIL image: nothing to convert.
            return image
        return Image.open(image)