"""Gradio app: natural-language image search over a URL database using SigLIP2."""

import os
from io import BytesIO
from typing import List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

# ---------------------------------------------------------------------------
# Model initialization (runs once at import time)
# ---------------------------------------------------------------------------
MODEL_NAME = "google/siglip2-so400m-patch16-naflex"
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model on {device}...")
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model.eval()

# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# URLs of the currently loaded image database. Mutable despite the
# constant-style name; kept as-is for backward compatibility.
IMAGE_DATABASE = []
# L2-normalized image embeddings (np.ndarray) aligned row-for-row with
# IMAGE_DATABASE, or None when no database is loaded.
embeddings_cache = None

# Cache of downloaded PIL images keyed by URL, so repeated searches do not
# re-download the same files.
image_cache = {}


def load_image_database_from_file(file_path: str) -> List[str]:
    """Load image URLs from an Excel spreadsheet.

    Args:
        file_path: Path to an .xlsx/.xls file with a URL column.

    Returns:
        List of URL strings (NaN rows dropped, whitespace stripped).

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no recognizable URL column is present.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Image database file '{file_path}' not found. "
            f"Please upload an Excel file with a column named 'url' containing image URLs."
        )

    df = pd.read_excel(file_path)

    # Accept several common spellings for the URL column (case-insensitive).
    url_column = None
    for col in df.columns:
        if col.lower() in ['url', 'image_url', 'image_urls', 'urls', 'link', 'image']:
            url_column = col
            break

    if url_column is None:
        raise ValueError(
            f"Could not find URL column in Excel file. "
            f"Please use one of these column names: 'url', 'URL', 'image_url', "
            f"'urls', 'link', or 'image'. "
            f"Found columns: {list(df.columns)}"
        )

    # Drop NaN rows and normalize every entry to a stripped string.
    urls = [str(url).strip() for url in df[url_column].dropna().tolist()]

    print(f"Loaded {len(urls)} image URLs from {file_path}")
    return urls


def load_image_from_url(url: str) -> Image.Image:
    """Download an image with caching; return a gray placeholder on failure.

    The placeholder keeps the URL list and the embedding matrix index-aligned
    even when a URL is dead or unreachable.
    """
    if url not in image_cache:
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            image_cache[url] = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            print(f"Error loading image from {url}: {e}")
            image_cache[url] = Image.new("RGB", (400, 400), color="gray")
    return image_cache[url]


def compute_image_embeddings(urls: List[str], batch_size: int = 16) -> np.ndarray:
    """Compute L2-normalized SigLIP2 embeddings for a list of image URLs.

    Images are processed in chunks of ``batch_size`` so that a large database
    does not exhaust GPU/CPU memory in a single forward pass (the previous
    implementation embedded the entire database at once).

    Args:
        urls: Image URLs to embed.
        batch_size: Number of images per forward pass.

    Returns:
        Array of shape (len(urls), embed_dim) with unit-norm rows.
    """
    print("Computing image embeddings...")
    images = [load_image_from_url(url) for url in urls]
    print(f"Loaded {len(images)} images")

    chunks = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            inputs = processor(images=batch, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model.get_image_features(**inputs)

            # get_image_features may return a bare tensor or a model-output
            # object depending on the transformers version.
            if hasattr(outputs, 'image_embeds'):
                batch_embeddings = outputs.image_embeds
            elif hasattr(outputs, 'pooler_output'):
                batch_embeddings = outputs.pooler_output
            else:
                batch_embeddings = outputs

            # Normalize so a plain dot product equals cosine similarity.
            batch_embeddings = batch_embeddings / batch_embeddings.norm(dim=-1, keepdim=True)
            chunks.append(batch_embeddings.cpu().numpy())

    embeddings_np = np.concatenate(chunks, axis=0) if chunks else np.empty((0, 0))
    print(f"Cached embeddings shape: {embeddings_np.shape}")
    print("Image embeddings computed!")
    return embeddings_np


def search_images(query: str, urls: List[str], image_embeddings: np.ndarray,
                  top_k: int = 5) -> List[Tuple[Image.Image, float]]:
    """Return the ``top_k`` (image, similarity) pairs best matching ``query``.

    Args:
        query: Free-text search query.
        urls: Image URLs aligned row-for-row with ``image_embeddings``.
        image_embeddings: Unit-norm embedding matrix, shape (N, embed_dim).
        top_k: Maximum number of results (clamped to len(urls)).

    Returns:
        List of (PIL image, cosine-similarity score) tuples, best first;
        empty list for a blank query or empty database.
    """
    if not query.strip() or len(urls) == 0:
        return []

    print(f"Image embeddings shape in search: {image_embeddings.shape}")

    # Embed the query text with the same model.
    with torch.no_grad():
        text_inputs = processor(text=[query], return_tensors="pt", padding=True)
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        outputs = model.get_text_features(**text_inputs)

        # Same version-tolerant unwrapping as for image features.
        if hasattr(outputs, 'text_embeds'):
            text_embedding = outputs.text_embeds
        elif hasattr(outputs, 'pooler_output'):
            text_embedding = outputs.pooler_output
        else:
            text_embedding = outputs

        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        text_embedding = text_embedding.cpu().numpy()

    print(f"Text embedding shape: {text_embedding.shape}")

    # Cosine similarity = dot product of unit vectors.
    similarities = np.dot(image_embeddings, text_embedding.T).squeeze()
    print(f"Similarities shape: {similarities.shape}")
    print(f"Similarities: {similarities}")

    # A one-image database squeezes down to a 0-d array; restore 1-d.
    if similarities.ndim == 0:
        similarities = np.array([similarities])

    top_k = min(top_k, len(urls))
    print(f"Requested top_k: {top_k}, Database size: {len(urls)}")
    top_indices = np.argsort(similarities)[::-1][:top_k]
    print(f"Top indices: {top_indices}")

    results = []
    for idx in top_indices:
        results.append((load_image_from_url(urls[idx]), float(similarities[idx])))

    print(f"Returning {len(results)} results")
    return results


def load_database(file):
    """Gradio callback: load URLs from the uploaded Excel file and embed them.

    Returns:
        (status message, gr.update enabling/disabling the query textbox).
    """
    global IMAGE_DATABASE, embeddings_cache

    if file is None:
        return "Please upload an Excel file.", None

    try:
        # BUG FIX: gr.File(type="filepath") passes a plain str path, which has
        # no .name attribute; the old `file.name` raised AttributeError on
        # every upload. Accept both the str and the legacy file-object forms.
        path = file if isinstance(file, str) else file.name
        IMAGE_DATABASE = load_image_database_from_file(path)

        if len(IMAGE_DATABASE) == 0:
            return "No valid URLs found in the uploaded file.", None

        # Invalidate before recomputing so a failure leaves no stale cache.
        embeddings_cache = None
        embeddings_cache = compute_image_embeddings(IMAGE_DATABASE)

        return (
            f"✓ Successfully loaded {len(IMAGE_DATABASE)} images from database!",
            gr.update(interactive=True),
        )
    except Exception as e:
        IMAGE_DATABASE = []
        embeddings_cache = None
        return f"Error loading database: {str(e)}", gr.update(interactive=False)


def gradio_search(query: str, top_k: float):
    """Gradio callback: run a search and format results for the gallery."""
    # Gradio sliders deliver floats; the search needs an int.
    top_k = int(top_k)

    # No database loaded yet — nothing to search.
    if len(IMAGE_DATABASE) == 0 or embeddings_cache is None:
        return None

    results = search_images(query, IMAGE_DATABASE, embeddings_cache, top_k)
    if not results:
        return None

    # Gallery expects (image, caption) pairs.
    return [(img, f"Score: {score:.4f}") for img, score in results]


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Image Search with SigLIP2") as demo:
    gr.Markdown(
        """
        # 🔍 Image Search with SigLIP2

        Search through a collection of images using natural language queries!
        The model used is **google/siglip2-so400m-patch16-naflex**.

        ## How to use:
        1. Upload an Excel file (.xlsx) with a column named **'url'** containing image URLs
        2. Wait for the images to be processed
        3. Enter your search query
        4. View the results!

        Try queries like:
        - "a cat"
        - "mountain landscape"
        - "city at night"
        - "food on a table"
        - "person doing sports"
        """
    )

    with gr.Row():
        with gr.Column():
            file_upload = gr.File(
                label="Upload Image Database (Excel file)",
                file_types=[".xlsx", ".xls"],
                type="filepath"
            )
            load_button = gr.Button("Load Database", variant="primary")
            status_text = gr.Textbox(
                label="Status",
                value="Please upload an Excel file with image URLs.",
                interactive=False
            )

    with gr.Row():
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="Enter your search term (e.g., 'sunset', 'dog', 'technology')",
                lines=2,
                interactive=False
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of Results",
                info="Select how many top results to display"
            )
            search_button = gr.Button("Search", variant="primary")

        with gr.Column(scale=2):
            gallery_output = gr.Gallery(
                label="Search Results",
                columns=3,
                rows=2,
                height="auto",
                object_fit="contain"
            )

    # Event handlers: loading enables the query box; search runs on click
    # or on pressing Enter in the query box.
    load_button.click(
        fn=load_database,
        inputs=[file_upload],
        outputs=[status_text, query_input]
    )
    search_button.click(
        fn=gradio_search,
        inputs=[query_input, top_k_slider],
        outputs=gallery_output
    )
    query_input.submit(
        fn=gradio_search,
        inputs=[query_input, top_k_slider],
        outputs=gallery_output
    )

    gr.Markdown(
        """
        ---
        **Excel File Format:**

        Your Excel file should have a column named `url` (or `URL`, `image_url`,
        `urls`, `link`, or `image`) containing the image URLs.

        Example:

        | url |
        |-----|
        | https://example.com/image1.jpg |
        | https://example.com/image2.jpg |

        **Note:** The SigLIP2 model computes similarity between your text query
        and the images to find the best matches.
        """
    )

if __name__ == "__main__":
    demo.launch()