File size: 4,485 Bytes
ce53f55
a1a61d3
ce53f55
 
 
 
a1a61d3
 
 
 
 
 
 
ce53f55
 
a1a61d3
 
 
 
 
 
 
ce53f55
a1a61d3
ce53f55
 
a1a61d3
 
 
 
ce53f55
 
a1a61d3
ce53f55
 
 
a1a61d3
 
ce53f55
 
a1a61d3
 
 
ce53f55
a1a61d3
ce53f55
a1a61d3
ce53f55
a1a61d3
ce53f55
a1a61d3
ce53f55
a1a61d3
 
 
ce53f55
 
a1a61d3
 
 
 
 
ce53f55
a1a61d3
ce53f55
a1a61d3
 
ce53f55
a1a61d3
 
ce53f55
a1a61d3
 
 
 
 
 
 
 
 
ce53f55
 
a1a61d3
 
 
 
 
 
 
 
ce53f55
a1a61d3
ce53f55
 
a1a61d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce53f55
 
 
a1a61d3
 
 
 
 
 
 
ce53f55
 
 
a1a61d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import time, faiss, gradio as gr, torch, numpy as np
from pathlib import Path
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import BlipProcessor, BlipForConditionalGeneration, logging as hf_log

# Make sure the FAISS index + caption array exist on disk before anything
# below tries to read them (downloaded once, then served from cache).
from scripts.get_assets import ensure_assets  # project-local download helper
ensure_assets()                               # download once, then cached

# House-keeping: silence transformers' info/warning chatter, stamp the run.
hf_log.set_verbosity_error()
print("🟢 fresh run", time.strftime("%H:%M:%S"))

# Assets produced by ensure_assets(): a prebuilt CLIP-embedding FAISS index
# over COCO captions, plus the parallel array of the caption strings.
FAISS_INDEX   = Path("scripts/coco_caption_clip.index")
CAPTION_ARRAY = Path("scripts/coco_caption_texts.npy")

# All torch models below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Quick FAISS smoke test: some environments ship a broken faiss build, so
# add + search one random vector and record whether native search works.
# FAISS_WORKING gates whether we use the on-disk index or a NumPy fallback.

print("Testing basic FAISS functionality…")
try:
    test_index = faiss.IndexFlatL2(512)  # 512 = CLIP ViT-B/32 embedding dim
    vec        = np.random.rand(1, 512).astype("float32")
    test_index.add(vec)
    D, I = test_index.search(vec, 1)  # querying with itself → distance ≈ 0
    print(f"✅ FAISS ok (D={D[0][0]:.3f})")
    FAISS_WORKING = True
except Exception as e:
    print(f"⚠️  FAISS broken: {e}")
    FAISS_WORKING = False


# Load all models: BLIP generates a caption from the image; the CLIP
# sentence-transformer embeds that caption into the FAISS search space.

try:
    blip_proc  = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    blip_model = (BlipForConditionalGeneration
                  .from_pretrained("Salesforce/blip-image-captioning-base")
                  .to(device).eval())  # eval(): inference mode, dropout off
    clip_model = SentenceTransformer("clip-ViT-B-32")
    print("✅ Models loaded")
except Exception as e:
    # Chain the original exception so the root cause (network failure,
    # corrupt cache, OOM, …) stays visible in the traceback.
    raise RuntimeError(f"Model load failed: {e}") from e

# Load FAISS index + captions (or build fallback embeddings).
# When FAISS is broken, `index` is None and `caption_embeddings` holds the
# normalized CLIP embeddings used by the brute-force NumPy fallback.
try:
    captions = np.load(CAPTION_ARRAY, allow_pickle=True)
    if FAISS_WORKING:
        index = faiss.read_index(str(FAISS_INDEX))
        print(f"✅ FAISS index: {index.ntotal} vectors × {index.d}")
        caption_embeddings = None
    else:
        index = None
        print("Building caption embeddings for fallback search…")
        # normalize_embeddings=True → dot products are cosine similarities.
        caption_embeddings = clip_model.encode(
            captions.tolist(), convert_to_numpy=True,
            normalize_embeddings=True, show_progress_bar=False
        ).astype("float32")
except Exception as e:
    # Keep the underlying cause attached to the re-raised error.
    raise RuntimeError(f"Loading FAISS assets failed: {e}") from e

# Helpers
@torch.inference_mode()
def pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Preprocess a PIL image into a BLIP-ready batch tensor.

    Resizes to 384×384, scales pixels to [0, 1], applies CLIP-style channel
    normalization, and returns a (1, 3, 384, 384) float32 tensor on `device`.
    """
    img = img.convert("RGB").resize((384, 384), Image.Resampling.LANCZOS)
    arr = np.asarray(img, dtype="float32") / 255.0
    # Pin mean/std to float32: with the (default) float64 constants,
    # (arr - mean) / std promotes the whole array — and the tensor fed to
    # the float32 BLIP weights — to float64, which torch rejects.
    mean = np.array([0.48145466, 0.4578275, 0.40821073], dtype="float32")
    std  = np.array([0.26862954, 0.26130258, 0.27577711], dtype="float32")
    arr  = (arr - mean) / std
    # HWC → CHW, add batch dim, move to the module-level device.
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0).to(device)

def fallback_search(vec, k=5):
    """Brute-force cosine search over the precomputed caption embeddings.

    `vec` is a (1, d) L2-normalized query and `caption_embeddings` is (N, d),
    also normalized, so dot products are cosine similarities.  Returns
    (distances, indices), each shaped (1, k), mirroring faiss.Index.search;
    distances are cosine distance (1 - similarity), smaller is better.
    """
    sims = (caption_embeddings @ vec.T).ravel()   # (N,) cosine similarities
    idx  = np.argsort(sims)[::-1][:k]             # top-k, best first
    # BUG FIX: the original `sims[0, idx]` indexed the columns of the (N, 1)
    # similarity matrix and raised IndexError for any idx > 0.
    dist = 1 - sims[idx]
    return dist.reshape(1, -1), idx.reshape(1, -1)

def safe_faiss_search(vec, k=5):
    """Search the FAISS index, degrading gracefully to the NumPy fallback.

    Returns (distances, indices) from the FAISS index when it is loaded and
    healthy; otherwise delegates to fallback_search() with the same (1, k)
    result shapes.
    """
    if index is not None:
        try:
            # FAISS requires a C-contiguous float32 matrix.
            return index.search(np.ascontiguousarray(vec), k)
        except Exception as e:
            print(f"FAISS search failed: {e} → fallback")
    return fallback_search(vec, k)

# Main retrieval fn 
@torch.inference_mode()
def retrieve(img: Image.Image, k: int = 5):
    """Caption `img` with BLIP, then retrieve the k most similar COCO captions.

    Returns (blip_caption, html): the generated caption and an HTML string
    listing the retrieved captions with their distances.  When no image is
    supplied, returns a prompt message and an empty string.
    """
    if img is None:
        return "📷 Please upload an image", ""
    # Clamp k into [1, corpus size]: the slider delivers a float, and a
    # zero/negative request would otherwise produce an empty result.
    k = max(1, min(int(k), len(captions)))

    # BLIP caption
    ids = blip_model.generate(pil_to_tensor(img), max_new_tokens=20)
    blip_cap = blip_proc.tokenizer.decode(ids[0], skip_special_tokens=True)

    # CLIP embedding (unit-normalized so inner products are cosine sims)
    vec = clip_model.encode([blip_cap], normalize_embeddings=True,
                            convert_to_numpy=True).astype("float32")

    # Similarity search
    D, I = safe_faiss_search(vec, k)
    lines = []
    for rank, (dist, idx) in enumerate(zip(D[0], I[0]), start=1):
        # FAISS pads short result sets with index -1, which would silently
        # render captions[-1]; skip those entries.
        if idx < 0:
            continue
        lines.append(f"**{rank}.** *dist {dist:.3f}*<br>{captions[idx]}")
    return blip_cap, "<br><br>".join(lines)


#  Gradio UI: image + slider in, BLIP caption + retrieved-caption HTML out.

demo = gr.Interface(
    fn=retrieve,
    inputs=[gr.Image(type="pil"), gr.Slider(1, 10, value=5, step=1,
                                            label="# of similar captions")],
    outputs=[gr.Textbox(label="BLIP caption"),
             gr.HTML(label="Nearest COCO captions")],
    title="Image-to-Text Retrieval (BLIP + CLIP + FAISS)",
    description=("Upload an image → BLIP generates a caption → CLIP embeds it → "
                 "FAISS retrieves the most similar human-written COCO captions.")
)

# Launch only when run as a script (Spaces/imports pick up `demo` directly).
if __name__ == "__main__":
    demo.launch()