"""Image-to-text retrieval demo: BLIP captions an uploaded image, CLIP embeds
the caption, and FAISS retrieves the most similar COCO captions."""
# Standard library
import time
from pathlib import Path

# Third-party
import faiss
import gradio as gr
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForConditionalGeneration, logging as hf_log

# Local: make sure the FAISS index + caption array exist before anything else runs.
from scripts.get_assets import ensure_assets  # helper you already have

ensure_assets()  # download once, then cached

# House-keeping: silence noisy transformers logging and stamp the run.
hf_log.set_verbosity_error()
print("🟢 fresh run", time.strftime("%H:%M:%S"))

# On-disk assets produced by the indexing pipeline.
FAISS_INDEX = Path("scripts/coco_caption_clip.index")
CAPTION_ARRAY = Path("scripts/coco_caption_texts.npy")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Quick FAISS smoke test: add one random vector and search for it, so we know
# up front whether the native library works in this environment.
print("Testing basic FAISS functionality…")
try:
    probe_index = faiss.IndexFlatL2(512)
    probe_vec = np.random.rand(1, 512).astype("float32")
    probe_index.add(probe_vec)
    D, I = probe_index.search(probe_vec, 1)
    print(f"✅ FAISS ok (D={D[0][0]:.3f})")
    FAISS_WORKING = True
except Exception as e:
    # Any failure here means "do not trust FAISS"; fall back to numpy search.
    print(f"⚠️ FAISS broken: {e}")
    FAISS_WORKING = False
# Load all models: BLIP (image captioning) and the CLIP sentence-transformer
# (text embedding). Failure here is fatal for the app.
try:
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = (BlipForConditionalGeneration
                  .from_pretrained("Salesforce/blip-image-captioning-base")
                  .to(device).eval())
    clip_model = SentenceTransformer("clip-ViT-B-32")
    print("✅ Models loaded")
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Model load failed: {e}") from e
# Load the FAISS index + caption array. When FAISS is unusable, pre-compute
# CLIP embeddings of all captions instead, to be searched with plain numpy.
try:
    captions = np.load(CAPTION_ARRAY, allow_pickle=True)
    if FAISS_WORKING:
        index = faiss.read_index(str(FAISS_INDEX))
        print(f"✅ FAISS index: {index.ntotal} vectors × {index.d}")
        caption_embeddings = None
    else:
        index = None
        print("Building caption embeddings for fallback search…")
        caption_embeddings = clip_model.encode(
            captions.tolist(), convert_to_numpy=True,
            normalize_embeddings=True, show_progress_bar=False
        ).astype("float32")
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Loading FAISS assets failed: {e}") from e
# Helpers
def pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a normalized (1, 3, 384, 384) float32 tensor.

    Resizes to BLIP's 384×384 input size, scales to [0, 1], applies the
    channel normalization constants, and moves the result to the module-level
    ``device``.
    """
    img = img.convert("RGB").resize((384, 384), Image.Resampling.LANCZOS)
    arr = np.asarray(img, dtype="float32") / 255.0
    # BUGFIX: these constants must be float32 — as plain float64 arrays the
    # subtraction promotes ``arr`` to float64, and torch.from_numpy then
    # yields a float64 tensor that the float32 BLIP weights reject.
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")
    arr = (arr - mean) / std
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)
def fallback_search(vec, k=5, embeddings=None):
    """Brute-force cosine-similarity search used when FAISS is unavailable.

    Args:
        vec: (1, d) float32, L2-normalized query embedding.
        k: number of neighbours to return.
        embeddings: optional (n, d) matrix to search; defaults to the
            module-level ``caption_embeddings``.

    Returns:
        ``(distances, indices)``, each shaped (1, k) FAISS-style, where
        distance = 1 - cosine similarity (rows assumed normalized).
    """
    emb = caption_embeddings if embeddings is None else embeddings
    sims = (emb @ vec.T).ravel()        # (n,) cosine similarities
    idx = np.argsort(sims)[::-1][:k]    # best-first
    # BUGFIX: the original did ``sims[0, idx]`` on the (n, 1) matrix, which
    # raises IndexError for any idx > 0; index the flat similarity vector.
    dist = 1 - sims[idx]
    return dist.reshape(1, -1), idx.reshape(1, -1)
def safe_faiss_search(vec, k=5):
    """Search the FAISS index, degrading gracefully to the numpy fallback.

    Returns FAISS-style ``(distances, indices)`` arrays of shape (1, k).
    """
    if index is None:
        return fallback_search(vec, k)
    try:
        # FAISS requires C-contiguous float32 input.
        return index.search(np.ascontiguousarray(vec), k)
    except Exception as e:
        print(f"FAISS search failed: {e} → fallback")
        return fallback_search(vec, k)
# Main retrieval fn
def retrieve(img: Image.Image, k: int = 5):
    """Caption *img* with BLIP, embed the caption with CLIP, and return the
    caption plus the k nearest COCO captions rendered as HTML."""
    if img is None:
        return "📷 Please upload an image", ""
    k = min(int(k), len(captions))

    # BLIP caption
    token_ids = blip_model.generate(pil_to_tensor(img), max_new_tokens=20)
    blip_cap = blip_proc.tokenizer.decode(token_ids[0], skip_special_tokens=True)

    # CLIP embedding of the generated caption
    vec = clip_model.encode(
        [blip_cap], normalize_embeddings=True, convert_to_numpy=True
    ).astype("float32")

    # Similarity search (FAISS, or numpy fallback)
    D, I = safe_faiss_search(vec, k)
    rendered = []
    for rank in range(k):
        rendered.append(f"**{rank+1}.** *dist {D[0][rank]:.3f}*<br>{captions[I[0][rank]]}")
    return blip_cap, "<br><br>".join(rendered)
# Gradio UI: image + neighbour-count slider in, caption + HTML list out.
_inputs = [
    gr.Image(type="pil"),
    gr.Slider(1, 10, value=5, step=1, label="# of similar captions"),
]
_outputs = [
    gr.Textbox(label="BLIP caption"),
    gr.HTML(label="Nearest COCO captions"),
]
demo = gr.Interface(
    fn=retrieve,
    inputs=_inputs,
    outputs=_outputs,
    title="Image-to-Text Retrieval (BLIP + CLIP + FAISS)",
    description=("Upload an image → BLIP generates a caption → CLIP embeds it → "
                 "FAISS retrieves the most similar human-written COCO captions."),
)

if __name__ == "__main__":
    demo.launch()