File size: 1,677 Bytes
cebad5c 9087ee6 cebad5c c49a9ad cebad5c c49a9ad cebad5c 9087ee6 c49a9ad cebad5c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import io
import logging
from typing import List
from urllib.parse import urlparse
from urllib.request import urlopen

from PIL import Image

from .data.classification_result import ClassificationResult
from .data.model_data import ModelData
from .models.clip_vit import ClipVit
from .models.google_vit import GoogleVit
from .models.mobilenet_v3 import MobilenetV3
from .models.resnet_50 import Resnet50
class ClassificationModel:
    """
    Registry of the supported classification models.

    Holds one ``ModelData`` entry per model (built by ``load_model``), lets
    callers look entries up by name, and runs a chosen model on an image
    after converting it to a PIL image.
    """

    # Diagnostic output goes through logging instead of print() so callers
    # decide whether (and where) it is emitted.
    _log = logging.getLogger(__name__)

    # URL schemes that image_to_pil downloads rather than opening as a path.
    _URL_SCHEMES = ("http", "https")

    def __init__(self):
        """Build the model registry by calling ``load_model()``."""
        self.load_model()

    def get_model_names(self):
        """Return the ``name`` of every registered model, in registry order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """
        Return the registry entry whose ``name`` equals *model_name*.

        Raises:
            ValueError: if no registered model has that name. ValueError is a
                subclass of Exception, so existing ``except Exception``
                handlers still catch it.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        raise ValueError(f'Model {model_name} not found')

    def load_model(self):
        """(Re)build ``self.models``, instantiating every supported model."""
        self.models = [
            ModelData('clip-vit-base-patch32', model_class=ClipVit()),
            ModelData('mobilenet_v3', model_class=MobilenetV3()),
            ModelData('google-vit-base-patch16-224', model_class=GoogleVit()),
            ModelData('microsoft/resnet-50', model_class=Resnet50()),
        ]

    def classify(self, model_name, image) -> "List[ClassificationResult]":
        """
        Classify *image* with the model registered as *model_name*.

        Args:
            model_name: name of a registered model (see ``get_model_names``).
            image: anything ``image_to_pil`` accepts.

        Returns:
            The result of the model's ``classify_image`` (annotated as a list
            of ClassificationResult).

        Raises:
            ValueError: if *model_name* is not registered.
        """
        # Resolve the model first so an unknown name fails before any
        # (possibly network-bound) image loading happens.
        model = self.get_model_data(model_name)
        self._log.debug('image type: %s', type(image))
        img = self.image_to_pil(image)
        return model.model_class.classify_image(img)

    def image_to_pil(self, image):
        """
        Return *image* as a PIL image.

        *image* may be:
          * an ``http://`` or ``https://`` URL string, which is downloaded;
          * an existing ``PIL.Image.Image``, which is returned unchanged;
          * anything else ``PIL.Image.open`` accepts (a path string, an
            ``os.PathLike`` such as ``pathlib.Path``, or a binary file object).
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, str) and urlparse(image).scheme in self._URL_SCHEMES:
            # Image.open is lazy and expects a seekable file object; an HTTP
            # response is neither seekable nor safe to leave open, so read the
            # body fully into memory and close the connection.
            with urlopen(image) as response:
                return Image.open(io.BytesIO(response.read()))
        return Image.open(image)
|