dev-odeh committed on
Commit
da54daf
·
verified ·
1 Parent(s): 0d33d3e

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +78 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from datasets import load_dataset
4
+ import torch
5
+ import clip
6
+ from PIL import Image
7
+ import pyarrow as pa
8
+ import lancedb
9
+
10
# Dataset that supplies both the images to index and their captions.
ds = load_dataset("vipulmaheshwari/GTA-Image-Captioning-Dataset")

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# clip.load returns a pair: the model and the preprocessing callable that is
# applied to images before they are encoded.
model, preprocess = clip.load("ViT-L/14", device=device)
14
+
15
+
16
+ #Embedding Image Function
17
def embedding_image(image):
    """Embed a single image with the CLIP image encoder.

    Args:
        image: Image accepted by the module-level ``preprocess`` callable
            (presumably a PIL image, as stored in the dataset -- confirm).

    Returns:
        The first row of the encoder output as a plain Python list of floats.
    """
    # Add a leading batch dimension and move the tensor to the model's device.
    batch = preprocess(image).unsqueeze(0).to(device)

    # Inference only: disable autograd so no graph is built for each image.
    with torch.no_grad():
        features = model.encode_image(batch)

    # Move to CPU, convert to a numpy array, and return the single row as a list.
    return features.cpu().numpy()[0].tolist()
25
+
26
# One record per training example: image embedding, caption, and row index.
data = []
for i, row in enumerate(ds["train"]):
    # Each row is fetched once, instead of indexing the dataset separately
    # for the image and for the text.
    data.append(
        {
            "vector": embedding_image(row["image"]),
            "text": row["text"],
            "id": i,
        }
    )
34
+
35
+
36
# Connect to the LanceDB database stored under ./data/tables.
db = lancedb.connect('./data/tables')

# Row layout: a fixed-size list of 768 float32 values, the caption text, and
# the integer index of the example in ds["train"].
schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), 768)),
        pa.field("text", pa.string()),
        pa.field("id", pa.int32()),
    ]
)

# NOTE: the name `tabel` (sic) is kept because search_text refers to it.
tabel = db.create_table("GTA Image Embedding Data", schema=schema, mode="overwrite")

# Insert every record built above. The former trailing `tabel.to_pandas()`
# call was dropped: its result was discarded, so it only loaded the whole
# table into memory at startup.
tabel.add(data)
46
+
47
+ import gradio as gr
48
+
49
def embedding_text(text):
    """Embed a text query with the CLIP text encoder.

    Args:
        text: Query string typed by the user.

    Returns:
        The first row of the encoder output as a plain Python list of floats,
        in the same form that embedding_image produces for images.
    """
    # truncate=True keeps over-long queries from raising inside the tokenizer.
    tokens = clip.tokenize([text], truncate=True).to(device)

    # Inference only: disable autograd for the forward pass.
    with torch.no_grad():
        features = model.encode_text(tokens)

    return features.cpu().numpy()[0].tolist()


def search_text(text):
    """Return the dataset images whose embeddings are nearest to a text query.

    Args:
        text: Query string typed by the user.

    Returns:
        A list of up to four images taken from ``ds["train"]``, in the order
        the vector search returned them.
    """
    hits = tabel.search(embedding_text(text)).limit(4).to_pandas()

    # The stored "id" column is the row index into ds["train"]; cast to a
    # plain int before indexing the dataset.
    return [ds["train"][int(row_id)]["image"] for row_id in hits["id"]]
58
+
59
+ # Create Gradio interface
60
# Gradio UI: a text box and a submit button feeding a gallery of results.
# NOTE(review): the nesting below is reconstructed from the statement order;
# the original indentation was not preserved -- confirm the layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Tab("Image search"):
            vector_query = gr.Textbox(value="Input Text to search; Car on traffic light ", show_label=False)
            b1 = gr.Button("Submit")

    with gr.Row():
        # Layout options are passed to the constructor: the chained
        # `.style(...)` call is no longer available in current Gradio
        # releases, and requirements.txt does not pin a version.
        gallery = gr.Gallery(
            label="Found images",
            show_label=False,
            elem_id="gallery",
            columns=[2],
            object_fit="contain",
            height="auto",
        )

    # Clicking the button runs the search and shows the images in the gallery.
    b1.click(search_text, inputs=vector_query, outputs=gallery)

demo.launch()
74
+
75
+
76
+
77
+
78
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ matplotlib
4
+ gradio
5
+ datasets
6
+ torch
7
+ git+https://github.com/openai/CLIP.git
8
+ Pillow
9
+ pyarrow
10
+ lancedb