"""Gradio demo: search a vegetables image dataset by text description using CLIP."""

import os
import pickle
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torch import nn
from transformers import CLIPModel, CLIPProcessor

model_checkpoint = "openai/clip-vit-base-patch32"


@lru_cache(maxsize=1)
def _load_clip(checkpoint):
    """Load the CLIP model and processor for *checkpoint*, caching the pair.

    Loading from the checkpoint is expensive, so it is done once and reused
    for every subsequent embedding request instead of once per query.
    """
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    return model, processor


def get_clip_embeddings(input_data, input_type: str = 'text') -> np.ndarray:
    """Return CLIP embeddings for a text or an image input.

    Args:
        input_data: For ``input_type='text'``, the text passed to the CLIP
            processor. For ``input_type='image'``, either a file path
            (``str`` or ``os.PathLike``) or a ``PIL.Image.Image``.
        input_type: ``'text'`` or ``'image'``.

    Returns:
        The model's text or image features converted to a NumPy array.

    Raises:
        ValueError: If ``input_type`` is not ``'text'``/``'image'``, or if an
            image input is neither a file path nor a PIL image.
    """
    # Validate cheaply before paying for the model load.
    if input_type not in ('text', 'image'):
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")

    if input_type == 'image':
        if isinstance(input_data, (str, os.PathLike)):
            image = Image.open(input_data)
        elif isinstance(input_data, Image.Image):
            image = input_data
        else:
            raise ValueError("For image input, provide either a file path or a PIL Image object")

    model, processor = _load_clip(model_checkpoint)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        if input_type == 'text':
            inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
            embeddings = model.get_text_features(**inputs)
        else:
            inputs = processor(images=image, return_tensors="pt")
            embeddings = model.get_image_features(**inputs)

    return embeddings.numpy()


veggies = load_dataset('vojtam/vegetables')

# NOTE(security): pickle.load can execute arbitrary code; only ship an
# 'img_embeddings.pkl' that was produced by a trusted source.
with open('img_embeddings.pkl', 'rb') as file:
    img_embeddings = pickle.load(file)

# The image embeddings do not change between queries, so convert them once.
_img_embeddings_tensor = torch.tensor(img_embeddings)

cos = nn.CosineSimilarity(dim=1, eps=1e-6)


def get_similar_images(text, n: int = 4) -> list:
    """Return the *n* dataset images most similar to the text description.

    Args:
        text: Free-text query. Empty or whitespace-only input yields ``[]``.
        n: Maximum number of images to return.

    Returns:
        A list of images from ``veggies['train']`` ordered from most to least
        similar, or an empty list when no query text was given.
    """
    if not text or (isinstance(text, str) and not text.strip()):
        return []

    text_embedding = get_clip_embeddings(text, input_type='text')
    sims = cos(torch.tensor(text_embedding), _img_embeddings_tensor)
    # argsort is ascending; reverse it so the best matches come first.
    top_n = np.argsort(sims.numpy())[::-1][:n]
    return [veggies['train'][int(index)]['image'] for index in top_n]


css = (
    " .full-height-gallery { height: calc(100vh - 250px); overflow-y: auto; }"
    " #submit-btn { background-color: #ff5b00; color: #ffffff; } "
)

with gr.Blocks(css=css) as intf:
    with gr.Row():
        text_input = gr.Textbox(
            label="Enter the description of the images you want to search for",
            placeholder='Your text goes here',
        )
    with gr.Row():
        submit_btn = gr.Button("Submit", elem_id="submit-btn")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        gallery = gr.Gallery(
            label="Similar Images",
            show_label=False,
            elem_classes=["full-height-gallery"],
        )

    submit_btn.click(fn=get_similar_images, inputs=text_input, outputs=gallery)
    # Clearing resets both the query box and the gallery.
    clear_btn.click(fn=lambda: [None, []], inputs=None, outputs=[text_input, gallery])

intf.launch(share=True)