Spaces:
Sleeping
Sleeping
File size: 4,146 Bytes
19567a7 b2aba87 19567a7 b2aba87 19567a7 36e2a11 1933b1d b2aba87 19567a7 1933b1d 0ffa00f 19567a7 b2aba87 19567a7 26e1d1d 19567a7 b2aba87 19567a7 b2aba87 19567a7 b2aba87 19567a7 26e1d1d 19567a7 bc50f93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import gradio as gr
import torch
import pandas as pd
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from torch.nn import functional as F
# --- 1. SETUP & CONFIG ---
MODEL_ID = "openai/clip-vit-base-patch32"
DATA_FILE = "food_embeddings_clip.parquet"
print("⏳ Starting App... Loading Model...")
# Load Model (CPU is fine for inference on single images)
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
# --- 2. LOAD DATA (Must match Colab logic EXACTLY) ---
print("⏳ Loading Dataset (this takes a moment)...")
# We load the same 5000 images using the same seed so row indices line up
# one-to-one with the precomputed embeddings stored in the parquet file.
# NOTE(review): any change to seed/count here silently desynchronizes the
# search results from the embeddings — keep in lockstep with the Colab job.
dataset = load_dataset("ethz/food101", split="train").shuffle(seed=42).select(range(5000))
# --- 3. LOAD EMBEDDINGS ---
print("⏳ Loading Pre-computed Embeddings...")
df = pd.read_parquet(DATA_FILE)
# Convert the list of numbers in the parquet back to a Torch Tensor
# (np.stack turns the per-row embedding arrays into one (N, D) matrix).
db_features = torch.tensor(np.stack(df['embedding'].to_numpy()))
# L2-normalize once up front so every query only needs a dot product
# (dot product of unit vectors == cosine similarity).
db_features = F.normalize(db_features, p=2, dim=1)
print("✅ App Ready!")
# --- 4. CORE SEARCH LOGIC ---
def find_best_matches(query_features, top_k=3):
    """Return the ``top_k`` most similar database entries for a query embedding.

    Args:
        query_features: Tensor of shape (1, D) produced by CLIP's image or
            text tower (D = 512 for clip-vit-base-patch32).
        top_k: Number of matches to return; clamped to the database size so
            a large request cannot make ``torch.topk`` raise.

    Returns:
        A list of ``(PIL.Image, "label (score)")`` tuples, the format
        expected by ``gr.Gallery``.
    """
    # Normalize the query so the dot product below equals cosine similarity
    # (the database side was normalized once at startup).
    query_features = F.normalize(query_features, p=2, dim=1)
    # (1, D) @ (D, N) -> (1, N): similarity of the query against every row.
    similarity = torch.mm(query_features, db_features.T)
    # Clamp k: torch.topk raises if k exceeds the number of candidates.
    k = min(top_k, db_features.size(0))
    scores, indices = torch.topk(similarity, k=k)
    results = []
    for idx, score in zip(indices[0], scores[0]):
        idx = idx.item()
        # Image comes from the HF dataset; label from the parquet dataframe.
        # Both are index-aligned because they were built with the same seed.
        img = dataset[idx]['image']
        label = df.iloc[idx]['label_name']
        results.append((img, f"{label} ({score:.2f})"))
    return results
# --- 5. GRADIO FUNCTIONS ---
def search_by_image(input_image):
    """Embed an uploaded PIL image with CLIP and return its closest matches."""
    # Nothing uploaded yet: hand the gallery an empty result set.
    if input_image is None:
        return []
    encoded = processor(images=input_image, return_tensors="pt")
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        image_embedding = model.get_image_features(**encoded)
    return find_best_matches(image_embedding)
def search_by_text(input_text):
    """Embed a free-text food description with CLIP and return closest matches."""
    # Empty or missing query: hand the gallery an empty result set.
    if not input_text:
        return []
    tokenized = processor(text=[input_text], return_tensors="pt", padding=True)
    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        text_embedding = model.get_text_features(**tokenized)
    return find_best_matches(text_embedding)
# --- 6. BUILD UI ---
with gr.Blocks(title="Food Matcher AI") as demo:
    gr.Markdown("# 🍔 Visual Dish Matcher")
    gr.Markdown("Upload a photo of food (or describe it) to find similar dishes in our database.")
    # --- VIDEO SECTION ---
    # Using Accordion so it doesn't clutter the UI. Open=False means it starts closed.
    with gr.Accordion("📺 Watch Project Demo", open=False):
        gr.HTML("""
        <div style="display: flex; justify-content: center;">
            <iframe width="560" height="315"
            src="https://www.youtube.com/embed/IXeIxYHi0Es"
            title="YouTube video player"
            frameborder="0"
            allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
            allowfullscreen>
            </iframe>
        </div>
        """)
    # ----------------------------
    # Tab 1: reverse image search — upload a photo, get the top matches.
    with gr.Tab("Image Search"):
        with gr.Row():
            img_input = gr.Image(type="pil", label="Upload Food Image")
            img_gallery = gr.Gallery(label="Top Matches")
        btn_img = gr.Button("Find Similar Dishes")
        btn_img.click(search_by_image, inputs=img_input, outputs=img_gallery)
    # Tab 2: text-to-image search via CLIP's shared embedding space.
    with gr.Tab("Text Search"):
        with gr.Row():
            txt_input = gr.Textbox(label="Describe the food (e.g., 'Spicy Tacos')")
            txt_gallery = gr.Gallery(label="Top Matches")
        btn_txt = gr.Button("Search by Description")
        btn_txt.click(search_by_text, inputs=txt_input, outputs=txt_gallery)
# NOTE(review): the original comment said "Disable SSR for stability", but no
# ssr_mode=False argument is actually passed — confirm whether launch() should
# receive ssr_mode=False on the Gradio version this Space runs.
demo.launch()
|