Spaces:
Sleeping
Sleeping
| """ | |
| Gradio app: Text-to-Image ranking using OpenCLIP (open-source) | |
| Features: | |
| - Accepts a text query and multiple images (100+). | |
| - Encodes text and images with OpenCLIP (ViT-B-32 by default). | |
| - Computes cosine similarity, normalizes scores to 0-100. | |
| - Returns a ranked CSV and a visual grid image annotated with scores. | |
| - GPU optional (will use CUDA if available). | |
| """ | |
| import os | |
| import io | |
| import math | |
| import time | |
| from typing import List, Tuple, Optional | |
| import torch | |
| import open_clip | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
# -------------------------
# Configuration / Globals
# -------------------------
# Architecture name handed to open_clip.create_model_and_transforms() and
# open_clip.get_tokenizer() in load_model().
MODEL_NAME = "ViT-B-32"
# Alternative pretrained tag kept for reference:
# MODEL_PRETRAIN = "laion2b_s32b_b79k"
# Pretrained-weights tag handed to open_clip.create_model_and_transforms().
MODEL_PRETRAIN = "openai"
# Preferred compute device: CUDA when torch reports it as available, else CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Number of images encoded per forward pass in encode_images().
BATCH_SIZE = 64
# Default number of top-ranked images shown in the result grid.
TOP_K_DEFAULT = 20
# (width, height) of each thumbnail in the result grid, in pixels.
THUMB_SIZE = (256, 256)
# Optional path to a TrueType font for captions; None selects PIL's default font.
FONT_PATH = None
# Upper bound of the normalized score range (scores are mapped to 0..this value).
NORMALIZE_SCORE_TO = 100
# -------------------------
# Process-wide cache filled by load_model() on its first call.
_model_data = {"loaded": False}
def _clip_embedding_dim(model):
    """Best-effort lookup of the model's embedding width; None when unknown."""
    proj = getattr(model, "text_projection", None)
    if proj is None:
        proj = getattr(model, "projection", None)
    if proj is None:
        return None
    # A layer-style projection (e.g. a linear layer) reports its width directly.
    out_features = getattr(proj, "out_features", None)
    if out_features is not None:
        return int(out_features)
    # A bare weight matrix: the second axis is taken as the embedding width.
    shape = getattr(proj, "shape", None)
    if shape is not None and len(shape) > 1:
        return int(shape[1])
    return None


def load_model(device: str = DEVICE):
    """
    Load the OpenCLIP model, its preprocessing transform and its tokenizer.

    The objects are created on the first call and cached in ``_model_data``.
    On every call the cached model is moved to ``device``, so a caller that
    first loaded on GPU and later asks for CPU (or the reverse) receives a
    model that actually lives on the device it requested.

    Args:
        device: Target torch device string (e.g. "cuda" or "cpu").

    Returns:
        Tuple ``(model, preprocess, tokenizer, dim)``. ``dim`` is the
        embedding width, or None if it cannot be read from the model.
    """
    if not _model_data.get("loaded", False):
        print(f"Loading OpenCLIP {MODEL_NAME} ({MODEL_PRETRAIN}) to {device} ...")
        model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, MODEL_PRETRAIN)
        tokenizer = open_clip.get_tokenizer(MODEL_NAME)
        model.eval()
        _model_data.update({
            "loaded": True,
            "model": model,
            "preprocess": preprocess,
            "tokenizer": tokenizer,
            "dim": _clip_embedding_dim(model),
        })
        print("Model loaded.")

    model = _model_data["model"]
    # Module.to() moves parameters in place; it is cheap when nothing changes.
    model.to(device)
    return model, _model_data["preprocess"], _model_data["tokenizer"], _model_data["dim"]
| # ------------------------- | |
| # Utilities | |
| # ------------------------- | |
def load_pil_image(file_obj) -> Image.Image:
    """
    Return an RGB PIL image for an uploaded file.

    Args:
        file_obj: Either a filesystem path (``str`` or ``os.PathLike``) or a
            readable, seekable binary file-like object.

    Returns:
        A fully loaded ``PIL.Image.Image`` in "RGB" mode. The returned image
        does not depend on the source file staying open.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        source = file_obj
    else:
        # Rewind in case the stream was already read, then detach from it.
        file_obj.seek(0)
        source = io.BytesIO(file_obj.read())

    # Image.open is lazy; convert() forces the pixel data to load, after
    # which the context manager can release the underlying file handle.
    with Image.open(source) as img:
        return img.convert("RGB")
def batchify(iterable, batch_size):
    """
    Yield successive lists of at most ``batch_size`` items from ``iterable``.

    Items are pulled lazily, so generators and other one-shot iterables are
    consumed batch by batch instead of being copied into memory first.

    Args:
        iterable: Any iterable.
        batch_size: Maximum number of items per batch; must be >= 1.

    Yields:
        Lists of items in original order; the last one may be shorter.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (raised when the
            generator is first advanced).
    """
    from itertools import islice

    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    iterator = iter(iterable)
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            return
        yield batch
def encode_text(text: str, model, tokenizer, device: str = DEVICE) -> torch.Tensor:
    """
    Encode a text query into an L2-normalized embedding of shape (1, dim).

    The tokens are sent to the device the model's parameters live on, so the
    forward pass cannot fail because of a model/input device mismatch.

    Args:
        text: The query string.
        model: Model exposing ``encode_text``.
        tokenizer: Callable turning a list of strings into a token tensor.
        device: Fallback device, used only when the model has no parameters
            to read a device from.

    Returns:
        Tensor of shape (1, dim) with unit L2 norm, on the device used for
        the forward pass.
    """
    tokens = tokenizer([text])
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device
    with torch.no_grad():
        text_feats = model.encode_text(tokens.to(run_device))  # (1, dim)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
    return text_feats
def encode_images(images: List[Image.Image], model, preprocess, device: str = DEVICE, batch_size: int = BATCH_SIZE) -> torch.Tensor:
    """
    Encode PIL images into L2-normalized embeddings of shape (N, dim).

    Images are processed ``batch_size`` at a time and each batch of features
    is moved to the CPU right away, so only one batch occupies the compute
    device at any moment.

    Args:
        images: PIL images to encode.
        model: Model exposing ``encode_image``.
        preprocess: Callable mapping one PIL image to a tensor.
        device: Fallback device, used only when the model has no parameters
            to read a device from.
        batch_size: Number of images per forward pass.

    Returns:
        CPU tensor of shape (N, dim), rows in the same order as ``images``.
        An empty ``images`` list is not supported (``torch.cat`` needs at
        least one batch).
    """
    # Run inputs on the device the model actually lives on.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    all_feats = []
    for batch in batchify(images, batch_size):
        batch_tensors = torch.stack([preprocess(img) for img in batch]).to(run_device)
        with torch.no_grad():
            feats = model.encode_image(batch_tensors)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        all_feats.append(feats.cpu())
    return torch.cat(all_feats, dim=0)
def cosine_similarity_matrix(text_feat: torch.Tensor, image_feats: torch.Tensor) -> np.ndarray:
    """
    Compute the similarity between one text embedding and N image embeddings.

    Both inputs are expected to be L2-normalized already, so the dot product
    equals the cosine similarity.

    Args:
        text_feat: Text embedding of shape (1, dim) or (dim,).
        image_feats: Image embeddings of shape (N, dim); a single (dim,)
            vector is treated as N == 1.

    Returns:
        1-D ndarray of shape (N,) with values clipped to [-1, 1]. The result
        stays 1-D even when N == 1.
    """
    # Bring both operands to CPU float32 so the matmul never mixes devices
    # or dtypes (e.g. half-precision GPU text features vs. CPU image features).
    text_vec = torch.as_tensor(text_feat).detach().cpu().float().reshape(-1)
    image_mat = torch.as_tensor(image_feats).detach().cpu().float()
    if image_mat.dim() == 1:
        image_mat = image_mat.unsqueeze(0)

    sims = (image_mat @ text_vec).numpy().reshape(-1)
    return np.clip(sims, -1.0, 1.0)
def normalize_scores_to_range(scores: np.ndarray, low=0.0, high=NORMALIZE_SCORE_TO) -> np.ndarray:
    """
    Map cosine scores from [-1, 1] linearly onto [low, high] (e.g. 0..100).

    The mapping is a fixed affine transform: -1 -> ``low``, 0 -> midpoint,
    +1 -> ``high``. It does not depend on the spread of ``scores``, so a
    single score or several identical scores are mapped like any others.

    Args:
        scores: Array of cosine similarities; values outside [-1, 1] are
            clipped first. May be empty.
        low: Output value for a cosine of -1.
        high: Output value for a cosine of +1.

    Returns:
        Array with the same shape as ``scores`` holding the mapped values.
    """
    clipped = np.clip(np.asarray(scores), -1.0, 1.0)
    unit = (clipped + 1.0) / 2.0  # now in [0, 1]
    return low + unit * (high - low)
def make_visual_grid(images: List[Image.Image], scores: List[float], top_k: int = 12,
                     thumb_size: Tuple[int, int] = THUMB_SIZE, columns: int = 4,
                     font_path: Optional[str] = FONT_PATH) -> Image.Image:
    """
    Arrange the first ``top_k`` images in a grid, each with its score below it.

    Args:
        images: Images in display order.
        scores: Score for each image, aligned with ``images``.
        top_k: Maximum number of tiles; clamped to the number of available
            image/score pairs.
        thumb_size: (width, height) of every thumbnail in pixels.
        columns: Number of tiles per row; values below 1 are treated as 1.
        font_path: Optional TrueType font file for the captions. PIL's
            default font is used when it is missing or cannot be loaded.

    Returns:
        A new RGB image on a white background containing the tiles.
    """
    columns = max(1, int(columns))
    # int() tolerates float values such as those a UI slider may deliver.
    top_k = max(0, min(int(top_k), len(images), len(scores)))
    rows = math.ceil(top_k / columns)
    w, h = thumb_size
    caption_height = 28
    cell_h = h + caption_height

    grid_img = Image.new("RGB", (columns * w, rows * cell_h), color=(255, 255, 255))
    draw = ImageDraw.Draw(grid_img)

    # Font loading is best-effort: any failure falls back to the default font.
    try:
        if font_path and os.path.exists(font_path):
            font = ImageFont.truetype(font_path, 16)
        else:
            font = ImageFont.load_default()
    except Exception:
        font = ImageFont.load_default()

    for idx in range(top_k):
        row, col = divmod(idx, columns)
        x = col * w
        y = row * cell_h

        # resize() already returns a new image, so no explicit copy is needed.
        thumb = images[idx].resize(thumb_size, Image.Resampling.LANCZOS)
        grid_img.paste(thumb, (x, y))

        caption = f"{scores[idx]:.1f}"
        _, top, _, bottom = draw.textbbox((0, 0), caption, font=font)
        text_h = bottom - top

        strip_top = y + h
        draw.rectangle([x, strip_top, x + w, strip_top + caption_height], fill=(255, 255, 255))
        # Subtract the bbox top offset so the visible glyphs, not the text
        # origin, end up vertically centered in the caption strip.
        text_y = strip_top + (caption_height - text_h) // 2 - top
        draw.text((x + 6, text_y), caption, fill=(0, 0, 0), font=font)

    return grid_img
def _upload_display_name(file_obj) -> str:
    """Return a short filename for an uploaded file object or path."""
    name = getattr(file_obj, "name", None)
    if not name and isinstance(file_obj, (str, os.PathLike)):
        # Plain path strings carry no ``name`` attribute; use the path itself.
        name = os.fspath(file_obj)
    if name:
        return os.path.basename(str(name))
    return getattr(file_obj, "filename", "uploaded_image")


def rank_images_by_text(query: str, files: List[gr.File], top_k: int = TOP_K_DEFAULT,
                        use_gpu: bool = (DEVICE == "cuda")) -> Tuple[pd.DataFrame, Image.Image]:
    """
    Rank uploaded images by their similarity to a text query.

    Pipeline: load the model (cached), read the images, encode text and
    images, compute cosine similarities, then build the ranked table and
    the annotated grid of the best matches.

    Args:
        query: Text to rank the images against; must not be blank.
        files: Uploaded files, as paths or file-like objects.
        top_k: Number of best matches drawn in the grid (clamped to the
            number of successfully loaded images).
        use_gpu: Use ``DEVICE`` when True, otherwise force "cpu".

    Returns:
        ``(df, grid_img)`` where ``df`` has one row per loaded image, best
        match first, with columns ``filename``, ``score_cosine`` and
        ``score_<NORMALIZE_SCORE_TO>`` (``score_100`` by default), and
        ``grid_img`` is the PIL grid of the top matches.

    Raises:
        ValueError: If the query is empty/blank, no files were given, or
            none of the files could be loaded as an image.
    """
    start_time = time.time()
    if not query or not query.strip() or not files:
        raise ValueError("Please provide both a text query and at least one image file.")

    device = DEVICE if use_gpu else "cpu"
    model, preprocess, tokenizer, _ = load_model(device)

    images = []
    filenames = []
    for f in files:
        # Unreadable uploads are skipped on purpose so one bad file does not
        # abort the whole ranking.
        try:
            pil = load_pil_image(f)
        except Exception as e:
            print(f"Skipping a file due to load error: {e}")
            continue
        # Append image and name together so both lists always stay aligned.
        images.append(pil)
        filenames.append(_upload_display_name(f))

    if len(images) == 0:
        raise ValueError("No valid images could be loaded from uploads.")

    text_feat = encode_text(query, model, tokenizer, device=device)
    image_feats = encode_images(images, model, preprocess, device=device, batch_size=BATCH_SIZE)

    # atleast_1d keeps the single-image case indexable like any other.
    sims = np.atleast_1d(cosine_similarity_matrix(text_feat, image_feats))  # range [-1,1]
    scores_norm = np.atleast_1d(
        normalize_scores_to_range(sims, low=0.0, high=float(NORMALIZE_SCORE_TO))
    )

    # Rank results: highest cosine first; a stable sort keeps upload order on ties.
    order = np.argsort(-sims, kind="stable")
    sims_sorted = sims[order]
    scores_sorted = scores_norm[order]
    filenames_sorted = [filenames[i] for i in order]
    images_sorted = [images[i] for i in order]

    df = pd.DataFrame({
        "filename": filenames_sorted,
        "score_cosine": sims_sorted,
        f"score_{int(NORMALIZE_SCORE_TO)}": scores_sorted
    })

    # int() tolerates float values such as those a UI slider may deliver.
    top_k = max(0, min(int(top_k), len(images_sorted)))
    top_images = images_sorted[:top_k]
    top_scores = scores_sorted[:top_k].tolist()
    grid_img = make_visual_grid(top_images, top_scores, top_k=top_k, thumb_size=THUMB_SIZE, columns=4)

    elapsed = time.time() - start_time
    print(f"Query processed in {elapsed:.2f}s. Images: {len(images)}. Top-K: {top_k}")
    return df, grid_img
| # ------------------------- | |
| # Gradio app UI | |
| # ------------------------- | |
def gradio_rank_fn(query: str, image_files: List[gr.File], top_k: int = TOP_K_DEFAULT, use_gpu: bool = (DEVICE == "cuda")):
    """
    Gradio-facing wrapper around rank_images_by_text.

    Returns a 3-tuple ``(message, grid_image, csv_payload)``:
      - message: summary text on success, or an explanatory/error string.
      - grid_image: PIL grid of the top matches, or None on failure.
      - csv_payload: ``("rankings.csv", utf-8 CSV bytes, "text/csv")``,
        or None on failure.
    """
    if not image_files:
        return "No images uploaded.", None, None

    try:
        ranking, grid = rank_images_by_text(query, image_files, top_k=top_k, use_gpu=use_gpu)
    except Exception as exc:
        # Surface the failure as text instead of raising into the UI.
        return f"Error: {exc}", None, None

    # With no target given, to_csv returns the CSV text directly.
    payload = ranking.to_csv(index=False).encode("utf-8")
    best = ranking["score_cosine"].max()
    message = f"Ranked {len(ranking)} images for query: '{query}'. Top score: {best:.4f}"
    return message, grid, ("rankings.csv", payload, "text/csv")
def build_interface():
    """
    Build the Gradio Blocks UI for the text-to-image ranking demo.

    Layout: the left column holds the query box, the multi-file image
    upload, the top-k slider, the GPU checkbox, the run button and a status
    box; the right column holds the result grid, the CSV download and the
    summary text. Clicking the button runs the ranking and fills the summary,
    grid and download outputs.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """
    import tempfile

    title = "Text → Image Ranking"
    description = (
        "\n"
        "Enter any text query (e.g., \"red chinos\") and upload multiple product images (100+ supported).\n"
        "The app uses an OpenCLIP model (open-source) to compute embeddings for text and images, then ranks images by cosine similarity.\n"
    )
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=3):
                query = gr.Textbox(label="Text query", placeholder="e.g. 'red chinos' or 'floral kurta with pockets'", lines=1)
                image_files = gr.File(label="Upload product images (multiple)", file_count="multiple",
                                      file_types=["image"], interactive=True)
                top_k = gr.Slider(minimum=1, maximum=64, value=TOP_K_DEFAULT, step=1, label="Top-K to visualize")
                use_gpu = gr.Checkbox(label=f"Use GPU (detected device: {DEVICE}). Uncheck to force CPU.", value=(DEVICE == "cuda"))
                run_btn = gr.Button("Rank images")
                # NOTE: this box is displayed but is not an output of any event.
                status_output = gr.Textbox(label="Status", interactive=False)
            with gr.Column(scale=2):
                gallery = gr.Image(type="pil", label="Top results grid (annotated)")
                download = gr.File(label="Download CSV rankings")
                summary = gr.Textbox(label="Summary", interactive=False)

        def wrapped_run(q, files, topk, use_gpu_flag):
            """Run the ranking and persist the CSV so it can be downloaded."""
            try:
                summary_text, grid_img, csv_tuple = gradio_rank_fn(q, files, topk, use_gpu_flag)
                csv_path = None
                if csv_tuple:
                    fname, content_bytes, _mime = csv_tuple
                    # A fresh directory per request keeps the download name
                    # stable while preventing concurrent requests from
                    # overwriting each other's file. The directory is left
                    # in the system temp location afterwards.
                    out_dir = tempfile.mkdtemp(prefix="rankings_")
                    csv_path = os.path.join(out_dir, fname)
                    with open(csv_path, "wb") as f:
                        f.write(content_bytes)
                return summary_text, grid_img, csv_path
            except Exception as e:
                return f"Error: {e}", None, None

        run_btn.click(fn=wrapped_run, inputs=[query, image_files, top_k, use_gpu], outputs=[summary, gallery, download])
        gr.Markdown("## Notes")
        gr.Markdown("- This uses an **open-source** OpenCLIP model. No paid API calls.")
        gr.Markdown("- The app is slow because every time it runs it creates embeddings of the text and the images . The speed of the app can be increased if we use already stored images so we don't have to create embeddings everytime.")
        gr.Markdown("The accuracy of this app can be increased if we used different models of open clip , but for computational efficiency i have utilized one of the efficient models . Also if we finetune this model , the accuracy of the model can be hugely increased, But since this is just a asssignment , i have created a demo prototype only.")
    return demo
if __name__ == "__main__":
    # Build the Blocks UI only when this file is executed as a script.
    demo = build_interface()
    # Start Gradio
    demo.launch()