# File size: 2,103 Bytes
# da54daf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import os
from datasets import load_dataset
import torch
import clip
from PIL import Image
import pyarrow as pa
import lancedb

# Hugging Face image/caption dataset; rows are read through their
# 'image' and 'text' columns.
ds = load_dataset("vipulmaheshwari/GTA-Image-Captioning-Dataset")

# Run CLIP on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ViT-L/14 CLIP: `model` produces embeddings, `preprocess` converts an input
# image into the tensor format the model consumes.
model, preprocess = clip.load("ViT-L/14", device=device)


# Embedding Image Function
def embedding_image(image):
    """Encode one image with CLIP and return its embedding as a list of floats.

    Args:
        image: A PIL image, e.g. a value from the dataset's ``image`` column.

    Returns:
        list[float]: The CLIP image embedding for ``image``.
    """
    # Preprocess, add a batch dimension of 1, and move to the model's device.
    pixel_batch = preprocess(image).unsqueeze(0).to(device)

    # Inference only: disabling autograd avoids building a graph we never use.
    with torch.no_grad():
        embed_image = model.encode_image(pixel_batch)

    # Move to CPU, convert to numpy, and return the single row as a plain list.
    return embed_image.cpu().numpy()[0].tolist()

# Embed every training image and collect one row per image for the table.
# Iterating the split yields each row once, so its image is decoded a single
# time; indexing ds["train"][i] decodes every column of the row per access.
data = [
    {"vector": embedding_image(row["image"]), "text": row["text"], "id": i}
    for i, row in enumerate(ds["train"])
]


# Local LanceDB database that stores one row per embedded dataset image.
db = lancedb.connect('./data/tables')

# Take the vector width from the loaded CLIP model (768 for ViT-L/14), so the
# schema stays in sync if a different checkpoint is loaded.
schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), model.visual.output_dim)),
        pa.field("text", pa.string()),
        pa.field("id", pa.int32()),
    ]
)

# mode="overwrite" replaces any table of the same name from a previous run.
# NOTE(review): newer LanceDB releases may reject table names that contain
# spaces -- confirm against the pinned lancedb version.
tabel = db.create_table("GTA Image Embedding Data", schema=schema, mode="overwrite")

tabel.add(data)

import gradio as gr

# Text embedding and search used by the Gradio "Submit" button.
def embedding_text(text):
    """Encode a text query with CLIP and return its embedding as a list of floats.

    Args:
        text: The free-text query.

    Returns:
        list[float]: CLIP text embedding, in the same space as the image
        embeddings produced by the same model.
    """
    # truncate=True clips queries longer than CLIP's context window instead
    # of raising an error.
    tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        embed_text = model.encode_text(tokens)
    return embed_text.cpu().numpy()[0].tolist()


def search_text(text, limit=4):
    """Return the dataset images whose embeddings best match ``text``.

    Args:
        text: Free-text query from the UI textbox.
        limit: Maximum number of images to return (default 4).

    Returns:
        list: Images from ``ds["train"]``, best match first; an empty list
        for a blank query.
    """
    if not text or not text.strip():
        return []

    # CLIP embeddings are compared by angle, so rank by cosine distance
    # instead of LanceDB's default L2 on unnormalized vectors.
    query = (
        tabel.search(embedding_text(text))
        .metric("cosine")
        .limit(limit)
        .to_pandas()
    )

    # Map each hit back to its source row through the stored "id" column.
    return [ds["train"][int(row_id)]["image"] for row_id in query["id"]]

# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Image search"):
            # Hint text goes in `placeholder` so it is shown to the user but
            # never submitted as part of the query.
            vector_query = gr.Textbox(
                placeholder="Input Text to search; Car on traffic light ",
                show_label=False,
            )
            b1 = gr.Button("Submit")

    with gr.Row():
        # Layout options are constructor arguments: Component.style() was
        # deprecated in Gradio 3.x and removed in Gradio 4.
        gallery = gr.Gallery(
            label="Found images",
            show_label=False,
            elem_id="gallery",
            columns=2,
            object_fit="contain",
            height="auto",
        )

    b1.click(search_text, inputs=vector_query, outputs=gallery)

demo.launch()