vojtam commited on
Commit
7816cb5
·
verified ·
1 Parent(s): 8dc3e2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -15
app.py CHANGED
@@ -39,25 +39,25 @@ def get_clip_embeddings(input_data, input_type='text'):
39
 
40
 
41
  veggies = load_dataset('vojtam/vegetables')
 
 
42
 
43
-
44
- text = gr.Textbox(label = "Enter the text")
45
  image = gr.Gallery()
46
 
47
  def get_similar_images(text, n = 4):
48
- with open('img_embeddings.pkl', 'rb') as file:
49
- img_embeddings = pickle.load(file)
50
- text_embedding = get_clip_embeddings(text, input_type='text')
51
- cos = nn.CosineSimilarity(dim=1, eps=1e-6)
52
- sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
53
- top_n = np.argsort(np.array(sims))[::-1][:4]
54
- print(top_n)
55
- print(img_embeddings)
56
- imgs = []
57
-
58
- for index in top_n:
59
- imgs.append(veggies['train'][index.item()]['image'])
60
- return imgs
61
 
62
 
63
 
 
39
 
40
 
41
  veggies = load_dataset('vojtam/vegetables')
42
+ with open('img_embeddings.pkl', 'rb') as file:
43
+ img_embeddings = pickle.load(file)
44
 
45
+ text = gr.Textbox(label = "Enter the text", 'Your text goes here')
 
46
  image = gr.Gallery()
47
 
48
  def get_similar_images(text, n = 4):
49
+ if text:
50
+ text_embedding = get_clip_embeddings(text, input_type='text')
51
+ cos = nn.CosineSimilarity(dim=1, eps=1e-6)
52
+ sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
53
+ top_n = np.argsort(np.array(sims))[::-1][:n]
54
+ print(top_n)
55
+ print(img_embeddings)
56
+ imgs = []
57
+
58
+ for index in top_n:
59
+ imgs.append(veggies['train'][index.item()]['image'])
60
+ return imgs
 
61
 
62
 
63