Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,10 +9,12 @@ from PIL import Image
|
|
| 9 |
from transformers import CLIPProcessor, CLIPModel
|
| 10 |
from datasets import load_dataset
|
| 11 |
|
|
|
|
|
|
|
| 12 |
def get_clip_embeddings(input_data, input_type='text'):
|
| 13 |
# Load the CLIP model and processor
|
| 14 |
-
model = CLIPModel.from_pretrained(
|
| 15 |
-
processor = CLIPProcessor.from_pretrained(
|
| 16 |
|
| 17 |
# Prepare the input based on the type
|
| 18 |
if input_type == 'text':
|
|
@@ -42,17 +44,16 @@ veggies = load_dataset('vojtam/vegetables')
|
|
| 42 |
with open('img_embeddings.pkl', 'rb') as file:
|
| 43 |
img_embeddings = pickle.load(file)
|
| 44 |
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
image = gr.Gallery()
|
| 47 |
|
| 48 |
def get_similar_images(text, n = 4):
|
| 49 |
if text:
|
| 50 |
text_embedding = get_clip_embeddings(text, input_type='text')
|
| 51 |
-
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
| 52 |
sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
|
| 53 |
top_n = np.argsort(np.array(sims))[::-1][:n]
|
| 54 |
-
print(top_n)
|
| 55 |
-
print(img_embeddings)
|
| 56 |
imgs = []
|
| 57 |
|
| 58 |
for index in top_n:
|
|
|
|
| 9 |
from transformers import CLIPProcessor, CLIPModel
|
| 10 |
from datasets import load_dataset
|
| 11 |
|
| 12 |
+
model_checkpoint = "openai/clip-vit-base-patch32"
|
| 13 |
+
|
| 14 |
def get_clip_embeddings(input_data, input_type='text'):
|
| 15 |
# Load the CLIP model and processor
|
| 16 |
+
model = CLIPModel.from_pretrained(model_checkpoint)
|
| 17 |
+
processor = CLIPProcessor.from_pretrained(model_checkpoint)
|
| 18 |
|
| 19 |
# Prepare the input based on the type
|
| 20 |
if input_type == 'text':
|
|
|
|
| 44 |
with open('img_embeddings.pkl', 'rb') as file:
|
| 45 |
img_embeddings = pickle.load(file)
|
| 46 |
|
| 47 |
+
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
|
| 48 |
+
|
| 49 |
+
text = gr.Textbox(label = "Enter the description of the images you want to search for", placeholder='Your text goes here')
|
| 50 |
image = gr.Gallery()
|
| 51 |
|
| 52 |
def get_similar_images(text, n = 4):
|
| 53 |
if text:
|
| 54 |
text_embedding = get_clip_embeddings(text, input_type='text')
|
|
|
|
| 55 |
sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
|
| 56 |
top_n = np.argsort(np.array(sims))[::-1][:n]
|
|
|
|
|
|
|
| 57 |
imgs = []
|
| 58 |
|
| 59 |
for index in top_n:
|