# Hugging Face Space app (scraped page status: "Sleeping" removed)
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel | |
| from datasets import load_dataset | |
| from torch.nn import functional as F | |
# --- 1. SETUP & CONFIG ---
MODEL_ID = "openai/clip-vit-base-patch32"
DATA_FILE = "food_embeddings_clip.parquet"

print("⏳ Starting App... Loading Model...")

# Load model (CPU is fine for inference on single images/queries).
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# --- 2. LOAD DATA (must match the Colab embedding job EXACTLY) ---
print("⏳ Loading Dataset (this takes a moment)...")
# Same 5000 images with the same shuffle seed, so dataset row i corresponds
# to row i of the pre-computed parquet file.
dataset = load_dataset("ethz/food101", split="train").shuffle(seed=42).select(range(5000))

# --- 3. LOAD EMBEDDINGS ---
print("⏳ Loading Pre-computed Embeddings...")
df = pd.read_parquet(DATA_FILE)

# Stack the per-row embedding lists into one (N, D) matrix. Force float32:
# parquet may round-trip the vectors as float64, and torch.mm would raise on
# a dtype mismatch against the float32 features CLIP produces at query time.
db_features = torch.tensor(np.stack(df['embedding'].to_numpy()), dtype=torch.float32)

# L2-normalize once so each search is a single dot product (cosine similarity).
db_features = F.normalize(db_features, p=2, dim=1)
print("✅ App Ready!")
| # --- 4. CORE SEARCH LOGIC --- | |
def find_best_matches(query_features, top_k=3):
    """Return the top_k most similar database entries for a CLIP query.

    Args:
        query_features: (1, D) tensor of CLIP image or text features
            (un-normalized; normalized here).
        top_k: number of matches requested (clamped to the database size,
            so oversized requests no longer make torch.topk raise).

    Returns:
        List of (PIL.Image, "label (score)") tuples, ready for a
        Gradio Gallery.
    """
    # Normalize the query so the dot product below is cosine similarity.
    query_features = F.normalize(query_features, p=2, dim=1)

    # Query (1, D) @ DB (D, N) -> scores (1, N).
    similarity = torch.mm(query_features, db_features.T)

    # Clamp k: topk raises if asked for more results than rows exist.
    k = min(top_k, db_features.size(0))
    scores, indices = torch.topk(similarity, k=k)

    results = []
    for idx, score in zip(indices[0], scores[0]):
        idx = idx.item()
        # Image comes from the dataset, label from the parquet dataframe;
        # both were built with the same seed so row indices line up.
        img = dataset[idx]['image']
        label = df.iloc[idx]['label_name']
        # .item() converts the 0-dim tensor to a plain float for formatting.
        results.append((img, f"{label} ({score.item():.2f})"))
    return results
| # --- 5. GRADIO FUNCTIONS --- | |
def search_by_image(input_image):
    """Embed an uploaded photo with CLIP and return the closest dishes."""
    # Nothing uploaded yet -> empty gallery.
    if input_image is None:
        return []

    encoded = processor(images=input_image, return_tensors="pt")
    with torch.no_grad():
        image_features = model.get_image_features(**encoded)

    return find_best_matches(image_features)
def search_by_text(input_text):
    """Embed a free-text description with CLIP and return the closest dishes."""
    # Empty/blank query -> empty gallery.
    if not input_text:
        return []

    encoded = processor(text=[input_text], return_tensors="pt", padding=True)
    with torch.no_grad():
        text_features = model.get_text_features(**encoded)

    return find_best_matches(text_features)
| # --- 6. BUILD UI --- | |
# --- 6. BUILD UI ---
with gr.Blocks(title="Food Matcher AI") as demo:
    gr.Markdown("# 🍔 Visual Dish Matcher")
    gr.Markdown("Upload a photo of food (or describe it) to find similar dishes in our database.")

    # --- VIDEO SECTION ---
    # Accordion keeps the demo video from cluttering the UI; open=False
    # means it starts collapsed.
    with gr.Accordion("📺 Watch Project Demo", open=False):
        gr.HTML("""
        <div style="display: flex; justify-content: center;">
            <iframe width="560" height="315"
                src="https://www.youtube.com/embed/IXeIxYHi0Es"
                title="YouTube video player"
                frameborder="0"
                allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
                allowfullscreen>
            </iframe>
        </div>
        """)
    # ----------------------------

    with gr.Tab("Image Search"):
        with gr.Row():
            img_input = gr.Image(type="pil", label="Upload Food Image")
            img_gallery = gr.Gallery(label="Top Matches")
        btn_img = gr.Button("Find Similar Dishes")
        btn_img.click(search_by_image, inputs=img_input, outputs=img_gallery)

    with gr.Tab("Text Search"):
        with gr.Row():
            txt_input = gr.Textbox(label="Describe the food (e.g., 'Spicy Tacos')")
            txt_gallery = gr.Gallery(label="Top Matches")
        btn_txt = gr.Button("Search by Description")
        btn_txt.click(search_by_text, inputs=txt_input, outputs=txt_gallery)

# NOTE(review): the original comment said "Disable SSR for stability", but
# launch() is called with defaults — pass ssr_mode=False here if that was
# actually intended.
demo.launch()