|
|
from typing import List |
|
|
from urllib.request import urlopen |
|
|
from PIL import Image |
|
|
from .data.model_data import ModelData |
|
|
from .models.mobilenet_v3 import MobilenetV3 |
|
|
from .models.clip_vit import ClipVit |
|
|
from .data.classification_result import ClassificationResult |
|
|
|
|
|
class ClassificationModel:
    """Registry of the available image-classification models.

    On construction every supported model is instantiated (see
    ``load_model``); ``classify`` then dispatches a request to the model
    registered under the requested name.
    """

    def __init__(self):
        self.load_model()

    def get_model_names(self):
        """Return the names of all registered models, in registration order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """Return the ``ModelData`` entry whose ``name`` equals *model_name*.

        Raises:
            ValueError: if no registered model has that name. ``ValueError``
                subclasses ``Exception``, so existing ``except Exception``
                handlers keep working.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        raise ValueError(f'Model {model_name} not found')

    def load_model(self):
        """Instantiate every supported model and store them in ``self.models``."""
        self.models = [
            ModelData('clip-vit-base-patch32', model_class=ClipVit()),
            ModelData('mobilenet_v3', model_class=MobilenetV3()),
        ]

    def classify(self, model_name, image) -> List[ClassificationResult]:
        """Classify *image* with the model registered under *model_name*.

        Args:
            model_name: one of the names returned by ``get_model_names``.
            image: an ``http://``/``https://`` URL string, a local path or
                file object accepted by ``PIL.Image.open``, or an
                already-loaded ``PIL.Image.Image``.

        Returns:
            The result of the selected model's ``classify_image`` call.

        Raises:
            ValueError: if *model_name* is not registered.
        """
        # Resolve the model first so an unknown name fails fast, before any
        # (possibly remote) image loading happens.
        model = self.get_model_data(model_name)
        img = self.image_to_pil(image)
        return model.model_class.classify_image(img)

    def image_to_pil(self, image):
        """Load *image* into a ``PIL.Image.Image``.

        Args:
            image: a ``PIL.Image.Image`` (returned unchanged); a string
                starting with ``http://`` or ``https://`` (downloaded); or
                anything else ``PIL.Image.open`` accepts (path or file object).

        Returns:
            The opened ``PIL.Image.Image``.
        """
        import io

        if isinstance(image, Image.Image):
            return image
        if isinstance(image, str) and image.startswith(('http://', 'https://')):
            # Buffer the body in memory so the HTTP response can be closed
            # deterministically instead of leaking the connection.
            with urlopen(image) as response:
                payload = response.read()
            return Image.open(io.BytesIO(payload))
        return Image.open(image)
|
|
|