YoungjaeDev committed on
Commit 2e15a8b · verified · 1 Parent(s): 150e5f4

Upload folder using huggingface_hub

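For context, commits with this message are typically produced by `huggingface_hub`'s folder upload. A minimal sketch of how such a commit could be reproduced — the repo id and local path below are assumptions for illustration, not taken from the commit:

```python
# Hypothetical reproduction of this commit; repo_id and folder_path are assumed.
from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path=".",                           # local Space folder
    repo_id="YoungjaeDev/multimodal-search",   # assumed Space id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)
```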
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/embeddings/image_index.faiss filter=lfs diff=lfs merge=lfs -text
+ data/embeddings/image_index.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,43 @@
  ---
  title: Multimodal Search
- emoji: 🐨
+ emoji: 🔍
- colorFrom: indigo
+ colorFrom: green
- colorTo: purple
+ colorTo: gray
  sdk: gradio
- sdk_version: 6.2.0
+ sdk_version: 5.50.0
  app_file: app.py
  pinned: false
+ license: mit
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Multimodal Search
+
+ Search Flickr30k images using **Text**, **Image**, or **Composed queries (CIR)**.
+
+ ## Features
+
+ - **Text Search**: Find images matching text descriptions
+ - **Image Search**: Find similar images using a reference image
+ - **Composed Image Retrieval (CIR)**: Combine a reference image with a text modification
+   - Formula: `q = normalize(image_embedding + lambda * text_embedding)`
+
+ ## Tech Stack
+
+ - **Model**: SigLIP 2 (`google/siglip2-so400m-patch14-384`)
+ - **Index**: FAISS (IndexFlatIP for cosine similarity)
+ - **Dataset**: Flickr30k (31,014 images, 155,070 captions)
+ - **UI**: Gradio with custom theme
+
+ ## Usage
+
+ 1. **Text Search**: Enter text in the query box
+ 2. **Image Search**: Upload an image or click one from the results
+ 3. **CIR**: Combine text + image for a composed search
+
+ Adjust the **Lambda** weight to balance image vs. text influence in CIR.
+
+ ## Links
+
+ - [GitHub Repository](https://github.com/YoungjaeDev/multimodal-search-mvp)
+ - [SigLIP 2 Model](https://huggingface.co/google/siglip2-so400m-patch14-384)
+ - [Flickr30k Dataset](https://huggingface.co/datasets/nlphuji/flickr30k)
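For reference, a minimal NumPy sketch of the CIR formula above, assuming both inputs are unit-normalized 1152-d SigLIP 2 vectors (this mirrors `_compose_query` in `core/search.py` below; `lam` stands in for the Lambda slider value):

```python
import numpy as np

def compose_query(image_emb: np.ndarray, text_emb: np.ndarray, lam: float = 1.0) -> np.ndarray:
    """q = normalize(image_embedding + lambda * text_embedding)."""
    q = image_emb + lam * text_emb
    norm = np.linalg.norm(q)
    return q / norm if norm > 0 else q

# With an inner-product index over unit vectors (FAISS IndexFlatIP), scoring q
# against each image embedding is exactly cosine similarity.
```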
app.py ADDED
@@ -0,0 +1,1018 @@
+ """Gradio application for Multimodal Search MVP - HF Spaces Edition."""
+
+ from __future__ import annotations
+
+ import os
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Iterable
+
+ import gradio as gr
+ import spaces
+ from gradio.themes import Soft
+ from gradio.themes.utils import colors, fonts, sizes
+ from PIL import Image as PILImage
+
+ if TYPE_CHECKING:
+     from core.search import MultimodalSearch
+
+ # Global search engine (lazy loaded)
+ _search_engine: MultimodalSearch | None = None
+
+ # Global dataset (lazy loaded)
+ _flickr30k_dataset = None
+
+ # Data paths (HF Spaces uses HF Hub for data)
+ DATA_DIR = Path(__file__).parent / "data"
+ EMBEDDINGS_DIR = DATA_DIR / "embeddings"
+
+
+ def get_flickr30k_dataset():
+     """Get or load the Flickr30k dataset (lazy loading).
+
+     Returns:
+         Flickr30k dataset with images.
+     """
+     global _flickr30k_dataset
+
+     if _flickr30k_dataset is None:
+         from datasets import load_dataset
+
+         _flickr30k_dataset = load_dataset(
+             "nlphuji/flickr30k",
+             split="test",
+         )
+
+     return _flickr30k_dataset
+
+
+ class RefinedTheme(Soft):
+     """Editorial/Documentation style theme.
+
+     Features:
+     - No gradients, solid colors only
+     - Single accent color (Emerald)
+     - High contrast, professional look
+     - Pretendard font (Korean support)
+     """
+
+     def __init__(
+         self,
+         *,
+         primary_hue: colors.Color | str = colors.zinc,
+         secondary_hue: colors.Color | str = colors.emerald,
+         neutral_hue: colors.Color | str = colors.zinc,
+         text_size: sizes.Size | str = sizes.text_md,
+         font: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("Pretendard"),
+             "Pretendard",
+             "-apple-system",
+             "BlinkMacSystemFont",
+             "system-ui",
+             "sans-serif",
+         ),
+         font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("JetBrains Mono"),
+             "ui-monospace",
+             "monospace",
+         ),
+     ):
+         super().__init__(
+             primary_hue=primary_hue,
+             secondary_hue=secondary_hue,
+             neutral_hue=neutral_hue,
+             text_size=text_size,
+             font=font,
+             font_mono=font_mono,
+         )
+         super().set(
+             # Background - Light mode (solid colors, no gradients)
+             body_background_fill="#fafafa",
+             background_fill_primary="#ffffff",
+             background_fill_secondary="#f4f4f5",
+             # Background - Dark mode
+             body_background_fill_dark="#18181b",
+             background_fill_primary_dark="#27272a",
+             background_fill_secondary_dark="#3f3f46",
+             # Text colors
+             body_text_color="*neutral_800",
+             body_text_color_dark="#fafafa",
+             # Buttons - Light mode (solid colors, no gradients)
+             button_primary_background_fill="*secondary_600",
+             button_primary_background_fill_hover="*secondary_700",
+             button_primary_text_color="white",
+             button_secondary_background_fill="*neutral_100",
+             button_secondary_background_fill_hover="*neutral_200",
+             # Buttons - Dark mode
+             button_primary_background_fill_dark="*secondary_500",
+             button_primary_background_fill_hover_dark="*secondary_600",
+             button_secondary_background_fill_dark="*neutral_700",
+             button_secondary_background_fill_hover_dark="*neutral_600",
+             # Minimal styling
+             block_border_width="1px",
+             block_border_color="*neutral_200",
+             block_border_color_dark="*neutral_700",
+             block_shadow="none",
+             button_primary_shadow="none",
+             button_secondary_shadow="none",
+             # Title styling
+             block_title_text_weight="600",
+             block_title_text_size="*text_md",
+             # Input fields - Light mode
+             input_background_fill="*neutral_50",
+             input_border_color="*neutral_300",
+             input_border_width="1px",
+             # Input fields - Dark mode
+             input_background_fill_dark="*neutral_800",
+             input_border_color_dark="*neutral_600",
+             # Accent colors - for tabs, links, and interactive elements
+             # Use secondary (emerald) instead of primary (zinc) for visibility
+             color_accent="*secondary_500",
+             color_accent_soft="*secondary_100",
+             color_accent_soft_dark="*secondary_800",
+             border_color_accent="*secondary_400",
+             border_color_accent_dark="*secondary_600",
+         )
+
+
+ css = """
+ /* Container */
+ #col-container {
+     margin: 0 auto;
+     max-width: 1400px;
+ }
+
+ /* Header row - 3 column grid layout */
+ #header-row {
+     display: grid !important;
+     grid-template-columns: auto 1fr auto;
+     align-items: center;
+     gap: 16px;
+ }
+
+ /* Header logo - left column */
+ #header-logo {
+     background: transparent !important;
+     border: none !important;
+     min-width: 120px;
+     max-width: 120px;
+ }
+
+ /* Title - center column (screen-centered) */
+ #main-title {
+     text-align: center;
+     justify-self: center;
+ }
+ #main-title h1 {
+     font-size: 1.75rem;
+     font-weight: 600;
+     color: var(--body-text-color);
+     margin: 0;
+ }
+ #main-title p {
+     color: var(--body-text-color-subdued);
+     font-size: 0.95rem;
+     margin: 0;
+ }
+
+ /* Header controls - right column (inner div for horizontal layout) */
+ #header-controls-inner {
+     display: flex;
+     flex-direction: row;
+     align-items: center;
+     gap: 8px;
+ }
+
+ /* Theme transition */
+ body, .gradio-container {
+     transition: background-color 0.2s ease, color 0.2s ease;
+ }
+
+ /* Theme toggle button (native HTML button) */
+ .theme-toggle-btn {
+     min-width: 40px;
+     height: 40px;
+     padding: 8px;
+     border: 1px solid var(--border-color-primary);
+     border-radius: 8px;
+     background-color: var(--background-fill-primary);
+     color: var(--body-text-color);
+     cursor: pointer;
+     display: inline-flex;
+     align-items: center;
+     justify-content: center;
+     transition: border-color 0.2s ease, background-color 0.2s ease;
+ }
+ .theme-toggle-btn:hover {
+     border-color: var(--color-accent);
+     background-color: var(--background-fill-secondary);
+ }
+ .theme-toggle-btn:focus {
+     outline: none;
+     border-color: var(--color-accent);
+     box-shadow: 0 0 0 2px rgba(var(--color-accent-rgb), 0.2);
+ }
+
+ /* Theme toggle icons - show moon in light, sun in dark */
+ #theme-toggle .icon-moon { display: inline-flex; }
+ #theme-toggle .icon-sun { display: none; }
+ .dark #theme-toggle .icon-moon { display: none; }
+ .dark #theme-toggle .icon-sun { display: inline-flex; }
+
+ /* Language selector (native select) */
+ #lang-selector {
+     min-width: 100px;
+     padding: 8px 12px;
+     font-size: 14px;
+     font-family: inherit;
+     border: 1px solid var(--border-color-primary);
+     border-radius: 8px;
+     background-color: var(--background-fill-primary);
+     color: var(--body-text-color);
+     cursor: pointer;
+     outline: none;
+     appearance: none;
+     background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 24 24' fill='none' stroke='%23666' stroke-width='2'%3E%3Cpath d='M6 9l6 6 6-6'/%3E%3C/svg%3E");
+     background-repeat: no-repeat;
+     background-position: right 8px center;
+     padding-right: 28px;
+ }
+ #lang-selector:hover {
+     border-color: var(--color-accent);
+ }
+ #lang-selector:focus {
+     border-color: var(--color-accent);
+     box-shadow: 0 0 0 2px rgba(var(--color-accent-rgb), 0.2);
+ }
+ .dark #lang-selector {
+     background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' viewBox='0 0 24 24' fill='none' stroke='%23aaa' stroke-width='2'%3E%3Cpath d='M6 9l6 6 6-6'/%3E%3C/svg%3E");
+ }
+
+ /* Buttons */
+ .submit-btn {
+     font-weight: 500 !important;
+ }
+
+ /* Text areas */
+ textarea {
+     font-size: 0.9rem !important;
+ }
+
+ /* Labels */
+ .label-wrap {
+     font-weight: 500 !important;
+ }
+
+ /* Gallery styling - unified interface */
+ .gallery-container {
+     min-height: 500px;
+ }
+
+ /* Main content layout - unified interface */
+ #main-content {
+     display: flex;
+     gap: 24px;
+ }
+
+ /* Input panel styling */
+ #input-panel {
+     min-width: 280px;
+     max-width: 320px;
+ }
+
+ /* Results panel styling */
+ #results-panel {
+     flex: 1;
+ }
+
+ /* Reference image container */
+ #ref-image-container {
+     border: 2px dashed var(--border-color-primary);
+     border-radius: 12px;
+     padding: 8px;
+     background: var(--background-fill-secondary);
+     transition: border-color 0.2s ease;
+ }
+ #ref-image-container:hover {
+     border-color: var(--color-accent);
+ }
+
+ /* Search mode indicator */
+ #search-mode-indicator {
+     padding: 8px 16px;
+     border-radius: 8px;
+     font-size: 0.875rem;
+     font-weight: 500;
+     text-align: center;
+ }
+ .mode-text {
+     background: var(--secondary-100);
+     color: var(--secondary-700);
+ }
+ .mode-image {
+     background: var(--secondary-200);
+     color: var(--secondary-800);
+ }
+ .mode-composed {
+     background: var(--secondary-300);
+     color: var(--secondary-900);
+ }
+ .mode-none {
+     background: var(--background-fill-secondary);
+     color: var(--body-text-color-subdued);
+ }
+
+ /* Click hint text */
+ .click-hint {
+     font-size: 0.8rem;
+     color: var(--body-text-color-subdued);
+     text-align: center;
+     margin-top: 8px;
+ }
+
+ /* Clear button */
+ #clear-image-btn {
+     margin-top: 8px;
+ }
+
+ /* Slider group */
+ .slider-group {
+     margin-top: 16px;
+ }
+ .slider-group p {
+     white-space: nowrap;
+ }
+
+ /* CIR info box */
+ .cir-info {
+     background: var(--background-fill-secondary);
+     border-radius: 8px;
+     padding: 12px;
+     margin-top: 16px;
+     font-size: 0.85rem;
+ }
+ .cir-info code {
+     background: var(--background-fill-primary);
+     padding: 2px 6px;
+     border-radius: 4px;
+     font-family: var(--font-mono);
+ }
+ """
+
+ # Header controls HTML (combined for horizontal layout - avoids gr.Group column issue)
+ HEADER_CONTROLS_HTML = """
+ <div id="header-controls-inner">
+     <select id="lang-selector" class="lang-select" aria-label="Select language">
+         <option value="en">English</option>
+         <option value="ko">한국어</option>
+     </select>
+     <button id="theme-toggle" class="theme-toggle-btn" type="button" aria-label="Toggle theme">
+         <span class="icon-moon"><svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 3a6 6 0 0 0 9 9 9 9 0 1 1-9-9Z"/></svg></span>
+         <span class="icon-sun"><svg xmlns="http://www.w3.org/2000/svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="4"/><path d="M12 2v2"/><path d="M12 20v2"/><path d="m4.93 4.93 1.41 1.41"/><path d="m17.66 17.66 1.41 1.41"/><path d="M2 12h2"/><path d="M20 12h2"/><path d="m6.34 17.66-1.41 1.41"/><path d="m19.07 4.93-1.41 1.41"/></svg></span>
+     </button>
+ </div>
+ """
+
+ # Theme initialization JavaScript (runs on page load, includes click handler)
+ INIT_THEME_JS = """
+ () => {
+     // Initialize theme from localStorage or system preference
+     const saved = localStorage.getItem('theme');
+     const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
+     const shouldBeDark = saved === 'dark' || (!saved && prefersDark);
+     if (shouldBeDark) {
+         document.documentElement.classList.add('dark');
+     }
+
+     // Theme toggle click handler
+     const toggleBtn = document.getElementById('theme-toggle');
+     if (toggleBtn) {
+         toggleBtn.addEventListener('click', (e) => {
+             e.preventDefault();
+             document.documentElement.classList.toggle('dark');
+             const isDark = document.documentElement.classList.contains('dark');
+             localStorage.setItem('theme', isDark ? 'dark' : 'light');
+         });
+     }
+ }
+ """
+
+ # Initial language detection and selector initialization JavaScript
+ INIT_LANG_JS = """
+ () => {
+     const params = new URLSearchParams(window.location.search);
+     const lang = params.get('lang') || 'en';
+
+     // Set language selector value based on URL parameter
+     const langSelector = document.querySelector('#lang-selector');
+     if (langSelector) {
+         langSelector.value = lang;
+
+         // Add change event listener for language switching
+         langSelector.addEventListener('change', (e) => {
+             const targetLang = e.target.value;
+             const currentUrl = new URL(window.location);
+             currentUrl.searchParams.set('lang', targetLang);
+             window.location.href = currentUrl.toString();
+         });
+     }
+
+     // Update UI text based on language
+     if (lang === 'ko') {
+         const titleContainer = document.querySelector('#main-title');
+         if (titleContainer) {
+             const h1 = titleContainer.querySelector('h1');
+             if (h1) h1.textContent = '멀티모달 검색';
+             const p = titleContainer.querySelector('p');
+             if (p) p.textContent = '텍스트, 이미지 또는 합성 쿼리(CIR)를 사용하여 Flickr30k 이미지를 검색합니다.';
+         }
+     }
+     // English is the default, no need to update text
+ }
+ """
+
+ # Internationalization labels
+ LABELS = {
+     "en": {
+         "title": "Multimodal Search",
+         "subtitle": "Search Flickr30k images using text, images, or composed queries (CIR).",
+         "text_search": "Text Search",
+         "image_search": "Image Search",
+         "composed_search": "Composed Search",
+         "top_k": "Top-K Results",
+         "top_k_info": "Number of results to return",
+         "search_query": "Search Query",
+         "search_query_placeholder": "e.g., a dog running on the beach",
+         "text_search_desc": "Enter a text query to find matching images.",
+         "search_results": "Search Results",
+         "query_image": "Query Image",
+         "image_search_desc": "Upload an image to find similar images.",
+         "similar_images": "Similar Images",
+         "reference_image": "Reference Image",
+         "modification_text": "Modification Text",
+         "modification_placeholder": "e.g., make it red, on the beach, with a dog",
+         "lambda_weight": "Lambda Weight",
+         "lambda_info": "< 1: Image | 1: Balanced | > 1: Text",
+         "composed_results": "Composed Search Results",
+         "cir_desc": "**Composed Image Retrieval (CIR)** combines a reference image with text modification.",
+         "cir_formula": "Formula: `q = normalize(image_embedding + lambda * text_embedding)`",
+         "footer": "Multimodal Search MVP | SigLIP 2 + FAISS + Flickr30k",
+         # Unified interface labels
+         "text_query_label": "Text Query",
+         "text_query_info": "Text only: text search | Text + Image: composed search (CIR)",
+         "click_to_select": "Click any result image to use as reference",
+         "search_mode_text": "Mode: Text Search",
+         "search_mode_image": "Mode: Image Search",
+         "search_mode_composed": "Mode: Composed Search (CIR)",
+         "search_mode_none": "Enter text or upload an image to search",
+         "clear_image": "Clear Image",
+     },
+     "ko": {
+         "title": "멀티모달 검색",
+         "subtitle": "텍스트, 이미지 또는 합성 쿼리(CIR)를 사용하여 Flickr30k 이미지를 검색합니다.",
+         "text_search": "텍스트 검색",
+         "image_search": "이미지 검색",
+         "composed_search": "합성 검색",
+         "top_k": "결과 개수",
+         "top_k_info": "반환할 결과 수",
+         "search_query": "검색어",
+         "search_query_placeholder": "예: 해변을 달리는 강아지",
+         "text_search_desc": "텍스트 쿼리를 입력하여 일치하는 이미지를 찾습니다.",
+         "search_results": "검색 결과",
+         "query_image": "쿼리 이미지",
+         "image_search_desc": "이미지를 업로드하여 유사한 이미지를 찾습니다.",
+         "similar_images": "유사 이미지",
+         "reference_image": "참조 이미지",
+         "modification_text": "수정 텍스트",
+         "modification_placeholder": "예: 빨간색으로, 해변에서, 강아지와 함께",
+         "lambda_weight": "람다 가중치",
+         "lambda_info": "< 1: 이미지 | 1: 균형 | > 1: 텍스트",
+         "composed_results": "합성 검색 결과",
+         "cir_desc": "**합성 이미지 검색 (CIR)**은 참조 이미지와 텍스트 수정을 결합합니다.",
+         "cir_formula": "공식: `q = normalize(이미지_임베딩 + lambda * 텍스트_임베딩)`",
+         "footer": "멀티모달 검색 MVP | SigLIP 2 + FAISS + Flickr30k",
+         # Unified interface labels
+         "text_query_label": "텍스트 쿼리",
+         "text_query_info": "텍스트만: 텍스트 검색 | 텍스트 + 이미지: 합성 검색 (CIR)",
+         "click_to_select": "결과 이미지를 클릭하면 참조 이미지로 설정됩니다",
+         "search_mode_text": "모드: 텍스트 검색",
+         "search_mode_image": "모드: 이미지 검색",
+         "search_mode_composed": "모드: 합성 검색 (CIR)",
+         "search_mode_none": "텍스트를 입력하거나 이미지를 업로드하세요",
+         "clear_image": "이미지 초기화",
+     },
+ }
+
+
+ def L(key: str, lang: str = "en") -> str:
+     """Get localized label.
+
+     Args:
+         key: Label key to look up.
+         lang: Language code ('en' or 'ko').
+
+     Returns:
+         Localized string, or key if not found.
+     """
+     return LABELS.get(lang, LABELS["en"]).get(key, key)
+
+
+ @spaces.GPU
+ def get_search_engine() -> MultimodalSearch:
+     """Get or create the search engine (lazy loading with GPU).
+
+     Returns:
+         MultimodalSearch instance.
+     """
+     global _search_engine
+
+     if _search_engine is None:
+         from core.embeddings import EmbeddingModel
+         from core.index import FaissIndex
+         from core.search import MultimodalSearch
+
+         # Initialize embedding model with GPU
+         device = "cuda"
+         embedding_model = EmbeddingModel(device=device)
+
+         # Load FAISS index
+         index = FaissIndex(device=device)
+         index_path = EMBEDDINGS_DIR / "image_index"
+         index.load(index_path)
+
+         # Create search engine
+         _search_engine = MultimodalSearch(
+             embedding_model=embedding_model,
+             index=index,
+             default_lambda=1.0,
+         )
+
+     return _search_engine
+
+
+ def get_image_path(filename: str) -> str | None:
+     """Get the full path to a Flickr30k image.
+
+     Note: This function is kept for backwards compatibility.
+     The actual image retrieval is done via get_image_by_index().
+
+     Args:
+         filename: Image filename (e.g., "1000092795.jpg").
+
+     Returns:
+         The filename itself (used as a key for image lookup).
+     """
+     return filename
+
+
+ def get_image_by_index(index: int) -> PILImage.Image | None:
+     """Get a Flickr30k image by its index.
+
+     Args:
+         index: Index in the dataset (0-31013).
+
+     Returns:
+         PIL Image or None if not found.
+     """
+     try:
+         dataset = get_flickr30k_dataset()
+         if 0 <= index < len(dataset):
+             return dataset[index]["image"]
+         return None
+     except Exception:
+         return None
+
+
+ def format_results(results: list[dict]) -> list[tuple[PILImage.Image | str, str]]:
+     """Format search results for Gradio Gallery.
+
+     Args:
+         results: List of result dicts from MultimodalSearch.
+
+     Returns:
+         List of (image, caption) tuples for gr.Gallery.
+         Images are PIL Image objects or paths.
+     """
+     if not results:
+         return []
+
+     formatted = []
+     for result in results:
+         index = result.get("index", -1)
+         image = get_image_by_index(index)
+
+         if image is None:
+             continue
+
+         # Get first caption and score
+         captions = result.get("captions", [])
+         score = result.get("score", 0.0)
+
+         # Format caption with score
+         caption = captions[0] if captions else "No caption"
+         caption_with_score = f"[{score:.3f}] {caption}"
+
+         formatted.append((image, caption_with_score))
+
+     return formatted
+
+
+ @spaces.GPU
+ def search_by_text_handler(query: str, top_k: int) -> list[tuple[PILImage.Image | str, str]]:
+     """Handle text search requests with GPU acceleration.
+
+     Args:
+         query: Text query string.
+         top_k: Number of results to return.
+
+     Returns:
+         List of (image, caption) tuples.
+     """
+     if not query or not query.strip():
+         return []
+
+     try:
+         engine = get_search_engine()
+         results = engine.search_by_text(query.strip(), k=int(top_k))
+         return format_results(results)
+     except Exception as e:
+         raise gr.Error(f"Search failed: {e}")
+
+
+ @spaces.GPU
+ def search_by_image_handler(
+     image: PILImage.Image | None, top_k: int
+ ) -> list[tuple[PILImage.Image | str, str]]:
+     """Handle image search requests with GPU acceleration.
+
+     Args:
+         image: Query image (PIL Image).
+         top_k: Number of results to return.
+
+     Returns:
+         List of (image, caption) tuples.
+     """
+     if image is None:
+         return []
+
+     try:
+         engine = get_search_engine()
+         results = engine.search_by_image(image, k=int(top_k))
+         return format_results(results)
+     except Exception as e:
+         raise gr.Error(f"Search failed: {e}")
+
+
+ @spaces.GPU
+ def search_composed_handler(
+     image: PILImage.Image | None,
+     modification_text: str,
+     top_k: int,
+     lambda_weight: float,
+ ) -> list[tuple[PILImage.Image | str, str]]:
+     """Handle composed image retrieval requests with GPU acceleration.
+
+     Args:
+         image: Reference image (PIL Image).
+         modification_text: Text describing desired modification.
+         top_k: Number of results to return.
+         lambda_weight: Weight for text embedding in CIR.
+
+     Returns:
+         List of (image, caption) tuples.
+     """
+     if image is None:
+         return []
+
+     if not modification_text or not modification_text.strip():
+         return []
+
+     try:
+         engine = get_search_engine()
+         results = engine.search_composed(
+             image,
+             modification_text.strip(),
+             k=int(top_k),
+             lambda_weight=float(lambda_weight),
+         )
+         return format_results(results)
+     except Exception as e:
+         raise gr.Error(f"Search failed: {e}")
+
+
+ def get_mode_label(mode: str, lang: str = "en") -> str:
+     """Get localized label for search mode."""
+     mode_labels = {
+         "text": "search_mode_text",
+         "image": "search_mode_image",
+         "composed": "search_mode_composed",
+         "none": "search_mode_none",
+     }
+     label_key = mode_labels.get(mode, "search_mode_none")
+     return L(label_key, lang)
+
+
+ def get_random_samples(top_k: int = 10) -> list[tuple[PILImage.Image | str, str]]:
+     """Get random sample images for initial display.
+
+     Args:
+         top_k: Number of random samples to return.
+
+     Returns:
+         List of (image, caption) tuples.
+     """
+     import random
+
+     try:
+         dataset = get_flickr30k_dataset()
+         total = len(dataset)
+         indices = random.sample(range(total), min(top_k, total))
+
+         samples = []
+         for idx in indices:
+             try:
+                 item = dataset[idx]
+                 image = item["image"]
+                 captions = item.get("captions", item.get("caption", []))
+                 caption = captions[0] if captions else "No caption"
+                 samples.append((image, caption))
+             except Exception:
+                 continue
+
+         return samples
+     except Exception as e:
+         print(f"Error loading samples: {e}")
+         return []
+
+
+ @spaces.GPU
+ def unified_search_handler(
+     text_query: str,
+     image: PILImage.Image | None,
+     top_k: int,
+     lambda_weight: float,
+ ) -> tuple[list[tuple[PILImage.Image | str, str]], str]:
+     """Handle unified search requests with GPU acceleration.
+
+     Automatically determines search mode based on inputs:
+     - Text only: text_search
+     - Image only: image_search
+     - Text + Image: composed_search (CIR)
+
+     Args:
+         text_query: Text query string.
+         image: Reference image (PIL Image) or None.
+         top_k: Number of results to return.
+         lambda_weight: Weight for text embedding in CIR.
+
+     Returns:
+         Tuple of (gallery_results, search_mode_indicator).
+     """
+     has_text = text_query and text_query.strip()
+     has_image = image is not None
+
+     # Determine search mode
+     if has_text and has_image:
+         mode = "composed"
+     elif has_text:
+         mode = "text"
+     elif has_image:
+         mode = "image"
+     else:
+         mode = "none"
+         return [], mode
+
+     # Execute search with unified error handling
+     results: list = []
+     try:
+         engine = get_search_engine()
+         if mode == "composed":
+             results = engine.search_composed(
+                 image,
+                 text_query.strip(),
+                 k=int(top_k),
+                 lambda_weight=float(lambda_weight),
+             )
+         elif mode == "text":
+             results = engine.search_by_text(text_query.strip(), k=int(top_k))
+         elif mode == "image":
+             results = engine.search_by_image(image, k=int(top_k))
+     except Exception as e:
+         raise gr.Error(f"Search failed: {e}")
+
+     gallery = format_results(results)
+     return gallery, mode
+
+
+ def on_gallery_select(
+     evt: gr.SelectData,
+     gallery_data: list[tuple[PILImage.Image | str, str]],
+ ) -> PILImage.Image | None:
+     """Handle gallery selection to set reference image.
+
+     Args:
+         evt: Gradio SelectData event containing selection index.
+         gallery_data: Current gallery data.
+
+     Returns:
+         Selected PIL Image or None.
+     """
+     if gallery_data is None or len(gallery_data) == 0:
+         return None
+
+     try:
+         idx = evt.index
+         if 0 <= idx < len(gallery_data):
+             image_data = gallery_data[idx]
+             # Gallery data is (image, caption) tuple
+             if isinstance(image_data, tuple):
+                 image = image_data[0]
+                 if isinstance(image, PILImage.Image):
+                     return image
+         return None
+     except Exception:
+         return None
+
+
+ def create_app(lang: str = "en") -> gr.Blocks:
+     """Create and configure the Gradio application.
+
+     Args:
+         lang: Language code ('en' or 'ko').
+
+     Returns:
+         Gradio Blocks application.
+     """
+     theme = RefinedTheme()
+     with gr.Blocks(title=L("title", lang), theme=theme, css=css) as app:
+         # State for gallery data (used by gallery select handler)
+         gallery_state = gr.State([])
+
+         with gr.Column(elem_id="col-container"):
+             # Header: 3-column grid [logo] [title - centered] [controls - right]
+             with gr.Row(elem_id="header-row"):
+                 # Left: Logo
+                 gr.Image(
+                     value="assets/logo.png",
+                     show_label=False,
+                     show_download_button=False,
+                     show_fullscreen_button=False,
+                     interactive=False,
+                     height=80,
+                     width=80,
+                     elem_id="header-logo",
+                 )
+                 # Center: Title (screen-centered via CSS Grid)
+                 gr.Markdown(
+                     f"# {L('title', lang)}\n{L('subtitle', lang)}",
+                     elem_id="main-title",
+                 )
+                 # Right: Controls (single HTML for horizontal layout)
+                 gr.HTML(value=HEADER_CONTROLS_HTML)
+
+             # Main content area - unified interface
+             with gr.Row(elem_id="main-content"):
+                 # Left panel: Input controls
+                 with gr.Column(scale=1, elem_id="input-panel"):
+                     # Reference Image
+                     ref_image = gr.Image(
+                         label=L("reference_image", lang),
+                         type="pil",
+                         height=200,
+                         elem_id="ref-image-container",
+                     )
+
+                     # Clear image button
+                     clear_btn = gr.Button(
+                         L("clear_image", lang),
+                         variant="secondary",
+                         size="sm",
+                         elem_id="clear-image-btn",
+                     )
+
+                     # Text query
+                     text_query = gr.Textbox(
+                         label=L("text_query_label", lang),
+                         placeholder=L("search_query_placeholder", lang),
+                         lines=2,
+                     )
+
+                     # Sliders group
+                     with gr.Group(elem_classes=["slider-group"]):
+                         lambda_slider = gr.Slider(
+                             minimum=0.3,
+                             maximum=2.0,
+                             value=1.0,
+                             step=0.1,
+                             label=L("lambda_weight", lang),
+                             info=L("lambda_info", lang),
+                         )
+
+                         top_k_slider = gr.Slider(
+                             minimum=1,
+                             maximum=50,
+                             value=10,
+                             step=1,
+                             label=L("top_k", lang),
+                             info=L("top_k_info", lang),
+                         )
+
+                 # Right panel: Results gallery
+                 with gr.Column(scale=3, elem_id="results-panel"):
+                     # Search mode indicator
+                     mode_indicator = gr.Markdown(
+                         f"<div id='search-mode-indicator' class='mode-none'>"
+                         f"{L('search_mode_none', lang)}</div>",
+                         elem_id="mode-indicator-container",
+                     )
+
+                     # Click hint
+                     gr.Markdown(
+                         f"<p class='click-hint'>{L('click_to_select', lang)}</p>",
+                     )
+
+                     # Results gallery
+                     results_gallery = gr.Gallery(
+                         label=L("search_results", lang),
+                         show_label=True,
+                         columns=5,
+                         rows=4,
+                         height="auto",
+                         object_fit="cover",
+                         elem_classes=["gallery-container"],
+                         allow_preview=True,
+                     )
+
+         # Helper function to run unified search and update state
+         def search_and_update_state(
+             text_query: str,
+             image: PILImage.Image | None,
+             top_k: int,
+             lambda_weight: float,
+         ) -> tuple[list, list, str]:
+             """Run search and return gallery data, state, and mode indicator."""
+             gallery, mode = unified_search_handler(
+                 text_query, image, top_k, lambda_weight
+             )
+             mode_html = (
+                 f"<div id='search-mode-indicator' class='mode-{mode}'>"
+                 f"{get_mode_label(mode, lang)}</div>"
+             )
+             return gallery, gallery, mode_html
+
+         # Real-time search on any input change
+         search_inputs = [text_query, ref_image, top_k_slider, lambda_slider]
+         search_outputs = [results_gallery, gallery_state, mode_indicator]
+
+         for input_component in [text_query, ref_image, lambda_slider, top_k_slider]:
+             input_component.change(
+                 fn=search_and_update_state,
+                 inputs=search_inputs,
+                 outputs=search_outputs,
+                 show_progress="hidden",
+             )
+
+         # Clear image button
+         clear_btn.click(
+             fn=lambda: None,
+             inputs=[],
+             outputs=[ref_image],
+         )
+
+         # Load initial samples on app start
+         def load_initial_samples(top_k: int) -> tuple[list, list]:
+             """Load random samples for initial display."""
+             samples = get_random_samples(int(top_k))
+             return samples, samples
+
+         app.load(
+             fn=load_initial_samples,
+             inputs=[top_k_slider],
+             outputs=[results_gallery, gallery_state],
+         )
+
+         # Initialize theme on page load
+         app.load(fn=None, js=INIT_THEME_JS)
+
+         # Initialize language from URL query parameter
+         app.load(fn=None, js=INIT_LANG_JS)
+
+     return app
+
+
+ def get_default_lang() -> str:
+     """Get default language from environment or URL."""
+     return os.environ.get("APP_LANG", "en")
+
+
+ # Create the demo instance for module-level access
+ demo = create_app(get_default_lang())
+
+
+ if __name__ == "__main__":
+     lang = get_default_lang()
+     app = create_app(lang)
+     app.queue(max_size=30).launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_error=True,
+     )
assets/logo.png ADDED
core/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Core module for multimodal search MVP."""
+
+ from core.embeddings import EmbeddingModel
+ from core.index import FaissIndex
+ from core.search import MultimodalSearch
+
+ __all__ = ["EmbeddingModel", "FaissIndex", "MultimodalSearch"]
core/embeddings.py ADDED
@@ -0,0 +1,177 @@
+ """SigLIP 2 embedding model wrapper."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from transformers import AutoModel, AutoProcessor
+
+ if TYPE_CHECKING:
+     from PIL import Image
+
+
+ class EmbeddingModel:
+     """SigLIP 2 embedding model for text and image encoding.
+
+     Model: google/siglip2-so400m-patch14-384
+     Dimension: 1152
+     """
+
+     MODEL_ID = "google/siglip2-so400m-patch14-384"
+     EMBEDDING_DIM = 1152
+
+     def __init__(self, device: str = "cpu") -> None:
+         """Initialize the embedding model.
+
+         Args:
+             device: Device to run the model on ('cpu' or 'cuda').
+         """
+         self.device = device
+         self.model = None
+         self.processor = None
+
+     def load(self) -> None:
+         """Load the model and processor."""
+         self.processor = AutoProcessor.from_pretrained(self.MODEL_ID)
+         self.model = AutoModel.from_pretrained(self.MODEL_ID)
+         self.model.to(self.device)
+         # Set model to evaluation mode (disable dropout, etc.)
+         self.model.eval()
+
+     def _ensure_loaded(self) -> None:
+         """Ensure model is loaded before inference."""
+         if self.model is None or self.processor is None:
+             self.load()
+
+     def encode_image(self, image: Image.Image) -> np.ndarray:
+         """Encode a single image to embedding vector.
+
+         Args:
+             image: PIL Image to encode.
+
+         Returns:
+             Normalized embedding vector of shape (1152,).
+         """
+         self._ensure_loaded()
+
+         inputs = self.processor(images=image, return_tensors="pt")
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             features = self.model.get_image_features(**inputs)
+             features = F.normalize(features, dim=-1)
+
+         return features.cpu().numpy().squeeze(0)
+
+     def encode_images(
+         self,
+         images: list[Image.Image],
+         batch_size: int = 32,
+         show_progress: bool = True,
+     ) -> np.ndarray:
+         """Encode multiple images to embedding vectors.
+
+         Args:
+             images: List of PIL Images to encode.
+             batch_size: Batch size for processing.
+             show_progress: Show progress bar.
+
+         Returns:
+             Normalized embedding vectors of shape (N, 1152).
+         """
+         if not images:
+             return np.empty((0, self.EMBEDDING_DIM), dtype=np.float32)
+
+         self._ensure_loaded()
+
+         all_embeddings = []
+         iterator = range(0, len(images), batch_size)
+         if show_progress:
+             iterator = tqdm(iterator, desc="Encoding images", unit="batch")
+
+         for i in iterator:
+             batch_images = images[i : i + batch_size]
+             inputs = self.processor(images=batch_images, return_tensors="pt")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 features = self.model.get_image_features(**inputs)
+                 features = F.normalize(features, dim=-1)
+
+             all_embeddings.append(features.cpu().numpy())
+
+         return np.concatenate(all_embeddings, axis=0)
+
+     def encode_text(self, text: str) -> np.ndarray:
+         """Encode a single text to embedding vector.
+
+         Args:
+             text: Text string to encode.
+
+         Returns:
+             Normalized embedding vector of shape (1152,).
+         """
+         self._ensure_loaded()
+
+         # SigLIP requires padding="max_length" as trained
+         inputs = self.processor(
+             text=text,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+         )
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             features = self.model.get_text_features(**inputs)
+             features = F.normalize(features, dim=-1)
+
+         return features.cpu().numpy().squeeze(0)
+
+     def encode_texts(
+         self,
+         texts: list[str],
+         batch_size: int = 32,
+         show_progress: bool = True,
+     ) -> np.ndarray:
+         """Encode multiple texts to embedding vectors.
+
+         Args:
+             texts: List of text strings to encode.
+             batch_size: Batch size for processing.
+             show_progress: Show progress bar.
+
+         Returns:
+             Normalized embedding vectors of shape (N, 1152).
+         """
+         if not texts:
+             return np.empty((0, self.EMBEDDING_DIM), dtype=np.float32)
+
+         self._ensure_loaded()
+
+         all_embeddings = []
+         iterator = range(0, len(texts), batch_size)
+         if show_progress:
+             iterator = tqdm(iterator, desc="Encoding texts", unit="batch")
+
+         for i in iterator:
+             batch_texts = texts[i : i + batch_size]
+             inputs = self.processor(
+                 text=batch_texts,
+                 padding="max_length",
+                 truncation=True,
+                 return_tensors="pt",
+             )
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 features = self.model.get_text_features(**inputs)
+                 features = F.normalize(features, dim=-1)
+
+             all_embeddings.append(features.cpu().numpy())
+
+         return np.concatenate(all_embeddings, axis=0)
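Usage of the wrapper is straightforward; a small sketch (the model weights download on the first encode call, and `example.jpg` and the query string are illustrative placeholders):

```python
from PIL import Image

from core.embeddings import EmbeddingModel

model = EmbeddingModel(device="cpu")  # lazy: weights load on first encode
text_vec = model.encode_text("a dog running on the beach")  # shape (1152,)
image_vec = model.encode_image(Image.open("example.jpg"))   # shape (1152,)

# Both vectors are L2-normalized, so their dot product is cosine similarity.
similarity = float(text_vec @ image_vec)
```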
core/index.py ADDED
@@ -0,0 +1,190 @@
+ """FAISS index wrapper for efficient similarity search."""
+
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import faiss
+ import numpy as np
+
+ if TYPE_CHECKING:
+     pass
+
+
+ class FaissIndex:
+     """FAISS IndexFlatIP wrapper for normalized vector search.
+
+     Uses Inner Product (IP) for cosine similarity with pre-normalized vectors.
+     """
+
+     def __init__(self, dimension: int = 1152, device: str = "cpu") -> None:
+         """Initialize FAISS index.
+
+         Args:
+             dimension: Embedding dimension (1152 for SigLIP 2).
+             device: Device to use ('cpu' or 'cuda').
+         """
+         self.dimension = dimension
+         self.device = device
+         self.index: faiss.Index | None = None
+         self.metadata: list[dict] = []
+         self._gpu_resources: faiss.StandardGpuResources | None = None
+
+     def _cleanup_gpu_resources(self) -> None:
+         """Release GPU resources to prevent memory leaks."""
+         if self._gpu_resources is not None:
+             self._gpu_resources = None
+             self.index = None
+             import gc
+             gc.collect()
+
+     def __del__(self) -> None:
+         """Destructor to clean up GPU resources."""
+         self._cleanup_gpu_resources()
+
+     def build(self, embeddings: np.ndarray, metadata: list[dict] | None = None) -> None:
+         """Build the index from embeddings.
+
+         Args:
+             embeddings: Normalized embedding vectors of shape (N, dimension).
+             metadata: Optional list of metadata dicts for each embedding.
+
+         Raises:
+             ValueError: If embeddings are empty, wrong dimension, or metadata mismatch.
+         """
+         if embeddings.size == 0:
+             raise ValueError("Cannot build index from empty embeddings")
+
+         if embeddings.shape[1] != self.dimension:
+             raise ValueError(
+                 f"Embedding dimension {embeddings.shape[1]} does not match "
+                 f"index dimension {self.dimension}"
+             )
+
+         if metadata is not None and len(metadata) != len(embeddings):
+             raise ValueError(
+                 f"Metadata length {len(metadata)} does not match "
+                 f"embeddings count {len(embeddings)}"
+             )
+
+         # Convert to float32 if needed (FAISS requirement)
+         if embeddings.dtype != np.float32:
+             embeddings = embeddings.astype(np.float32)
+
+         # Clean up existing GPU resources before rebuilding
+         self._cleanup_gpu_resources()
+
+         # Create IndexFlatIP for inner product (cosine similarity with normalized vectors)
+         self.index = faiss.IndexFlatIP(self.dimension)
+
+         # Move to GPU if requested
+         if self.device == "cuda":
+             self._gpu_resources = faiss.StandardGpuResources()
+             self.index = faiss.index_cpu_to_gpu(self._gpu_resources, 0, self.index)
+
+         # Add embeddings to index
+         self.index.add(embeddings)
+
+         # Store metadata
+         self.metadata = metadata if metadata is not None else [{} for _ in range(len(embeddings))]
+
+     def search(
+         self, query: np.ndarray, k: int = 10
+     ) -> tuple[np.ndarray, np.ndarray, list[dict]]:
+         """Search for k nearest neighbors.
+
+         Args:
+             query: Query embedding of shape (1, dimension) or (dimension,).
+             k: Number of results to return.
+
+         Returns:
+             Tuple of (scores, indices, metadata_list).
+
+         Raises:
+             ValueError: If index is not built.
+         """
+         if self.index is None:
+             raise ValueError("Index not built. Call build() first.")
+
+         # Reshape 1D query to 2D
+         if query.ndim == 1:
+             query = query.reshape(1, -1)
+
+         # Convert to float32 if needed
+         if query.dtype != np.float32:
+             query = query.astype(np.float32)
+
+         # Limit k to index size
+         k = min(k, self.index.ntotal)
+
+         # Perform search
+         scores, indices = self.index.search(query, k)
+
+         # Flatten results (single query)
+         scores = scores[0]
+         indices = indices[0]
+
+         # Get metadata for results
+         result_metadata = [self.metadata[idx] for idx in indices]
+
+         return scores, indices, result_metadata
+
+     def save(self, path: str | Path) -> None:
+         """Save index and metadata to disk.
+
+         Args:
+             path: Path to save the index (without extension).
+                 Creates {path}.faiss and {path}.json files.
+
+         Raises:
+             ValueError: If index is not built.
+         """
+         if self.index is None:
+             raise ValueError("Index not built. Call build() first.")
+
+         path = Path(path)
+         path.parent.mkdir(parents=True, exist_ok=True)
+
+         # Convert GPU index to CPU before saving
+         index_to_save = self.index
+         if self.device == "cuda":
+             index_to_save = faiss.index_gpu_to_cpu(self.index)
+
+         # Save FAISS index
+         faiss.write_index(index_to_save, str(path.with_suffix(".faiss")))
+
+         # Save metadata as JSON
+         with open(path.with_suffix(".json"), "w", encoding="utf-8") as f:
+             json.dump(
+                 {"dimension": self.dimension, "metadata": self.metadata},
+                 f,
+                 ensure_ascii=False,
+             )
+
+     def load(self, path: str | Path) -> None:
+         """Load index and metadata from disk.
+
+         Args:
+             path: Path to load the index from (without extension).
+                 Expects {path}.faiss and {path}.json files.
+         """
+         path = Path(path)
+
+         # Clean up existing GPU resources before loading
+         self._cleanup_gpu_resources()
+
+         # Load FAISS index
+         self.index = faiss.read_index(str(path.with_suffix(".faiss")))
+
+         # Move to GPU if requested
+         if self.device == "cuda":
+             self._gpu_resources = faiss.StandardGpuResources()
+             self.index = faiss.index_cpu_to_gpu(self._gpu_resources, 0, self.index)
+
+         # Load metadata
+         with open(path.with_suffix(".json"), encoding="utf-8") as f:
+             data = json.load(f)
+         self.dimension = data["dimension"]
+         self.metadata = data["metadata"]
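A quick round-trip sketch of this wrapper with random unit vectors (purely illustrative data; a vector queried against itself should score ~1.0 under cosine/IP):

```python
import numpy as np

from core.index import FaissIndex

rng = np.random.default_rng(0)
vecs = rng.standard_normal((100, 1152)).astype(np.float32)
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)  # normalize so IP == cosine

index = FaissIndex(dimension=1152, device="cpu")
index.build(vecs, metadata=[{"i": i} for i in range(100)])

scores, indices, meta = index.search(vecs[0], k=5)
assert indices[0] == 0 and abs(scores[0] - 1.0) < 1e-5  # self-match scores ~1.0
```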
core/search.py ADDED
@@ -0,0 +1,146 @@
+ """Multimodal search logic for Text, Image, and Composed queries."""
+
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+
+ if TYPE_CHECKING:
+     from PIL import Image
+
+     from core.embeddings import EmbeddingModel
+     from core.index import FaissIndex
+
+
+ class MultimodalSearch:
+     """Unified search interface for Text, Image, and CIR (Composed Image Retrieval).
+
+     CIR Zero-shot formula: q = normalize(img_emb + lambda * text_emb)
+     """
+
+     def __init__(
+         self,
+         embedding_model: EmbeddingModel,
+         index: FaissIndex,
+         default_lambda: float = 1.0,
+     ) -> None:
+         """Initialize multimodal search.
+
+         Args:
+             embedding_model: Loaded EmbeddingModel instance.
+             index: Loaded FaissIndex instance.
+             default_lambda: Default weight for text embedding in CIR.
+         """
+         self.embedding_model = embedding_model
+         self.index = index
+         self.default_lambda = default_lambda
+
+     def _format_results(
+         self,
+         scores: np.ndarray,
+         indices: np.ndarray,
+         metadata_list: list[dict],
+     ) -> list[dict]:
+         """Format FAISS search results into list of result dicts.
+
+         Args:
+             scores: Score array from FAISS.
+             indices: Index array from FAISS.
+             metadata_list: Metadata list from FAISS.
+
+         Returns:
+             List of result dicts with 'score', 'index', and metadata fields.
+         """
+         if len(scores) == 0:
+             return []
+
+         results = []
+         for score, idx, meta in zip(scores, indices, metadata_list):
+             result = {
+                 "score": float(score),
+                 "index": int(idx),
+             }
+             result.update(meta)
+             results.append(result)
+         return results
+
+     def search_by_text(self, query: str, k: int = 10) -> list[dict]:
+         """Search images by text query.
+
+         Args:
+             query: Text query string.
+             k: Number of results to return.
+
+         Returns:
+             List of result dicts with 'score', 'index', and metadata.
+         """
+         text_embedding = self.embedding_model.encode_text(query)
+         scores, indices, metadata_list = self.index.search(text_embedding, k=k)
+         return self._format_results(scores, indices, metadata_list)
+
+     def search_by_image(self, image: Image.Image, k: int = 10) -> list[dict]:
+         """Search similar images by reference image.
+
+         Args:
+             image: Query image.
+             k: Number of results to return.
+
+         Returns:
+             List of result dicts with 'score', 'index', and metadata.
+         """
+         image_embedding = self.embedding_model.encode_image(image)
+         scores, indices, metadata_list = self.index.search(image_embedding, k=k)
+         return self._format_results(scores, indices, metadata_list)
+
+     def search_composed(
+         self,
+         image: Image.Image,
+         modification_text: str,
+         k: int = 10,
+         lambda_weight: float | None = None,
+     ) -> list[dict]:
+         """Composed Image Retrieval: find images matching (image + text modification).
+
+         Uses zero-shot CIR: q = normalize(img_emb + lambda * text_emb)
+
+         Args:
+             image: Reference image.
+             modification_text: Text describing desired modification.
+             k: Number of results to return.
+             lambda_weight: Weight for text embedding (uses default if None).
+
+         Returns:
+             List of result dicts with 'score', 'index', and metadata.
+         """
+         if lambda_weight is None:
+             lambda_weight = self.default_lambda
+
+         image_embedding = self.embedding_model.encode_image(image)
+         text_embedding = self.embedding_model.encode_text(modification_text)
+         composed_query = self._compose_query(image_embedding, text_embedding, lambda_weight)
+
+         scores, indices, metadata_list = self.index.search(composed_query, k=k)
+         return self._format_results(scores, indices, metadata_list)
+
+     def _compose_query(
+         self,
+         image_emb: np.ndarray,
+         text_emb: np.ndarray,
+         lambda_weight: float,
+     ) -> np.ndarray:
+         """Compose query embedding from image and text.
+
+         Args:
+             image_emb: Image embedding vector.
+             text_emb: Text embedding vector.
+             lambda_weight: Weight for text embedding.
+
+         Returns:
+             Normalized composed query embedding.
+         """
+         composed = image_emb + lambda_weight * text_emb
+         norm = np.linalg.norm(composed)
+         if norm > 0:
+             composed = composed / norm
+         return composed
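Wiring the three pieces together follows `get_search_engine()` in `app.py`; a condensed sketch (the index path assumes the `data/embeddings/` files added below, and the query string is illustrative):

```python
from pathlib import Path

from core.embeddings import EmbeddingModel
from core.index import FaissIndex
from core.search import MultimodalSearch

index = FaissIndex(device="cpu")
index.load(Path("data/embeddings/image_index"))  # loads .faiss + .json pair

search = MultimodalSearch(EmbeddingModel(device="cpu"), index, default_lambda=1.0)
results = search.search_by_text("a dog running on the beach", k=5)

# CIR: reference image + text modification, weighted by lambda
# results = search.search_composed(pil_image, "on the beach", k=5, lambda_weight=1.0)
```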
data/embeddings/image_index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:00fffaa18facc712fd49346af90954c47afa8ad7ec3453a782e0afe463e602e8
+ size 142912557
data/embeddings/image_index.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:48615896a09252765a532b70153cec8b35d2c537a16abdfa9a8b71e2a73079aa
+ size 12842559
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ # Core dependencies for HF Spaces
+ transformers>=4.49
+ torch>=2.0
+ faiss-cpu>=1.7.4
+ gradio>=5.0,<6.0
+ pillow>=10.0
+ numpy>=1.24
+ datasets>=2.14,<4.0
+ huggingface-hub>=0.20
+ tqdm>=4.65
+
+ # HF Spaces ZeroGPU
+ spaces