MatanKriel commited on
Commit
19567a7
·
verified ·
1 Parent(s): 36e2a11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -115
app.py CHANGED
@@ -1,134 +1,114 @@
 
1
import io

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# --- CONFIGURATION ---
MODEL_ID = "openai/clip-vit-base-patch32"
DATA_FILE = "food_embeddings_clip.parquet"

# 🎥 PASTE YOUR YOUTUBE VIDEO ID HERE
# (e.g. if link is https://www.youtube.com/watch?v=dQw4w9WgXcQ, the ID is dQw4w9WgXcQ)
YOUTUBE_ID = "IXeIxYHi0Es"

print(f"⏳ Loading {MODEL_ID} and Data...")

# 1. Load model + processor (used by search() for both text and image queries).
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# 2. Load the pre-computed embedding database.
try:
    df = pd.read_parquet(DATA_FILE)
    # Stack the per-row embedding lists into one (N, dim) matrix for fast search.
    all_vectors = np.stack(df['embedding'].to_numpy())
    db_features = torch.tensor(all_vectors)
except FileNotFoundError as err:
    # Chain the original exception so the traceback shows the real cause
    # instead of silently replacing it.
    raise RuntimeError(f"❌ ERROR: Could not find {DATA_FILE}. Did you upload it?") from err

print("✅ System Ready!")
33
-
34
- # --- SEARCH LOGIC ---
35
def search(text_query, image_query):
    """Return the top-5 database matches for a text or image query.

    Args:
        text_query: free-text description (may be empty/None).
        image_query: PIL image to search by (may be None). Takes
            precedence over the text query when both are supplied.

    Returns:
        A list of (PIL.Image, "label (score)") tuples suitable for a
        gr.Gallery, or [] when no query was provided.
    """
    if not text_query and image_query is None:
        return []

    # A. Encode the query with the matching CLIP tower.
    if image_query is not None:
        inputs = processor(images=image_query, return_tensors="pt", padding=True)
        get_feat = model.get_image_features
    else:
        inputs = processor(text=[text_query], return_tensors="pt", padding=True)
        get_feat = model.get_text_features

    # B. Inference & Search
    with torch.no_grad():
        query_vec = get_feat(**inputs)
        # BUG FIX: the original called torch.topk(scores, ...) without ever
        # computing `scores`, so every search raised NameError. Compute
        # cosine similarity between the query and all database embeddings.
        query_norm = query_vec / query_vec.norm(dim=-1, keepdim=True)
        db_norm = db_features / db_features.norm(dim=-1, keepdim=True)
        scores = query_norm @ db_norm.T
        # Never request more results than the database holds.
        top_scores, top_indices = torch.topk(scores, k=min(5, scores.shape[-1]))

    # C. Format Results
    results = []
    for idx, score in zip(top_indices[0], top_scores[0]):
        row = df.iloc[idx.item()]

        # The parquet may store images as {'bytes': ...} dicts or as
        # already-decoded objects; handle both.
        img_data = row['image']
        if isinstance(img_data, dict) and 'bytes' in img_data:
            img = Image.open(io.BytesIO(img_data['bytes']))
        else:
            img = img_data

        results.append((img, f"{row['label_name']} ({score.item():.2f})"))
    return results
66
 
67
# --- APP INTERFACE (The Original Design) ---
# We use a 'Soft' theme for a professional look.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Food Search") as demo:

    # 1. Header Section
    gr.Markdown(
        """
        # 🍔 AI Food Search Engine
        ### Powered by OpenAI CLIP & Hugging Face
        Search through 5,000 food images using natural language or reference images.
        """
    )

    # 2. YouTube Demo Section (Embedded Player)
    # Only embed when a real video ID has been configured above.
    if YOUTUBE_ID and YOUTUBE_ID != "YOUR_YOUTUBE_ID_HERE":
        gr.HTML(
            f"""
            <div style="display: flex; justify-content: center; margin-bottom: 20px;">
                <iframe width="560" height="315"
                    src="https://www.youtube.com/embed/{YOUTUBE_ID}"
                    title="YouTube video player" frameborder="0"
                    allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                    allowfullscreen></iframe>
            </div>
            """
        )
    else:
        # NOTE(review): gr.Info at build time (outside an event handler) may
        # not display in every Gradio version — confirm on the target version.
        gr.Info("ℹ️ Add your YouTube ID in the code to display the video here.")

    # 3. Main Search Interface
    with gr.Row():
        # Left Column: Inputs (text OR reference image; both feed search()).
        with gr.Column(scale=1):
            gr.Markdown("### 🔍 Your Query")
            txt_input = gr.Textbox(
                label="Search by Text",
                placeholder="e.g. 'spicy tacos with lime'",
                show_label=True
            )
            gr.Markdown("**OR**")
            img_input = gr.Image(
                type="pil",
                label="Search by Image",
                height=300
            )

            search_btn = gr.Button("🚀 Find Food", variant="primary", size="lg")

        # Right Column: Results gallery populated by search().
        with gr.Column(scale=2):
            gr.Markdown("### 🍕 Top Matches")
            gallery = gr.Gallery(
                label="Results",
                columns=3,
                height="auto",
                object_fit="cover"
            )

    # 4. Footer / Credits
    gr.Markdown("---")
    gr.Markdown(f"*Model: {MODEL_ID} | Dataset: Food-101 (Subset)*")

    # Event Listeners (Enter key on the textbox + button click both trigger
    # the same search with both inputs).
    txt_input.submit(search, inputs=[txt_input, img_input], outputs=gallery)
    search_btn.click(search, inputs=[txt_input, img_input], outputs=gallery)

# Launch — bind to all interfaces on the standard Spaces port.
demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import gradio as gr
import torch
import pandas as pd
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from torch.nn import functional as F

# --- 1. SETUP & CONFIG ---
MODEL_ID = "openai/clip-vit-base-patch32"
DATA_FILE = "food_embeddings_clip.parquet"

print("⏳ Starting App... Loading Model...")
# Load Model (CPU is fine for inference on single images)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# --- 2. LOAD DATA (Must match Colab logic EXACTLY) ---
print("⏳ Loading Dataset (this takes a moment)...")
# We load the same 5000 images using the same seed so indices match the
# parquet file produced offline — a different seed/count would silently
# mismatch images and embeddings.
dataset = load_dataset("ethz/food101", split="train").shuffle(seed=42).select(range(5000))

# --- 3. LOAD EMBEDDINGS ---
print("⏳ Loading Pre-computed Embeddings...")
df = pd.read_parquet(DATA_FILE)
# Convert the list of numbers in the parquet back to a Torch Tensor.
db_features = torch.tensor(np.stack(df['embedding'].to_numpy()))
# Normalize once for speed: with L2-normalized rows, find_best_matches can
# use a plain dot product as cosine similarity.
db_features = F.normalize(db_features, p=2, dim=1)

print("✅ App Ready!")
34
+
35
+ # --- 4. CORE SEARCH LOGIC ---
36
def find_best_matches(query_features, top_k=3):
    """Return the top_k (image, caption) pairs most similar to the query.

    Args:
        query_features: CLIP embedding of the query, shape (1, dim)
            — presumably (1, 512) for ViT-B/32; TODO confirm.
        top_k: number of matches to return (clamped to the database size).

    Returns:
        List of (PIL.Image, "label (score)") tuples for a gr.Gallery.
    """
    # Normalize the query; db_features was normalized once at startup, so
    # the dot product below is cosine similarity.
    query_features = F.normalize(query_features, p=2, dim=1)

    # Query (1 x dim) @ DB.T (dim x N) -> similarity scores (1 x N).
    similarity = torch.mm(query_features, db_features.T)

    # ROBUSTNESS FIX: never request more matches than the database holds —
    # torch.topk raises when k exceeds the dimension size.
    k = min(top_k, similarity.shape[-1])
    scores, indices = torch.topk(similarity, k=k)

    results = []
    for idx, score in zip(indices[0], scores[0]):
        idx = idx.item()

        # Image comes from the re-loaded dataset; the label comes from the
        # parquet dataframe (indices line up because both were built from
        # the same shuffled 5000-image subset).
        img = dataset[idx]['image']
        label = df.iloc[idx]['label_name']

        # .item() extracts the Python float so the :.2f format is unambiguous.
        results.append((img, f"{label} ({score.item():.2f})"))

    return results
58
 
59
+ # --- 5. GRADIO FUNCTIONS ---
60
def search_by_image(input_image):
    """Gradio handler: embed an uploaded image and return its top matches."""
    # Nothing uploaded -> empty gallery.
    if input_image is None:
        return []

    # Encode the picture with CLIP's vision tower; inference only, so skip
    # gradient tracking.
    batch = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        image_embedding = model.get_image_features(**batch)

    return find_best_matches(image_embedding)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
def search_by_text(input_text):
    """Gradio handler: embed a text description and return its top matches."""
    # Empty or missing query -> empty gallery.
    if not input_text:
        return []

    # Encode the description with CLIP's text tower; inference only, so skip
    # gradient tracking.
    batch = processor(text=[input_text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_embedding = model.get_text_features(**batch)

    return find_best_matches(text_embedding)
77
 
78
# --- 6. BUILD UI ---
with gr.Blocks(title="Food Matcher AI") as demo:
    gr.Markdown("# 🍔 Visual Dish Matcher")
    gr.Markdown("Upload a photo of food (or describe it) to find similar dishes in our database.")

    # --- VIDEO SECTION ---
    # Using Accordion so it doesn't clutter the UI. Open=False means it starts closed.
    with gr.Accordion("📺 Watch Project Demo", open=False):
        gr.HTML("""
        <div style="display: flex; justify-content: center;">
            <iframe width="560" height="315"
                src="https://www.youtube.com/embed/IXeIxYHi0Es"
                title="YouTube video player"
                frameborder="0"
                allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                allowfullscreen>
            </iframe>
        </div>
        """)
    # ----------------------------

    # Tab 1: image query -> search_by_image -> gallery of matches.
    with gr.Tab("Image Search"):
        with gr.Row():
            img_input = gr.Image(type="pil", label="Upload Food Image")
            img_gallery = gr.Gallery(label="Top Matches")
        btn_img = gr.Button("Find Similar Dishes")
        btn_img.click(search_by_image, inputs=img_input, outputs=img_gallery)

    # Tab 2: free-text query -> search_by_text -> gallery of matches.
    with gr.Tab("Text Search"):
        with gr.Row():
            txt_input = gr.Textbox(label="Describe the food (e.g., 'Spicy Tacos')")
            txt_gallery = gr.Gallery(label="Top Matches")
        btn_txt = gr.Button("Search by Description")
        btn_txt.click(search_by_text, inputs=txt_input, outputs=txt_gallery)

# Launch (Disable SSR for stability)
demo.launch(ssr_mode=False)