Spaces:
Build error
Build error
| import gradio as gr | |
| import os | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from transformers import CLIPProcessor, CLIPModel | |
| from datasets import load_dataset, concatenate_datasets | |
# ============================================================
# LOAD EVERYTHING ON STARTUP (runs once when the Space boots)
# ============================================================
print("Loading CLIP model...")
MODEL_NAME = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
# eval() switches off training-only layers; it does NOT disable autograd, so the
# embedding helpers still need torch.no_grad() for inference.
clip_model.eval()
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

print("Loading dataset...")
ds = load_dataset("chavajaz/wonders_dataset")
# recommend() uses the same row index into full_ds and into the embedding matrix,
# so this concatenation order must match the order the parquet was built in
# (presumably train → validation → test — verify against the precompute script).
full_ds = concatenate_datasets([ds["train"], ds["validation"], ds["test"]])
class_names = full_ds.features["label"].names

print("Loading precomputed embeddings...")
embeddings_df = pd.read_parquet("wonders_embeddings.parquet")
image_embeddings = np.array(embeddings_df["embedding"].tolist(), dtype=np.float32)

# Fail fast on a misaligned embeddings file: otherwise results silently point at
# the wrong images, or full_ds[idx] raises IndexError in the middle of a request.
if image_embeddings.ndim != 2 or len(image_embeddings) != len(full_ds):
    raise ValueError(
        f"wonders_embeddings.parquet yields an array of shape {image_embeddings.shape}, "
        f"but the dataset has {len(full_ds)} images; expected one embedding row per image."
    )

# Cosine similarity (and the diversity threshold in recommend()) assume unit-norm
# rows. Re-normalizing is a no-op for already-normalized data and guards against a
# parquet that stored raw CLIP features. The epsilon avoids division by zero.
_row_norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
image_embeddings = image_embeddings / np.maximum(_row_norms, 1e-12)

EMBEDDINGS_TENSOR = torch.tensor(image_embeddings, device=device, dtype=torch.float32)
print(f"Ready. {len(full_ds)} images, embeddings {image_embeddings.shape}, on {device}")
# ============================================================
# CORE FUNCTIONS
# ============================================================
def _as_feature_tensor(feats, embeds_attr):
    """Return a plain tensor from whatever CLIP's get_*_features produced.

    Depending on the installed transformers version, get_image_features /
    get_text_features may return a tensor directly or a ModelOutput-like object.
    In the latter case prefer ``embeds_attr`` ("image_embeds"/"text_embeds"),
    then ``pooler_output``, then the first element.

    NOTE(review): the pooler_output / [0] fallbacks may not be in the projected
    space, or may have the wrong rank; their dimensionality must match the
    precomputed embeddings — confirm for the installed transformers version.
    """
    if isinstance(feats, torch.Tensor):
        return feats
    embeds = getattr(feats, embeds_attr, None)
    if embeds is not None:
        return embeds
    pooled = getattr(feats, "pooler_output", None)
    if pooled is not None:
        return pooled
    return feats[0]


def _l2_normalize(feats):
    """Scale each vector along the last dimension to unit length."""
    return feats / feats.norm(dim=-1, keepdim=True)


@torch.no_grad()  # inference only: avoid building an autograd graph per query
def embed_image(pil_image):
    """Embed one PIL image with CLIP and L2-normalize the result.

    Args:
        pil_image: any PIL image; it is converted to RGB first.

    Returns:
        A batch-of-one feature tensor on ``device``, unit-normalized along the
        last dimension (presumably shape (1, 512) for ViT-B/32 — confirm).
    """
    img = pil_image.convert("RGB")
    inputs = processor(images=img, return_tensors="pt").to(device)
    feats = _as_feature_tensor(clip_model.get_image_features(**inputs), "image_embeds")
    return _l2_normalize(feats)


@torch.no_grad()  # inference only: avoid building an autograd graph per query
def embed_text(text):
    """Embed one text query with CLIP and L2-normalize the result.

    Args:
        text: the query string; tokens beyond the tokenizer's limit are truncated.

    Returns:
        A batch-of-one feature tensor on ``device``, unit-normalized along the
        last dimension.
    """
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True).to(device)
    feats = _as_feature_tensor(clip_model.get_text_features(**inputs), "text_embeds")
    return _l2_normalize(feats)
@torch.no_grad()  # pure retrieval: no gradients needed even if the query tracks them
def recommend(query_embedding, top_k=3, diversity_threshold=0.98, candidate_multiplier=20):
    """Return up to ``top_k`` diverse nearest neighbours of ``query_embedding``.

    Scores every precomputed image embedding by dot product with the query (cosine
    similarity, given unit-norm inputs). It over-fetches the best
    ``top_k * candidate_multiplier`` candidates, then greedily keeps candidates
    whose similarity to every already-chosen result is <= ``diversity_threshold``.
    This suppresses near-duplicate images.

    Args:
        query_embedding: query feature tensor on the same device as
            EMBEDDINGS_TENSOR (as returned by embed_image / embed_text).
        top_k: maximum number of results; values <= 0 yield an empty list.
        diversity_threshold: candidates more similar than this to an already
            chosen result are skipped.
        candidate_multiplier: how many candidates per requested result to
            consider before the diversity filter (previously hard-coded to 20).

    Returns:
        A list of dicts with keys "index", "score", "image" and "label_name",
        ordered by descending score. It may hold fewer than ``top_k`` entries if
        the candidate pool is exhausted.
    """
    if top_k <= 0:
        return []

    sims = (query_embedding @ EMBEDDINGS_TENSOR.T).squeeze(0)
    pool_size = min(max(top_k * candidate_multiplier, top_k), sims.shape[0])
    top_scores, top_indices = sims.topk(pool_size)

    results = []
    chosen_embeddings = []
    for score, idx in zip(top_scores.cpu().tolist(), top_indices.cpu().tolist()):
        candidate_emb = EMBEDDINGS_TENSOR[idx]
        # One matmul against all kept results: a single device sync per candidate
        # instead of one .item() per (candidate, kept) pair.
        if chosen_embeddings and bool(
            (torch.stack(chosen_embeddings) @ candidate_emb > diversity_threshold).any()
        ):
            continue
        item = full_ds[idx]
        results.append({
            "index": idx,
            "score": score,
            "image": item["image"],
            "label_name": class_names[item["label"]],
        })
        chosen_embeddings.append(candidate_emb)
        if len(results) >= top_k:
            break
    return results
_MEDALS = ["🥇", "🥈", "🥉"]


def _format_results(results):
    """Build Gallery items and ranked summary lines for ``recommend()`` output.

    Args:
        results: list of dicts with "image", "label_name" and "score" keys.

    Returns:
        (gallery_items, ranked_text). gallery_items is a list of (image, caption)
        tuples for gr.Gallery. ranked_text holds one line per result, joined with
        newlines. Ranks beyond the three medals fall back to "4.", "5.", ...
        instead of raising IndexError.
    """
    gallery_items = []
    lines = []
    for rank, r in enumerate(results):
        pretty = r["label_name"].replace("_", " ").title()
        gallery_items.append((r["image"], f"{pretty} • match {r['score']*100:.1f}%"))
        medal = _MEDALS[rank] if rank < len(_MEDALS) else f"{rank + 1}."
        lines.append(f"{medal} {pretty:<22} similarity {r['score']:.3f}")
    return gallery_items, "\n".join(lines)


def recommend_from_image(input_image):
    """Gradio handler: find the top wonder matches for an uploaded image.

    Returns:
        (gallery_items, summary_text); on missing input, an empty gallery and a prompt.
    """
    if input_image is None:
        return [], "✋ Please upload an image to find matching wonders."
    results = recommend(embed_image(input_image), top_k=3)
    gallery_items, ranked = _format_results(results)
    return gallery_items, "Your top 3 wonder matches:\n\n" + ranked


def recommend_from_text(text_query):
    """Gradio handler: find the top wonder matches for a free-text description.

    Returns:
        (gallery_items, summary_text); on empty/blank input, an empty gallery and a prompt.
    """
    if not text_query or not text_query.strip():
        return [], "✋ Please describe what you're looking for."
    results = recommend(embed_text(text_query), top_k=3)
    gallery_items, ranked = _format_results(results)
    return gallery_items, f'Best matches for "{text_query}":\n\n' + ranked
# ============================================================
# UI: custom CSS theme for the Blocks app below
# ============================================================
# Warm brown/tan palette with Quicksand (headings) and Nunito (body) web fonts.
# Most rules use !important to win over the Gradio theme's own styles.
# NOTE(review): .gr-box / .gr-form / .gr-panel / .gr-gallery / .gr-button-primary
# and .tab-nav look like class names from older Gradio releases; confirm they still
# exist in the installed Gradio version, otherwise those rules are silently inert.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Quicksand:wght@400;500;600;700&family=Nunito:wght@400;600;700;800&display=swap');
.gradio-container {
background: linear-gradient(135deg, #F5EBDD 0%, #EDE0CC 100%) !important;
font-family: 'Nunito', 'Quicksand', -apple-system, sans-serif !important;
}
/* Headings get the rounder, friendlier Quicksand */
h1, h2, h3, h4 {
font-family: 'Quicksand', sans-serif !important;
letter-spacing: 0.3px !important;
}
/* ---------- HEADER ---------- */
#header-block {
background: linear-gradient(135deg, #8B4513 0%, #A0522D 50%, #CD853F 100%);
padding: 36px 28px;
border-radius: 20px;
margin-bottom: 28px;
box-shadow: 0 8px 24px rgba(139, 69, 19, 0.25);
text-align: center;
}
#header-block h1 {
color: #FFF8E7 !important;
font-size: 2.8em !important;
font-weight: 700 !important;
margin: 0 !important;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
#header-block h3 {
color: #FFE4B5 !important;
font-weight: 500 !important;
margin: 10px 0 0 0 !important;
}
#header-block p {
color: #FFF8E7 !important;
margin-top: 14px !important;
font-size: 1.05em !important;
opacity: 0.95;
}
/* ---------- TABS (the big upgrade) ---------- */
.tab-nav {
background: transparent !important;
border-bottom: none !important;
gap: 12px !important;
padding: 0 4px !important;
margin-bottom: 8px !important;
}
.tab-nav button {
background: #FFF8E7 !important;
border: 2px solid #D2B48C !important;
color: #8B4513 !important;
font-family: 'Nunito', sans-serif !important;
font-size: 1.15em !important;
font-weight: 700 !important;
padding: 14px 32px !important;
border-radius: 14px !important;
margin: 0 !important;
box-shadow: 0 2px 6px rgba(139, 69, 19, 0.12) !important;
transition: all 0.25s ease !important;
cursor: pointer !important;
}
.tab-nav button:hover {
background: #FFE8C8 !important;
border-color: #A0522D !important;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(139, 69, 19, 0.25) !important;
}
.tab-nav button.selected {
background: linear-gradient(135deg, #8B4513 0%, #A0522D 100%) !important;
border-color: #8B4513 !important;
color: #FFF8E7 !important;
box-shadow: 0 6px 16px rgba(139, 69, 19, 0.4) !important;
transform: translateY(-2px);
}
/* ---------- BUTTONS ---------- */
button.primary, .gr-button-primary {
background: linear-gradient(135deg, #8B4513 0%, #A0522D 100%) !important;
border: none !important;
color: #FFF8E7 !important;
font-family: 'Nunito', sans-serif !important;
font-weight: 700 !important;
font-size: 1.08em !important;
padding: 14px 30px !important;
border-radius: 12px !important;
box-shadow: 0 4px 12px rgba(139, 69, 19, 0.3) !important;
transition: all 0.2s ease !important;
}
button.primary:hover, .gr-button-primary:hover {
transform: translateY(-2px);
box-shadow: 0 6px 16px rgba(139, 69, 19, 0.45) !important;
}
/* ---------- INPUTS / PANELS ---------- */
.gr-box, .gr-form, .gr-panel {
background: #FFF8E7 !important;
border: 2px solid #D2B48C !important;
border-radius: 14px !important;
}
label, .gr-input-label {
color: #5C4033 !important;
font-family: 'Nunito', sans-serif !important;
font-weight: 700 !important;
font-size: 1em !important;
}
textarea, input[type="text"] {
background: #FFFAF0 !important;
border: 2px solid #D2B48C !important;
color: #3E2723 !important;
font-family: 'Nunito', sans-serif !important;
font-size: 1.02em !important;
border-radius: 10px !important;
padding: 12px !important;
}
textarea:focus, input[type="text"]:focus {
border-color: #8B4513 !important;
outline: none !important;
box-shadow: 0 0 0 3px rgba(139, 69, 19, 0.15) !important;
}
.gr-gallery {
background: #FFF8E7 !important;
border: 2px solid #D2B48C !important;
border-radius: 14px !important;
padding: 10px !important;
}
/* ---------- FOOTER ---------- */
#footer-block {
margin-top: 28px;
padding: 22px 24px;
background: rgba(139, 69, 19, 0.08);
border-radius: 14px;
border-left: 5px solid #8B4513;
color: #5C4033 !important;
font-family: 'Nunito', sans-serif !important;
line-height: 1.7;
}
#footer-block a {
color: #8B4513 !important;
font-weight: 700;
text-decoration: none;
border-bottom: 1px dashed #8B4513;
}
#footer-block a:hover {
color: #A0522D !important;
}
"""
# Dataset rows offered as one-click examples in the image tab. Indices past the end
# of the dataset are skipped so a smaller or rebuilt dataset cannot crash startup.
EXAMPLE_IMAGE_INDICES = [50, 2000, 5000, 7500, 10000]

# Figures shown in the header/footer are derived from the loaded data (instead of
# hard-coded "12", "11,544" and "512"), so the copy stays accurate if the data changes.
_num_images = len(full_ds)
_num_classes = len(class_names)
_embed_dim = image_embeddings.shape[-1]

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(
    primary_hue="orange", secondary_hue="amber", neutral_hue="stone",
), title="Wonder Finder") as demo:
    gr.HTML(f"""
    <div id="header-block">
        <h1>🌍 Wonder Finder</h1>
        <h3>Discover the World's {_num_classes} Wonders Through AI Vision</h3>
        <p>Upload a travel photo or describe a place — get the closest matches from {_num_images:,} images.<br>
        Powered by CLIP's joint image–text embedding space.</p>
    </div>
    """)

    with gr.Tabs():
        # ---- Image → image search ----
        with gr.Tab("📷 Search by Image"):
            gr.Markdown("### Upload your travel photo, and we'll find the wonders that look most like it.")
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image(type="pil", label="Drop your photo here", height=320)
                    img_btn = gr.Button("✨ Find Similar Wonders", variant="primary", size="lg")
                with gr.Column(scale=2):
                    img_gallery = gr.Gallery(label="Top 3 Matches", columns=3, rows=1, height=320, object_fit="cover")
                    img_summary = gr.Textbox(label="📊 Match Details", lines=6, show_copy_button=True)
            gr.Examples(
                examples=[[full_ds[i]["image"]] for i in EXAMPLE_IMAGE_INDICES if i < _num_images],
                inputs=img_input,
                label="✨ Or try these sample images:",
            )
            img_btn.click(recommend_from_image, inputs=img_input, outputs=[img_gallery, img_summary])

        # ---- Text → image search (CLIP's shared embedding space) ----
        with gr.Tab("💬 Search by Description"):
            gr.Markdown("### Describe a place in your own words — CLIP translates language into visual matches.")
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(
                        label="Describe a wonder",
                        placeholder='e.g. "an ancient stone temple in the jungle" or "a tall tower at sunset"',
                        lines=3,
                    )
                    text_btn = gr.Button("✨ Find Matching Wonders", variant="primary", size="lg")
                with gr.Column(scale=2):
                    text_gallery = gr.Gallery(label="Top 3 Matches", columns=3, rows=1, height=320, object_fit="cover")
                    text_summary = gr.Textbox(label="📊 Match Details", lines=6, show_copy_button=True)
            gr.Examples(
                examples=[
                    ["ancient stone pyramid in the desert"],
                    ["tall modern skyscraper at night"],
                    ["waterfall in the tropical jungle"],
                    ["ancient Roman amphitheater"],
                    ["statue of a religious figure with outstretched arms"],
                    ["a misty stone monument at sunrise"],
                    ["white marble palace with a dome"],
                ],
                inputs=text_input,
                label="✨ Or try these example queries:",
            )
            text_btn.click(recommend_from_text, inputs=text_input, outputs=[text_gallery, text_summary])

    gr.HTML(f"""
    <div id="footer-block">
        <strong>About this app</strong><br>
        <strong>Dataset:</strong> <a href="https://huggingface.co/datasets/chavajaz/wonders_dataset">chavajaz/wonders_dataset</a> — {_num_images:,} images across {_num_classes} wonder classes (CC0).<br>
        <strong>Model:</strong> <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP ViT-B/32</a> — embeds images and text into the same {_embed_dim}-D space for cross-modal retrieval.<br>
        <strong>Method:</strong> L2-normalized cosine similarity over precomputed embeddings, with a diversity filter (threshold 0.98) to suppress near-duplicate results.
    </div>
    """)
# Launch the Blocks UI built above. Do not rebind `demo` here: the old template
# line `demo = gr.Interface(fn=greet, ...)` referenced an undefined `greet`
# (NameError at startup) and would have discarded the whole custom UI.
demo.launch()