flaviupop committed on
Commit
577e64d
·
1 Parent(s): 4bb11f7

Added genre prediction

Browse files
Files changed (3) hide show
  1. app.py +3 -1
  2. predict_genre.py +58 -0
  3. requirements.txt +4 -1
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
 
 
 
3
 
4
  def picture_analysis(input_image) -> str:
5
- return "This is a random genre"
6
 
7
 
8
  demo = gr.Interface(
 
1
  import gradio as gr
2
 
3
+ from predict_genre import image_analyzer
4
+
5
 
6
  def picture_analysis(input_image) -> str:
7
+ return image_analyzer.predict_genre(input_image)
8
 
9
 
10
  demo = gr.Interface(
predict_genre.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import CLIPModel, CLIPProcessor
3
+
4
+
5
+ def transform_genre_to_label(genre: int) -> str:
6
+ label = "Unknown Genre"
7
+ if genre == 0:
8
+ label = "abstract_painting"
9
+ elif genre == 1:
10
+ label = "cityscape"
11
+ elif genre == 2:
12
+ label = "enre_painting"
13
+ elif genre == 3:
14
+ label = "illustration"
15
+ elif genre == 4:
16
+ label = "landscape"
17
+ elif genre == 5:
18
+ label = "nude_painting"
19
+ elif genre == 6:
20
+ label = "portrait"
21
+ elif genre == 7:
22
+ label = "religious_painting"
23
+ elif genre == 8:
24
+ label = "sketch_and_study"
25
+ elif genre == 9:
26
+ label = "still_life"
27
+
28
+ return label
29
+
30
+
31
+ genres = set([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
32
+
33
+ label2id = {transform_genre_to_label(genre): i for i, genre in enumerate(genres)}
34
+ id2label = {i: label for label, i in label2id.items()}
35
+ labels = list(label2id)
36
+ label_prompt = [f"the genre of the painting is {transform_genre_to_label(genre)}" for genre in range(11)]
37
+
38
+ MODEL_NAME = "flaviupop/CLIP-Finetuned-Painting-Genre-Recognition"
39
+
40
+
41
+ class ImageAnalyzer:
42
+ def __init__(self):
43
+ self.model = CLIPModel.from_pretrained(MODEL_NAME)
44
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
45
+
46
+ def predict_genre(self, input_image) -> str:
47
+ inputs = self.processor(text=label_prompt, images=input_image, return_tensors="pt", padding=True)
48
+
49
+ outputs = self.model(**inputs)
50
+ logits_per_image = outputs.logits_per_image # this is the image-text similarity score
51
+ probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
52
+
53
+ result = torch.argmax(probs)
54
+
55
+ return transform_genre_to_label(result)
56
+
57
+
58
+ image_analyzer = ImageAnalyzer()
requirements.txt CHANGED
@@ -1 +1,4 @@
1
- gradio
 
 
 
 
1
+ gradio
2
+ transformers
3
+ datasets
4
+ torch