# File size: 10,697 Bytes
# 3f8c153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import gradio as gr
import torch
from safetensors.torch import load_file
from pathlib import Path
import json
from functools import lru_cache

# Import utilities
from models.model_loader import load_embed_model, load_rerank_model
from utils.search import search_embeddings, rerank_results
from utils.visualization import plot_results

# ============================================================================
# Data Loading Functions
# ============================================================================

@lru_cache(maxsize=1)
def load_metadata(metadata_path: str = "data/recipes/metadata.json") -> dict:
    """Read and parse the JSON metadata file at ``metadata_path``.

    The result is memoized by a single-entry LRU cache keyed on
    ``metadata_path``: repeated calls with the same path return the very same
    parsed object without touching the file again, while a call with a
    different path evicts the previous entry.
    """
    with open(metadata_path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)


def get_recipes_and_paths(metadata_path: str = "data/recipes/metadata.json"):
    """Return recipe markdown texts, image paths, and the raw recipe records.

    The metadata is obtained through ``load_metadata`` and its ``"recipes"``
    list is split into two parallel lists (same order as the records).

    Returns:
        A 3-tuple ``(markdown_texts, image_paths, recipe_records)`` where the
        first two lists hold each record's ``"markdown"`` and ``"image_path"``
        values and the third is the ``"recipes"`` list itself.
    """
    print("📋 Loading metadata...")
    recipe_records = load_metadata(metadata_path)["recipes"]

    markdown_texts = []
    for record in recipe_records:
        markdown_texts.append(record["markdown"])

    image_paths = []
    for record in recipe_records:
        image_paths.append(record["image_path"])

    print(f"✅ Loaded {len(recipe_records)} recipes")
    return markdown_texts, image_paths, recipe_records


# ============================================================================
# Load Models & Data (once at startup)
# ============================================================================

# Everything in this section runs at import time. The models, the three
# embedding tensors, and the recipe metadata are loaded once into module-level
# globals, which search_and_display() reads on every request.
print("🔄 Loading models...")
# embed_device is not referenced again in the visible code; rerank_processor
# and rerank_device are forwarded to rerank_results() during reranking.
embed_model, embed_device = load_embed_model()
rerank_model, rerank_processor, rerank_device = load_rerank_model()

print("📦 Loading embeddings...")

# Each safetensors file is loaded in full and a single named tensor is pulled
# out of the returned mapping.
# NOTE(review): the image file name is singular ("image_embedding_500") while
# the text file is plural ("text_embeddings_500"), and the image+text tensor
# key below has no "_500" suffix — confirm these match the files on disk.

# Load Image embeddings
print("   - Loading image embeddings...")
image_embeddings = load_file("data/embeddings/image_embedding_500.safetensors")["image_embeddings_500"]
print(f"     ✅ Image embeddings shape: {image_embeddings.shape}")

# Load Text embeddings
print("   - Loading text embeddings...")
text_embeddings = load_file("data/embeddings/text_embeddings_500.safetensors")["text_embeddings_500"]
print(f"     ✅ Text embeddings shape: {text_embeddings.shape}")

# Load Image+Text embeddings
print("   - Loading image+text embeddings...")
image_text_embeddings = load_file("data/embeddings/image_text_embeddings_500.safetensors")["image_text_embeddings"]
print(f"     ✅ Image+Text embeddings shape: {image_text_embeddings.shape}")

print("📋 Loading recipe data...")
# all_recipes: markdown text per recipe; all_image_paths: image path per
# recipe; full_recipes: the raw recipe records (all three share one ordering).
all_recipes, all_image_paths, full_recipes = get_recipes_and_paths()

# Startup summary of what was loaded.
print("\n✅ System ready!")
print(f"   📚 Recipes loaded: {len(all_recipes)}")
print(f"   🖼️  Images loaded: {len(all_image_paths)}")
print(f"   📊 Embeddings:")
print(f"      - Image: {image_embeddings.shape}")
print(f"      - Text: {text_embeddings.shape}")
print(f"      - Image+Text: {image_text_embeddings.shape}")

# ============================================================================
# Gradio Interface
# ============================================================================

def search_and_display(
    query: str,
    modality: str,
    use_rerank: bool,
    num_results: int = 3
):
    """Run a recipe search for the Gradio UI and return ``(image, payload)``.

    Args:
        query: Free-text search query. ``None``, empty, or whitespace-only
            queries short-circuit with an error payload.
        modality: ``"Image"`` or ``"Text"``; any other value (the UI sends
            ``"Image+Text"``) selects the combined image+text embeddings.
        use_rerank: When true and ``modality`` is not ``"Text"``, the
            retrieved candidates are passed through ``rerank_results``.
        num_results: How many top results to return and plot. Coerced to
            ``int`` because it is used as a slice bound and UI sliders may
            deliver a float.

    Returns:
        A 2-tuple matching the two Gradio outputs. The first item is the
        value returned by ``plot_results`` for non-text modalities and
        ``None`` for ``"Text"`` or for a rejected query. The second item is
        the list of the top ``num_results`` result dicts (enriched with a
        ``"recipe_details"`` entry where the recipe could be resolved), or
        ``{"error": ...}`` for a rejected query.
    """
    # Guard against None as well as blank input (None has no .strip()).
    if not query or not query.strip():
        return None, {"error": "Please enter a search query"}

    # Slice bounds must be integers; normalise once up front.
    num_results = int(num_results)

    # Select embeddings and documents based on modality
    if modality == "Image":
        embeddings = image_embeddings
        documents = [""] * len(all_image_paths)  # Empty for image-only
        search_modality = "image"
    elif modality == "Text":
        embeddings = text_embeddings
        documents = all_recipes
        search_modality = "text"
    else:  # Image+Text
        embeddings = image_text_embeddings
        documents = all_recipes
        search_modality = "image_text"

    print(f"\n🔍 Searching with modality: {modality}")
    print(f"   Embeddings shape: {embeddings.shape}")
    print(f"   Query: {query[:50]}...")

    # Initial search: retrieve a wider candidate pool than is displayed so
    # the optional reranker has something to reorder.
    results = search_embeddings(
        query=query,
        query_image=None,
        model=embed_model,
        embeddings=embeddings,
        documents=documents,
        image_paths=all_image_paths,
        top_k=20,
        modality=search_modality
    )

    print(f"   ✅ Found {len(results)} results")

    # Optional reranking (only for image/image_text modalities)
    if use_rerank and modality != "Text":
        rerank_top_k = min(10, len(results))
        print(f"   🎯 Reranking top {rerank_top_k} results...")
        results = rerank_results(
            query=query,
            results=results,
            rerank_model=rerank_model,
            rerank_processor=rerank_processor,
            device=rerank_device,
            top_k=rerank_top_k
        )
        print("   ✅ Reranking complete")

    # The slice shares the same dict objects as `results`, so the enrichment
    # below is also visible to plot_results().
    top_results = results[:num_results]

    # Add full recipe details to the results that will be shown
    for result in top_results:
        img_path = result.get("image_path")
        if not img_path:
            continue
        try:
            # Results are mapped back to their recipe record via the image path.
            recipe = full_recipes[all_image_paths.index(img_path)]
            result["recipe_details"] = {
                "name": recipe["name"],
                "description": recipe["description"][:200] + "...",
                "tags": recipe["tags"][:5],
                "ingredients_count": len(recipe["ingredients"]),
                "steps_count": len(recipe["steps"]),
            }
        except (ValueError, IndexError, KeyError) as exc:
            # Best effort: a result without resolvable details is still
            # returned, but the reason is logged instead of being hidden.
            print(f"   ⚠️ Could not attach recipe details for {img_path!r}: {exc!r}")

    # Text-only searches have nothing to plot; the JSON payload is the output.
    if modality == "Text":
        return None, top_results

    # Visualize
    output_img = plot_results(results, query, num_images=num_results)
    return output_img, top_results


# ============================================================================
# Gradio UI
# ============================================================================

# Declarative UI definition. `demo` is the Blocks app launched at the bottom of
# the file. Layout: intro markdown, a row with the input controls (left column)
# and the outputs (right column), a collapsible examples panel, and a footer.
with gr.Blocks(title="🍳 Multimodal Recipe Search", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🍳 Multimodal Recipe RAG System
        Search **500 recipes** using text queries across images and documents.
        
        **Three Search Modes:**
        - 🖼️ **Image**: Visual similarity search (find similar-looking dishes)
        - 📝 **Text**: Semantic text search (find by ingredients, instructions, reviews)
        - 🎨 **Image+Text**: Combined multimodal search (best of both worlds)
        
        Powered by **NVIDIA Nemotron Embed-VL** and **Rerank-VL** models.
        """
    )
    
    with gr.Row():
        # Left column: the four inputs passed to search_and_display().
        with gr.Column(scale=1):
            query_input = gr.Textbox(
                label="🔍 Search Query",
                placeholder="e.g., 'chocolate cake recipe', 'healthy breakfast', 'pasta with tomatoes'",
                lines=2
            )
            
            # The choice strings are compared verbatim in search_and_display().
            modality_radio = gr.Radio(
                choices=["Image", "Text", "Image+Text"],
                value="Image+Text",  # ✅ Default to Image+Text
                label="📊 Search Modality",
                info="Choose how to search: visual, text, or combined"
            )
            
            # search_and_display() ignores this flag when the modality is "Text".
            rerank_check = gr.Checkbox(
                label="🎯 Use Reranking (Cross-Encoder)",
                value=True,
                info="Rerank top results for better accuracy (adds ~1-2s)"
            )
            
            num_results_slider = gr.Slider(
                minimum=1,
                maximum=5,
                value=3,
                step=1,
                label="📈 Number of Results to Display"
            )
            
            search_btn = gr.Button("🚀 Search", variant="primary", size="lg")
            
            # Add info box
            gr.Markdown(
                """
                **💡 Tips:**
                - Use **Image** for visual similarity (e.g., "desserts with chocolate")
                - Use **Text** for ingredient/instruction search (e.g., "vegetarian pasta")
                - Use **Image+Text** for best overall results
                """
            )
        
        # Right column: the two outputs returned by search_and_display().
        with gr.Column(scale=2):
            output_image = gr.Image(label="🖼️ Top Recipe Results", type="pil")
            output_json = gr.JSON(label="📋 Detailed Results")
    
    # Examples for each modality
    # Each row lists values in the same order as `inputs` below:
    # [query, modality, use reranking, number of results].
    with gr.Accordion("💡 Example Queries", open=False):
        gr.Examples(
            examples=[
                # Image searches
                ["recipes with steak", "Image", True, 3],
                ["chocolate desserts", "Image", True, 3],
                ["colorful salads", "Image", False, 4],
                
                # Text searches
                ["healthy breakfast ideas", "Text", False, 5],
                ["vegetarian meals with pasta", "Text", False, 3],
                ["quick dinner under 30 minutes", "Text", False, 4],
                
                # Image+Text searches (best results)
                ["creamy pasta dishes", "Image+Text", True, 3],
                ["spicy chicken recipes", "Image+Text", True, 3],
                ["fresh summer salads", "Image+Text", True, 4],
            ],
            inputs=[query_input, modality_radio, rerank_check, num_results_slider],
            label="Try these examples"
        )
    
    # Footer
    gr.Markdown(
        """
        ---
        ### 🔧 Technical Details
        
        **Models:**
        - 🔍 Embedding: [nvidia/llama-nemotron-embed-vl-1b-v2](https://huggingface.co/nvidia/llama-nemotron-embed-vl-1b-v2)
        - 🎯 Reranking: [nvidia/llama-nemotron-rerank-vl-1b-v2](https://huggingface.co/nvidia/llama-nemotron-rerank-vl-1b-v2)
        
        **Dataset:** [TurkishCodeMan/recipe-synthetic-images-10k](https://huggingface.co/datasets/TurkishCodeMan/recipe-synthetic-images-10k) (500 recipes)
        
        **Embeddings:**
        - Image: 2048-dim visual embeddings
        - Text: 2048-dim semantic embeddings  
        - Image+Text: 2048-dim multimodal embeddings
        """
    )
    
    # Connect button
    # The input order matches search_and_display's positional parameters and
    # the output order matches its returned (image, results) tuple.
    search_btn.click(
        fn=search_and_display,
        inputs=[query_input, modality_radio, rerank_check, num_results_slider],
        outputs=[output_image, output_json]
    )

# ============================================================================
# Launch
# ============================================================================

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)