"""Text-to-image search over the GTA image-captioning dataset.

Every training image is embedded with CLIP ViT-L/14 and stored in a local
LanceDB table. A Gradio UI embeds the user's text query with the same model
and shows the closest-matching images.
"""
import os

import clip
import gradio as gr
import lancedb
import pyarrow as pa
import torch
from datasets import load_dataset
from PIL import Image

# CLIP ViT-L/14 produces 768-dimensional image and text embeddings; the
# fixed-size vector column in the LanceDB schema must match this width.
EMBEDDING_DIM = 768
# LanceDB only accepts alphanumerics, "_", "-" and "." in table names, so the
# previous name containing spaces is rejected.
TABLE_NAME = "gta_image_embeddings"
# Number of images returned per query.
TOP_K = 4

ds = load_dataset("vipulmaheshwari/GTA-Image-Captioning-Dataset")
train = ds["train"]

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)


def _to_unit_list(features):
    """Return the first row of *features*, L2-normalised, as a list of floats.

    CLIP embeddings are compared by cosine similarity. Normalising them to unit
    length makes LanceDB's default L2 distance rank results in the same order.
    The cast to float32 avoids doing the norm in fp16 on CUDA.
    """
    features = features.float()
    features = features / features.norm(dim=-1, keepdim=True)
    return features[0].cpu().tolist()


@torch.no_grad()  # inference only: don't build an autograd graph per call
def embedding_image(image):
    """Embed one PIL image with CLIP.

    Returns:
        A list of EMBEDDING_DIM floats (unit length).
    """
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    return _to_unit_list(model.encode_image(image_tensor))


@torch.no_grad()  # inference only: don't build an autograd graph per call
def embedding_text(text):
    """Embed one text query with CLIP, in the same space as embedding_image.

    Queries longer than CLIP's token context are truncated rather than raising.

    Returns:
        A list of EMBEDDING_DIM floats (unit length).
    """
    tokens = clip.tokenize([text], truncate=True).to(device)
    return _to_unit_list(model.encode_text(tokens))


# Index the dataset. Iterating the split directly decodes each row once,
# instead of once per field access.
rows = [
    {"vector": embedding_image(example["image"]), "text": example["text"], "id": idx}
    for idx, example in enumerate(train)
]

db = lancedb.connect("./data/tables")
schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), EMBEDDING_DIM)),
        pa.field("text", pa.string()),
        pa.field("id", pa.int32()),
    ]
)
table = db.create_table(TABLE_NAME, schema=schema, mode="overwrite")
table.add(rows)


def search_text(text):
    """Return the TOP_K dataset images whose embeddings best match *text*.

    Args:
        text: Free-form query from the UI; an empty/blank query yields no images.

    Returns:
        A list of images from the training split, best match first.
    """
    if not text or not text.strip():
        return []
    hits = table.search(embedding_text(text)).limit(TOP_K).to_pandas()
    # The stored "id" is the row index into the training split.
    return [train[int(data_id)]["image"] for data_id in hits["id"]]


with gr.Blocks() as demo:
    with gr.Tab("Image search"):
        # The hint is a placeholder, not a prefilled value: with `value=` the
        # instruction text itself was submitted as the search query.
        vector_query = gr.Textbox(
            placeholder="Input Text to search; Car on traffic light ",
            show_label=False,
        )
        b1 = gr.Button("Submit")
        with gr.Row():
            # Layout options are constructor arguments; `.style()` was removed
            # in Gradio 4.
            gallery = gr.Gallery(
                label="Found images",
                show_label=False,
                elem_id="gallery",
                columns=2,
                object_fit="contain",
                height="auto",
            )
    b1.click(search_text, inputs=vector_query, outputs=gallery)

if __name__ == "__main__":
    demo.launch()