# Hugging Face Space (status at scrape time: Runtime error)
# Acknowledgments:
# This project is inspired by:
# 1. https://github.com/haltakov/natural-language-image-search by Vladimir Haltakov
# 2. DrishtiSharma/Text-to-Image-search-using-CLIP
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image as PILIMAGE
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP model plus its paired processor and tokenizer.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Catalog metadata, the precomputed CLIP image-embedding matrix, and the
# id list that aligns rows of that matrix with catalog entries.
photos = pd.read_csv("./items_data.csv")
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")["photo_id"].tolist()
def find_best_matches(text):
    """Return the three catalog images whose CLIP embeddings best match *text*.

    The query is tokenized, encoded with the CLIP text tower on the same
    device as the model, moved back to the CPU, and L2-normalized so the
    dot product against the precomputed image-feature matrix is a cosine
    similarity (assuming the stored image features are normalized too —
    TODO confirm against the feature-extraction script). The top-3 product
    images are fetched over HTTP and returned as PIL images for the gallery.
    """
    with torch.no_grad():
        # Move the tokenized inputs to `device`; the original left them on
        # the CPU, which crashes when the model runs on CUDA. (A redundant
        # `processor(...)` call that overwrote this result was removed.)
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        # `.cpu()` before `.numpy()` — .numpy() raises on CUDA tensors.
        text_encoded = model.get_text_features(**inputs).cpu().numpy()
    # Normalize; dividing by a positive scalar does not change the ranking.
    text_encoded = text_encoded / np.linalg.norm(text_encoded, axis=-1, keepdims=True)

    # Score the query against every image embedding in one matrix product,
    # then sort ONCE (the original re-sorted the full list on every loop
    # iteration) and keep the indices of the three highest scores.
    similarities = (text_encoded @ photo_features.T).squeeze(0)
    best_idxs = np.argsort(similarities)[::-1][:3]

    matched_images = []
    for idx in best_idxs:
        photo_id = photo_ids[idx]
        photo_data = photos[photos["Uniq Id"] == photo_id].iloc[0]
        # Request a 640px-wide rendition of the product image.
        response = requests.get(photo_data["Image"] + "?w=640")
        matched_images.append(PILIMAGE.open(BytesIO(response.content)))
    return matched_images
# Gradio UI: a search box, a button, and a 3-column results gallery.
# NOTE(review): the original used the `gr.*.style(...)` API, which was
# removed in Gradio 4.x — the likely cause of the Space's "Runtime error".
# Styling options are now constructor keyword arguments.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Search product",
                show_label=False,
                max_lines=1,
                placeholder="Type product",
                container=False,  # was .style(container=False)
            )
            # scale=0 keeps the button at its natural width
            # (was .style(full_width=False)).
            btn = gr.Button("Search", scale=0)
        gallery = gr.Gallery(
            label="Products",
            show_label=False,
            elem_id="gallery",
            columns=3,        # was .style(grid=[3])
            height="auto",
        )
    # Wire the search action: query text in, list of PIL images out.
    btn.click(find_best_matches, inputs=text, outputs=gallery)

demo.launch(show_api=False)