| | import torch |
| | from transformers import CLIPModel, CLIPProcessor |
| |
|
| |
|
def transform_genre_to_label(genre: int) -> str:
    """Map a numeric genre id (0-9) to its painting-genre label string.

    Any value that does not compare equal to one of the ids 0-9 yields
    "Unknown Genre". Matching uses ``==``, so inputs such as a 0-dim
    tensor or a float that equals an int id resolve like the plain int.
    """
    # NOTE(review): "enre_painting" looks like a typo for "genre_painting";
    # kept verbatim in case the fine-tuned model's prompts rely on this
    # exact text — confirm against the checkpoint before changing it.
    ordered_labels = (
        "abstract_painting",
        "cityscape",
        "enre_painting",
        "illustration",
        "landscape",
        "nude_painting",
        "portrait",
        "religious_painting",
        "sketch_and_study",
        "still_life",
    )
    for genre_id, genre_name in enumerate(ordered_labels):
        if genre == genre_id:
            return genre_name
    return "Unknown Genre"
| |
|
| |
|
# Genre ids 0-9 are named genres; id 10 has no name and maps to
# "Unknown Genre" (see transform_genre_to_label), giving 11 classes total.
genres = set(range(11))

# Enumerate a *sorted* sequence: iterating a set directly has
# implementation-defined order, which would make label2id/id2label unstable
# across interpreters instead of matching the numeric genre ids.
label2id = {transform_genre_to_label(genre): i for i, genre in enumerate(sorted(genres))}
id2label = {i: label for label, i in label2id.items()}
labels = list(label2id)
# One natural-language prompt per class, in id order, used for CLIP scoring.
label_prompt = [f"the genre of the painting is {transform_genre_to_label(genre)}" for genre in range(11)]
| |
|
| | MODEL_NAME = "flaviupop/CLIP-Finetuned-Painting-Genre-Recognition" |
| |
|
| |
|
class ImageAnalyzer:
    """Painting-genre classifier built on a fine-tuned CLIP checkpoint."""

    def __init__(self):
        # Weights come from the fine-tuned checkpoint; the processor
        # (tokenizer + image transforms) comes from the base CLIP model
        # the checkpoint was presumably fine-tuned from.
        self.model = CLIPModel.from_pretrained(MODEL_NAME)
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    def predict_genre(self, input_image) -> str:
        """Return the genre label the model scores highest for *input_image*.

        *input_image* is anything CLIPProcessor accepts (e.g. a PIL image).
        The image is scored against the module-level ``label_prompt`` texts.
        """
        inputs = self.processor(text=label_prompt, images=input_image, return_tensors="pt", padding=True)

        # Inference only: disable autograd so no gradient graph is built.
        with torch.no_grad():
            outputs = self.model(**inputs)

        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

        # .item() turns the 0-dim argmax tensor into a plain Python int
        # before it is compared against integer genre ids.
        result = int(torch.argmax(probs).item())

        return transform_genre_to_label(result)
| |
|
| |
|
| | image_analyzer = ImageAnalyzer() |
| |
|