Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModel | |
| import numpy as np | |
| from typing import List, Tuple | |
| import requests | |
| from io import BytesIO | |
| import pandas as pd | |
| import os | |
# Model setup: pick the SigLIP2 checkpoint and run it on the GPU when one is
# visible to torch, otherwise fall back to the CPU.
MODEL_NAME = "google/siglip2-so400m-patch16-naflex"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Loading model on {device}...")
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(device)
model.eval()  # switch the module to evaluation mode

# Mutable module state shared by the Gradio callbacks.
IMAGE_DATABASE = []  # image URLs from the most recently loaded spreadsheet
embeddings_cache = None  # embeddings for IMAGE_DATABASE; None until a load succeeds
image_cache = {}  # URL -> image already downloaded (or its placeholder)
| # Load image URLs from Excel file | |
def load_image_database_from_file(file_path: str) -> List[str]:
    """Load image URLs from an Excel spreadsheet.

    The first column whose header (case-insensitively, ignoring surrounding
    whitespace) is one of ``url``, ``image_url``, ``image_urls``, ``urls``,
    ``link`` or ``image`` is used as the URL column.

    Args:
        file_path: Path to the Excel file to read.

    Returns:
        The non-empty cell values of the URL column, converted to ``str`` and
        stripped of surrounding whitespace, in spreadsheet order.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        ValueError: If no column with an accepted header is present.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Image database file '{file_path}' not found. "
            f"Please upload an Excel file with a column named 'url' containing image URLs."
        )

    df = pd.read_excel(file_path)

    accepted_names = {'url', 'image_url', 'image_urls', 'urls', 'link', 'image'}
    # Headers are not guaranteed to be strings (a numeric header is read back
    # as a number), so normalise through str() before comparing.
    url_column = next(
        (col for col in df.columns if str(col).strip().lower() in accepted_names),
        None,
    )
    if url_column is None:
        raise ValueError(
            f"Could not find URL column in Excel file. "
            f"Please use one of these column names: 'url', 'URL', 'image_url', 'urls', 'link', or 'image'. "
            f"Found columns: {list(df.columns)}"
        )

    # Drop empty cells, then drop cells that contain only whitespace: an empty
    # string is not a usable URL.
    cleaned = (str(value).strip() for value in df[url_column].dropna().tolist())
    urls = [url for url in cleaned if url]

    print(f"Loaded {len(urls)} image URLs from {file_path}")
    return urls
def load_image_from_url(url: str) -> Image.Image:
    """Return the image behind ``url``, downloading it at most once.

    Results are memoised in the module-level ``image_cache``. If the download
    or decoding fails, the error is printed and a 400x400 gray placeholder is
    cached and returned instead, so callers always receive an image.
    """
    # Fast path: this URL was already fetched (or already failed).
    if url in image_cache:
        return image_cache[url]

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        fetched = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        print(f"Error loading image from {url}: {e}")
        # Create a placeholder image
        fetched = Image.new("RGB", (400, 400), color="gray")

    image_cache[url] = fetched
    return fetched
def compute_image_embeddings(urls: List[str], batch_size: int = 32):
    """Compute normalized embeddings for a list of image URLs.

    Images are fetched through ``load_image_from_url`` and pushed through the
    model in chunks of ``batch_size`` so that memory use during the forward
    pass does not grow with the size of the database.

    Args:
        urls: Image URLs to embed.
        batch_size: Maximum number of images per forward pass.

    Returns:
        A NumPy array holding the per-batch embeddings concatenated along the
        first axis, each vector divided by its norm along the last axis
        (presumably one row per URL, in input order).

    Raises:
        ValueError: If ``urls`` is empty or ``batch_size`` is less than 1.
    """
    if not urls:
        raise ValueError("Cannot compute embeddings for an empty list of image URLs.")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")

    print("Computing image embeddings...")
    images = [load_image_from_url(url) for url in urls]
    print(f"Loaded {len(images)} images")

    batch_embeddings = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            chunk = images[start:start + batch_size]
            inputs = processor(images=chunk, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model.get_image_features(**inputs)

            # get_image_features may hand back an output object or a bare
            # tensor; pull the embedding tensor out of whichever we got.
            if hasattr(outputs, 'image_embeds'):
                features = outputs.image_embeds
            elif hasattr(outputs, 'pooler_output'):
                features = outputs.pooler_output
            else:
                features = outputs

            # Normalize so that a dot product with a normalized text
            # embedding is a cosine similarity.
            features = features / features.norm(dim=-1, keepdim=True)
            batch_embeddings.append(features.cpu().numpy())

    embeddings_np = np.concatenate(batch_embeddings, axis=0)
    print(f"Image embeddings shape: {embeddings_np.shape}")
    print(f"Cached embeddings shape: {embeddings_np.shape}")
    print("Image embeddings computed!")
    return embeddings_np
def search_images(query: str, urls: List[str], image_embeddings: np.ndarray, top_k: int = 5) -> List[Tuple[Image.Image, float]]:
    """Search for the images that best match a text query.

    Args:
        query: Free-text search string.
        urls: Image URLs, index-aligned with the rows of ``image_embeddings``.
        image_embeddings: Image embeddings as produced by
            ``compute_image_embeddings``.
        top_k: Maximum number of results to return.

    Returns:
        Up to ``top_k`` ``(image, score)`` pairs ordered from highest to
        lowest similarity. An empty list is returned for a missing or blank
        query, an empty ``urls`` list, or a non-positive ``top_k``.
    """
    if not query or not query.strip():
        return []
    # A non-positive top_k must not reach the slice below: a negative bound
    # would silently drop results from the end instead of returning none.
    if len(urls) == 0 or top_k <= 0:
        return []

    print(f"Image embeddings shape in search: {image_embeddings.shape}")

    # Compute text embedding
    with torch.no_grad():
        text_inputs = processor(text=[query], return_tensors="pt", padding=True)
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        outputs = model.get_text_features(**text_inputs)

        # get_text_features may hand back an output object or a bare tensor.
        if hasattr(outputs, 'text_embeds'):
            text_embedding = outputs.text_embeds
        elif hasattr(outputs, 'pooler_output'):
            text_embedding = outputs.pooler_output
        else:
            text_embedding = outputs

        # Normalize the embeddings
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)

    text_embedding = text_embedding.cpu().numpy()
    print(f"Text embedding shape: {text_embedding.shape}")

    # One similarity per image; reshape(-1) keeps the result 1-D even when
    # the database holds a single image.
    similarities = np.dot(image_embeddings, text_embedding.T).reshape(-1)
    print(f"Similarities shape: {similarities.shape}")
    print(f"Similarities: {similarities}")

    # Never ask for more results than there are URLs or similarity scores.
    top_k = min(top_k, len(urls), similarities.shape[0])
    print(f"Requested top_k: {top_k}, Database size: {len(urls)}")

    top_indices = np.argsort(similarities)[::-1][:top_k]
    print(f"Top indices: {top_indices}")

    results = [
        (load_image_from_url(urls[idx]), float(similarities[idx]))
        for idx in top_indices
    ]
    print(f"Returning {len(results)} results")
    return results
def load_database(file):
    """Load the image database from an uploaded Excel file.

    Reads the URLs, computes their embeddings and, only once both steps have
    succeeded, replaces the module-level ``IMAGE_DATABASE`` and
    ``embeddings_cache``.

    Args:
        file: The uploaded file, either as a filesystem path (``str`` or
            ``os.PathLike``) or as an object exposing the path via ``.name``.
            ``None`` means nothing was uploaded.

    Returns:
        A ``(status_message, query_input_update)`` pair for the two Gradio
        outputs wired to this callback. The second element enables the query
        box on success and disables it when no usable database is loaded.
    """
    global IMAGE_DATABASE, embeddings_cache

    if file is None:
        return "Please upload an Excel file.", None

    try:
        # Path-like values are used as-is; taking ``.name`` of a pathlib.Path
        # would yield only the basename, and a plain str has no ``.name``.
        if isinstance(file, (str, os.PathLike)):
            file_path = os.fspath(file)
        else:
            file_path = file.name

        urls = load_image_database_from_file(file_path)
        if not urls:
            # Do not leave embeddings from a previous upload behind.
            IMAGE_DATABASE = []
            embeddings_cache = None
            return "No valid URLs found in the uploaded file.", gr.update(interactive=False)

        # Compute into a local first so a failure cannot leave the URL list
        # and the embeddings out of step with each other.
        embeddings = compute_image_embeddings(urls)
        IMAGE_DATABASE, embeddings_cache = urls, embeddings
        return f"✓ Successfully loaded {len(IMAGE_DATABASE)} images from database!", gr.update(interactive=True)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        IMAGE_DATABASE = []
        embeddings_cache = None
        return f"Error loading database: {str(e)}", gr.update(interactive=False)
def gradio_search(query: str, top_k: float):
    """Gradio callback: run a search and format the hits for the gallery.

    Returns a list of ``(image, caption)`` pairs, or ``None`` when no database
    is loaded or the search yields nothing.
    """
    # Gradio sliders return floats; the search expects an integer count.
    limit = int(top_k)

    database_ready = len(IMAGE_DATABASE) > 0 and embeddings_cache is not None
    if not database_ready:
        return None

    matches = search_images(query, IMAGE_DATABASE, embeddings_cache, limit)
    if not matches:
        return None

    # The gallery takes (image, caption) pairs.
    return [(img, f"Score: {score:.4f}") for img, score in matches]
| # Create Gradio interface | |
| with gr.Blocks(title="Image Search with SigLIP2") as demo: | |
| gr.Markdown( | |
| """ | |
| # 🔍 Image Search with SigLIP2 | |
| Search through a collection of images using natural language queries! | |
| The model used is **google/siglip2-so400m-patch16-naflex**. | |
| ## How to use: | |
| 1. Upload an Excel file (.xlsx) with a column named **'url'** containing image URLs | |
| 2. Wait for the images to be processed | |
| 3. Enter your search query | |
| 4. View the results! | |
| Try queries like: | |
| - "a cat" | |
| - "mountain landscape" | |
| - "city at night" | |
| - "food on a table" | |
| - "person doing sports" | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| file_upload = gr.File( | |
| label="Upload Image Database (Excel file)", | |
| file_types=[".xlsx", ".xls"], | |
| type="filepath" | |
| ) | |
| load_button = gr.Button("Load Database", variant="primary") | |
| status_text = gr.Textbox( | |
| label="Status", | |
| value="Please upload an Excel file with image URLs.", | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| query_input = gr.Textbox( | |
| label="Search Query", | |
| placeholder="Enter your search term (e.g., 'sunset', 'dog', 'technology')", | |
| lines=2, | |
| interactive=False | |
| ) | |
| top_k_slider = gr.Slider( | |
| minimum=1, | |
| maximum=20, | |
| value=5, | |
| step=1, | |
| label="Number of Results", | |
| info="Select how many top results to display" | |
| ) | |
| search_button = gr.Button("Search", variant="primary") | |
| with gr.Column(scale=2): | |
| gallery_output = gr.Gallery( | |
| label="Search Results", | |
| columns=3, | |
| rows=2, | |
| height="auto", | |
| object_fit="contain" | |
| ) | |
| # Set up event handlers | |
| load_button.click( | |
| fn=load_database, | |
| inputs=[file_upload], | |
| outputs=[status_text, query_input] | |
| ) | |
| search_button.click( | |
| fn=gradio_search, | |
| inputs=[query_input, top_k_slider], | |
| outputs=gallery_output | |
| ) | |
| query_input.submit( | |
| fn=gradio_search, | |
| inputs=[query_input, top_k_slider], | |
| outputs=gallery_output | |
| ) | |
| gr.Markdown( | |
| """ | |
| --- | |
| **Excel File Format:** | |
| Your Excel file should have a column named `url` (or `URL`, `image_url`, `urls`, `link`, or `image`) containing the image URLs. | |
| Example: | |
| | url | | |
| |-----| | |
| | https://example.com/image1.jpg | | |
| | https://example.com/image2.jpg | | |
| **Note:** The SigLIP2 model computes similarity between your text query and the images to find the best matches. | |
| """ | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch() | |