Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -39,25 +39,25 @@ def get_clip_embeddings(input_data, input_type='text'):
|
|
| 39 |
|
| 40 |
|
| 41 |
veggies = load_dataset('vojtam/vegetables')
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
text = gr.Textbox(label = "Enter the text")
|
| 45 |
image = gr.Gallery()
|
| 46 |
|
| 47 |
def get_similar_images(text, n = 4):
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
imgs
|
| 60 |
-
return imgs
|
| 61 |
|
| 62 |
|
| 63 |
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
veggies = load_dataset('vojtam/vegetables')
|
| 42 |
+
with open('img_embeddings.pkl', 'rb') as file:
|
| 43 |
+
img_embeddings = pickle.load(file)
|
| 44 |
|
| 45 |
+
text = gr.Textbox(label = "Enter the text", 'Your text goes here')
|
|
|
|
| 46 |
image = gr.Gallery()
|
| 47 |
|
| 48 |
def get_similar_images(text, n = 4):
|
| 49 |
+
if text:
|
| 50 |
+
text_embedding = get_clip_embeddings(text, input_type='text')
|
| 51 |
+
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
| 52 |
+
sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
|
| 53 |
+
top_n = np.argsort(np.array(sims))[::-1][:n]
|
| 54 |
+
print(top_n)
|
| 55 |
+
print(img_embeddings)
|
| 56 |
+
imgs = []
|
| 57 |
+
|
| 58 |
+
for index in top_n:
|
| 59 |
+
imgs.append(veggies['train'][index.item()]['image'])
|
| 60 |
+
return imgs
|
|
|
|
| 61 |
|
| 62 |
|
| 63 |
|