| import gradio as gr |
| import os |
| from datasets import load_dataset |
| import torch |
| import clip |
| from PIL import Image |
| import pyarrow as pa |
| import lancedb |
|
|
# Pull the GTA image/caption dataset from the Hugging Face Hub; its "train"
# split is what gets embedded and searched below.
ds = load_dataset("vipulmaheshwari/GTA-Image-Captioning-Dataset")

# Prefer the GPU when CUDA is present; otherwise run CLIP on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-L/14 model plus the matching image preprocessing transform.
model, preprocess = clip.load("ViT-L/14", device=device)
|
|
|
|
| |
def embedding_image(image):
    """Encode one image into a CLIP image embedding.

    Args:
        image: An image accepted by CLIP's ``preprocess`` transform (the
            dataset's ``image`` column is passed here).

    Returns:
        The embedding for this single image as a plain Python list of floats,
        suitable for storing in the LanceDB ``vector`` column.
    """
    # preprocess returns one unbatched tensor; unsqueeze(0) adds the batch
    # axis the model expects, and [0] below removes it again.
    image_batch = preprocess(image).unsqueeze(0).to(device)
    # Pure inference: disable autograd so no graph is built per image.
    with torch.no_grad():
        embed_image = model.encode_image(image_batch)
    return embed_image.cpu().numpy()[0].tolist()
|
|
# One record per training example: its CLIP image embedding, its caption, and
# its row index in ds["train"] (used later to fetch the image for a search hit).
# Iterating the split directly decodes each row once, instead of indexing
# ds["train"][i] twice per row as the original loop did.
data = [
    {"vector": embedding_image(row["image"]), "text": row["text"], "id": i}
    for i, row in enumerate(ds["train"])
]
|
|
|
|
# Local LanceDB database holding the image embeddings.
db = lancedb.connect('./data/tables')

# The fixed list size must equal the length of the vectors produced by
# embedding_image (768 for CLIP ViT-L/14).
schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), 768)),
        pa.field("text", pa.string()),
        pa.field("id", pa.int32()),
    ]
)

# LanceDB table names may only contain letters, digits, underscores, hyphens
# and periods, so the name uses underscores instead of spaces.
# mode="overwrite" rebuilds the table from scratch on every run.
tabel = db.create_table("GTA_Image_Embedding_Data", schema=schema, mode="overwrite")
tabel.add(data)
|
|
| import gradio as gr |
|
|
| |
def embedding_text(text):
    """Encode a text query into a CLIP text embedding.

    Args:
        text: The free-form query string.

    Returns:
        The embedding as a plain Python list of floats, produced by the same
        CLIP model as the stored image vectors.
    """
    # truncate=True: clip.tokenize otherwise raises for queries longer than
    # its context length, and this text comes straight from the UI.
    tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        embed_text = model.encode_text(tokens)
    return embed_text.cpu().numpy()[0].tolist()


def search_text(text):
    """Return the dataset images whose embeddings are nearest to ``text``.

    Args:
        text: The query string entered by the user.

    Returns:
        A list of up to 4 images from ds["train"], ordered as LanceDB
        returned the matching rows.
    """
    query = tabel.search(embedding_text(text)).limit(4).to_pandas()
    # Each hit's "id" is its row index in ds["train"]; fetch the image itself.
    return [ds["train"][int(data_id)]["image"] for data_id in query["id"]]
|
|
| |
# Minimal search UI: a query box and button, with matching images in a gallery.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Image search"):
            vector_query = gr.Textbox(
                value="Input Text to search; Car on traffic light ",
                show_label=False,
            )
            b1 = gr.Button("Submit")

            with gr.Row():
                # Layout options go to the constructor: Gallery.style() was
                # deprecated in Gradio 3.x and removed in 4.x.
                gallery = gr.Gallery(
                    label="Found images",
                    show_label=False,
                    elem_id="gallery",
                    columns=2,
                    object_fit="contain",
                    height="auto",
                )

    b1.click(search_text, inputs=vector_query, outputs=gallery)

demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|