# Source: Hugging Face Spaces page for this app (status shown at capture time: "Build error").
| import gradio as gr | |
| import pickle | |
| from datasets import load_dataset | |
| from torch import nn | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel | |
| from datasets import load_dataset | |
model_checkpoint = "openai/clip-vit-base-patch32"

# Loaded (model, processor) pairs keyed by checkpoint name, so the weights are
# read once per process instead of on every embedding request.
_CLIP_CACHE = {}


def _load_clip(checkpoint):
    """Return the cached ``(model, processor)`` pair for *checkpoint*, loading it on first use."""
    if checkpoint not in _CLIP_CACHE:
        _CLIP_CACHE[checkpoint] = (
            CLIPModel.from_pretrained(checkpoint),
            CLIPProcessor.from_pretrained(checkpoint),
        )
    return _CLIP_CACHE[checkpoint]


def get_clip_embeddings(input_data, input_type='text'):
    """Embed text or an image with the CLIP checkpoint named by ``model_checkpoint``.

    Args:
        input_data: For ``input_type='text'``, the text handed to the CLIP
            processor. For ``input_type='image'``, either a file path (``str``)
            or a ``PIL.Image.Image``.
        input_type: ``'text'`` (default) or ``'image'``.

    Returns:
        The result of ``get_text_features`` / ``get_image_features`` converted
        to a NumPy array.

    Raises:
        ValueError: If ``input_type`` is not ``'text'`` or ``'image'``, or if an
            image input is neither a path string nor a PIL image.
    """
    # Validate the arguments before touching the model so that bad calls fail
    # fast instead of paying for (or depending on) a checkpoint load.
    if input_type not in ('text', 'image'):
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")

    image = None
    if input_type == 'image':
        if isinstance(input_data, str):
            image = Image.open(input_data)
        elif isinstance(input_data, Image.Image):
            image = input_data
        else:
            raise ValueError("For image input, provide either a file path or a PIL Image object")

    model, processor = _load_clip(model_checkpoint)

    # Prepare the model inputs for the requested modality.
    if input_type == 'text':
        inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
    else:
        inputs = processor(images=image, return_tensors="pt")

    # Inference only: no gradients are needed for the embeddings.
    with torch.no_grad():
        if input_type == 'text':
            embeddings = model.get_text_features(**inputs)
        else:
            embeddings = model.get_image_features(**inputs)
    return embeddings.numpy()
# Cosine similarity along dim 1, with a small eps guarding against division by zero.
cos = nn.CosineSimilarity(eps=1e-6, dim=1)

# Dataset that supplies the images shown in the search results.
veggies = load_dataset(path="vojtam/vegetables")

# Image embeddings stored next to the app.
# NOTE(review): pickle executes arbitrary code on load; only ship a trusted file here.
with open(file="img_embeddings.pkl", mode="rb") as file:
    img_embeddings = pickle.load(file)
def get_similar_images(text, n=4):
    """Return the ``n`` dataset images whose embeddings best match ``text``.

    Args:
        text: Query string. ``None``, an empty string, or a whitespace-only
            string yields an empty result without running the model.
        n: Maximum number of images to return. Values ``<= 0`` yield an empty
            result; values larger than the number of embeddings are clamped.

    Returns:
        A list of the ``'image'`` entries from ``veggies['train']``, ordered
        from most to least similar by cosine similarity.
    """
    # Nothing to search for (or nothing requested): skip the model entirely.
    if not text or not text.strip() or n <= 0:
        return []

    query_embedding = torch.as_tensor(get_clip_embeddings(text, input_type='text'))
    # as_tensor avoids re-copying data that is already a tensor / ndarray.
    # assumes img_embeddings is (num_images, dim), row-aligned with
    # veggies['train'] -- TODO confirm against how the pickle was built
    gallery_embeddings = torch.as_tensor(img_embeddings)

    sims = cos(query_embedding, gallery_embeddings)

    # topk returns the best matches already sorted in descending order; clamp k
    # because topk rejects k larger than the number of scores.
    k = min(n, sims.shape[0])
    best_indices = torch.topk(sims, k=k).indices.tolist()

    train_split = veggies['train']
    return [train_split[index]['image'] for index in best_indices]
# Custom styling: a gallery that fills the remaining viewport height and an
# orange submit button.
css = """
.full-height-gallery {
height: calc(100vh - 250px);
overflow-y: auto;
}
#submit-btn {
background-color: #ff5b00;
color: #ffffff;
}
"""

with gr.Blocks(css=css) as intf:
    # Row 1: free-text query.
    with gr.Row():
        text_input = gr.Textbox(
            label="Enter the description of the images you want to search for",
            placeholder='Your text goes here',
        )

    # Row 2: action buttons.
    with gr.Row():
        submit_btn = gr.Button(value="Submit", elem_id="submit-btn")
        clear_btn = gr.Button(value="Clear")

    # Row 3: search results.
    with gr.Row():
        gallery = gr.Gallery(
            label="Similar Images",
            show_label=False,
            elem_classes=["full-height-gallery"],
        )

    # Submit runs the search and fills the gallery.
    submit_btn.click(
        fn=get_similar_images,
        inputs=text_input,
        outputs=gallery,
    )
    # Clear resets the textbox to None and the gallery to an empty list.
    clear_btn.click(
        fn=lambda: [None, []],
        inputs=None,
        outputs=[text_input, gallery],
    )

intf.launch(share=True)