File size: 4,146 Bytes
19567a7
b2aba87
 
 
 
 
19567a7
 
 
b2aba87
19567a7
36e2a11
1933b1d
b2aba87
19567a7
 
1933b1d
 
0ffa00f
19567a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2aba87
19567a7
 
26e1d1d
19567a7
 
 
 
 
 
b2aba87
 
19567a7
 
 
b2aba87
19567a7
 
 
 
 
b2aba87
19567a7
 
 
 
 
 
 
 
26e1d1d
19567a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc50f93
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

import gradio as gr
import torch
import pandas as pd
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from torch.nn import functional as F

# --- 1. SETUP & CONFIG ---
# CLIP checkpoint used both to embed queries here and (previously, offline)
# to produce the stored database embeddings — the two MUST match.
MODEL_ID = "openai/clip-vit-base-patch32"
# Parquet file with one row per database image: at least an 'embedding'
# column (list of floats) and a 'label_name' column.
DATA_FILE = "food_embeddings_clip.parquet"

print("⏳ Starting App... Loading Model...")
# Load Model (CPU is fine for inference on single images)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# --- 2. LOAD DATA (Must match Colab logic EXACTLY) ---
print("⏳ Loading Dataset (this takes a moment)...")
# We load the same 5000 images using the same seed so indices match the parquet file
# NOTE: changing the seed, split, or range here silently breaks the
# index correspondence between `dataset` rows and `df`/`db_features` rows.
dataset = load_dataset("ethz/food101", split="train").shuffle(seed=42).select(range(5000))

# --- 3. LOAD EMBEDDINGS ---
print("⏳ Loading Pre-computed Embeddings...")
df = pd.read_parquet(DATA_FILE)
# Convert the list of numbers in the parquet back to a Torch Tensor
# (np.stack turns the column of per-row lists into one (N, D) array).
db_features = torch.tensor(np.stack(df['embedding'].to_numpy()))
# Normalize once for speed — dot products against these rows then equal
# cosine similarity, so queries only need to be normalized per request.
db_features = F.normalize(db_features, p=2, dim=1)

print("✅ App Ready!")

# --- 4. CORE SEARCH LOGIC ---
def find_best_matches(query_features, top_k=3):
    """Return the top_k database images most similar to a CLIP embedding.

    Args:
        query_features: torch.Tensor of shape (1, D) — raw (unnormalized)
            CLIP image or text features. D must match the stored database
            embeddings (512 for clip-vit-base-patch32).
        top_k: number of matches to return; clamped to the database size
            so torch.topk cannot raise when top_k > number of stored images.

    Returns:
        List of (PIL.Image, "label (score)") tuples, ready for gr.Gallery.
    """
    # L2-normalize the query so the dot product below is cosine similarity
    # (db_features was already normalized once at load time).
    query_features = F.normalize(query_features, p=2, dim=1)

    # Calculate Similarity (Dot Product)
    # Query (1xD) @ DB.T (DxN) = Scores (1xN)
    similarity = torch.mm(query_features, db_features.T)

    # Get Top K — clamp so k never exceeds the number of database rows.
    k = min(top_k, similarity.shape[1])
    scores, indices = torch.topk(similarity, k=k)

    results = []
    for idx, score in zip(indices[0], scores[0]):
        idx = idx.item()

        # dataset rows and df rows share indices because both were produced
        # from the same shuffle(seed=42).select(range(5000)) pipeline.
        img = dataset[idx]['image']
        label = df.iloc[idx]['label_name'] # Get label from our dataframe

        # Format output: gallery caption is "label (cosine score)".
        results.append((img, f"{label} ({score:.2f})"))
    return results

# --- 5. GRADIO FUNCTIONS ---
def search_by_image(input_image):
    """Embed an uploaded PIL image with CLIP and return its top matches."""
    if input_image is None:
        return []

    # Preprocess into model-ready tensors; inference needs no gradients.
    batch = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**batch)

    return find_best_matches(image_features)

def search_by_text(input_text):
    """Embed a free-text food description with CLIP and return its top matches."""
    if not input_text:
        return []

    # Tokenize the single query string; padding keeps the processor happy
    # even though there is only one sequence in the batch.
    batch = processor(text=[input_text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_features = model.get_text_features(**batch)

    return find_best_matches(text_features)

# --- 6. BUILD UI ---
# Two-tab layout: image-based search and text-based search, each wiring a
# button click to the matching search function and showing results in a Gallery.
with gr.Blocks(title="Food Matcher AI") as demo:
    gr.Markdown("# 🍔 Visual Dish Matcher")
    gr.Markdown("Upload a photo of food (or describe it) to find similar dishes in our database.")
    
    # --- VIDEO SECTION ---
    # Using Accordion so it doesn't clutter the UI. Open=False means it starts closed.
    with gr.Accordion("📺 Watch Project Demo", open=False):
        gr.HTML("""
            <div style="display: flex; justify-content: center;">
                <iframe width="560" height="315"
                    src="https://www.youtube.com/embed/IXeIxYHi0Es"
                    title="YouTube video player"
                    frameborder="0"
                    allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                    allowfullscreen>
                </iframe>
            </div>
        """)
    # ----------------------------
    
    with gr.Tab("Image Search"):
        with gr.Row():
            img_input = gr.Image(type="pil", label="Upload Food Image")
            img_gallery = gr.Gallery(label="Top Matches")
        btn_img = gr.Button("Find Similar Dishes")
        btn_img.click(search_by_image, inputs=img_input, outputs=img_gallery)

    with gr.Tab("Text Search"):
        with gr.Row():
            txt_input = gr.Textbox(label="Describe the food (e.g., 'Spicy Tacos')")
            txt_gallery = gr.Gallery(label="Top Matches")
        btn_txt = gr.Button("Search by Description")
        btn_txt.click(search_by_text, inputs=txt_input, outputs=txt_gallery)

# Launch the app with default settings.
# NOTE(review): the original comment claimed SSR was disabled "for stability",
# but no ssr_mode argument is actually passed — confirm whether
# launch(ssr_mode=False) was intended here.
demo.launch()