tooba248 committed on
Commit
e403cae
·
verified Β·
1 Parent(s): 444f65d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -62
app.py CHANGED
@@ -1,99 +1,96 @@
1
  import gradio as gr
2
  import torch
3
  import clip
 
4
  from PIL import Image
5
  import faiss
6
  import requests
7
- from datasets import load_dataset
8
  from io import BytesIO
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- # Load base CLIP model and preprocess
13
  model_clip, preprocess = clip.load("ViT-B/32", device=device)
14
 
15
- # Load fine-tuned weights (state_dict) and apply to CLIP model
16
  state_dict = torch.load("best_model.pt", map_location=device)
17
- model_clip.load_state_dict(state_dict)
 
18
  model_clip.eval()
19
 
20
- # Load Flickr30k test split dataset
21
  dataset = load_dataset("nlphuji/flickr30k", split="test")
22
 
23
- captions = []
24
- images = []
25
- image_embeddings = []
26
- text_embeddings = []
27
-
28
- print("Preparing embeddings for retrieval pool...")
29
 
 
30
  for example in dataset:
31
  try:
32
- # Load image from URL
33
  img = Image.open(requests.get(example["image"], stream=True).raw).convert("RGB")
34
  images.append(img)
35
  captions.append(example["sentence"])
36
 
37
- # Preprocess and encode image
38
- img_tensor = preprocess(img).unsqueeze(0).to(device)
39
  with torch.no_grad():
40
- img_feat = model_clip.encode_image(img_tensor)
41
- img_feat /= img_feat.norm(dim=-1, keepdim=True)
42
- image_embeddings.append(img_feat.cpu())
43
 
44
- # Tokenize and encode text
45
- txt_token = clip.tokenize([example["sentence"]]).to(device)
46
- txt_feat = model_clip.encode_text(txt_token)
47
- txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
48
- text_embeddings.append(txt_feat.cpu())
49
- except Exception as e:
50
- print(f"Skipping one example due to error: {e}")
51
  continue
52
 
53
- # Convert lists of embeddings to tensors
54
- image_embeddings = torch.cat(image_embeddings, dim=0)
55
- text_embeddings = torch.cat(text_embeddings, dim=0)
56
 
57
- # Create FAISS indices for fast similarity search (Inner Product = cosine similarity)
58
- image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
59
- image_index.add(image_embeddings.numpy())
60
 
61
- text_index = faiss.IndexFlatIP(text_embeddings.shape[1])
62
- text_index.add(text_embeddings.numpy())
63
 
64
- def image_to_text(image):
65
- image_input = preprocess(image).unsqueeze(0).to(device)
 
66
  with torch.no_grad():
67
- image_feature = model_clip.encode_image(image_input)
68
- image_feature /= image_feature.norm(dim=-1, keepdim=True)
69
- D, I = text_index.search(image_feature.cpu().numpy(), 1)
70
- score = round(float(D[0][0]) * 100, 2)
71
- return f"{captions[I[0][0]]}\n(Match Score: {score}%)"
72
-
73
- def text_to_image(text):
74
- text_input = clip.tokenize([text]).to(device)
75
  with torch.no_grad():
76
- text_feature = model_clip.encode_text(text_input)
77
- text_feature /= text_feature.norm(dim=-1, keepdim=True)
78
- D, I = image_index.search(text_feature.cpu().numpy(), 1)
79
- score = round(float(D[0][0]) * 100, 2)
80
- img = images[I[0][0]]
81
- return img, f"Match Score: {score}%"
82
 
 
83
  with gr.Blocks() as demo:
84
- gr.Markdown("## πŸ”„ Cross-Modal Retriever on Flickr30k (Image ↔ Text Matching)")
85
-
86
- with gr.Tab("πŸ–ΌοΈ Image to Text"):
87
- img_input = gr.Image(type="pil", label="Upload Image")
88
- text_output = gr.Textbox(label="Most Similar Caption")
89
- btn1 = gr.Button("Find Caption")
90
- btn1.click(image_to_text, inputs=img_input, outputs=text_output)
91
-
92
- with gr.Tab("πŸ“ Text to Image"):
93
- text_input = gr.Textbox(label="Enter a Caption")
94
- img_output = gr.Image(label="Most Similar Image")
95
- score_output = gr.Textbox(label="Similarity Score")
96
- btn2 = gr.Button("Find Image")
97
- btn2.click(text_to_image, inputs=text_input, outputs=[img_output, score_output])
98
 
99
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import clip
4
+ from datasets import load_dataset
5
  from PIL import Image
6
  import faiss
7
  import requests
 
8
  from io import BytesIO
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # 1) Load base CLIP model + preprocess
13
  model_clip, preprocess = clip.load("ViT-B/32", device=device)
14
 
15
+ # 2) Load your fine‐tuned weights (state_dict) into model_clip
16
  state_dict = torch.load("best_model.pt", map_location=device)
17
+ missing, unexpected = model_clip.load_state_dict(state_dict, strict=False)
18
+ print(f"⚠️ Missing keys: {missing}\n⚠️ Unexpected keys: {unexpected}")
19
  model_clip.eval()
20
 
21
+ # 3) Build retrieval pool from Flickr30k test split
22
  dataset = load_dataset("nlphuji/flickr30k", split="test")
23
 
24
+ images, captions = [], []
25
+ img_embs, txt_embs = [], []
 
 
 
 
26
 
27
+ print("πŸ”„ Preparing retrieval pool embeddings...")
28
  for example in dataset:
29
  try:
30
+ # load & store raw image + caption
31
  img = Image.open(requests.get(example["image"], stream=True).raw).convert("RGB")
32
  images.append(img)
33
  captions.append(example["sentence"])
34
 
35
+ # encode image
36
+ img_t = preprocess(img).unsqueeze(0).to(device)
37
  with torch.no_grad():
38
+ v = model_clip.encode_image(img_t)
39
+ v /= v.norm(dim=-1, keepdim=True)
40
+ img_embs.append(v.cpu())
41
 
42
+ # encode text
43
+ t = clip.tokenize([example["sentence"]]).to(device)
44
+ with torch.no_grad():
45
+ tfeat = model_clip.encode_text(t)
46
+ tfeat /= tfeat.norm(dim=-1, keepdim=True)
47
+ txt_embs.append(tfeat.cpu())
48
+ except:
49
  continue
50
 
51
+ # cat into tensors
52
+ img_embs = torch.cat(img_embs, dim=0)
53
+ txt_embs = torch.cat(txt_embs, dim=0)
54
 
55
+ # build FAISS indices (Inner‐Product = cosine)
56
+ img_index = faiss.IndexFlatIP(img_embs.shape[1])
57
+ img_index.add(img_embs.numpy())
58
 
59
+ txt_index = faiss.IndexFlatIP(txt_embs.shape[1])
60
+ txt_index.add(txt_embs.numpy())
61
 
62
+ # 4) Gradio callbacks
63
+ def image_to_text(inp_img):
64
+ im = preprocess(inp_img).unsqueeze(0).to(device)
65
  with torch.no_grad():
66
+ v = model_clip.encode_image(im)
67
+ v /= v.norm(dim=-1, keepdim=True)
68
+ D, I = txt_index.search(v.cpu().numpy(), 1)
69
+ score = D[0][0] * 100
70
+ return f"{captions[I[0][0]]}\n(Match Score: {score:.2f}%)"
71
+
72
+ def text_to_image(inp_txt):
73
+ tok = clip.tokenize([inp_txt]).to(device)
74
  with torch.no_grad():
75
+ t = model_clip.encode_text(tok)
76
+ t /= t.norm(dim=-1, keepdim=True)
77
+ D, I = img_index.search(t.cpu().numpy(), 1)
78
+ score = D[0][0] * 100
79
+ return images[I[0][0]], f"Match Score: {score:.2f}%"
 
80
 
81
+ # 5) Gradio UI
82
  with gr.Blocks() as demo:
83
+ gr.Markdown("## πŸ”„ Cross-Modal Retriever (Flickr30k Test Split)\nUpload an image or enter text to retrieve the best match.")
84
+
85
+ with gr.Tab("πŸ–ΌοΈ Image β†’ Text"):
86
+ img_in = gr.Image(type="pil", label="Upload Image")
87
+ txt_out = gr.Textbox(label="Retrieved Caption")
88
+ gr.Button("Search Caption").click(image_to_text, img_in, txt_out)
89
+
90
+ with gr.Tab("πŸ“ Text β†’ Image"):
91
+ txt_in = gr.Textbox(label="Enter Text")
92
+ img_out = gr.Image(label="Retrieved Image")
93
+ score_out = gr.Textbox(label="Score")
94
+ gr.Button("Search Image").click(text_to_image, txt_in, [img_out, score_out])
 
 
95
 
96
  demo.launch()