Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import CLIPProcessor, CLIPModel | |
| from datasets import load_dataset | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from PIL import Image | |
# --- 1. SETUP & LOADING ---
# Everything in this section runs once at import time: the CLIP model and its
# processor, the Pokemon image dataset, and the precomputed image embeddings.
print("Loading model...")
device = "cpu"
model_id = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_id)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(model_id)

print("Loading dataset...")
ds = load_dataset(
    "sgtsaughter/pokemon-classification-images-151",
    split="train",
)

print("Loading embeddings...")
try:
    df_emb = pd.read_parquet("pokemon_embeddings.parquet")
    # One embedding per row of the 'embedding' column, stacked into a 2-D array.
    dataset_embeddings = np.stack(df_emb["embedding"].to_numpy())
except Exception as exc:
    print(f"Error loading embeddings: {exc}")
    # Placeholder matrix so the app can still start for testing when the
    # parquet file is missing or unreadable.
    dataset_embeddings = np.zeros(shape=(100, 512))
# --- 2. CORE LOGIC ---
def get_embedding(input_data):
    """Return the L2-normalised CLIP embedding of a text string or an image.

    Args:
        input_data: A ``str`` is embedded with the text encoder; any other
            value is passed to the processor as an image.

    Returns:
        A flat NumPy array holding the unit-length feature vector.
    """
    # Inference only: disable autograd instead of building a graph for every
    # query and throwing it away afterwards.
    with torch.no_grad():
        if isinstance(input_data, str):
            # truncation=True keeps over-long descriptions within the text
            # encoder's maximum sequence length instead of failing in the model.
            inputs = processor(
                text=[input_data],
                return_tensors="pt",
                padding=True,
                truncation=True,
            ).to(device)
            features = model.get_text_features(**inputs)
        else:
            inputs = processor(images=input_data, return_tensors="pt", padding=True).to(device)
            features = model.get_image_features(**inputs)
        # Scale to unit length.
        features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.cpu().numpy().flatten()
def find_similar_pokemon(query, input_type="text", top_k=3):
    """Return the dataset entries most similar to a text or image query.

    Args:
        query: Text description or image to search with. ``None`` and blank
            strings yield no results.
        input_type: Unused; kept so existing callers keep working.
        top_k: Maximum number of matches to return (defaults to 3).

    Returns:
        A list of ``(image, caption)`` tuples ordered from most to least
        similar, where the caption is ``"<Name> (<score>)"``. The list is
        empty when there is nothing to search for or the lookup fails.
    """
    if query is None or top_k < 1:
        return []
    # A blank description carries no information; don't return arbitrary hits.
    if isinstance(query, str) and not query.strip():
        return []
    try:
        user_emb = get_embedding(query)
        scores = cosine_similarity(user_emb.reshape(1, -1), dataset_embeddings).flatten()
        top_indices = np.argsort(scores)[::-1][:top_k]

        # The label column may hold class ids (convertible to names via
        # int2str) or plain values; decide once, outside the loop.
        label_feature = ds.features['label']
        has_names = hasattr(label_feature, 'int2str')

        results = []
        for idx in top_indices:
            # Fetch each row defensively: one bad row is skipped and reported
            # rather than failing the whole request.
            try:
                item = ds[int(idx)]
                label = item['label']
                name = label_feature.int2str(label) if has_names else str(label)
                score = scores[idx]
                results.append((item['image'], f"{name.capitalize()} ({score:.2f})"))
            except Exception as e:
                print(f"Skipping result {idx}: {e}")
                continue
        return results
    except Exception as e:
        print(f"Error in recommendation: {e}")
        return []
# --- 3. ADVANCED UI ---
# Builds the Gradio interface: a text-search tab and an image-search tab, both
# wired to find_similar_pokemon, followed by a footer and an embedded video.
# NOTE(review): the block nesting below was reconstructed from a copy of the
# file that had lost its indentation — confirm the layout against the app.

# Custom CSS for the page container, headings, paragraphs and gallery items
custom_css = """
.container {max-width: 1200px; margin: auto; padding-top: 20px;}
h1 {text-align: center; color: #4F46E5; font-size: 3em; margin-bottom: 10px;}
p {text-align: center; font-size: 1.2em; color: #555;}
.gallery-item {border-radius: 10px; overflow: hidden;}
"""
# Use the 'Soft' theme for a clean look
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PokeMatch AI") as demo:
    with gr.Column(elem_classes="container"):
        gr.Markdown("# 🔍 Poké-Match AI")
        gr.Markdown("### Discover Pokemon using Semantic Search powered by CLIP")
        with gr.Tabs():
            # --- TAB 1: TEXT SEARCH ---
            with gr.TabItem("📝 Search by Text"):
                with gr.Row():
                    # Left column - input
                    with gr.Column(scale=1):
                        gr.Markdown("### Describe your Pokemon")
                        text_input = gr.Textbox(
                            placeholder="E.g., 'A cute pink fairy' or 'Fire dragon'",
                            label="Your Description",
                            lines=2
                        )
                        text_button = gr.Button("✨ Find Matches", variant="primary")
                        # Quick-click examples - improves the user experience
                        gr.Examples(
                            examples=["A giant blue water turtle", "A small yellow electric mouse", "Scary ghost in the shadows", "A pink singing balloon"],
                            inputs=[text_input]
                        )
                    # Right column - output
                    with gr.Column(scale=2):
                        gr.Markdown("### Top Recommendations")
                        text_gallery = gr.Gallery(
                            label="Results",
                            columns=3,
                            height=350,
                            object_fit="contain"
                        )
                # Clicking the button sends the textbox value to the search
                # function and shows the returned items in the gallery.
                text_button.click(find_similar_pokemon, inputs=[text_input], outputs=[text_gallery])
            # --- TAB 2: IMAGE SEARCH ---
            with gr.TabItem("🖼️ Search by Image"):
                with gr.Row():
                    # Left column
                    with gr.Column(scale=1):
                        gr.Markdown("### Upload an Image")
                        image_input = gr.Image(type="pil", label="Upload Pokemon Image")
                        image_button = gr.Button("🔍 Analyze & Match", variant="primary")
                    # Right column
                    with gr.Column(scale=2):
                        gr.Markdown("### Visual Matches")
                        image_gallery = gr.Gallery(
                            label="Similar Pokemon",
                            columns=3,
                            height=350,
                            object_fit="contain"
                        )
                # Same search function as the text tab, fed with the uploaded image.
                image_button.click(find_similar_pokemon, inputs=[image_input], outputs=[image_gallery])
        gr.Markdown("---")
        gr.Markdown("Created for Data Science Assignment • Powered by Hugging Face & OpenAI CLIP")
    # --- PART 4: VIDEO PRESENTATION ---
    gr.Markdown("---")
    gr.Markdown("### 🎥 Project Presentation")
    # Static HTML that embeds the presentation video in a centered iframe.
    video_html = """
<div style="display: flex; justify-content: center;">
<iframe width="800" height="450"
src="https://www.youtube.com/embed/fr3Og1y7oeg"
title="YouTube video player"
frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
</div>
"""
    gr.HTML(video_html)
# Start the app when this file is executed.
demo.launch()