import time
from pathlib import Path

import faiss
import gradio as gr
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForConditionalGeneration, logging as hf_log

# Make sure the FAISS index + caption array exist
from scripts.get_assets import ensure_assets  # helper you already have

ensure_assets()  # download once, then cached

# House-keeping
hf_log.set_verbosity_error()
print("🟢 fresh run", time.strftime("%H:%M:%S"))

FAISS_INDEX = Path("scripts/coco_caption_clip.index")
CAPTION_ARRAY = Path("scripts/coco_caption_texts.npy")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Quick FAISS smoke test
print("Testing basic FAISS functionality…")
try:
    test_index = faiss.IndexFlatL2(512)
    vec = np.random.rand(1, 512).astype("float32")
    test_index.add(vec)
    D, I = test_index.search(vec, 1)
    print(f"✅ FAISS ok (D={D[0][0]:.3f})")
    FAISS_WORKING = True
except Exception as e:
    print(f"⚠️ FAISS broken: {e}")
    FAISS_WORKING = False

# Load all models
try:
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = (BlipForConditionalGeneration
                  .from_pretrained("Salesforce/blip-image-captioning-base")
                  .to(device).eval())
    clip_model = SentenceTransformer("clip-ViT-B-32")
    print("✅ Models loaded")
except Exception as e:
    raise RuntimeError(f"Model load failed: {e}")

# Load FAISS index + captions (or build fallback embeddings)
try:
    captions = np.load(CAPTION_ARRAY, allow_pickle=True)
    if FAISS_WORKING:
        index = faiss.read_index(str(FAISS_INDEX))
        print(f"✅ FAISS index: {index.ntotal} vectors × {index.d} dims")
        caption_embeddings = None
    else:
        index = None
        print("Building caption embeddings for fallback search…")
        caption_embeddings = clip_model.encode(
            captions.tolist(), convert_to_numpy=True,
            normalize_embeddings=True, show_progress_bar=False,
        ).astype("float32")
except Exception as e:
    raise RuntimeError(f"Loading FAISS assets failed: {e}")


# Helpers
@torch.inference_mode()
def pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Resize to 384×384 and apply CLIP-style normalization for BLIP."""
    img = img.convert("RGB").resize((384, 384), Image.Resampling.LANCZOS)
    arr = np.asarray(img, dtype="float32") / 255.0
    # float32 constants keep the result float32 (float64 would break the fp32 model)
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")
    arr = (arr - mean) / std
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)


def fallback_search(vec, k=5):
    # Cosine similarity against the pre-normalized caption embeddings
    sims = caption_embeddings @ vec.T           # shape (N, 1)
    idx = np.argsort(sims.ravel())[::-1][:k]    # top-k most similar captions
    dist = 1.0 - sims.ravel()[idx]              # report as cosine distance
    return dist.reshape(1, -1), idx.reshape(1, -1)


def safe_faiss_search(vec, k=5):
    if index is None:
        return fallback_search(vec, k)
    try:
        D, I = index.search(np.ascontiguousarray(vec), k)
        return D, I
    except Exception as e:
        print(f"FAISS search failed: {e} → fallback")
        if caption_embeddings is None:
            raise  # fallback embeddings were never built
        return fallback_search(vec, k)


# Main retrieval fn
@torch.inference_mode()
def retrieve(img: Image.Image, k: int = 5):
    if img is None:
        return "📷 Please upload an image", ""
    k = min(int(k), len(captions))

    # BLIP caption
    ids = blip_model.generate(pil_to_tensor(img), max_new_tokens=20)
    blip_cap = blip_proc.tokenizer.decode(ids[0], skip_special_tokens=True)

    # CLIP embedding of the generated caption
    vec = clip_model.encode([blip_cap], normalize_embeddings=True,
                            convert_to_numpy=True).astype("float32")

    # Similarity search
    D, I = safe_faiss_search(vec, k)
    lines = [f"**{i+1}.** *dist {D[0][i]:.3f}*\n{captions[I[0][i]]}"
             for i in range(k)]
    return blip_cap, "\n\n".join(lines)
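
# ----------------------------------------------------------------------
# Reference sketch (an assumption, not the original build script): one way
# the FAISS index + caption array loaded above could be produced. It is
# never called here; the real assets come from scripts.get_assets. Uses
# IndexFlatL2 to match the smoke test above; note that the fallback search
# reports cosine distance (1 - sim), which lives on a different scale than
# squared-L2 distance.
def build_assets_sketch(caption_list):
    embs = clip_model.encode(
        caption_list, convert_to_numpy=True,
        normalize_embeddings=True, show_progress_bar=True,
    ).astype("float32")
    idx = faiss.IndexFlatL2(embs.shape[1])  # 512-d for clip-ViT-B-32
    idx.add(embs)
    faiss.write_index(idx, str(FAISS_INDEX))
    # object array so the loader's allow_pickle=True round-trips it
    np.save(CAPTION_ARRAY, np.array(caption_list, dtype=object))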

# Gradio UI
demo = gr.Interface(
    fn=retrieve,
    inputs=[gr.Image(type="pil"),
            gr.Slider(1, 10, value=5, step=1, label="# of similar captions")],
    # retrieve() returns Markdown-formatted results, so render them as Markdown
    outputs=[gr.Textbox(label="BLIP caption"),
             gr.Markdown(label="Nearest COCO captions")],
    title="Image-to-Text Retrieval (BLIP + CLIP + FAISS)",
    description=("Upload an image → BLIP generates a caption → CLIP embeds it → "
                 "FAISS retrieves the most similar human-written COCO captions."),
)

if __name__ == "__main__":
    demo.launch()