MatanKriel committed on
Commit
9bb2979
Β·
verified Β·
1 Parent(s): 0ffa00f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -124
app.py CHANGED
@@ -2,149 +2,96 @@ import gradio as gr
2
  import torch
3
  import pandas as pd
4
  import numpy as np
5
- import os
6
  from PIL import Image
7
- from transformers import AutoProcessor, AutoModel
8
- from datasets import load_dataset
9
- from torch.nn import functional as F
10
 
11
- # --- 1. SETUP & CONFIG ---
12
- MODEL_ID = "google/siglip-base-patch16-224"
13
- DATA_FILE = "food_embeddings_siglip.parquet"
 
14
 
15
- print(f"⏳ Starting App... Loading Model: {MODEL_ID}...")
16
- try:
17
- model = AutoModel.from_pretrained(MODEL_ID)
18
- processor = AutoProcessor.from_pretrained(MODEL_ID)
19
- except Exception as e:
20
- print(f"❌ Model Error: {e}")
21
 
22
- # --- 2. LOAD DATA ---
23
- # --- 2. LOAD DATA (SMART MATCHING) ---
24
- print("⏳ Loading Dataset...")
25
 
26
- # 1. Load the Embeddings File FIRST
27
  df = pd.read_parquet(DATA_FILE)
28
- valid_indices = df.index.tolist() # Assuming you preserved the original indices in the dataframe index
29
- # OR if you reset the index in the notebook, we just check the length:
30
- num_embeddings = len(df)
31
 
32
- print(f" πŸ‘‰ Embeddings file has {num_embeddings} rows.")
33
-
34
- # 2. Load the Dataset
35
- dataset_full = load_dataset("ethz/food101", split="train").shuffle(seed=42).select(range(5000))
36
-
37
- # 3. CRITICAL FIX: If lengths don't match, we assume the parquet is a subset.
38
- # (This is a guess - if you didn't save the original indices, this might still be slightly off,
39
- # but it prevents the 'IndexError' crash).
40
- if len(dataset_full) > num_embeddings:
41
- print(f"⚠️ DATA MISMATCH DETECTED: Dataset has {len(dataset_full)} but Parquet has {num_embeddings}.")
42
- print(" βœ‚οΈ Truncating dataset to match Parquet length...")
43
- dataset = dataset_full.select(range(num_embeddings))
44
- else:
45
- dataset = dataset_full
46
-
47
- print(f"βœ… Final Dataset Size: {len(dataset)}")
48
-
49
- # --- 3. LOAD EMBEDDINGS ---
50
- print(f"⏳ Loading Embeddings from {DATA_FILE}...")
51
- try:
52
- df = pd.read_parquet(DATA_FILE)
53
- db_features = torch.tensor(np.stack(df['embedding'].to_numpy()))
54
- db_features = F.normalize(db_features, p=2, dim=1)
55
- print("βœ… System Ready!")
56
- except Exception as e:
57
- print(f"❌ Error loading parquet file: {e}")
58
- db_features = None
59
 
60
- # --- 4. CORE SEARCH LOGIC (SAFE MODE) ---
61
- def find_best_matches(query_features, top_k=3):
62
- if db_features is None:
63
- return []
64
 
65
- # Normalize query
66
- query_features = F.normalize(query_features, p=2, dim=1)
67
 
68
- # Similarity Search
69
- similarity = torch.mm(query_features, db_features.T)
70
- scores, indices = torch.topk(similarity, k=top_k)
71
 
 
72
  results = []
73
- for idx, score in zip(indices[0], scores[0]):
74
  idx = idx.item()
 
75
 
76
- # 1. Get the raw image
77
- img_data = dataset[idx]['image']
78
-
79
- # 2. Resize it to be small & fast (300x300 max)
80
- img_data.thumbnail((300, 300))
81
-
82
- # 3. Save to a temporary path (prevents the "Too much data" crash)
83
- save_path = f"/tmp/temp_result_{idx}.jpg"
84
- img_data.save(save_path)
85
-
86
- label = df.iloc[idx]['label_name']
87
-
88
- # 4. Return the PATH (string), NOT the image object
89
- results.append((save_path, f"{label} ({score:.2f})"))
90
 
91
  return results
92
 
93
- # --- 5. GRADIO FUNCTIONS ---
94
- def search_by_image(input_image):
95
- if input_image is None: return []
96
- inputs = processor(images=input_image, return_tensors="pt")
97
- with torch.no_grad():
98
- features = model.get_image_features(**inputs)
99
- return find_best_matches(features)
100
-
101
- def search_by_text(input_text):
102
- if not input_text: return []
103
- inputs = processor(text=[input_text], return_tensors="pt", padding="max_length")
104
- with torch.no_grad():
105
- features = model.get_text_features(**inputs)
106
- return find_best_matches(features)
107
-
108
- # --- 6. BUILD UI (Clean & Centered) ---
109
- custom_css = """
110
- .gradio-container { width: 100%; max-width: 1000px; margin: 0 auto !important; }
111
- h1 { text-align: center; color: #E67E22; }
112
- """
113
-
114
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="Food Matcher AI") as demo:
115
 
116
  with gr.Row():
117
- gr.Markdown("# πŸ” Visual Dish Matcher (SigLIP)")
118
-
119
- gr.Markdown("Upload a food photo or describe a craving. We'll find the closest matches.", elem_classes=["center-text"])
120
-
121
- with gr.Accordion("πŸ“Ί Watch Demo Video", open=False):
122
- gr.HTML('<div style="display:flex; justify-content:center;"><iframe width="560" height="315" src="https://www.youtube.com/embed/IXeIxYHi0Es" frameborder="0" allowfullscreen></iframe></div>')
123
-
124
- with gr.Tab("πŸ–ΌοΈ Search by Image"):
125
- with gr.Row():
126
- with gr.Column(scale=1):
127
- img_input = gr.Image(type="pil", label="Your Photo", height=300)
128
- btn_img = gr.Button("πŸ” Find Matches", variant="primary", size="lg")
129
 
130
- with gr.Column(scale=2):
131
- img_gallery = gr.Gallery(label="Similar Dishes", columns=3, height=350, object_fit="contain")
132
-
133
- btn_img.click(search_by_image, inputs=img_input, outputs=img_gallery)
134
-
135
- with gr.Tab("πŸ“ Search by Text"):
136
- with gr.Row():
137
- with gr.Column(scale=1):
138
- txt_input = gr.Textbox(label="Describe it", placeholder="e.g. 'Spicy Tacos'", lines=4)
139
- btn_txt = gr.Button("πŸ” Search", variant="primary", size="lg")
140
-
141
- with gr.Column(scale=2):
142
- txt_gallery = gr.Gallery(label="Similar Dishes", columns=3, height=350, object_fit="contain")
143
-
144
- btn_txt.click(search_by_text, inputs=txt_input, outputs=txt_gallery)
145
 
146
- gr.Markdown("---")
147
- gr.Markdown("By Matan Kriel & Odeya Shmuel | Powered by Google SigLIP")
148
 
149
- # Launch
150
- demo.launch()
 
2
  import torch
3
  import pandas as pd
4
  import numpy as np
5
+ from transformers import AutoModel, AutoProcessor
6
  from PIL import Image
7
+ import io
 
 
8
 
9
+ # --- CONFIGURATION ---
10
+ # ⚠️ IMPORTANT: Change this if 'MetaCLIP' or 'OpenAI CLIP' won your notebook battle!
11
+ MODEL_ID = "google/siglip-base-patch16-224"
12
+ DATA_FILE = "food_embeddings_best.parquet"
13
 
14
+ print("⏳ Loading Model & Data...")
 
 
 
 
 
15
 
16
+ # 1. Load Model (Only once)
17
+ model = AutoModel.from_pretrained(MODEL_ID)
18
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
19
 
20
+ # 2. Load the "Memory" (Parquet file)
21
  df = pd.read_parquet(DATA_FILE)
 
 
 
22
 
23
+ # 3. Prepare the Database Vectors
24
+ # Convert the dataframe column into a PyTorch Tensor
25
+ all_vectors = np.stack(df['embedding'].to_numpy())
26
+ db_features = torch.tensor(all_vectors)
27
+
28
+ # (Optional: If your notebook didn't normalize, uncomment this.
29
+ # But your notebook code already did, so we skip it to be fast!)
30
+ # db_features = db_features / db_features.norm(p=2, dim=-1, keepdim=True)
31
+
32
+ print("βœ… System Ready!")
33
+
34
+ def search(text_query, image_query):
35
+ # A. Decide: Is this a Text search or Image search?
36
+ if image_query:
37
+ # Process Image
38
+ inputs = processor(images=image_query, return_tensors="pt")
39
+ get_feat_func = model.get_image_features
40
+ elif text_query:
41
+ # Process Text
42
+ inputs = processor(text=[text_query], return_tensors="pt", padding=True)
43
+ get_feat_func = model.get_text_features
44
+ else:
45
+ return None
46
+
47
+ # B. Run Model (Inference)
48
+ with torch.no_grad():
49
+ query_vec = get_feat_func(**inputs)
50
 
51
+ # C. Search Logic (Pure Math)
52
+ # 1. Normalize Query (Math requirement: Vector / Magnitude)
53
+ query_vec = query_vec / query_vec.norm(p=2, dim=-1, keepdim=True)
 
54
 
55
+ # 2. Dot Product (Similarity)
56
+ scores = torch.mm(query_vec, db_features.T)
57
 
58
+ # 3. Get Top 5
59
+ top_scores, top_indices = torch.topk(scores, k=5)
 
60
 
61
+ # D. Fetch Results
62
  results = []
63
+ for idx, score in zip(top_indices[0], top_scores[0]):
64
  idx = idx.item()
65
+ row = df.iloc[idx]
66
 
67
+ # Handle Image Loading (Parquet saves images as binary/dict)
68
+ img_data = row['image']
69
+ if isinstance(img_data, dict) and 'bytes' in img_data:
70
+ img = Image.open(io.BytesIO(img_data['bytes']))
71
+ else:
72
+ img = img_data # It might already be a PIL object
73
+
74
+ results.append((img, f"{row['label_name']} ({score.item():.2f})"))
 
 
 
 
 
 
75
 
76
  return results
77
 
78
+ # --- INTERFACE ---
79
+ with gr.Blocks(title="AI Food Search") as demo:
80
+ gr.Markdown("# πŸ” AI Food Search")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  with gr.Row():
83
+ # Left: Inputs
84
+ with gr.Column():
85
+ txt_input = gr.Textbox(label="Search by Text", placeholder="e.g. 'spicy pepperoni pizza'")
86
+ img_input = gr.Image(type="pil", label="Or Search by Image")
87
+ btn = gr.Button("Search", variant="primary")
 
 
 
 
 
 
 
88
 
89
+ # Right: Output Gallery
90
+ with gr.Column():
91
+ gallery = gr.Gallery(label="Top Matches")
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ # Connect buttons
94
+ btn.click(fn=search, inputs=[txt_input, img_input], outputs=gallery)
95
 
96
+ # Force bind to 0.0.0.0 for Spaces
97
+ demo.launch(server_name="0.0.0.0", server_port=7860)