File size: 1,677 Bytes
cebad5c
 
 
 
 
 
9087ee6
 
 
cebad5c
c49a9ad
 
 
 
 
 
 
cebad5c
c49a9ad
 
 
 
 
 
 
 
 
 
cebad5c
 
 
9087ee6
 
 
c49a9ad
 
cebad5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import io
import logging
from typing import List
from urllib.request import urlopen

from PIL import Image

from .data.model_data import ModelData
from .models.mobilenet_v3 import MobilenetV3
from .models.clip_vit import ClipVit
from .models.google_vit import GoogleVit
from .models.resnet_50 import Resnet50
from .data.classification_result import ClassificationResult

class ClassificationModel:
    """
    Registry of the available image-classification models.

    On construction every supported model wrapper is instantiated (see
    ``load_model``); ``classify`` then dispatches an image to the wrapper
    registered under the requested name.
    """

    # Logger for diagnostic output (emitted at DEBUG level, not to stdout).
    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.load_model()

    def get_model_names(self):
        """Return the registered model names, in registration order."""
        return [model.name for model in self.models]

    def get_model_data(self, model_name):
        """
        Return the ``ModelData`` entry registered under ``model_name``.

        Raises:
            ValueError: if no model with that name is registered. ValueError
                is a subclass of Exception, so callers that catch Exception
                keep working.
        """
        for model in self.models:
            if model.name == model_name:
                return model
        raise ValueError(f'Model {model_name} not found')

    def load_model(self):
        """
        Build ``self.models``: one ``ModelData`` entry per supported model.

        Every wrapper is instantiated eagerly here, so whatever setup the
        wrapper constructors perform happens once, when this object is built.
        """
        self.models = [
            ModelData('clip-vit-base-patch32', model_class=ClipVit()),
            ModelData('mobilenet_v3', model_class=MobilenetV3()),
            ModelData('google-vit-base-patch16-224', model_class=GoogleVit()),
            ModelData('microsoft/resnet-50', model_class=Resnet50()),
        ]

    def classify(self, model_name, image) -> "List[ClassificationResult]":
        """
        Classify ``image`` with the model registered as ``model_name``.

        Args:
            model_name: name of a registered model (see ``get_model_names``).
            image: an http(s) URL string, a local path or binary file object
                accepted by ``PIL.Image.open``, or an already-open
                ``PIL.Image.Image``.

        Returns:
            The result of the model's ``classify_image`` call (annotated as a
            list of ``ClassificationResult``).

        Raises:
            ValueError: if ``model_name`` is not registered.
        """
        self._logger.debug('>> image type --> %s', type(image))

        # Resolve the model first so an unknown name fails before any
        # (possibly remote) image is fetched and decoded.
        model = self.get_model_data(model_name)
        img = self.image_to_pil(image)
        return model.model_class.classify_image(img)

    def image_to_pil(self, image):
        """
        Convert ``image`` to a ``PIL.Image.Image``.

        * An existing PIL image is returned unchanged.
        * A string starting with ``http://`` or ``https://`` (any case) is
          downloaded and decoded from memory.
        * Anything else (local path string, ``pathlib.Path``, binary file
          object) is passed straight to ``PIL.Image.open``.
        """
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, str) and image.lower().startswith(('http://', 'https://')):
            # Read the whole body and close the connection immediately:
            # Image.open is lazy, so handing it the live response would keep
            # the socket open (and unclosed) until the pixels are decoded.
            with urlopen(image) as response:
                data = response.read()
            return Image.open(io.BytesIO(data))
        return Image.open(image)