File size: 10,418 Bytes
3191302
 
 
 
 
 
 
 
1f8f3c7
 
3191302
 
 
 
 
 
 
 
 
 
af485ed
 
 
 
 
 
 
1f8f3c7
af485ed
1f8f3c7
af485ed
1f8f3c7
af485ed
 
1f8f3c7
 
af485ed
1f8f3c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af485ed
1f8f3c7
 
3191302
 
 
 
 
 
 
 
 
 
 
 
 
 
af485ed
 
3191302
af485ed
e19fc45
3191302
 
 
 
139aeb2
 
 
 
 
 
 
 
 
 
 
e19fc45
 
139aeb2
3191302
 
af485ed
 
3191302
af485ed
3191302
af485ed
3191302
 
 
 
af485ed
 
 
e19fc45
3191302
 
 
 
 
139aeb2
 
 
 
 
 
 
 
 
 
 
 
3191302
 
 
e19fc45
 
3191302
 
e19fc45
 
3191302
c870638
 
 
 
3191302
af485ed
 
e19fc45
3191302
e19fc45
3191302
 
 
af485ed
3191302
 
 
e19fc45
3191302
 
af485ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e19fc45
3191302
e19fc45
 
 
af485ed
 
 
 
 
3191302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af485ed
 
 
 
 
1f8f3c7
3191302
 
 
 
 
 
 
 
 
af485ed
 
 
 
 
 
 
 
 
 
 
 
 
 
3191302
 
 
 
 
af485ed
 
3191302
 
 
e19fc45
3191302
 
e19fc45
 
3191302
 
 
 
 
 
 
 
 
 
 
 
 
af485ed
 
 
 
 
 
3191302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af485ed
 
 
 
 
 
 
 
 
 
3191302
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
import numpy as np
from typing import List, Tuple
import requests
from io import BytesIO
import pandas as pd
import os

# Model configuration: one SigLIP2 checkpoint provides both the text encoder
# (queries) and the image encoder (database) used for search.
MODEL_NAME = "google/siglip2-so400m-patch16-naflex"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Mutable module state, reassigned by load_database():
#   IMAGE_DATABASE   - image URLs from the most recently loaded spreadsheet
#   embeddings_cache - L2-normalised image embeddings, one row per URL
IMAGE_DATABASE = []
embeddings_cache = None

# URL -> decoded PIL image, filled lazily by load_image_from_url().
image_cache = {}

print(f"Loading model on {device}...")
processor = AutoProcessor.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
# The model is only used for inference here, so put it in evaluation mode.
model.eval()

# Load image URLs from Excel file
def load_image_database_from_file(file_path: str) -> List[str]:
    """Load image URLs from Excel spreadsheet"""
    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Image database file '{file_path}' not found. "
            f"Please upload an Excel file with a column named 'url' containing image URLs."
        )
    
    df = pd.read_excel(file_path)
    
    # Look for a column named 'url', 'URL', 'image_url', or similar
    url_column = None
    for col in df.columns:
        if col.lower() in ['url', 'image_url', 'image_urls', 'urls', 'link', 'image']:
            url_column = col
            break
    
    if url_column is None:
        raise ValueError(
            f"Could not find URL column in Excel file. "
            f"Please use one of these column names: 'url', 'URL', 'image_url', 'urls', 'link', or 'image'. "
            f"Found columns: {list(df.columns)}"
        )
    
    # Extract URLs and remove any NaN values
    urls = df[url_column].dropna().tolist()
    
    # Convert to strings and strip whitespace
    urls = [str(url).strip() for url in urls]
    
    print(f"Loaded {len(urls)} image URLs from {file_path}")
    return urls

def load_image_from_url(url: str) -> Image.Image:
    """Return the image at *url* as RGB, memoised in ``image_cache``.

    Download or decode failures are printed and replaced by a 400x400 grey
    placeholder. The placeholder is cached as well, so a failing URL is only
    attempted once per process.
    """
    if url in image_cache:
        return image_cache[url]

    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        picture = Image.open(BytesIO(resp.content)).convert("RGB")
    except Exception as err:
        print(f"Error loading image from {url}: {err}")
        # Stand-in so callers always receive an image object.
        picture = Image.new("RGB", (400, 400), color="gray")

    image_cache[url] = picture
    return picture

def compute_image_embeddings(urls: List[str], batch_size: int = 32):
    """Compute L2-normalised SigLIP2 image embeddings for a list of URLs.

    Images are fetched with load_image_from_url() and encoded in batches of
    at most *batch_size* images. This keeps peak memory bounded regardless
    of database size, instead of pushing every image through the model in a
    single forward pass.

    Args:
        urls: Image URLs; row ``i`` of the result corresponds to ``urls[i]``.
        batch_size: Maximum number of images per forward pass (>= 1).

    Returns:
        A float32 NumPy array with one unit-length embedding row per URL.

    Raises:
        ValueError: if *batch_size* < 1 or *urls* is empty.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if not urls:
        raise ValueError("Cannot compute embeddings for an empty list of URLs")

    print("Computing image embeddings...")
    total = len(urls)
    chunks = []

    with torch.no_grad():
        for start in range(0, total, batch_size):
            images = [load_image_from_url(url) for url in urls[start:start + batch_size]]
            inputs = processor(images=images, return_tensors="pt", padding=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model.get_image_features(**inputs)

            # Accept either a raw tensor or an output object that exposes
            # the embeddings as ``image_embeds`` / ``pooler_output``.
            if hasattr(outputs, 'image_embeds'):
                embeds = outputs.image_embeds
            elif hasattr(outputs, 'pooler_output'):
                embeds = outputs.pooler_output
            else:
                embeds = outputs

            # Unit-normalise so a later dot product equals cosine similarity.
            embeds = embeds / embeds.norm(dim=-1, keepdim=True)
            # .float() so .numpy() also works if the model runs in half precision.
            chunks.append(embeds.float().cpu())
            print(f"Embedded {min(start + batch_size, total)}/{total} images")

    embeddings_np = torch.cat(chunks, dim=0).numpy()
    print(f"Cached embeddings shape: {embeddings_np.shape}")
    print("Image embeddings computed!")
    return embeddings_np

def search_images(query: str, urls: List[str], image_embeddings: np.ndarray, top_k: int = 5) -> List[Tuple[Image.Image, float]]:
    """Rank database images by similarity to a text query.

    Args:
        query: Search text. Empty, whitespace-only or ``None`` queries
            return ``[]``.
        urls: Image URLs; ``urls[i]`` is described by row ``i`` of
            *image_embeddings*.
        image_embeddings: L2-normalised image embeddings as produced by
            compute_image_embeddings() - assumed shape ``(len(urls), dim)``.
        top_k: Maximum number of results. Clamped to the database size;
            values below 1 return ``[]``.

    Returns:
        Up to *top_k* ``(image, score)`` pairs sorted best-first, where
        *score* is the cosine similarity between the query and the image.
    """
    if not query or not query.strip():
        return []

    top_k = min(int(top_k), len(urls))
    # Guard before slicing: a negative top_k would otherwise slice [:-k]
    # and return almost the whole database. Also covers an empty database.
    if top_k < 1:
        return []

    with torch.no_grad():
        # SigLIP2's text encoder was trained on text padded to a fixed 64
        # tokens; padding=True on a single query adds no padding at all and
        # yields a degraded embedding. truncation=True keeps over-long
        # queries at that same length.
        text_inputs = processor(
            text=[query],
            return_tensors="pt",
            padding="max_length",
            max_length=64,
            truncation=True,
        )
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
        outputs = model.get_text_features(**text_inputs)

        # Accept either a raw tensor or an output object that exposes the
        # embeddings as ``text_embeds`` / ``pooler_output``.
        if hasattr(outputs, 'text_embeds'):
            text_embedding = outputs.text_embeds
        elif hasattr(outputs, 'pooler_output'):
            text_embedding = outputs.pooler_output
        else:
            text_embedding = outputs

        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
        query_vector = text_embedding.float().cpu().numpy().reshape(-1)

    # Both sides are unit-length, so the dot product is cosine similarity.
    # An (n, dim) @ (dim,) product is always 1-D, even when n == 1.
    similarities = np.asarray(image_embeddings) @ query_vector

    top_indices = np.argsort(similarities)[::-1][:top_k]
    print(f"Returning {len(top_indices)} results for query {query!r}")

    return [
        (load_image_from_url(urls[idx]), float(similarities[idx]))
        for idx in top_indices
    ]

def load_database(file):
    """Load an uploaded spreadsheet of image URLs and embed every image.

    Callback for the "Load Database" button. Returns a pair
    ``(status_message, query_box_update)``.

    Args:
        file: Value of the ``gr.File`` component. With ``type="filepath"``
            this is a path string; objects exposing a ``.name`` path
            (temp-file wrappers) are accepted too.
    """
    global IMAGE_DATABASE, embeddings_cache

    if file is None:
        # Nothing uploaded: keep any previously loaded database and leave the
        # query box as it is (a bare update changes nothing).
        return "Please upload an Excel file.", gr.update()

    try:
        file_path = file if isinstance(file, (str, os.PathLike)) else file.name
        urls = load_image_database_from_file(file_path)

        if not urls:
            # Clear both globals so stale embeddings never outlive their URLs.
            IMAGE_DATABASE = []
            embeddings_cache = None
            return "No valid URLs found in the uploaded file.", gr.update(interactive=False)

        embeddings = compute_image_embeddings(urls)

    except Exception as e:
        IMAGE_DATABASE = []
        embeddings_cache = None
        return f"Error loading database: {str(e)}", gr.update(interactive=False)

    # Publish URLs and embeddings together, only after embedding succeeded,
    # so a search never pairs one database's URLs with another's embeddings.
    IMAGE_DATABASE = urls
    embeddings_cache = embeddings
    return f"✓ Successfully loaded {len(IMAGE_DATABASE)} images from database!", gr.update(interactive=True)

def gradio_search(query: str, top_k: float):
    """Search callback for the Gradio UI.

    Returns a list of ``(image, caption)`` pairs for ``gr.Gallery``, or
    ``None`` when no database is loaded or the search found nothing.
    """
    # Gradio sliders deliver floats; the search expects an integer count.
    count = int(top_k)

    database_ready = len(IMAGE_DATABASE) != 0 and embeddings_cache is not None
    if not database_ready:
        return None

    matches = search_images(query, IMAGE_DATABASE, embeddings_cache, count)
    if not matches:
        return None

    return [(img, f"Score: {score:.4f}") for img, score in matches]

# Create Gradio interface
# Page layout, top to bottom: intro markdown, upload/load controls with a
# status box, query controls beside the results gallery, and a file-format
# note. Components are placed in the order they are created inside the
# Row/Column context managers, so statement order here defines the layout.
with gr.Blocks(title="Image Search with SigLIP2") as demo:
    gr.Markdown(
        """
        # 🔍 Image Search with SigLIP2
        
        Search through a collection of images using natural language queries!
        The model used is **google/siglip2-so400m-patch16-naflex**.
        
        ## How to use:
        1. Upload an Excel file (.xlsx) with a column named **'url'** containing image URLs
        2. Wait for the images to be processed
        3. Enter your search query
        4. View the results!
        
        Try queries like:
        - "a cat"
        - "mountain landscape"
        - "city at night"
        - "food on a table"
        - "person doing sports"
        """
    )
    
    # Database loading: the uploaded file is only processed when the
    # "Load Database" button is clicked (see load_button.click below).
    with gr.Row():
        with gr.Column():
            # NOTE(review): reading .xls via pandas may need an extra engine
            # (e.g. xlrd) to be installed - confirm before advertising it.
            file_upload = gr.File(
                label="Upload Image Database (Excel file)",
                file_types=[".xlsx", ".xls"],
                type="filepath"
            )
            load_button = gr.Button("Load Database", variant="primary")
            # Read-only box that displays the message returned by load_database().
            status_text = gr.Textbox(
                label="Status",
                value="Please upload an Excel file with image URLs.",
                interactive=False
            )
    
    # Search controls (left, narrower) next to the results gallery (right).
    with gr.Row():
        with gr.Column(scale=1):
            # Starts disabled; load_database() returns a gr.update(...) for
            # this component, which re-enables it after a successful load.
            query_input = gr.Textbox(
                label="Search Query",
                placeholder="Enter your search term (e.g., 'sunset', 'dog', 'technology')",
                lines=2,
                interactive=False
            )
            top_k_slider = gr.Slider(
                minimum=1,
                maximum=20,
                value=5,
                step=1,
                label="Number of Results",
                info="Select how many top results to display"
            )
            # Clickable even before a database is loaded; gradio_search()
            # returns None in that case.
            search_button = gr.Button("Search", variant="primary")
        
        with gr.Column(scale=2):
            # Receives the (image, caption) pairs returned by gradio_search().
            gallery_output = gr.Gallery(
                label="Search Results",
                columns=3,
                rows=2,
                height="auto",
                object_fit="contain"
            )
    
    # Set up event handlers
    # load_database returns (status message, query box update), matching the
    # order of these two outputs.
    load_button.click(
        fn=load_database,
        inputs=[file_upload],
        outputs=[status_text, query_input]
    )
    
    # Searching is triggered both by the button and by pressing Enter in the
    # query box; both routes call the same handler with the same inputs.
    search_button.click(
        fn=gradio_search,
        inputs=[query_input, top_k_slider],
        outputs=gallery_output
    )
    
    query_input.submit(
        fn=gradio_search,
        inputs=[query_input, top_k_slider],
        outputs=gallery_output
    )
    
    gr.Markdown(
        """
        ---
        **Excel File Format:**
        Your Excel file should have a column named `url` (or `URL`, `image_url`, `urls`, `link`, or `image`) containing the image URLs.
        
        Example:
        | url |
        |-----|
        | https://example.com/image1.jpg |
        | https://example.com/image2.jpg |
        
        **Note:** The SigLIP2 model computes similarity between your text query and the images to find the best matches.
        """
    )

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()