from typing import List, Optional

from transformers import pipeline

from internals.util.commons import download_image
class ImageClassifier:
    """Zero-shot image classifier that picks one label from a candidate set.

    The Hugging Face ``zero-shot-image-classification`` pipeline is built
    lazily on the first call to ``load`` (which ``classify`` calls), so
    constructing an instance is cheap.
    """

    # Class-level default; ``load`` assigns an instance attribute that shadows it.
    __loaded = False

    def __init__(self, candidates: Optional[List[str]] = None):
        """Create a classifier.

        Args:
            candidates: Labels to score images against. Defaults to
                ``["realistic", "anime", "comic"]``. The sequence is copied, so
                instances never share (or alias the caller's) label list.
        """
        # A fresh list per call avoids the shared-mutable-default pitfall.
        if candidates is None:
            candidates = ["realistic", "anime", "comic"]
        self.__candidates = list(candidates)

    def load(self):
        """Build the zero-shot pipeline once; subsequent calls are no-ops."""
        if self.__loaded:
            return
        self.pipe = pipeline(
            "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
        )
        self.__loaded = True

    def classify(self, image_url: str, width: int, height: int) -> str:
        """Return the first label the pipeline predicts for the image at ``image_url``.

        The image is downloaded and resized to ``(width, height)`` before being
        scored against the configured candidate labels.

        Args:
            image_url: Location passed to ``download_image``.
            width: Target width for the resize.
            height: Target height for the resize.

        Returns:
            The ``"label"`` of the first prediction for the image. Per the
            Hugging Face docs, the pipeline orders predictions by descending
            score. Returns ``""`` when the pipeline yields no predictions.
        """
        self.load()
        # NOTE(review): assumes download_image returns a PIL-like image with .resize().
        image = download_image(image_url).resize((width, height))
        # A list input yields one list of predictions per input image.
        batch_results = self.pipe([image], candidate_labels=self.__candidates)
        if not batch_results:
            return ""
        predictions = batch_results[0]
        if not predictions:
            return ""
        return predictions[0]["label"]