stephenebert committed on
Commit
a1a61d3
·
verified ·
1 Parent(s): 9581c84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -160
app.py CHANGED
@@ -1,199 +1,128 @@
1
  import time, faiss, gradio as gr, torch, numpy as np
 
2
  from PIL import Image
3
  from sentence_transformers import SentenceTransformer
4
  from transformers import BlipProcessor, BlipForConditionalGeneration, logging as hf_log
5
- hf_log.set_verbosity_error()
6
 
 
 
 
 
 
 
 
7
  print("🟒 fresh run", time.strftime("%H:%M:%S"))
8
 
9
- FAISS_INDEX = "scripts/coco_caption_clip.index"
10
- CAPTION_ARRAY = "scripts/coco_caption_texts.npy"
 
 
 
 
 
11
 
12
- # Test basic FAISS functionality first
13
- print("Testing basic FAISS functionality...")
14
  try:
15
  test_index = faiss.IndexFlatL2(512)
16
- test_vec = np.random.random((1, 512)).astype(np.float32)
17
- test_vec = np.ascontiguousarray(test_vec)
18
- test_index.add(test_vec)
19
- D, I = test_index.search(test_vec, 1)
20
- print(f"Basic FAISS test passed: D={D[0][0]:.3f}, I={I[0][0]}")
21
  FAISS_WORKING = True
22
  except Exception as e:
23
- print(f"Basic FAISS test failed: {e}")
24
  FAISS_WORKING = False
25
 
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
27
- print(f"Using device: {device}")
28
 
29
- # Load models
 
30
  try:
31
  blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
32
- blip_model = BlipForConditionalGeneration.from_pretrained(
33
- "Salesforce/blip-image-captioning-base").to(device).eval()
 
34
  clip_model = SentenceTransformer("clip-ViT-B-32")
35
- print("Models loaded successfully")
36
  except Exception as e:
37
- print(f"Error loading models: {e}")
38
- raise
39
 
40
- # Load FAISS index and captions
41
  try:
 
42
  if FAISS_WORKING:
43
- index = faiss.read_index(FAISS_INDEX)
44
- captions = np.load(CAPTION_ARRAY, allow_pickle=True)
45
- print(f"FAISS index loaded: {index.ntotal} vectors, dimension {index.d}")
46
  else:
47
- print("FAISS not working, will use fallback similarity search")
48
  index = None
49
- captions = np.load(CAPTION_ARRAY, allow_pickle=True)
50
- # Create embeddings for all captions for fallback
51
- print("Creating embeddings for fallback search...")
52
- caption_embeddings = clip_model.encode(captions.tolist(), normalize_embeddings=True, convert_to_numpy=True)
53
- caption_embeddings = np.array(caption_embeddings, dtype=np.float32)
54
- print(f"Created {len(caption_embeddings)} caption embeddings")
55
  except Exception as e:
56
- print(f"Error loading FAISS index or captions: {e}")
57
- raise
58
 
 
 
59
  def pil_to_tensor(img: Image.Image) -> torch.Tensor:
60
- """Convert PIL image to tensor for BLIP model"""
61
- # Convert to RGB and resize
62
- img_rgb = img.convert("RGB")
63
- img_resized = img_rgb.resize((384, 384), Image.Resampling.LANCZOS)
64
-
65
- # Convert to numpy array
66
- img_array = np.array(img_resized, dtype=np.float32) / 255.0
67
-
68
- # Apply BLIP normalization
69
  mean = np.array([0.48145466, 0.4578275, 0.40821073])
70
- std = np.array([0.26862954, 0.26130258, 0.27577711])
71
- img_normalized = (img_array - mean) / std
72
-
73
- # Convert to tensor format [1, 3, H, W]
74
- img_tensor = torch.from_numpy(img_normalized.transpose(2, 0, 1)).float()
75
- return img_tensor.unsqueeze(0).to(device)
76
-
77
- def fallback_similarity_search(query_vec, k=5):
78
- """Fallback similarity search using numpy when FAISS fails"""
79
- # Compute cosine similarity
80
- similarities = np.dot(caption_embeddings, query_vec.T).flatten()
81
-
82
- # Get top-k indices
83
- top_indices = np.argsort(similarities)[::-1][:k]
84
-
85
- # Return in FAISS format (distances, indices)
86
- distances = 1 - similarities[top_indices] # Convert similarity to distance
87
- return distances.reshape(1, -1), top_indices.reshape(1, -1)
88
 
89
  def safe_faiss_search(vec, k=5):
90
- """Safely perform FAISS search with multiple fallback methods"""
91
- if not FAISS_WORKING or index is None:
92
- return fallback_similarity_search(vec, k)
93
-
94
- # Try multiple vector preparation methods
95
- methods = [
96
- lambda v: v, # Use as-is
97
- lambda v: np.ascontiguousarray(v), # Ensure contiguous
98
- lambda v: np.array(v, dtype=np.float32, copy=True), # Force copy
99
- lambda v: np.array(v.tolist(), dtype=np.float32), # Convert via list
100
- ]
101
-
102
- for i, method in enumerate(methods):
103
- try:
104
- vec_processed = method(vec)
105
- if vec_processed.ndim == 1:
106
- vec_processed = vec_processed.reshape(1, -1)
107
-
108
- # Verify array properties
109
- if not vec_processed.flags.c_contiguous:
110
- vec_processed = np.ascontiguousarray(vec_processed)
111
-
112
- print(f"Method {i+1}: shape={vec_processed.shape}, dtype={vec_processed.dtype}, contiguous={vec_processed.flags.c_contiguous}")
113
-
114
- D, I = index.search(vec_processed, k)
115
- print(f"FAISS search successful with method {i+1}")
116
- return D, I
117
-
118
- except Exception as e:
119
- print(f"Method {i+1} failed: {e}")
120
- continue
121
-
122
- # If all FAISS methods fail, use fallback
123
- print("⚠️ All FAISS methods failed, using fallback similarity search")
124
- return fallback_similarity_search(vec, k)
125
 
 
126
  @torch.inference_mode()
127
  def retrieve(img: Image.Image, k: int = 5):
128
- """Main retrieval function"""
129
- try:
130
- if img is None:
131
- return "No image provided", "Please upload an image."
132
-
133
- # Ensure k is within bounds
134
- k = min(k, len(captions))
135
-
136
- print(f"Processing image with k={k}")
137
-
138
- # Generate caption with BLIP
139
- px = pil_to_tensor(img)
140
- ids = blip_model.generate(px, max_new_tokens=20)
141
- blip_cap = blip_proc.tokenizer.decode(ids[0], skip_special_tokens=True)
142
- print(f"BLIP caption: {blip_cap}")
143
-
144
- # Get embeddings from CLIP model
145
- embeddings = clip_model.encode([blip_cap], normalize_embeddings=True, convert_to_numpy=True)
146
-
147
- # Ensure proper numpy array format
148
- vec = np.array(embeddings, dtype=np.float32)
149
- if vec.ndim == 1:
150
- vec = vec.reshape(1, -1)
151
-
152
- print(f"Embedding shape: {vec.shape}, dtype: {vec.dtype}")
153
-
154
- # Perform similarity search
155
- D, I = safe_faiss_search(vec, k)
156
-
157
- # Format results
158
- if FAISS_WORKING and index is not None:
159
- neigh = [f"**{i+1}.** *distance {D[0][i]:.3f}*<br>{captions[I[0][i]]}"
160
- for i in range(k)]
161
- else:
162
- neigh = [f"**{i+1}.** *distance {D[0][i]:.3f}*<br>{captions[I[0][i]]}"
163
- for i in range(k)]
164
-
165
- return blip_cap, "<br><br>".join(neigh)
166
-
167
- except Exception as e:
168
- print(f"Error in retrieve: {str(e)}")
169
- import traceback
170
- traceback.print_exc()
171
- return f"Error: {str(e)}", "Please try again with a different image."
172
 
173
- # Create Gradio interface
174
  demo = gr.Interface(
175
  fn=retrieve,
176
- inputs=[
177
- gr.Image(type="pil", label="Upload Image"),
178
- gr.Slider(1, 10, 5, 1, label="Number of Similar Captions")
179
- ],
180
- outputs=[
181
- gr.Textbox(label="BLIP Generated Caption"),
182
- gr.HTML(label="Most Similar COCO Captions")
183
- ],
184
- title="Image-to-Text Retrieval Demo (BLIP + CLIP + FAISS)",
185
- description=("Upload an image β†’ AI generates caption (BLIP) β†’ finds embedding (CLIP) β†’ "
186
- "retrieves most similar captions from COCO dataset" +
187
- (" (FAISS)" if FAISS_WORKING else " (Fallback Search)"))
188
  )
189
 
190
  if __name__ == "__main__":
191
- print("Launching Gradio demo...")
192
- demo.launch(share = True) # add share=True if you need a public link
193
-
194
- """
195
- Usage:
196
- conda activate capstone-gradio-py310
197
- cd ~/Desktop/Springboard/Capstone/extra_credit
198
- python gradio_demo.py
199
- """
 
1
  import time, faiss, gradio as gr, torch, numpy as np
2
+ from pathlib import Path
3
  from PIL import Image
4
  from sentence_transformers import SentenceTransformer
5
  from transformers import BlipProcessor, BlipForConditionalGeneration, logging as hf_log
 
6
 
7
+ # Make sure the FAISS index + caption array exist
8
+
9
+ from scripts.get_assets import ensure_assets # helper you already have
10
+ ensure_assets() # download once, then cached
11
+
12
+ # House-keeping
13
+ hf_log.set_verbosity_error()
14
  print("🟒 fresh run", time.strftime("%H:%M:%S"))
15
 
16
+ FAISS_INDEX = Path("scripts/coco_caption_clip.index")
17
+ CAPTION_ARRAY = Path("scripts/coco_caption_texts.npy")
18
+
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ print(f"Using device: {device}")
21
+
22
+ # Quick FAISS smoke test
23
 
24
+ print("Testing basic FAISS functionality…")
 
25
  try:
26
  test_index = faiss.IndexFlatL2(512)
27
+ vec = np.random.rand(1, 512).astype("float32")
28
+ test_index.add(vec)
29
+ D, I = test_index.search(vec, 1)
30
+ print(f"βœ… FAISS ok (D={D[0][0]:.3f})")
 
31
  FAISS_WORKING = True
32
  except Exception as e:
33
+ print(f"⚠️ FAISS broken: {e}")
34
  FAISS_WORKING = False
35
 
 
 
36
 
37
+ # Load all models
38
+
39
  try:
40
  blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
41
+ blip_model = (BlipForConditionalGeneration
42
+ .from_pretrained("Salesforce/blip-image-captioning-base")
43
+ .to(device).eval())
44
  clip_model = SentenceTransformer("clip-ViT-B-32")
45
+ print("βœ… Models loaded")
46
  except Exception as e:
47
+ raise RuntimeError(f"Model load failed: {e}")
 
48
 
49
+ # Load FAISS index + captions (or build fallback embeddings)
50
  try:
51
+ captions = np.load(CAPTION_ARRAY, allow_pickle=True)
52
  if FAISS_WORKING:
53
+ index = faiss.read_index(str(FAISS_INDEX))
54
+ print(f"βœ… FAISS index: {index.ntotal} vectors Γ— {index.d}")
55
+ caption_embeddings = None
56
  else:
 
57
  index = None
58
+ print("Building caption embeddings for fallback search…")
59
+ caption_embeddings = clip_model.encode(
60
+ captions.tolist(), convert_to_numpy=True,
61
+ normalize_embeddings=True, show_progress_bar=False
62
+ ).astype("float32")
 
63
  except Exception as e:
64
+ raise RuntimeError(f"Loading FAISS assets failed: {e}")
 
65
 
66
+ # Helpers
67
+ @torch.inference_mode()
68
  def pil_to_tensor(img: Image.Image) -> torch.Tensor:
69
+ img = img.convert("RGB").resize((384, 384), Image.Resampling.LANCZOS)
70
+ arr = np.asarray(img, dtype="float32") / 255.0
 
 
 
 
 
 
 
71
  mean = np.array([0.48145466, 0.4578275, 0.40821073])
72
+ std = np.array([0.26862954, 0.26130258, 0.27577711])
73
+ arr = (arr - mean) / std
74
+ return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)
75
+
76
+ def fallback_search(vec, k=5):
77
+ sims = caption_embeddings @ vec.T
78
+ idx = np.argsort(sims.ravel())[::-1][:k]
79
+ dist = 1 - sims[0, idx]
80
+ return dist.reshape(1, -1), idx.reshape(1, -1)
 
 
 
 
 
 
 
 
 
81
 
82
  def safe_faiss_search(vec, k=5):
83
+ if index is None:
84
+ return fallback_search(vec, k)
85
+ try:
86
+ D, I = index.search(np.ascontiguousarray(vec), k)
87
+ return D, I
88
+ except Exception as e:
89
+ print(f"FAISS search failed: {e} β†’ fallback")
90
+ return fallback_search(vec, k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ # Main retrieval fn
93
  @torch.inference_mode()
94
  def retrieve(img: Image.Image, k: int = 5):
95
+ if img is None:
96
+ return "πŸ“· Please upload an image", ""
97
+ k = min(int(k), len(captions))
98
+
99
+ # BLIP caption
100
+ ids = blip_model.generate(pil_to_tensor(img), max_new_tokens=20)
101
+ blip_cap = blip_proc.tokenizer.decode(ids[0], skip_special_tokens=True)
102
+
103
+ # CLIP embedding
104
+ vec = clip_model.encode([blip_cap], normalize_embeddings=True,
105
+ convert_to_numpy=True).astype("float32")
106
+
107
+ # Similarity search
108
+ D, I = safe_faiss_search(vec, k)
109
+ lines = [f"**{i+1}.** *dist {D[0][i]:.3f}*<br>{captions[I[0][i]]}"
110
+ for i in range(k)]
111
+ return blip_cap, "<br><br>".join(lines)
112
+
113
+
114
+ # Gradio UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
 
116
  demo = gr.Interface(
117
  fn=retrieve,
118
+ inputs=[gr.Image(type="pil"), gr.Slider(1, 10, value=5, step=1,
119
+ label="# of similar captions")],
120
+ outputs=[gr.Textbox(label="BLIP caption"),
121
+ gr.HTML(label="Nearest COCO captions")],
122
+ title="Image-to-Text Retrieval (BLIP + CLIP + FAISS)",
123
+ description=("Upload an image β†’ BLIP generates a caption β†’ CLIP embeds it β†’ "
124
+ "FAISS retrieves the most similar human-written COCO captions.")
 
 
 
 
 
125
  )
126
 
127
  if __name__ == "__main__":
128
+ demo.launch()