# Page-scrape header (not part of the program): author "giladbecher",
# commit message "Update app.py", commit 2d263ff (verified).
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
# --- 1. SETUP & LOADING ---
# (Unchanged from the previous version: loads the model and the data.)

# CLIP model and its paired processor, both pulled from the same checkpoint.
print("Loading model...")
device = "cpu"
model_id = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_id)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(model_id)

# Image dataset that search results are looked up in by row index.
print("Loading dataset...")
ds = load_dataset("sgtsaughter/pokemon-classification-images-151", split="train")

# Precomputed embeddings, one row per dataset item, stacked into a 2-D matrix.
print("Loading embeddings...")
try:
    df_emb = pd.read_parquet("pokemon_embeddings.parquet")
    dataset_embeddings = np.stack(df_emb["embedding"].to_numpy())
except Exception as err:
    print(f"Error loading embeddings: {err}")
    # Fallback for testing if the file doesn't exist: an all-zero matrix keeps
    # the app importable, but every similarity score will then be 0.
    dataset_embeddings = np.zeros(shape=(100, 512))
# --- 2. CORE LOGIC ---
def get_embedding(input_data):
    """Encode a text prompt or an image into a unit-length CLIP embedding.

    Args:
        input_data: A ``str`` is encoded with the text tower; any other value
            is handed to the processor as an image (the UI passes PIL images).

    Returns:
        A flattened 1-D NumPy array holding the L2-normalised feature vector.
    """
    # Inference only: disabling autograd avoids building a graph that would
    # be discarded immediately afterwards.
    with torch.no_grad():
        if isinstance(input_data, str):
            # The text encoder has a fixed maximum sequence length; truncate
            # so that a very long prompt cannot overflow it and raise.
            inputs = processor(
                text=[input_data],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)
            features = model.get_text_features(**inputs)
        else:
            # `padding` is a tokenizer option and has no meaning for images.
            inputs = processor(images=input_data, return_tensors="pt").to(device)
            features = model.get_image_features(**inputs)

        # Scale to unit L2 norm along the feature axis.
        features = features / features.norm(p=2, dim=-1, keepdim=True)

    return features.cpu().numpy().flatten()
def find_similar_pokemon(query, input_type="text", top_k=3):
    """Return gallery entries for the dataset images most similar to *query*.

    Args:
        query: Text prompt (``str``) or image to search with. ``None`` and
            blank/whitespace-only strings produce no results.
        input_type: Unused; retained so existing callers keep working.
        top_k: Maximum number of matches to return (default 3).

    Returns:
        A list of ``(image, caption)`` tuples ordered best match first, where
        the caption is ``"<Name> (<score>)"`` with the score shown to two
        decimals. An empty list is returned when there is no usable query or
        when the search fails; this function never raises to the UI.
    """
    if query is None or top_k <= 0:
        return []
    # An empty textbox would otherwise be embedded and return arbitrary hits.
    if isinstance(query, str) and not query.strip():
        return []

    try:
        user_emb = get_embedding(query)
        scores = cosine_similarity(user_emb.reshape(1, -1), dataset_embeddings).flatten()
        # Highest score first, keep only the requested number of matches.
        top_indices = np.argsort(scores)[::-1][:top_k]

        # The label column may be a class-label feature (integer ids with an
        # `int2str` mapping) or plain values; resolve the converter once.
        to_name = getattr(ds.features['label'], 'int2str', str)

        results = []
        for idx in top_indices:
            # Fetch from the dataset carefully: one bad row (e.g. an index
            # beyond the dataset length) must not discard the other matches.
            try:
                item = ds[int(idx)]
                name = to_name(item['label'])
                caption = f"{name.capitalize()} ({scores[idx]:.2f})"
                results.append((item['image'], caption))
            except Exception as item_err:
                print(f"Skipping result {int(idx)}: {item_err}")
                continue
        return results
    except Exception as e:
        print(f"Error in recommendation: {e}")
        return []
# --- 3. ADVANCED UI ---
# Custom CSS for the page: a centred, width-limited `.container`, centred and
# coloured `h1` headings, centred grey paragraph text, and rounded corners for
# `.gallery-item` elements. The string is handed to `gr.Blocks(css=...)`.
custom_css = """
.container {max-width: 1200px; margin: auto; padding-top: 20px;}
h1 {text-align: center; color: #4F46E5; font-size: 3em; margin-bottom: 10px;}
p {text-align: center; font-size: 1.2em; color: #555;}
.gallery-item {border-radius: 10px; overflow: hidden;}
"""
# Use the 'Soft' theme for a clean look.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PokeMatch AI") as demo:
    # Display settings shared by both result galleries.
    gallery_options = {"columns": 3, "height": 350, "object_fit": "contain"}

    with gr.Column(elem_classes="container"):
        gr.Markdown("# 🔍 Poké-Match AI")
        gr.Markdown("### Discover Pokemon using Semantic Search powered by CLIP")

        with gr.Tabs():
            # --- TAB 1: TEXT SEARCH ---
            with gr.TabItem("📝 Search by Text"):
                with gr.Row():
                    # Left column - input
                    with gr.Column(scale=1):
                        gr.Markdown("### Describe your Pokemon")
                        text_input = gr.Textbox(
                            label="Your Description",
                            placeholder="E.g., 'A cute pink fairy' or 'Fire dragon'",
                            lines=2,
                        )
                        text_button = gr.Button("✨ Find Matches", variant="primary")

                        # One-click example prompts - improves the user experience.
                        sample_prompts = [
                            "A giant blue water turtle",
                            "A small yellow electric mouse",
                            "Scary ghost in the shadows",
                            "A pink singing balloon",
                        ]
                        gr.Examples(examples=sample_prompts, inputs=[text_input])

                    # Right column - output
                    with gr.Column(scale=2):
                        gr.Markdown("### Top Recommendations")
                        text_gallery = gr.Gallery(label="Results", **gallery_options)

                text_button.click(
                    fn=find_similar_pokemon,
                    inputs=[text_input],
                    outputs=[text_gallery],
                )

            # --- TAB 2: IMAGE SEARCH ---
            with gr.TabItem("🖼️ Search by Image"):
                with gr.Row():
                    # Left column - input
                    with gr.Column(scale=1):
                        gr.Markdown("### Upload an Image")
                        image_input = gr.Image(label="Upload Pokemon Image", type="pil")
                        image_button = gr.Button("🔍 Analyze & Match", variant="primary")

                    # Right column - output
                    with gr.Column(scale=2):
                        gr.Markdown("### Visual Matches")
                        image_gallery = gr.Gallery(label="Similar Pokemon", **gallery_options)

                image_button.click(
                    fn=find_similar_pokemon,
                    inputs=[image_input],
                    outputs=[image_gallery],
                )

        # Footer
        gr.Markdown("---")
        gr.Markdown("Created for Data Science Assignment • Powered by Hugging Face & OpenAI CLIP")

        # --- PART 4: VIDEO PRESENTATION ---
        gr.Markdown("---")
        gr.Markdown("### 🎥 Project Presentation")
        # Centred YouTube embed of the project presentation.
        video_html = """
<div style="display: flex; justify-content: center;">
<iframe width="800" height="450"
src="https://www.youtube.com/embed/fr3Og1y7oeg"
title="YouTube video player"
frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
</div>
"""
        gr.HTML(video_html)
# Start the web server only when this file is run as a script, so importing
# the module (for example to reuse `find_similar_pokemon`) does not block on
# a running server.
if __name__ == "__main__":
    demo.launch()