vojtam committed (verified)
Commit f7621f0 · 1 Parent(s): 09558d9

Create app.py

Files changed (1)
  app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
+ import gradio as gr
+ import pickle
+ from datasets import load_dataset
+ from torch import nn
+ import numpy as np
+
+ import torch
+ from PIL import Image
+ from transformers import CLIPProcessor, CLIPModel
+
+ def get_clip_embeddings(input_data, input_type='text'):
+     # Load the CLIP model and processor
+     model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+     # Prepare the input based on the type
+     if input_type == 'text':
+         inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
+     elif input_type == 'image':
+         if isinstance(input_data, str):
+             image = Image.open(input_data)
+         elif isinstance(input_data, Image.Image):
+             image = input_data
+         else:
+             raise ValueError("For image input, provide either a file path or a PIL Image object")
+         inputs = processor(images=image, return_tensors="pt")
+     else:
+         raise ValueError("Invalid input_type. Choose 'text' or 'image'")
+
+     # Get the embeddings
+     with torch.no_grad():
+         if input_type == 'text':
+             embeddings = model.get_text_features(**inputs)
+         else:
+             embeddings = model.get_image_features(**inputs)
+
+     return embeddings.numpy()
+
+
+ veggies = load_dataset('vojtam/vegetables')
+
+
+ text = gr.Textbox(label="Enter the text")
+ image = gr.Gallery()
+
+ def get_similar_images(text, n=4):
+     # Load the precomputed CLIP embeddings for every image in the dataset
+     with open('img_embeddings.pkl', 'rb') as file:
+         img_embeddings = pickle.load(file)
+     # Embed the query text and rank images by cosine similarity
+     text_embedding = get_clip_embeddings(text, input_type='text')
+     cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+     sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
+     # Indices of the n most similar images, best match first
+     top_n = np.argsort(sims.numpy())[::-1][:n]
+     imgs = []
+     for index in top_n:
+         imgs.append(veggies['train'][index.item()]['image'])
+     return imgs
+
+
+ intf = gr.Interface(fn=get_similar_images, inputs=text, outputs=image)
+ intf.launch(share=True)
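
Note: app.py reads img_embeddings.pkl, but this commit does not include the script that produced it. A minimal sketch of how those embeddings could be precomputed, assuming it runs in the same context as app.py (so get_clip_embeddings and veggies are in scope) and that the pickle holds one CLIP image embedding per row of the train split, which is what get_similar_images expects:

# Hypothetical companion script (not part of this commit): precompute the
# image embeddings that get_similar_images() loads from img_embeddings.pkl.
import pickle
import numpy as np

# Stack one (1, 512) CLIP embedding per training image into an (N, 512) array
embeddings = [
    get_clip_embeddings(example['image'], input_type='image')
    for example in veggies['train']
]
img_embeddings = np.vstack(embeddings)

with open('img_embeddings.pkl', 'wb') as file:
    pickle.dump(img_embeddings, file)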