# File size: 2,103 Bytes
# da54daf
import gradio as gr
import os
from datasets import load_dataset
import torch
import clip
from PIL import Image
import pyarrow as pa
import lancedb
# Load the GTA image-captioning dataset and the CLIP model used to embed it.
ds = load_dataset("vipulmaheshwari/GTA-Image-Captioning-Dataset")

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)
def embedding_image(image):
    """Return the CLIP ViT-L/14 embedding of *image* as a plain Python list.

    Parameters
    ----------
    image : PIL.Image.Image
        Image to embed; it is run through CLIP's `preprocess` transform.

    Returns
    -------
    list[float]
        768-dimensional embedding vector (matches the LanceDB schema below).
    """
    processed_image = preprocess(image)
    unsqueezed_image = processed_image.unsqueeze(0).to(device)
    # FIX: wrap inference in no_grad — the original built an autograd graph
    # for every image only to detach it afterwards, wasting time and memory.
    with torch.no_grad():
        embed_image = model.encode_image(unsqueezed_image)
    # Move to CPU and return the single batch row as a list of floats.
    return embed_image.cpu().numpy()[0].tolist()
# One record per training sample: CLIP vector + caption + row id.
data = [
    {
        "vector": embedding_image(sample["image"]),
        "text": sample["text"],
        "id": idx,
    }
    for idx, sample in enumerate(ds["train"])
]
# Persist the embeddings in a local LanceDB table.
db = lancedb.connect('./data/tables')
schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), 768)),  # CLIP ViT-L/14 dim
        pa.field("text", pa.string()),
        pa.field("id", pa.int32()),
    ]
)
# NOTE(review): the (misspelled) name `tabel` is kept — search_text below
# references it, so renaming here would break the search path.
tabel = db.create_table("GTA Image Embedding Data", schema=schema, mode="overwrite")
tabel.add(data)
# FIX: dropped the original trailing `tabel.to_pandas()` — its result was
# discarded, so it only materialized the entire table for nothing.
import gradio as gr
def embedding_text(text):
    """Return the CLIP ViT-L/14 text embedding of *text* as a list of floats.

    BUG FIX: the original script called `embedding_text` from search_text
    without defining it anywhere, so every search raised NameError. This
    mirrors `embedding_image` on the text tower so image and text vectors
    live in the same space.
    """
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        features = model.encode_text(tokens)
    return features.cpu().numpy()[0].tolist()

def search_text(text):
    """Vector-search the LanceDB table and return the 4 best-matching images.

    Parameters
    ----------
    text : str
        Free-form query, e.g. "Car on traffic light".

    Returns
    -------
    list
        The `image` entries of the top-4 nearest dataset rows.
    """
    query = tabel.search(embedding_text(text)).limit(4).to_pandas()
    images = []
    for i in range(len(query)):
        # Map each hit's stored row id back to the original dataset image.
        data_id = int(query['id'][i])
        images.append(ds["train"][data_id]['image'])
    return images
# Create Gradio interface: a text box feeding the vector search, results in a gallery.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Image search"):
            vector_query = gr.Textbox(value="Input Text to search; Car on traffic light ", show_label=False)
            b1 = gr.Button("Submit")
            with gr.Row():
                # FIX: `Gallery.style()` was deprecated in Gradio 3.41 and
                # removed in 4.x — pass the layout options to the constructor
                # instead (same rendering, works on both lines).
                gallery = gr.Gallery(
                    label="Found images",
                    show_label=False,
                    elem_id="gallery",
                    columns=[2],
                    object_fit="contain",
                    height="auto",
                )
    b1.click(search_text, inputs=vector_query, outputs=gallery)
demo.launch()