vojtam committed on
Commit
be613af
·
verified ·
1 Parent(s): 7816cb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -9,10 +9,12 @@ from PIL import Image
9
  from transformers import CLIPProcessor, CLIPModel
10
  from datasets import load_dataset
11
 
 
 
12
  def get_clip_embeddings(input_data, input_type='text'):
13
  # Load the CLIP model and processor
14
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
15
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
16
 
17
  # Prepare the input based on the type
18
  if input_type == 'text':
@@ -42,17 +44,16 @@ veggies = load_dataset('vojtam/vegetables')
42
  with open('img_embeddings.pkl', 'rb') as file:
43
  img_embeddings = pickle.load(file)
44
 
45
- text = gr.Textbox(label = "Enter the text", 'Your text goes here')
 
 
46
  image = gr.Gallery()
47
 
48
  def get_similar_images(text, n = 4):
49
  if text:
50
  text_embedding = get_clip_embeddings(text, input_type='text')
51
- cos = nn.CosineSimilarity(dim=1, eps=1e-6)
52
  sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
53
  top_n = np.argsort(np.array(sims))[::-1][:n]
54
- print(top_n)
55
- print(img_embeddings)
56
  imgs = []
57
 
58
  for index in top_n:
 
9
  from transformers import CLIPProcessor, CLIPModel
10
  from datasets import load_dataset
11
 
12
+ model_checkpoint = "openai/clip-vit-base-patch32"
13
+
14
  def get_clip_embeddings(input_data, input_type='text'):
15
  # Load the CLIP model and processor
16
+ model = CLIPModel.from_pretrained(model_checkpoint)
17
+ processor = CLIPProcessor.from_pretrained(model_checkpoint)
18
 
19
  # Prepare the input based on the type
20
  if input_type == 'text':
 
44
  with open('img_embeddings.pkl', 'rb') as file:
45
  img_embeddings = pickle.load(file)
46
 
47
+ cos = nn.CosineSimilarity(dim=1, eps=1e-6)
48
+
49
+ text = gr.Textbox(label = "Enter the description of the images you want to search for", placeholder='Your text goes here')
50
  image = gr.Gallery()
51
 
52
  def get_similar_images(text, n = 4):
53
  if text:
54
  text_embedding = get_clip_embeddings(text, input_type='text')
 
55
  sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
56
  top_n = np.argsort(np.array(sims))[::-1][:n]
 
 
57
  imgs = []
58
 
59
  for index in top_n: