akash4552 commited on
Commit
9a5ac5b
·
verified ·
1 Parent(s): 7e4d943

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +349 -82
app.py CHANGED
@@ -1,95 +1,362 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
- import clip
4
- import faiss
5
  import numpy as np
6
- from PIL import Image
7
- import os
8
 
9
- # Load CLIP model
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model, preprocess = clip.load("ViT-B/32", device=device)
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Global storage
14
- image_paths = []
15
- image_embeddings = None
16
- faiss_index = None
17
 
18
- def build_faiss_index(images):
19
- """Build FAISS index from uploaded images"""
20
- global image_paths, image_embeddings, faiss_index
21
- image_paths = []
22
- embeddings = []
23
 
24
- for img in images:
25
- image_paths.append(img.name)
26
- pil_img = Image.open(img.name).convert("RGB")
27
- tensor_img = preprocess(pil_img).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- with torch.no_grad():
30
- emb = model.encode_image(tensor_img)
31
- emb /= emb.norm(dim=-1, keepdim=True)
32
- embeddings.append(emb.cpu().numpy())
33
 
34
- image_embeddings = np.vstack(embeddings).astype("float32")
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # Build FAISS index
37
- d = image_embeddings.shape[1] # embedding dimension
38
- faiss_index = faiss.IndexFlatIP(d) # cosine similarity (inner product)
39
- faiss_index.add(image_embeddings)
40
 
41
- return f"Indexed {len(image_paths)} images."
 
 
 
 
42
 
43
- def search(query, top_k=5):
44
- """Search top-k most similar images given a text query"""
45
- global image_paths, faiss_index, image_embeddings
46
- if faiss_index is None:
47
- return "Please upload and index images first.", []
48
 
49
- # Encode query
50
- text = clip.tokenize([query]).to(device)
 
 
 
51
  with torch.no_grad():
52
- text_emb = model.encode_text(text)
53
- text_emb /= text_emb.norm(dim=-1, keepdim=True)
54
-
55
- text_emb = text_emb.cpu().numpy().astype("float32")
56
-
57
- # Search FAISS
58
- scores, indices = faiss_index.search(text_emb, top_k)
59
- results = []
60
- for idx, score in zip(indices[0], scores[0]):
61
- img = image_paths[idx]
62
- results.append((img, float(score)))
63
-
64
- return f"Top {top_k} results for '{query}'", results
65
-
66
- def display_results(query, top_k=5):
67
- message, results = search(query, top_k)
68
- images, scores = [], []
69
- for img, score in results:
70
- images.append(img)
71
- scores.append(f"{score:.3f}")
72
- return message, images, scores
73
-
74
- with gr.Blocks() as demo:
75
- gr.Markdown("## Image Search with CLIP + FAISS 🚀")
76
-
77
- with gr.Row():
78
- img_upload = gr.File(file_types=[".png", ".jpg", ".jpeg"], file_count="multiple")
79
- build_btn = gr.Button("Build Index")
80
-
81
- status = gr.Textbox(label="Status")
82
-
83
- with gr.Row():
84
- query = gr.Textbox(label="Search Query")
85
- top_k = gr.Slider(1, 20, value=5, step=1, label="Top K Results")
86
- search_btn = gr.Button("Search")
87
-
88
- output_text = gr.Textbox(label="Results")
89
- output_gallery = gr.Gallery(label="Ranked Images").style(grid=[5], height="auto")
90
- output_scores = gr.Textbox(label="Similarity Scores")
91
-
92
- build_btn.click(fn=build_faiss_index, inputs=[img_upload], outputs=[status])
93
- search_btn.click(fn=display_results, inputs=[query, top_k], outputs=[output_text, output_gallery, output_scores])
94
-
95
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio app: Text-to-Image ranking using OpenCLIP (open-source)
3
+ Features:
4
+ - Accepts a text query and multiple images (100+).
5
+ - Encodes text and images with OpenCLIP (ViT-B-32 by default).
6
+ - Computes cosine similarity, normalizes scores to 0-100.
7
+ - Returns a ranked CSV and a visual grid image annotated with scores.
8
+ - GPU optional (will use CUDA if available).
9
+ """
10
+
11
+ import os
12
+ import io
13
+ import math
14
+ import time
15
+ from typing import List, Tuple, Optional
16
+
17
  import torch
18
+ import open_clip
19
+ from PIL import Image, ImageDraw, ImageFont
20
  import numpy as np
21
+ import pandas as pd
22
+ import gradio as gr
23
 
24
# -------------------------
# Configuration / Globals
# -------------------------
MODEL_NAME = "ViT-B-32"  # OpenCLIP model backbone
# MODEL_PRETRAIN = "laion2b_s32b_b79k"
MODEL_PRETRAIN = "openai"  # pretraining dataset variant (open weights)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # preferred device, probed once at import
BATCH_SIZE = 64  # image encoding batch size (tune by your GPU/CPU memory)
TOP_K_DEFAULT = 20  # how many top results to show visually
THUMB_SIZE = (256, 256)  # thumbnail size for visual grid
FONT_PATH = None  # if you want a custom TTF, set path, else default PIL font used
NORMALIZE_SCORE_TO = 100  # final scores in 0..NORMALIZE_SCORE_TO
# -------------------------

# Load model once at startup (lazy load wrapped in function).
# `_model_data` is the process-wide cache filled by load_model(); keys after
# loading: "loaded", "model", "preprocess", "tokenizer", "dim".
_model_data = {"loaded": False}
 
 
40
 
 
 
 
 
 
41
 
42
def load_model(device: str = DEVICE):
    """
    Load the OpenCLIP model and transforms, caching them in `_model_data`.

    Args:
        device: torch device string ("cuda" or "cpu") the model should run on.

    Returns:
        (model, preprocess, tokenizer, dim) where `dim` is the embedding size.
    """
    if _model_data.get("loaded", False):
        model = _model_data["model"]
        # Bug fix: honor the requested device on cache hits too. Previously a
        # model first loaded on CPU stayed on CPU when the caller later asked
        # for CUDA (and vice versa), causing device-mismatch errors during
        # encoding. `.to()` is a no-op when the model is already there.
        model.to(device)
        return model, _model_data["preprocess"], _model_data["tokenizer"], _model_data["dim"]

    print(f"Loading OpenCLIP {MODEL_NAME} ({MODEL_PRETRAIN}) to {device} ...")
    model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, MODEL_PRETRAIN)
    tokenizer = open_clip.get_tokenizer(MODEL_NAME)
    model.to(device)
    model.eval()
    # Embedding dimension: the text projection parameter is (width, embed_dim),
    # so its last axis is the embedding size. Fall back to the visual tower's
    # declared output size (open_clip exposes `visual.output_dim`) instead of
    # the previous `model.projection`, which open_clip models do not define.
    if getattr(model, "text_projection", None) is not None:
        dim = model.text_projection.shape[-1]
    else:
        dim = model.visual.output_dim
    _model_data.update({
        "loaded": True,
        "model": model,
        "preprocess": preprocess,
        "tokenizer": tokenizer,
        "dim": dim,
    })
    print("Model loaded.")
    return model, preprocess, tokenizer, dim
66
 
 
 
 
 
67
 
68
+ # -------------------------
69
+ # Utilities
70
+ # -------------------------
71
def load_pil_image(file_obj) -> Image.Image:
    """
    Open an uploaded file as a PIL image in RGB mode.

    Accepts either a filesystem path (str) or a file-like object (as produced
    by Gradio uploads); file-like objects are rewound and read fully into
    memory before decoding.
    """
    if not isinstance(file_obj, str):
        # File-like upload: rewind first, then decode from an in-memory copy.
        file_obj.seek(0)
        raw = file_obj.read()
        opened = Image.open(io.BytesIO(raw))
    else:
        # Plain filesystem path.
        opened = Image.open(file_obj)
    return opened.convert("RGB")
81
 
 
 
 
 
82
 
83
def batchify(iterable, batch_size):
    """Yield the items of *iterable* as consecutive lists of at most *batch_size*."""
    # Materialize once so arbitrary iterables (generators included) can be sliced.
    buffered = list(iterable)
    start = 0
    while start < len(buffered):
        yield buffered[start:start + batch_size]
        start += batch_size
88
 
 
 
 
 
 
89
 
90
def encode_text(text: str, model, tokenizer, device: str = DEVICE) -> torch.Tensor:
    """
    Encode a single text query into an L2-normalized embedding.

    Returns a (1, dim) tensor on `device`; unit-normalized so a plain inner
    product against image embeddings yields cosine similarity.
    """
    tokens = tokenizer([text])
    with torch.no_grad():
        feats = model.encode_text(tokens.to(device))  # (1, dim)
        feats = feats / feats.norm(dim=-1, keepdim=True)  # unit-normalize
    return feats
100
+
101
+
102
def encode_images(images: List[Image.Image], model, preprocess, device: str = DEVICE, batch_size: int = BATCH_SIZE) -> torch.Tensor:
    """
    Encode a list of PIL images into L2-normalized embeddings (N x dim).

    Processes images in batches of `batch_size` to bound peak memory and
    always returns the stacked embeddings on CPU so callers can mix devices.

    Args:
        images: PIL images to embed (assumed RGB — see load_pil_image).
        model: OpenCLIP model exposing `encode_image`.
        preprocess: transform mapping a PIL image to a model-ready tensor.
        device: device to run the encoder on; the model is assumed to already
            live there (see load_model).
        batch_size: number of images per forward pass.
    """
    all_feats = []
    # Fix: dropped the unused `next(model.parameters()).device` lookup — it
    # served no purpose (placement is driven by `device`) and raised
    # StopIteration for parameter-free model stubs.
    for batch in batchify(images, batch_size):
        # Preprocess each image and stack into one (B, C, H, W) batch tensor.
        batch_tensors = torch.stack([preprocess(img) for img in batch]).to(device)
        with torch.no_grad():
            feats = model.encode_image(batch_tensors)
            feats = feats / feats.norm(dim=-1, keepdim=True)  # unit vectors for cosine
        all_feats.append(feats.cpu())
    all_feats = torch.cat(all_feats, dim=0)
    return all_feats  # on CPU
118
+
119
+
120
def cosine_similarity_matrix(text_feat: torch.Tensor, image_feats: torch.Tensor) -> np.ndarray:
    """
    Compute cosine similarities between one text embedding and N image embeddings.

    Args:
        text_feat: (1, dim) L2-normalized text embedding (any device).
        image_feats: (N, dim) L2-normalized image embeddings on CPU.

    Returns:
        1-D ndarray of shape (N,), values clipped to [-1, 1].
    """
    # Move text features to CPU to match image_feats (which encode_images
    # returns on CPU).
    if isinstance(text_feat, torch.Tensor):
        text_feat = text_feat.cpu()
    # Inner product of unit vectors == cosine similarity; (N, dim) @ (dim,) -> (N,).
    sims = (image_feats @ text_feat.squeeze(0).T).numpy()
    # Bug fix: the previous bare .squeeze() collapsed the N == 1 case to a
    # 0-d scalar, breaking the documented (N,) shape and downstream
    # argsort/indexing; reshape(-1) always yields a 1-D array.
    sims = sims.reshape(-1)
    # Clamp tiny numerical excursions outside the valid cosine range.
    return np.clip(sims, -1.0, 1.0)
132
+
133
+
134
def normalize_scores_to_range(scores: np.ndarray, low=0.0, high=NORMALIZE_SCORE_TO) -> np.ndarray:
    """
    Linearly map cosine scores from [-1, 1] onto [low, high] (e.g. 0..100).

    A degenerate input (every score identical) maps to the midpoint of the
    target range, so callers never hit a divide-by-zero.
    """
    lo_s = float(scores.min())
    hi_s = float(scores.max())
    if math.isclose(lo_s, hi_s):
        # All scores equal — there is no spread to map, so use the midpoint.
        midpoint = (low + high) / 2.0
        return np.full_like(scores, fill_value=midpoint, dtype=float)
    # Clamp to the valid cosine range, then map linearly: -1 -> low, +1 -> high.
    clipped = np.clip(scores, -1.0, 1.0)
    fraction = (clipped + 1.0) / 2.0  # position within [-1, 1] expressed as 0..1
    return low + fraction * (high - low)
151
+
152
+
153
def make_visual_grid(images: List[Image.Image], scores: List[float], top_k: int = 12,
                     thumb_size: Tuple[int, int] = THUMB_SIZE, columns: int = 4,
                     font_path: Optional[str] = FONT_PATH) -> Image.Image:
    """
    Create a single PIL image that arranges top_k thumbnails in a grid with score captions.

    Args:
        images: candidate images, already sorted best-first by the caller.
        scores: one score per image; rendered as a one-decimal caption.
        top_k: how many thumbnails to draw (capped at len(images)).
        thumb_size: (width, height) each thumbnail is resized to.
        columns: thumbnails per row.
        font_path: optional TTF path; falls back to PIL's built-in font.

    Returns:
        One composited RGB image: `columns` wide, rows of thumbnail + caption strip.
    """
    top_k = min(top_k, len(images))
    rows = math.ceil(top_k / columns)
    w, h = thumb_size
    caption_height = 28  # pixel height of the white strip under each thumbnail
    grid_w = columns * w
    grid_h = rows * (h + caption_height)

    # White canvas; thumbnails and captions are pasted/drawn onto it below.
    grid_img = Image.new("RGB", (grid_w, grid_h), color=(255, 255, 255))
    draw = ImageDraw.Draw(grid_img)
    try:
        if font_path and os.path.exists(font_path):
            font = ImageFont.truetype(font_path, 16)
        else:
            font = ImageFont.load_default()
    except Exception:
        # Any font-loading failure degrades gracefully to the default font.
        font = ImageFont.load_default()

    for idx in range(top_k):
        # copy() so the caller's image is not mutated by resize.
        img = images[idx].copy().resize(thumb_size, Image.Resampling.LANCZOS)
        col = idx % columns
        row = idx // columns
        x = col * w
        y = row * (h + caption_height)  # each row also reserves caption space
        grid_img.paste(img, (x, y))
        # caption with background rectangle for readability
        caption = f"{scores[idx]:.1f}"
        # text_w, text_h = draw.textsize(caption, font=font)
        # For Pillow >=10 (textsize was removed; textbbox is the replacement)
        bbox = draw.textbbox((0, 0), caption, font=font)
        text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1]

        # Caption strip sits directly below the thumbnail, full cell width.
        rect_x0 = x
        rect_y0 = y + h
        rect_x1 = x + w
        rect_y1 = rect_y0 + caption_height
        draw.rectangle([rect_x0, rect_y0, rect_x1, rect_y1], fill=(255, 255, 255))
        text_x = x + 6  # small fixed left margin; text_w is computed but unused
        text_y = rect_y0 + (caption_height - text_h) // 2  # vertically centered
        draw.text((text_x, text_y), caption, fill=(0, 0, 0), font=font)

    return grid_img
200
+
201
+
202
+ # -------------------------
203
+ # Core pipeline
204
+ # -------------------------
205
def rank_images_by_text(query: str, files: List[gr.File], top_k: int = TOP_K_DEFAULT,
                        use_gpu: bool = (DEVICE == "cuda")) -> Tuple[pd.DataFrame, Image.Image]:
    """
    Main pipeline:
    - load model (if not)
    - read images from files
    - encode text and images
    - compute cosine similarity
    - produce ranked DataFrame and visual grid image

    Args:
        query: free-text search query.
        files: uploaded image files (Gradio file objects or paths).
        top_k: how many top results to render into the visual grid.
        use_gpu: when True (and CUDA is available), encode on DEVICE; else CPU.

    Returns:
        (pandas.DataFrame with columns ['filename','score_cosine','score_<N>'], PIL.Image grid)

    Raises:
        ValueError: if query/files are missing or no upload could be decoded.
    """
    start_time = time.time()
    if not query or (not files):
        raise ValueError("Please provide both a text query and at least one image file.")

    # `dim` is returned by load_model but not needed in this pipeline.
    model, preprocess, tokenizer, dim = load_model(DEVICE if use_gpu else "cpu")
    device = DEVICE if use_gpu else "cpu"

    # Load images and remember filenames
    images = []
    filenames = []
    for f in files:
        # f is a tempfile-like object from gradio
        try:
            pil = load_pil_image(f)
            images.append(pil)
            # get filename attribute gracefully
            name = getattr(f, "name", None)
            if name:
                fname = os.path.basename(name)
            else:
                # try to get filename from object dict
                fname = getattr(f, "filename", "uploaded_image")
            filenames.append(fname)
        except Exception as e:
            # Best-effort: a single unreadable upload is skipped, not fatal.
            print(f"Skipping a file due to load error: {e}")

    if len(images) == 0:
        raise ValueError("No valid images could be loaded from uploads.")

    # Encode text
    text_feat = encode_text(query, model, tokenizer, device=device)

    # Encode images (batched)
    image_feats = encode_images(images, model, preprocess, device=device, batch_size=BATCH_SIZE)

    # Compute cosine similarities
    sims = cosine_similarity_matrix(text_feat, image_feats)  # range [-1,1]
    scores_norm = normalize_scores_to_range(sims, low=0.0, high=float(NORMALIZE_SCORE_TO))

    # Rank results
    order = np.argsort(-sims)  # descending by raw cosine
    sims_sorted = sims[order]
    scores_sorted = scores_norm[order]
    filenames_sorted = [filenames[i] for i in order]
    images_sorted = [images[i] for i in order]

    # Build DataFrame (normalized-score column name embeds the scale, e.g. "score_100")
    df = pd.DataFrame({
        "filename": filenames_sorted,
        "score_cosine": sims_sorted,
        f"score_{int(NORMALIZE_SCORE_TO)}": scores_sorted
    })

    # Create visual grid of top_k results
    top_k = min(top_k, len(images_sorted))
    top_images = images_sorted[:top_k]
    top_scores = scores_sorted[:top_k].tolist()
    grid_img = make_visual_grid(top_images, top_scores, top_k=top_k, thumb_size=THUMB_SIZE, columns=4)

    elapsed = time.time() - start_time
    print(f"Query processed in {elapsed:.2f}s. Images: {len(images)}. Top-K: {top_k}")
    return df, grid_img
278
+
279
+
280
+ # -------------------------
281
+ # Gradio app UI
282
+ # -------------------------
283
def gradio_rank_fn(query: str, image_files: List[gr.File], top_k: int = TOP_K_DEFAULT, use_gpu: bool = (DEVICE == "cuda")):
    """
    Wrapper for Gradio. Returns (ranked table as CSV string / DataFrame, grid image as PIL, optionally downloadable CSV).

    On success returns (summary_text, grid_image, (filename, csv_bytes, mime));
    on any failure returns (error_message, None, None) instead of raising, so
    the UI always gets a displayable result.
    """
    if not image_files:
        return "No images uploaded.", None, None
    try:
        df, grid_img = rank_images_by_text(query, image_files, top_k=top_k, use_gpu=use_gpu)
    except Exception as e:
        # Surface the error as the status text rather than crashing the app.
        return f"Error: {e}", None, None

    # Save CSV to buffer so user can download
    csv_buffer = io.StringIO()
    df.to_csv(csv_buffer, index=False)
    csv_bytes = csv_buffer.getvalue().encode("utf-8")
    csv_buffer.close()

    # Return textual summary, grid image, and CSV bytes for download component
    summary = f"Ranked {len(df)} images for query: '{query}'. Top score: {df['score_cosine'].max():.4f}"
    return summary, grid_img, ("rankings.csv", csv_bytes, "text/csv")
303
+
304
+
305
def build_interface():
    """
    Construct and return the Gradio Blocks UI wiring the ranking pipeline
    to text/file inputs and image/CSV outputs.
    """
    title = "Text → Image Ranking (OpenCLIP) — Free & Open-source"
    description = """
    Enter any text query (e.g., "red chinos") and upload multiple product images (100+ supported).
    The app uses an OpenCLIP model (open-source) to compute embeddings for text and images, then ranks images by cosine similarity.
    You will get a visual grid of the top results annotated with normalized similarity scores (0–100) and a downloadable CSV of all rankings.
    """
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=3):
                query = gr.Textbox(label="Text query", placeholder="e.g. 'red chinos' or 'floral kurta with pockets'", lines=1)
                image_files = gr.File(label="Upload product images (multiple)", file_count="multiple",
                                      file_types=["image"], interactive=True)
                top_k = gr.Slider(minimum=1, maximum=64, value=TOP_K_DEFAULT, step=1, label="Top-K to visualize")
                use_gpu = gr.Checkbox(label=f"Use GPU (detected device: {DEVICE}). Uncheck to force CPU.", value=(DEVICE == "cuda"))
                run_btn = gr.Button("Rank images")
                # NOTE(review): status_output is created but never written to by
                # any callback below — confirm whether it can be removed.
                status_output = gr.Textbox(label="Status", interactive=False)
            with gr.Column(scale=2):
                gallery = gr.Image(type="pil", label="Top results grid (annotated)")
                download = gr.File(label="Download CSV rankings")
                summary = gr.Textbox(label="Summary", interactive=False)

        # Hook up
        def wrapped_run(q, files, topk, use_gpu_flag):
            # Adapts gradio_rank_fn's (summary, image, csv-bytes-tuple) result
            # to what the output components expect: gr.File needs a filesystem
            # path, so the CSV bytes are written to disk first.
            status = "Processing..."
            # Gradio won't show intermediate states in this simple wrapper, so return at the end
            try:
                summary_text, grid_img, csv_tuple = gradio_rank_fn(q, files, topk, use_gpu_flag)
                # for gr.File returning bytes tuple: (filename, bytes, mime)
                # Save csv bytes to temp file for gr.File returning
                if csv_tuple:
                    fname, content_bytes, mime = csv_tuple
                    # save to a BytesIO that gr.File can serve via memory? Gradio expects a path or a file-like?
                    # We'll save to disk in a temp file to make it simple:
                    # NOTE(review): writes "rankings.csv" into the process CWD —
                    # concurrent users overwrite each other's file; consider
                    # tempfile.NamedTemporaryFile instead.
                    tmp_path = os.path.join(os.getcwd(), fname)
                    with open(tmp_path, "wb") as f:
                        f.write(content_bytes)
                    csv_path = tmp_path
                else:
                    csv_path = None
                return summary_text, grid_img, csv_path
            except Exception as e:
                return f"Error: {e}", None, None

        run_btn.click(fn=wrapped_run, inputs=[query, image_files, top_k, use_gpu], outputs=[summary, gallery, download])
        gr.Markdown("## Notes")
        gr.Markdown("- This uses an **open-source** OpenCLIP model. No paid API calls.")
        gr.Markdown("- For best performance on large batches, run on a machine with a CUDA GPU. If you don't have a GPU, leave 'Use GPU' unchecked.")
        gr.Markdown("- If you want to scale beyond thousands of images in a production setting, index the image embeddings with FAISS/Annoy and perform ANN search rather than computing full cosine in-memory.")
    return demo
357
+
358
+
359
if __name__ == "__main__":
    # Build the UI and start the Gradio server.
    build_interface().launch()