import time, faiss, gradio as gr, torch, numpy as np
from pathlib import Path
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForConditionalGeneration, logging as hf_log

# Make sure the FAISS index + caption array exist before anything else runs.
from scripts.get_assets import ensure_assets  # project helper; downloads assets
ensure_assets()  # download once, then cached on later runs

# House-keeping: silence transformers' warning spam and log the start time.
hf_log.set_verbosity_error()
print("🟢 fresh run", time.strftime("%H:%M:%S"))

# Paths to the prebuilt retrieval assets fetched by ensure_assets().
FAISS_INDEX = Path("scripts/coco_caption_clip.index")
CAPTION_ARRAY = Path("scripts/coco_caption_texts.npy")

# Prefer GPU when available; the BLIP model is moved to this device below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Sanity-check that FAISS can index and search at all on this machine;
# FAISS_WORKING gates whether the real index is used later.
print("Testing basic FAISS functionality…")
try:
    probe = np.random.rand(1, 512).astype("float32")
    smoke_index = faiss.IndexFlatL2(512)
    smoke_index.add(probe)
    dists, _ids = smoke_index.search(probe, 1)
    FAISS_WORKING = True
    print(f"✅ FAISS ok (D={dists[0][0]:.3f})")
except Exception as exc:
    FAISS_WORKING = False
    print(f"⚠️ FAISS broken: {exc}")
# Load all models: BLIP for image captioning, CLIP for text embeddings.
try:
    blip_proc = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = (BlipForConditionalGeneration
                  .from_pretrained("Salesforce/blip-image-captioning-base")
                  .to(device).eval())
    clip_model = SentenceTransformer("clip-ViT-B-32")
    print("✅ Models loaded")
except Exception as e:
    # Chain the cause (`from e`) so the original download/load traceback
    # is preserved instead of being swallowed by the RuntimeError.
    raise RuntimeError(f"Model load failed: {e}") from e
# Load FAISS index + captions; if the smoke test failed, precompute CLIP
# embeddings of every caption so brute-force search can stand in for FAISS.
try:
    captions = np.load(CAPTION_ARRAY, allow_pickle=True)
    if FAISS_WORKING:
        index = faiss.read_index(str(FAISS_INDEX))
        print(f"✅ FAISS index: {index.ntotal} vectors × {index.d}")
        caption_embeddings = None  # unused on the FAISS path
    else:
        index = None
        print("Building caption embeddings for fallback search…")
        # Normalized float32 embeddings → dot product == cosine similarity.
        caption_embeddings = clip_model.encode(
            captions.tolist(), convert_to_numpy=True,
            normalize_embeddings=True, show_progress_bar=False
        ).astype("float32")
except Exception as e:
    # Chain the cause so the real failure (missing file, bad index, …)
    # survives in the traceback.
    raise RuntimeError(f"Loading FAISS assets failed: {e}") from e
# Helpers
@torch.inference_mode()
def pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a (1, 3, 384, 384) normalized batch on `device`.

    Resizes to BLIP's 384×384 input, scales to [0, 1], then applies the
    CLIP/BLIP channel mean/std normalization.

    Args:
        img: input image; converted to RGB internally.

    Returns:
        float32 tensor of shape (1, 3, 384, 384) on the module-level `device`.
    """
    img = img.convert("RGB").resize((384, 384), Image.Resampling.LANCZOS)
    arr = np.asarray(img, dtype="float32") / 255.0
    # BUG FIX: mean/std must be float32 — the original float64 arrays upcast
    # `arr` to float64, producing a double tensor that the float32 BLIP model
    # rejects at generate() time.
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
    std = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")
    arr = (arr - mean) / std
    # HWC → CHW, add batch dim, move to the compute device.
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)
def fallback_search(vec, k=5, embeddings=None):
    """Brute-force cosine search over caption embeddings (FAISS fallback).

    Args:
        vec: (1, d) L2-normalized float32 query embedding.
        k: number of neighbours to return.
        embeddings: optional (N, d) L2-normalized matrix to search; defaults
            to the module-level `caption_embeddings`.

    Returns:
        (D, I): (1, k) arrays of distances (1 − cosine similarity) and
        indices, mirroring `faiss.Index.search`.
    """
    if embeddings is None:
        embeddings = caption_embeddings
    sims = (embeddings @ vec.T).ravel()      # (N,) cosine similarities
    idx = np.argsort(sims)[::-1][:k]         # most similar first
    # BUG FIX: the original `sims[0, idx]` treated neighbour indices as
    # column indices of the (N, 1) matrix, raising IndexError for any hit
    # index > 0. Index the flat similarity vector instead.
    dist = 1.0 - sims[idx]
    return dist.reshape(1, -1), idx.reshape(1, -1)
def safe_faiss_search(vec, k=5):
    """Search the FAISS index for `vec`, degrading to brute force on failure.

    Returns (D, I) in the same (1, k) layout as `faiss.Index.search`.
    """
    # No index was loaded at startup → brute-force path directly.
    if index is None:
        return fallback_search(vec, k)
    try:
        # FAISS requires a contiguous float32 array.
        return index.search(np.ascontiguousarray(vec), k)
    except Exception as e:
        # Deliberate broad catch: any FAISS failure degrades gracefully.
        print(f"FAISS search failed: {e} → fallback")
        return fallback_search(vec, k)
# Main retrieval fn
@torch.inference_mode()
def retrieve(img: Image.Image, k: int = 5):
    """Caption `img` with BLIP and return the k nearest COCO captions.

    Returns:
        (blip_caption, html): the generated caption and an HTML-formatted
        list of the k most similar stored captions with their distances.
    """
    if img is None:
        return "📷 Please upload an image", ""
    k = min(int(k), len(captions))
    # 1) Generate a caption with BLIP.
    token_ids = blip_model.generate(pil_to_tensor(img), max_new_tokens=20)
    blip_cap = blip_proc.tokenizer.decode(token_ids[0], skip_special_tokens=True)
    # 2) Embed that caption with CLIP (normalized float32, shape (1, d)).
    query = clip_model.encode([blip_cap], normalize_embeddings=True,
                              convert_to_numpy=True).astype("float32")
    # 3) Nearest-neighbour lookup + markdown/HTML rendering.
    D, I = safe_faiss_search(query, k)
    rows = []
    for rank in range(k):
        rows.append(f"**{rank+1}.** *dist {D[0][rank]:.3f}*<br>{captions[I[0][rank]]}")
    return blip_cap, "<br><br>".join(rows)
# Gradio UI: a single-function interface wiring `retrieve` to the widgets.
demo = gr.Interface(
    fn=retrieve,
    # PIL image input plus a slider selecting k in [1, 10] (default 5).
    inputs=[gr.Image(type="pil"), gr.Slider(1, 10, value=5, step=1,
                                            label="# of similar captions")],
    # Two outputs matching retrieve()'s (caption, html) return pair.
    outputs=[gr.Textbox(label="BLIP caption"),
             gr.HTML(label="Nearest COCO captions")],
    title="Image-to-Text Retrieval (BLIP + CLIP + FAISS)",
    description=("Upload an image → BLIP generates a caption → CLIP embeds it → "
                 "FAISS retrieves the most similar human-written COCO captions.")
)
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()