akash4552 commited on
Commit
ea1afb6
·
verified ·
1 Parent(s): 6514a05

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import clip
4
+ import faiss
5
+ import numpy as np
6
+ from PIL import Image
7
+ import os
8
+
9
+ # Load CLIP model
10
+ device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ model, preprocess = clip.load("ViT-B/32", device=device)
12
+
13
+ # Global storage
14
+ image_paths = []
15
+ image_embeddings = None
16
+ faiss_index = None
17
+
18
+ def build_faiss_index(images):
19
+ """Build FAISS index from uploaded images"""
20
+ global image_paths, image_embeddings, faiss_index
21
+ image_paths = []
22
+ embeddings = []
23
+
24
+ for img in images:
25
+ image_paths.append(img.name)
26
+ pil_img = Image.open(img.name).convert("RGB")
27
+ tensor_img = preprocess(pil_img).unsqueeze(0).to(device)
28
+
29
+ with torch.no_grad():
30
+ emb = model.encode_image(tensor_img)
31
+ emb /= emb.norm(dim=-1, keepdim=True)
32
+ embeddings.append(emb.cpu().numpy())
33
+
34
+ image_embeddings = np.vstack(embeddings).astype("float32")
35
+
36
+ # Build FAISS index
37
+ d = image_embeddings.shape[1] # embedding dimension
38
+ faiss_index = faiss.IndexFlatIP(d) # cosine similarity (inner product)
39
+ faiss_index.add(image_embeddings)
40
+
41
+ return f"Indexed {len(image_paths)} images."
42
+
43
+ def search(query, top_k=5):
44
+ """Search top-k most similar images given a text query"""
45
+ global image_paths, faiss_index, image_embeddings
46
+ if faiss_index is None:
47
+ return "Please upload and index images first.", []
48
+
49
+ # Encode query
50
+ text = clip.tokenize([query]).to(device)
51
+ with torch.no_grad():
52
+ text_emb = model.encode_text(text)
53
+ text_emb /= text_emb.norm(dim=-1, keepdim=True)
54
+
55
+ text_emb = text_emb.cpu().numpy().astype("float32")
56
+
57
+ # Search FAISS
58
+ scores, indices = faiss_index.search(text_emb, top_k)
59
+ results = []
60
+ for idx, score in zip(indices[0], scores[0]):
61
+ img = image_paths[idx]
62
+ results.append((img, float(score)))
63
+
64
+ return f"Top {top_k} results for '{query}'", results
65
+
66
+ def display_results(query, top_k=5):
67
+ message, results = search(query, top_k)
68
+ images, scores = [], []
69
+ for img, score in results:
70
+ images.append(img)
71
+ scores.append(f"{score:.3f}")
72
+ return message, images, scores
73
+
74
+ with gr.Blocks() as demo:
75
+ gr.Markdown("## Image Search with CLIP + FAISS 🚀")
76
+
77
+ with gr.Row():
78
+ img_upload = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
79
+ build_btn = gr.Button("Build Index")
80
+
81
+ status = gr.Textbox(label="Status")
82
+
83
+ with gr.Row():
84
+ query = gr.Textbox(label="Search Query")
85
+ top_k = gr.Slider(1, 20, value=5, step=1, label="Top K Results")
86
+ search_btn = gr.Button("Search")
87
+
88
+ output_text = gr.Textbox(label="Results")
89
+ output_gallery = gr.Gallery(label="Ranked Images").style(grid=[5], height="auto")
90
+ output_scores = gr.Textbox(label="Similarity Scores")
91
+
92
+ build_btn.click(fn=build_faiss_index, inputs=[img_upload], outputs=[status])
93
+ search_btn.click(fn=display_results, inputs=[query, top_k], outputs=[output_text, output_gallery, output_scores])
94
+
95
+ demo.launch()