|
|
import gradio as gr |
|
|
import os |
|
|
from datasets import load_dataset |
|
|
import torch |
|
|
import clip |
|
|
from PIL import Image |
|
|
import pyarrow as pa |
|
|
import lancedb |
|
|
|
|
|
# Image/caption dataset used both for building the index and for returning
# the matching images at query time.
ds = load_dataset("vipulmaheshwari/GTA-Image-Captioning-Dataset")

# Run CLIP on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# `model` is the CLIP network; `preprocess` is the transform that clip.load
# returns alongside it and that is applied to images before encoding.
model, preprocess = clip.load("ViT-L/14", device=device)
|
|
|
|
|
|
|
|
|
|
|
def embedding_image(image):
    """Return the CLIP image embedding of ``image`` as a flat list of floats.

    Args:
        image: The image to embed, in whatever form the module-level
            ``preprocess`` transform accepts (presumably a PIL image, as
            yielded by the dataset -- confirm against the caller).

    Returns:
        A plain Python ``list`` of floats: the first (and only) row of the
        encoder output for the single-image batch.
    """
    # unsqueeze(0) adds the batch dimension expected by the encoder.
    batch = preprocess(image).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping so no graph is built (and
    # kept alive) for every encoded image.
    with torch.no_grad():
        features = model.encode_image(batch)

    return features.cpu().numpy()[0].tolist()
|
|
|
|
|
# Build one record per training example: the image embedding, its caption,
# and the row index (used later to look the image up again in the dataset).
data = []

# Iterate the split directly so each row is fetched only once; indexing
# ds["train"][i] separately for 'image' and 'text' loads the row twice.
for i, row in enumerate(ds["train"]):
    encoded_img = embedding_image(row['image'])
    data.append({"vector": encoded_img, "text": row['text'], "id": i})
|
|
|
|
|
|
|
|
# Connect to the on-disk LanceDB database under ./data/tables.
db = lancedb.connect('./data/tables')

# Table layout: a fixed-size 768-element float32 vector, the caption text,
# and the integer row id pointing back into the dataset.
schema = pa.schema([
    pa.field("vector", pa.list_(pa.float32(), 768)),
    pa.field("text", pa.string()),
    pa.field("id", pa.int32()),
])

# mode="overwrite" is passed so a previous table of the same name is replaced.
# NOTE: the misspelled name `tabel` is kept because search_text refers to it.
tabel = db.create_table(
    "GTA Image Embedding Data",
    schema=schema,
    mode="overwrite",
)

# Insert all precomputed embedding records.
tabel.add(data)

# Return value is unused here (looks like a notebook-style preview call).
tabel.to_pandas()
|
|
|
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
def embedding_text(text):
    """Return the CLIP text embedding of ``text`` as a flat list of floats.

    Text-side counterpart of ``embedding_image``. ``search_text`` calls this
    name, but no definition of it was present in this file.

    Args:
        text: The query string to embed.

    Returns:
        A plain Python ``list`` of floats: the first (and only) row of the
        text-encoder output for the single-string batch.
    """
    # truncate=True keeps over-long user input from raising inside tokenize.
    tokens = clip.tokenize([text], truncate=True).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        features = model.encode_text(tokens)

    return features.cpu().numpy()[0].tolist()


def search_text(text):
    """Return the dataset images whose embeddings best match ``text``.

    Args:
        text: Free-form query string entered by the user.

    Returns:
        A list of up to four images taken from ``ds["train"]``, in the order
        the vector search returned them.
    """
    # Nearest-neighbour search in the LanceDB table, limited to four hits.
    hits = tabel.search(embedding_text(text)).limit(4).to_pandas()

    # The stored `id` is the row index into the training split; iterate the
    # column values directly instead of positional label lookups.
    train_split = ds["train"]
    return [train_split[int(row_id)]['image'] for row_id in hits['id']]
|
|
|
|
|
|
|
|
# Gradio UI: a text box + submit button on top, a gallery of results below.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Image search"):
            vector_query = gr.Textbox(value="Input Text to search; Car on traffic light ", show_label=False)
            b1 = gr.Button("Submit")

    with gr.Row():
        # Layout options are passed to the constructor; the former
        # `.style(...)` chaining is deprecated and removed in Gradio 4.
        gallery = gr.Gallery(
            label="Found images",
            show_label=False,
            elem_id="gallery",
            columns=[2],
            object_fit="contain",
            height="auto",
        )

    # Clicking the button runs the text query and fills the gallery.
    b1.click(search_text, inputs=vector_query, outputs=gallery)

demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|