# app.py — commit bc50f93 (MatanKriel, verified)
import gradio as gr
import torch
import pandas as pd
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from torch.nn import functional as F
# --- 1. SETUP & CONFIG ---
# CLIP checkpoint used both here (for queries) and offline (to build the parquet).
MODEL_ID = "openai/clip-vit-base-patch32"
# Pre-computed image embeddings; must have been produced with the same model
# and the same dataset ordering as below, or indices won't line up.
DATA_FILE = "food_embeddings_clip.parquet"
print("⏳ Starting App... Loading Model...")
# Load Model (CPU is fine for inference on single images)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
# --- 2. LOAD DATA (Must match Colab logic EXACTLY) ---
print("⏳ Loading Dataset (this takes a moment)...")
# We load the same 5000 images using the same seed so indices match the parquet file
# (row i of the parquet corresponds to dataset[i] only because shuffle(seed=42)
# is deterministic — do not change seed or count independently of the parquet).
dataset = load_dataset("ethz/food101", split="train").shuffle(seed=42).select(range(5000))
# --- 3. LOAD EMBEDDINGS ---
print("⏳ Loading Pre-computed Embeddings...")
df = pd.read_parquet(DATA_FILE)
# Convert the list of numbers in the parquet back to a Torch Tensor
# (presumably shape (5000, 512) for this CLIP variant — one row per image).
db_features = torch.tensor(np.stack(df['embedding'].to_numpy()))
# Normalize once for speed
# After L2-normalization, a plain dot product equals cosine similarity.
db_features = F.normalize(db_features, p=2, dim=1)
print("βœ… App Ready!")
# --- 4. CORE SEARCH LOGIC ---
def find_best_matches(query_features, top_k=3):
    """Return the top_k database entries most similar to a CLIP embedding.

    query_features: un-normalized CLIP feature tensor of shape (1, dim).
    Returns a list of (PIL image, "label (score)") tuples, the format
    expected by gr.Gallery.
    """
    # L2-normalize so the dot product below is cosine similarity.
    query = F.normalize(query_features, p=2, dim=1)
    # (1, dim) @ (dim, N) -> (1, N): one similarity score per database image.
    sims = torch.mm(query, db_features.T)
    top_scores, top_idx = torch.topk(sims, k=top_k)

    matches = []
    for rank in range(top_k):
        i = top_idx[0][rank].item()
        s = top_scores[0][rank].item()
        # Image comes from the dataset, label from the dataframe; both share
        # index i because they were built with the same seed/order.
        img = dataset[i]['image']
        label = df.iloc[i]['label_name']
        matches.append((img, f"{label} ({s:.2f})"))
    return matches
# --- 5. GRADIO FUNCTIONS ---
def search_by_image(input_image):
    """Embed an uploaded PIL image with CLIP and return its top gallery matches."""
    if input_image is None:
        return []
    # Preprocess into the model's expected pixel tensor, then embed without grads.
    batch = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        embedding = model.get_image_features(**batch)
    return find_best_matches(embedding)
def search_by_text(input_text):
    """Embed a free-text food description with CLIP and return its top matches."""
    if not input_text:
        return []
    # Tokenize the single query string (padding keeps the call shape-safe).
    batch = processor(text=[input_text], return_tensors="pt", padding=True)
    with torch.no_grad():
        embedding = model.get_text_features(**batch)
    return find_best_matches(embedding)
# --- 6. BUILD UI ---
with gr.Blocks(title="Food Matcher AI") as demo:
    gr.Markdown("# πŸ” Visual Dish Matcher")
    gr.Markdown("Upload a photo of food (or describe it) to find similar dishes in our database.")

    # Demo video lives in a collapsed accordion so it doesn't clutter the page.
    with gr.Accordion("πŸ“Ί Watch Project Demo", open=False):
        gr.HTML("""
<div style="display: flex; justify-content: center;">
<iframe width="560" height="315"
src="https://www.youtube.com/embed/IXeIxYHi0Es"
title="YouTube video player"
frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
</div>
""")

    # Tab 1: query by image.
    with gr.Tab("Image Search"):
        with gr.Row():
            image_query = gr.Image(type="pil", label="Upload Food Image")
            image_results = gr.Gallery(label="Top Matches")
        image_button = gr.Button("Find Similar Dishes")
        image_button.click(search_by_image, inputs=image_query, outputs=image_results)

    # Tab 2: query by free text.
    with gr.Tab("Text Search"):
        with gr.Row():
            text_query = gr.Textbox(label="Describe the food (e.g., 'Spicy Tacos')")
            text_results = gr.Gallery(label="Top Matches")
        text_button = gr.Button("Search by Description")
        text_button.click(search_by_text, inputs=text_query, outputs=text_results)

# Start the Gradio server with default settings.
demo.launch()