tooba248 committed on
Commit
444f65d
·
verified ·
1 Parent(s): 2698d32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -3,19 +3,21 @@ import torch
3
  import clip
4
  from PIL import Image
5
  import faiss
 
6
  from datasets import load_dataset
 
7
 
8
- # Device setup
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- # Load fine-tuned model (from local file)
12
- model = torch.load("best_model.pt", map_location=device)
13
- model.eval()
14
-
15
- # Load base CLIP model for encoding
16
  model_clip, preprocess = clip.load("ViT-B/32", device=device)
17
 
18
- # Load Flickr30k test dataset
 
 
 
 
 
19
  dataset = load_dataset("nlphuji/flickr30k", split="test")
20
 
21
  captions = []
@@ -23,38 +25,42 @@ images = []
23
  image_embeddings = []
24
  text_embeddings = []
25
 
26
- # Prepare image and text embeddings
 
27
  for example in dataset:
28
  try:
 
29
  img = Image.open(requests.get(example["image"], stream=True).raw).convert("RGB")
30
  images.append(img)
31
  captions.append(example["sentence"])
32
 
 
33
  img_tensor = preprocess(img).unsqueeze(0).to(device)
34
  with torch.no_grad():
35
  img_feat = model_clip.encode_image(img_tensor)
36
  img_feat /= img_feat.norm(dim=-1, keepdim=True)
37
  image_embeddings.append(img_feat.cpu())
38
 
 
39
  txt_token = clip.tokenize([example["sentence"]]).to(device)
40
  txt_feat = model_clip.encode_text(txt_token)
41
  txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
42
  text_embeddings.append(txt_feat.cpu())
43
- except:
 
44
  continue
45
 
46
- # Convert lists to tensors
47
  image_embeddings = torch.cat(image_embeddings, dim=0)
48
  text_embeddings = torch.cat(text_embeddings, dim=0)
49
 
50
- # Build FAISS indices
51
  image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
52
  image_index.add(image_embeddings.numpy())
53
 
54
  text_index = faiss.IndexFlatIP(text_embeddings.shape[1])
55
  text_index.add(text_embeddings.numpy())
56
 
57
- # Image-to-Text search
58
  def image_to_text(image):
59
  image_input = preprocess(image).unsqueeze(0).to(device)
60
  with torch.no_grad():
@@ -64,7 +70,6 @@ def image_to_text(image):
64
  score = round(float(D[0][0]) * 100, 2)
65
  return f"{captions[I[0][0]]}\n(Match Score: {score}%)"
66
 
67
- # Text-to-Image search
68
  def text_to_image(text):
69
  text_input = clip.tokenize([text]).to(device)
70
  with torch.no_grad():
@@ -75,10 +80,9 @@ def text_to_image(text):
75
  img = images[I[0][0]]
76
  return img, f"Match Score: {score}%"
77
 
78
- # Gradio UI
79
  with gr.Blocks() as demo:
80
  gr.Markdown("## 🔄 Cross-Modal Retriever on Flickr30k (Image ↔ Text Matching)")
81
-
82
  with gr.Tab("🖼️ Image to Text"):
83
  img_input = gr.Image(type="pil", label="Upload Image")
84
  text_output = gr.Textbox(label="Most Similar Caption")
 
3
  import clip
4
  from PIL import Image
5
  import faiss
6
+ import requests
7
  from datasets import load_dataset
8
+ from io import BytesIO
9
 
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # Load base CLIP model and preprocess
 
 
 
 
13
  model_clip, preprocess = clip.load("ViT-B/32", device=device)
14
 
15
+ # Load fine-tuned weights (state_dict) and apply to CLIP model
16
+ state_dict = torch.load("best_model.pt", map_location=device)
17
+ model_clip.load_state_dict(state_dict)
18
+ model_clip.eval()
19
+
20
+ # Load Flickr30k test split dataset
21
  dataset = load_dataset("nlphuji/flickr30k", split="test")
22
 
23
  captions = []
 
25
  image_embeddings = []
26
  text_embeddings = []
27
 
28
+ print("Preparing embeddings for retrieval pool...")
29
+
30
  for example in dataset:
31
  try:
32
+ # Load image from URL
33
  img = Image.open(requests.get(example["image"], stream=True).raw).convert("RGB")
34
  images.append(img)
35
  captions.append(example["sentence"])
36
 
37
+ # Preprocess and encode image
38
  img_tensor = preprocess(img).unsqueeze(0).to(device)
39
  with torch.no_grad():
40
  img_feat = model_clip.encode_image(img_tensor)
41
  img_feat /= img_feat.norm(dim=-1, keepdim=True)
42
  image_embeddings.append(img_feat.cpu())
43
 
44
+ # Tokenize and encode text
45
  txt_token = clip.tokenize([example["sentence"]]).to(device)
46
  txt_feat = model_clip.encode_text(txt_token)
47
  txt_feat /= txt_feat.norm(dim=-1, keepdim=True)
48
  text_embeddings.append(txt_feat.cpu())
49
+ except Exception as e:
50
+ print(f"Skipping one example due to error: {e}")
51
  continue
52
 
53
+ # Convert lists of embeddings to tensors
54
  image_embeddings = torch.cat(image_embeddings, dim=0)
55
  text_embeddings = torch.cat(text_embeddings, dim=0)
56
 
57
+ # Create FAISS indices for fast similarity search (Inner Product = cosine similarity)
58
  image_index = faiss.IndexFlatIP(image_embeddings.shape[1])
59
  image_index.add(image_embeddings.numpy())
60
 
61
  text_index = faiss.IndexFlatIP(text_embeddings.shape[1])
62
  text_index.add(text_embeddings.numpy())
63
 
 
64
  def image_to_text(image):
65
  image_input = preprocess(image).unsqueeze(0).to(device)
66
  with torch.no_grad():
 
70
  score = round(float(D[0][0]) * 100, 2)
71
  return f"{captions[I[0][0]]}\n(Match Score: {score}%)"
72
 
 
73
  def text_to_image(text):
74
  text_input = clip.tokenize([text]).to(device)
75
  with torch.no_grad():
 
80
  img = images[I[0][0]]
81
  return img, f"Match Score: {score}%"
82
 
 
83
  with gr.Blocks() as demo:
84
  gr.Markdown("## 🔄 Cross-Modal Retriever on Flickr30k (Image ↔ Text Matching)")
85
+
86
  with gr.Tab("🖼️ Image to Text"):
87
  img_input = gr.Image(type="pil", label="Upload Image")
88
  text_output = gr.Textbox(label="Most Similar Caption")