tooba248 committed on
Commit
10b2979
·
verified ·
1 Parent(s): e54509f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+ from torchvision import transforms
6
+ import faiss
7
+ import requests
8
+ from io import BytesIO
9
+ from datasets import load_dataset
10
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the fine-tuned model checkpoint.
# NOTE(review): the original fetched "best_model.pt" with requests.get(), which
# raises MissingSchema for a bare filename — load the local file directly.
# SECURITY: torch.load unpickles arbitrary code; only load trusted checkpoints.
model = torch.load("best_model.pt", map_location=device)
model.eval()

# Load CLIP for preprocessing and encoding.
model_clip, preprocess = clip.load("ViT-B/32", device=device)

# Load full test split from Flickr30k.
dataset = load_dataset("nlphuji/flickr30k", split="test")

captions = []
images = []
image_embeddings = []
text_embeddings = []

for example in dataset:
    try:
        # HF datasets usually decode image columns to PIL images already;
        # fall back to fetching by URL only if the field is a string.
        raw = example["image"]
        if isinstance(raw, str):
            img = Image.open(requests.get(raw, stream=True).raw)
        else:
            img = raw
        img = img.convert("RGB")

        img_tensor = preprocess(img).unsqueeze(0).to(device)
        with torch.no_grad():
            img_feat = model_clip.encode_image(img_tensor)
            # L2-normalize so inner product == cosine similarity.
            img_feat /= img_feat.norm(dim=-1, keepdim=True)

            # Text encoding also belongs under no_grad (the original built a
            # gradient graph here for no reason).
            # NOTE(review): assumes the caption column is "sentence" — verify
            # against the dataset schema.
            txt_token = clip.tokenize([example["sentence"]]).to(device)
            txt_feat = model_clip.encode_text(txt_token)
            txt_feat /= txt_feat.norm(dim=-1, keepdim=True)

        # Append only after both encodings succeed so all four lists stay
        # index-aligned (the original appended image/caption first, which
        # desynchronized them on a mid-loop failure).
        images.append(img)
        captions.append(example["sentence"])
        image_embeddings.append(img_feat.cpu())
        text_embeddings.append(txt_feat.cpu())
    except Exception:
        # Best-effort indexing: skip unreadable or undecodable examples.
        continue

image_embeddings = torch.cat(image_embeddings, dim=0)
text_embeddings = torch.cat(text_embeddings, dim=0)

# Inner-product indexes over unit vectors => cosine-similarity search.
image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
image_index.add(image_embeddings.numpy())

text_index = faiss.IndexFlatIP(text_embeddings.shape[1])
text_index.add(text_embeddings.numpy())
59
# Define functions
def image_to_text(image):
    """Return the best-matching caption for a query image, with a score.

    Encodes *image* with CLIP, L2-normalizes the feature, and takes the
    nearest neighbor from the caption FAISS index. The inner product of
    unit vectors is the cosine similarity, reported as a percentage.

    NOTE(review): the original function had a second, unreachable copy of
    this body after the return statement; it has been removed.
    """
    image_input = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_feature = model_clip.encode_image(image_input)
        image_feature /= image_feature.norm(dim=-1, keepdim=True)
    D, I = text_index.search(image_feature.cpu().numpy(), 1)
    score = round(float(D[0][0]) * 100, 2)
    return f"{captions[I[0][0]]}\n(Match Score: {score}%)"
76
def text_to_image(text):
    """Return the best-matching image and a similarity score for *text*.

    Tokenizes and encodes *text* with CLIP, L2-normalizes the feature, and
    takes the nearest neighbor from the image FAISS index. Returns a
    ``(PIL.Image, str)`` pair for the two Gradio outputs.

    NOTE(review): the original function had a second, unreachable copy of
    this body after the return statement; it has been removed.
    """
    text_input = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_feature = model_clip.encode_text(text_input)
        text_feature /= text_feature.norm(dim=-1, keepdim=True)
    D, I = image_index.search(text_feature.cpu().numpy(), 1)
    score = round(float(D[0][0]) * 100, 2)
    img = images[I[0][0]]
    return img, f"Match Score: {score}%"
93
# Gradio UI: one tab per retrieval direction, wired to the functions above.
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️📝 Cross-Modal Retriever on Flickr30k Test Split")

    with gr.Tab("Image to Text"):
        query_image = gr.Image(type="pil")
        caption_box = gr.Textbox(label="Retrieved Caption")
        caption_btn = gr.Button("Search Caption")
        caption_btn.click(image_to_text, inputs=query_image, outputs=caption_box)

    with gr.Tab("Text to Image"):
        query_text = gr.Textbox(label="Enter Text Prompt")
        result_image = gr.Image(label="Most Similar Image")
        result_score = gr.Textbox(label="Similarity Score")
        image_btn = gr.Button("Search Image")
        image_btn.click(
            text_to_image, inputs=query_text, outputs=[result_image, result_score]
        )

demo.launch()