import time, faiss, gradio as gr, torch, numpy as np
from pathlib import Path
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForConditionalGeneration, logging as hf_log
# Make sure the FAISS index + caption array exist
from scripts.get_assets import ensure_assets # helper you already have
ensure_assets() # download once, then cached
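# The two cached assets are assumed to be a FAISS index built from CLIP
# ("clip-ViT-B-32") embeddings of COCO captions plus a parallel .npy array of
# the caption strings, i.e. row i of the index corresponds to captions[i].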
# House-keeping
hf_log.set_verbosity_error()
print("🟢 fresh run", time.strftime("%H:%M:%S"))
FAISS_INDEX = Path("scripts/coco_caption_clip.index")
CAPTION_ARRAY = Path("scripts/coco_caption_texts.npy")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Quick FAISS smoke test
print("Testing basic FAISS functionality…")
try:
    test_index = faiss.IndexFlatL2(512)
    vec = np.random.rand(1, 512).astype("float32")
    test_index.add(vec)
    D, I = test_index.search(vec, 1)
    print(f"✅ FAISS ok (D={D[0][0]:.3f})")
    FAISS_WORKING = True
except Exception as e:
    print(f"⚠️ FAISS broken: {e}")
    FAISS_WORKING = False
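# The smoke test above guards against FAISS installs that import fine but fail
# at runtime (e.g. an incompatible wheel for the local CPU/OS); if basic
# add/search fails here, the app switches to the brute-force NumPy fallback
# defined further below.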
# Load all models
try:
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = (BlipForConditionalGeneration
                  .from_pretrained("Salesforce/blip-image-captioning-base")
                  .to(device).eval())
    clip_model = SentenceTransformer("clip-ViT-B-32")
    print("✅ Models loaded")
except Exception as e:
    raise RuntimeError(f"Model load failed: {e}")
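# Note: clip-ViT-B-32 is assumed to be the same sentence-transformers model that
# embedded the COCO captions when the index was built, so the query embedding
# produced in retrieve() lives in the same vector space as the index rows.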
# Load FAISS index + captions (or build fallback embeddings)
try:
    captions = np.load(CAPTION_ARRAY, allow_pickle=True)
    if FAISS_WORKING:
        index = faiss.read_index(str(FAISS_INDEX))
        print(f"✅ FAISS index: {index.ntotal} vectors × {index.d} dims")
        caption_embeddings = None
    else:
        index = None
        print("Building caption embeddings for fallback search…")
        caption_embeddings = clip_model.encode(
            captions.tolist(), convert_to_numpy=True,
            normalize_embeddings=True, show_progress_bar=False
        ).astype("float32")
except Exception as e:
    raise RuntimeError(f"Loading FAISS assets failed: {e}")
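# Sanity assumption: the index (or the fallback embedding matrix) has exactly
# one row per entry in `captions`, so any row id returned by a search can be
# used directly to look up the caption text.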
# Helpers
@torch.inference_mode()
def pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Resize + CLIP-style normalize a PIL image into a (1, 3, 384, 384) float32 batch."""
    img = img.convert("RGB").resize((384, 384), Image.Resampling.LANCZOS)
    arr = np.asarray(img, dtype="float32") / 255.0
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")
    arr = (arr - mean) / std  # keep float32 so it matches the model weights' dtype
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)
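# The mean/std above are the OpenAI CLIP normalization constants, which the
# default BLIP image processor for this checkpoint also uses. An equivalent
# preprocessing path, if preferred, is to let the processor do the resize and
# normalization, e.g.:
#
#     pixel_values = blip_proc(images=img, return_tensors="pt").pixel_values.to(device)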
def fallback_search(vec, k=5):
    """Brute-force cosine search over caption embeddings (used when FAISS is unavailable)."""
    sims = (caption_embeddings @ vec.T).ravel()  # cosine similarities (unit vectors)
    idx = np.argsort(sims)[::-1][:k]             # indices of the top-k most similar captions
    dist = 1.0 - sims[idx]                       # report as cosine distance
    return dist.reshape(1, -1), idx.reshape(1, -1)
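# Note: the fallback reports cosine distance (1 - similarity), while FAISS may
# return squared-L2 or inner-product scores depending on how the index was
# built; on unit-normalized vectors the resulting ranking is identical, only
# the displayed "dist" values differ.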
def safe_faiss_search(vec, k=5):
    if index is None:
        return fallback_search(vec, k)
    try:
        D, I = index.search(np.ascontiguousarray(vec), k)
        return D, I
    except Exception as e:
        print(f"FAISS search failed: {e} → fallback")
        return fallback_search(vec, k)
# Main retrieval fn
@torch.inference_mode()
def retrieve(img: Image.Image, k: int = 5):
    if img is None:
        return "📷 Please upload an image", ""
    k = min(int(k), len(captions))
    # BLIP caption
    ids = blip_model.generate(pil_to_tensor(img), max_new_tokens=20)
    blip_cap = blip_proc.tokenizer.decode(ids[0], skip_special_tokens=True)
    # CLIP embedding
    vec = clip_model.encode([blip_cap], normalize_embeddings=True,
                            convert_to_numpy=True).astype("float32")
    # Similarity search
    D, I = safe_faiss_search(vec, k)
    # Markdown output: hard line break inside each hit, blank line between hits
    lines = [f"**{i+1}.** *dist {D[0][i]:.3f}*  \n{captions[I[0][i]]}"
             for i in range(k)]
    return blip_cap, "\n\n".join(lines)
# Gradio UI
demo = gr.Interface(
    fn=retrieve,
    inputs=[gr.Image(type="pil"),
            gr.Slider(1, 10, value=5, step=1, label="# of similar captions")],
    outputs=[gr.Textbox(label="BLIP caption"),
             gr.Markdown(label="Nearest COCO captions")],
    title="Image-to-Text Retrieval (BLIP + CLIP + FAISS)",
    description=("Upload an image → BLIP generates a caption → CLIP embeds it → "
                 "FAISS retrieves the most similar human-written COCO captions.")
)
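# Optional ad-hoc check without the UI (hypothetical local file "sample.jpg"):
#
#     cap, hits = retrieve(Image.open("sample.jpg"), k=3)
#     print(cap)
#     print(hits)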
if __name__ == "__main__":
    demo.launch()