Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,134 +1,114 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import pandas as pd
|
| 4 |
import numpy as np
|
| 5 |
-
from transformers import CLIPModel, CLIPProcessor
|
| 6 |
from PIL import Image
|
| 7 |
-
import
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
# ---
|
| 10 |
MODEL_ID = "openai/clip-vit-base-patch32"
|
| 11 |
DATA_FILE = "food_embeddings_clip.parquet"
|
| 12 |
|
| 13 |
-
|
| 14 |
-
#
|
| 15 |
-
YOUTUBE_ID = "IXeIxYHi0Es"
|
| 16 |
-
|
| 17 |
-
print(f"⏳ Loading {MODEL_ID} and Data...")
|
| 18 |
-
|
| 19 |
-
# 1. Load Model
|
| 20 |
model = CLIPModel.from_pretrained(MODEL_ID)
|
| 21 |
processor = CLIPProcessor.from_pretrained(MODEL_ID)
|
| 22 |
|
| 23 |
-
# 2.
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
#
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
top_scores, top_indices = torch.topk(scores, k=5)
|
| 51 |
-
|
| 52 |
-
# C. Format Results
|
| 53 |
results = []
|
| 54 |
-
for idx, score in zip(
|
| 55 |
-
|
| 56 |
|
| 57 |
-
#
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
results.append((img, f"{row['label_name']} ({score.item():.2f})"))
|
| 65 |
return results
|
| 66 |
|
| 67 |
-
# ---
|
| 68 |
-
|
| 69 |
-
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
Search through 5,000 food images using natural language or reference images.
|
| 77 |
-
"""
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
# 2. YouTube Demo Section (Embedded Player)
|
| 81 |
-
if YOUTUBE_ID and YOUTUBE_ID != "YOUR_YOUTUBE_ID_HERE":
|
| 82 |
-
gr.HTML(
|
| 83 |
-
f"""
|
| 84 |
-
<div style="display: flex; justify-content: center; margin-bottom: 20px;">
|
| 85 |
-
<iframe width="560" height="315"
|
| 86 |
-
src="https://www.youtube.com/embed/{YOUTUBE_ID}"
|
| 87 |
-
title="YouTube video player" frameborder="0"
|
| 88 |
-
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
|
| 89 |
-
allowfullscreen></iframe>
|
| 90 |
-
</div>
|
| 91 |
-
"""
|
| 92 |
-
)
|
| 93 |
-
else:
|
| 94 |
-
gr.Info("ℹ️ Add your YouTube ID in the code to display the video here.")
|
| 95 |
-
|
| 96 |
-
# 3. Main Search Interface
|
| 97 |
-
with gr.Row():
|
| 98 |
-
# Left Column: Inputs
|
| 99 |
-
with gr.Column(scale=1):
|
| 100 |
-
gr.Markdown("### 🔍 Your Query")
|
| 101 |
-
txt_input = gr.Textbox(
|
| 102 |
-
label="Search by Text",
|
| 103 |
-
placeholder="e.g. 'spicy tacos with lime'",
|
| 104 |
-
show_label=True
|
| 105 |
-
)
|
| 106 |
-
gr.Markdown("**OR**")
|
| 107 |
-
img_input = gr.Image(
|
| 108 |
-
type="pil",
|
| 109 |
-
label="Search by Image",
|
| 110 |
-
height=300
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
search_btn = gr.Button("🚀 Find Food", variant="primary", size="lg")
|
| 114 |
-
|
| 115 |
-
# Right Column: Results
|
| 116 |
-
with gr.Column(scale=2):
|
| 117 |
-
gr.Markdown("### 🍕 Top Matches")
|
| 118 |
-
gallery = gr.Gallery(
|
| 119 |
-
label="Results",
|
| 120 |
-
columns=3,
|
| 121 |
-
height="auto",
|
| 122 |
-
object_fit="cover"
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
# 4. Footer / Credits
|
| 126 |
-
gr.Markdown("---")
|
| 127 |
-
gr.Markdown(f"*Model: {MODEL_ID} | Dataset: Food-101 (Subset)*")
|
| 128 |
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
#
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
import pandas as pd
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
+
from transformers import CLIPProcessor, CLIPModel
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
|
| 11 |
+
# --- 1. SETUP & CONFIG ---
|
| 12 |
MODEL_ID = "openai/clip-vit-base-patch32"
|
| 13 |
DATA_FILE = "food_embeddings_clip.parquet"
|
| 14 |
|
| 15 |
+
print("⏳ Starting App... Loading Model...")
|
| 16 |
+
# Load Model (CPU is fine for inference on single images)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
model = CLIPModel.from_pretrained(MODEL_ID)
|
| 18 |
processor = CLIPProcessor.from_pretrained(MODEL_ID)
|
| 19 |
|
| 20 |
+
# --- 2. LOAD DATA (Must match Colab logic EXACTLY) ---
|
| 21 |
+
print("⏳ Loading Dataset (this takes a moment)...")
|
| 22 |
+
# We load the same 5000 images using the same seed so indices match the parquet file
|
| 23 |
+
dataset = load_dataset("ethz/food101", split="train").shuffle(seed=42).select(range(5000))
|
| 24 |
+
|
| 25 |
+
# --- 3. LOAD EMBEDDINGS ---
|
| 26 |
+
print("⏳ Loading Pre-computed Embeddings...")
|
| 27 |
+
df = pd.read_parquet(DATA_FILE)
|
| 28 |
+
# Convert the list of numbers in the parquet back to a Torch Tensor
|
| 29 |
+
db_features = torch.tensor(np.stack(df['embedding'].to_numpy()))
|
| 30 |
+
# Normalize once for speed
|
| 31 |
+
db_features = F.normalize(db_features, p=2, dim=1)
|
| 32 |
+
|
| 33 |
+
print("✅ App Ready!")
|
| 34 |
+
|
| 35 |
+
# --- 4. CORE SEARCH LOGIC ---
|
| 36 |
+
def find_best_matches(query_features, top_k=3):
|
| 37 |
+
# Normalize query
|
| 38 |
+
query_features = F.normalize(query_features, p=2, dim=1)
|
| 39 |
+
|
| 40 |
+
# Calculate Similarity (Dot Product)
|
| 41 |
+
# Query (1x512) * DB (5000x512) = Scores (1x5000)
|
| 42 |
+
similarity = torch.mm(query_features, db_features.T)
|
| 43 |
+
|
| 44 |
+
# Get Top K
|
| 45 |
+
scores, indices = torch.topk(similarity, k=top_k)
|
| 46 |
+
|
|
|
|
|
|
|
|
|
|
| 47 |
results = []
|
| 48 |
+
for idx, score in zip(indices[0], scores[0]):
|
| 49 |
+
idx = idx.item()
|
| 50 |
|
| 51 |
+
# Grab image and info from the loaded dataset
|
| 52 |
+
img = dataset[idx]['image']
|
| 53 |
+
label = df.iloc[idx]['label_name'] # Get label from our dataframe
|
| 54 |
+
|
| 55 |
+
# Format output
|
| 56 |
+
results.append((img, f"{label} ({score:.2f})"))
|
|
|
|
|
|
|
| 57 |
return results
|
| 58 |
|
| 59 |
+
# --- 5. GRADIO FUNCTIONS ---
|
| 60 |
+
def search_by_image(input_image):
|
| 61 |
+
if input_image is None: return []
|
| 62 |
|
| 63 |
+
inputs = processor(images=input_image, return_tensors="pt")
|
| 64 |
+
with torch.no_grad():
|
| 65 |
+
features = model.get_image_features(**inputs)
|
| 66 |
+
|
| 67 |
+
return find_best_matches(features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
def search_by_text(input_text):
|
| 70 |
+
if not input_text: return []
|
| 71 |
+
|
| 72 |
+
inputs = processor(text=[input_text], return_tensors="pt", padding=True)
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
features = model.get_text_features(**inputs)
|
| 75 |
+
|
| 76 |
+
return find_best_matches(features)
|
| 77 |
|
| 78 |
+
# --- 6. BUILD UI ---
|
| 79 |
+
with gr.Blocks(title="Food Matcher AI") as demo:
|
| 80 |
+
gr.Markdown("# 🍔 Visual Dish Matcher")
|
| 81 |
+
gr.Markdown("Upload a photo of food (or describe it) to find similar dishes in our database.")
|
| 82 |
+
|
| 83 |
+
# --- VIDEO SECTION ---
|
| 84 |
+
# Using Accordion so it doesn't clutter the UI. Open=False means it starts closed.
|
| 85 |
+
with gr.Accordion("📺 Watch Project Demo", open=False):
|
| 86 |
+
gr.HTML("""
|
| 87 |
+
<div style="display: flex; justify-content: center;">
|
| 88 |
+
<iframe width="560" height="315"
|
| 89 |
+
src="https://www.youtube.com/embed/IXeIxYHi0Es"
|
| 90 |
+
title="YouTube video player"
|
| 91 |
+
frameborder="0"
|
| 92 |
+
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
|
| 93 |
+
allowfullscreen>
|
| 94 |
+
</iframe>
|
| 95 |
+
</div>
|
| 96 |
+
""")
|
| 97 |
+
# ----------------------------
|
| 98 |
+
|
| 99 |
+
with gr.Tab("Image Search"):
|
| 100 |
+
with gr.Row():
|
| 101 |
+
img_input = gr.Image(type="pil", label="Upload Food Image")
|
| 102 |
+
img_gallery = gr.Gallery(label="Top Matches")
|
| 103 |
+
btn_img = gr.Button("Find Similar Dishes")
|
| 104 |
+
btn_img.click(search_by_image, inputs=img_input, outputs=img_gallery)
|
| 105 |
+
|
| 106 |
+
with gr.Tab("Text Search"):
|
| 107 |
+
with gr.Row():
|
| 108 |
+
txt_input = gr.Textbox(label="Describe the food (e.g., 'Spicy Tacos')")
|
| 109 |
+
txt_gallery = gr.Gallery(label="Top Matches")
|
| 110 |
+
btn_txt = gr.Button("Search by Description")
|
| 111 |
+
btn_txt.click(search_by_text, inputs=txt_input, outputs=txt_gallery)
|
| 112 |
+
|
| 113 |
+
# Launch (Disable SSR for stability)
|
| 114 |
+
demo.launch(ssr_mode=False)
|