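# Fail200 / app.py — Hugging Face Space source, uploaded by Jonathandav
# ("Upload 4 files", commit 911d411, verified, 13.4 kB).
# The page-navigation text "Raw" and "History Blame Contribute Delete" was
# captured along with the file and is not part of the program.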
import gradio as gr
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset, concatenate_datasets
# ============================================================
# LOAD EVERYTHING ON STARTUP (runs once when the Space boots)
# ============================================================
print("Loading CLIP model...")
MODEL_NAME = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
clip_model.eval()  # inference only
processor = CLIPProcessor.from_pretrained(MODEL_NAME)

print("Loading dataset...")
ds = load_dataset("chavajaz/wonders_dataset")
# All splits are concatenated in a fixed order (train, validation, test).
# recommend() uses row i of EMBEDDINGS_TENSOR to look up full_ds[i], so the
# precomputed embeddings file must have been built in this same order.
full_ds = concatenate_datasets([ds["train"], ds["validation"], ds["test"]])
class_names = full_ds.features["label"].names

print("Loading precomputed embeddings...")
embeddings_df = pd.read_parquet("wonders_embeddings.parquet")
image_embeddings = np.array(embeddings_df["embedding"].tolist(), dtype=np.float32)

# Fail fast if the embeddings cannot line up one-to-one with dataset rows;
# otherwise results would silently show the wrong image or crash later with
# an IndexError inside recommend().
if image_embeddings.ndim != 2 or image_embeddings.shape[0] != len(full_ds):
    raise ValueError(
        f"wonders_embeddings.parquet has shape {image_embeddings.shape}, but the "
        f"dataset has {len(full_ds)} images; expected one embedding row per image."
    )

# recommend() treats plain dot products as cosine similarity, both between the
# unit-length query and each stored row and between two stored rows (the
# diversity filter). That only holds if the stored rows are unit-length too.
# Re-normalizing is a no-op for rows that already are; the floor guards
# against division by zero for an all-zero row.
row_norms = np.linalg.norm(image_embeddings, axis=1, keepdims=True)
image_embeddings = image_embeddings / np.maximum(row_norms, 1e-12)

EMBEDDINGS_TENSOR = torch.tensor(image_embeddings, device=device, dtype=torch.float32)
print(f"Ready. {len(full_ds)} images, embeddings {image_embeddings.shape}, on {device}")
# ============================================================
# CORE FUNCTIONS
# ============================================================
@torch.no_grad()
def embed_image(pil_image):
    """Encode a PIL image as a unit-length CLIP image embedding.

    The image is converted to RGB and preprocessed with the module-level
    ``processor``. It is then moved to ``device`` and passed through
    ``clip_model.get_image_features``. Gradient tracking is disabled by the
    decorator.

    Returns:
        The feature tensor divided by its L2 norm along the last dimension.
    """
    rgb = pil_image.convert("RGB")
    model_inputs = processor(images=rgb, return_tensors="pt").to(device)
    output = clip_model.get_image_features(**model_inputs)

    # get_image_features may hand back an output object instead of a bare
    # tensor. Pull the embedding out of it, preferring image_embeds, then
    # pooler_output, and finally the object's first element.
    if not isinstance(output, torch.Tensor):
        for field in ("image_embeds", "pooler_output"):
            candidate = getattr(output, field, None)
            if candidate is not None:
                output = candidate
                break
        else:
            output = output[0]

    return output / output.norm(dim=-1, keepdim=True)
@torch.no_grad()
def embed_text(text):
    """Encode a text string as a unit-length CLIP text embedding.

    The string is tokenized as a batch of one with padding and truncation
    enabled, moved to ``device``, and passed through
    ``clip_model.get_text_features``. Gradient tracking is disabled by the
    decorator.

    Returns:
        The feature tensor divided by its L2 norm along the last dimension.
    """
    tokenized = processor(
        text=[text],
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
    output = clip_model.get_text_features(**tokenized)

    # get_text_features may hand back an output object instead of a bare
    # tensor. Pull the embedding out of it, preferring text_embeds, then
    # pooler_output, and finally the object's first element.
    if not isinstance(output, torch.Tensor):
        for field in ("text_embeds", "pooler_output"):
            candidate = getattr(output, field, None)
            if candidate is not None:
                output = candidate
                break
        else:
            output = output[0]

    return output / output.norm(dim=-1, keepdim=True)
def recommend(query_embedding, top_k=3, diversity_threshold=0.98, candidate_multiplier=20):
    """Return up to ``top_k`` dataset items most similar to a query embedding.

    Similarity is the dot product between ``query_embedding`` and every row of
    the module-level ``EMBEDDINGS_TENSOR``; for unit-length vectors this is
    cosine similarity.

    The best ``top_k * candidate_multiplier`` rows (capped at the number of
    rows) form a candidate pool, walked in descending score order. A candidate
    is skipped when its dot product with any already-chosen embedding exceeds
    ``diversity_threshold``, so near-duplicate images do not crowd the results.

    Args:
        query_embedding: Tensor that can be matrix-multiplied with
            ``EMBEDDINGS_TENSOR.T``. Presumably shape (1, D), as produced by
            embed_image / embed_text — confirm for other callers.
        top_k: Maximum number of results. Values <= 0 return an empty list.
        diversity_threshold: Candidates whose similarity to an already-chosen
            result is above this value are skipped.
        candidate_multiplier: Size of the candidate pool relative to ``top_k``
            (clamped to at least 1). A larger pool lets the diversity filter
            still fill ``top_k`` slots when many top hits are near-duplicates.

    Returns:
        A list of at most ``top_k`` dicts with keys ``index``, ``score``,
        ``image`` and ``label_name``, ordered by descending score. It can be
        shorter than ``top_k`` if the filter rejects too many candidates.
    """
    # topk() rejects a negative k; an empty request needs no work at all.
    if top_k <= 0:
        return []

    sims = (query_embedding @ EMBEDDINGS_TENSOR.T).squeeze(0)
    pool_size = min(top_k * max(candidate_multiplier, 1), len(sims))
    top_scores, top_indices = sims.topk(pool_size)  # sorted descending by default
    top_scores = top_scores.cpu().tolist()
    top_indices = top_indices.cpu().tolist()

    results = []
    chosen_embeddings = []
    for score, idx in zip(top_scores, top_indices):
        candidate_emb = EMBEDDINGS_TENSOR[idx]
        too_similar = any(
            (candidate_emb @ prev_emb).item() > diversity_threshold
            for prev_emb in chosen_embeddings
        )
        if too_similar:
            continue
        item = full_ds[idx]
        results.append({
            "index": idx,
            "score": score,
            "image": item["image"],
            "label_name": class_names[item["label"]],
        })
        chosen_embeddings.append(candidate_emb)
        if len(results) >= top_k:
            break
    return results
def _format_matches(results, header):
    """Convert recommend() results into (gallery_items, summary) for the UI.

    Each gallery item is an ``(image, caption)`` tuple. The summary is
    ``header`` followed by one line per result, each with a medal for the first
    three ranks and a plain ``"N."`` rank after that. The fallback means a
    larger ``top_k`` can never raise IndexError.
    """
    medals = ("🥇", "🥈", "🥉")
    gallery_items = []
    summary_lines = []
    for rank, r in enumerate(results):
        pretty_label = r["label_name"].replace("_", " ").title()
        gallery_items.append(
            (r["image"], f"{pretty_label} • match {r['score']*100:.1f}%")
        )
        medal = medals[rank] if rank < len(medals) else f"{rank + 1}."
        summary_lines.append(f"{medal} {pretty_label:<22} similarity {r['score']:.3f}")
    return gallery_items, header + "\n".join(summary_lines)


def recommend_from_image(input_image):
    """Gradio handler: find the wonders whose images best match an uploaded photo.

    Returns a ``(gallery_items, summary)`` pair. When no image was uploaded,
    returns an empty gallery plus a prompt message.
    """
    if input_image is None:
        return [], "✋ Please upload an image to find matching wonders."
    query_emb = embed_image(input_image)
    results = recommend(query_emb, top_k=3)
    return _format_matches(results, "Your top 3 wonder matches:\n\n")


def recommend_from_text(text_query):
    """Gradio handler: find the wonders whose images best match a text description.

    Returns a ``(gallery_items, summary)`` pair. When the query is empty or
    whitespace-only, returns an empty gallery plus a prompt message.
    """
    if not text_query or not text_query.strip():
        return [], "✋ Please describe what you're looking for."
    query_emb = embed_text(text_query)
    results = recommend(query_emb, top_k=3)
    return _format_matches(results, f'Best matches for "{text_query}":\n\n')
# ============================================================
# UI: custom stylesheet and Gradio Blocks layout
# ============================================================
# CUSTOM_CSS is passed to gr.Blocks(css=...) below. It styles the header and
# footer HTML blocks (#header-block, #footer-block), the tab buttons, primary
# buttons, text inputs and the gallery in a brown/beige palette, and pulls
# the Quicksand and Nunito fonts from Google Fonts.
# NOTE(review): selectors such as .tab-nav, .gr-box and .gr-gallery target
# Gradio-internal class names that may change between Gradio versions —
# confirm they still match the installed version.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Quicksand:wght@400;500;600;700&family=Nunito:wght@400;600;700;800&display=swap');
.gradio-container {
background: linear-gradient(135deg, #F5EBDD 0%, #EDE0CC 100%) !important;
font-family: 'Nunito', 'Quicksand', -apple-system, sans-serif !important;
}
/* Headings get the rounder, friendlier Quicksand */
h1, h2, h3, h4 {
font-family: 'Quicksand', sans-serif !important;
letter-spacing: 0.3px !important;
}
/* ---------- HEADER ---------- */
#header-block {
background: linear-gradient(135deg, #8B4513 0%, #A0522D 50%, #CD853F 100%);
padding: 36px 28px;
border-radius: 20px;
margin-bottom: 28px;
box-shadow: 0 8px 24px rgba(139, 69, 19, 0.25);
text-align: center;
}
#header-block h1 {
color: #FFF8E7 !important;
font-size: 2.8em !important;
font-weight: 700 !important;
margin: 0 !important;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
#header-block h3 {
color: #FFE4B5 !important;
font-weight: 500 !important;
margin: 10px 0 0 0 !important;
}
#header-block p {
color: #FFF8E7 !important;
margin-top: 14px !important;
font-size: 1.05em !important;
opacity: 0.95;
}
/* ---------- TABS (the big upgrade) ---------- */
.tab-nav {
background: transparent !important;
border-bottom: none !important;
gap: 12px !important;
padding: 0 4px !important;
margin-bottom: 8px !important;
}
.tab-nav button {
background: #FFF8E7 !important;
border: 2px solid #D2B48C !important;
color: #8B4513 !important;
font-family: 'Nunito', sans-serif !important;
font-size: 1.15em !important;
font-weight: 700 !important;
padding: 14px 32px !important;
border-radius: 14px !important;
margin: 0 !important;
box-shadow: 0 2px 6px rgba(139, 69, 19, 0.12) !important;
transition: all 0.25s ease !important;
cursor: pointer !important;
}
.tab-nav button:hover {
background: #FFE8C8 !important;
border-color: #A0522D !important;
transform: translateY(-2px);
box-shadow: 0 4px 12px rgba(139, 69, 19, 0.25) !important;
}
.tab-nav button.selected {
background: linear-gradient(135deg, #8B4513 0%, #A0522D 100%) !important;
border-color: #8B4513 !important;
color: #FFF8E7 !important;
box-shadow: 0 6px 16px rgba(139, 69, 19, 0.4) !important;
transform: translateY(-2px);
}
/* ---------- BUTTONS ---------- */
button.primary, .gr-button-primary {
background: linear-gradient(135deg, #8B4513 0%, #A0522D 100%) !important;
border: none !important;
color: #FFF8E7 !important;
font-family: 'Nunito', sans-serif !important;
font-weight: 700 !important;
font-size: 1.08em !important;
padding: 14px 30px !important;
border-radius: 12px !important;
box-shadow: 0 4px 12px rgba(139, 69, 19, 0.3) !important;
transition: all 0.2s ease !important;
}
button.primary:hover, .gr-button-primary:hover {
transform: translateY(-2px);
box-shadow: 0 6px 16px rgba(139, 69, 19, 0.45) !important;
}
/* ---------- INPUTS / PANELS ---------- */
.gr-box, .gr-form, .gr-panel {
background: #FFF8E7 !important;
border: 2px solid #D2B48C !important;
border-radius: 14px !important;
}
label, .gr-input-label {
color: #5C4033 !important;
font-family: 'Nunito', sans-serif !important;
font-weight: 700 !important;
font-size: 1em !important;
}
textarea, input[type="text"] {
background: #FFFAF0 !important;
border: 2px solid #D2B48C !important;
color: #3E2723 !important;
font-family: 'Nunito', sans-serif !important;
font-size: 1.02em !important;
border-radius: 10px !important;
padding: 12px !important;
}
textarea:focus, input[type="text"]:focus {
border-color: #8B4513 !important;
outline: none !important;
box-shadow: 0 0 0 3px rgba(139, 69, 19, 0.15) !important;
}
.gr-gallery {
background: #FFF8E7 !important;
border: 2px solid #D2B48C !important;
border-radius: 14px !important;
padding: 10px !important;
}
/* ---------- FOOTER ---------- */
#footer-block {
margin-top: 28px;
padding: 22px 24px;
background: rgba(139, 69, 19, 0.08);
border-radius: 14px;
border-left: 5px solid #8B4513;
color: #5C4033 !important;
font-family: 'Nunito', sans-serif !important;
line-height: 1.7;
}
#footer-block a {
color: #8B4513 !important;
font-weight: 700;
text-decoration: none;
border-bottom: 1px dashed #8B4513;
}
#footer-block a:hover {
color: #A0522D !important;
}
"""
# Dataset rows offered as clickable image examples. Indices past the end of
# the loaded dataset are dropped, so a smaller dataset revision cannot crash
# startup with an IndexError while the UI is being built.
_EXAMPLE_IMAGE_INDICES = [50, 2000, 5000, 7500, 10000]

# Two-tab Blocks app: search by uploaded image, or by free-text description.
# Both tabs render the top-3 results as a gallery plus a text summary.
with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(
    primary_hue="orange", secondary_hue="amber", neutral_hue="stone",
), title="Wonder Finder") as demo:
    gr.HTML("""
<div id="header-block">
<h1>🌍 Wonder Finder</h1>
<h3>Discover the World's 12 Wonders Through AI Vision</h3>
<p>Upload a travel photo or describe a place — get the closest matches from 11,544 images.<br>
Powered by CLIP's joint image–text embedding space.</p>
</div>
""")
    with gr.Tabs():
        with gr.Tab("📷 Search by Image"):
            gr.Markdown("### Upload your travel photo, and we'll find the wonders that look most like it.")
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image(type="pil", label="Drop your photo here", height=320)
                    img_btn = gr.Button("✨ Find Similar Wonders", variant="primary", size="lg")
                with gr.Column(scale=2):
                    img_gallery = gr.Gallery(label="Top 3 Matches", columns=3, rows=1, height=320, object_fit="cover")
                    img_summary = gr.Textbox(label="📊 Match Details", lines=6, show_copy_button=True)
            sample_images = [
                [full_ds[i]["image"]] for i in _EXAMPLE_IMAGE_INDICES if i < len(full_ds)
            ]
            if sample_images:
                gr.Examples(
                    examples=sample_images,
                    inputs=img_input,
                    label="✨ Or try these sample images:",
                )
            img_btn.click(recommend_from_image, inputs=img_input, outputs=[img_gallery, img_summary])
        with gr.Tab("💬 Search by Description"):
            gr.Markdown("### Describe a place in your own words — CLIP translates language into visual matches.")
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Textbox(
                        label="Describe a wonder",
                        placeholder='e.g. "an ancient stone temple in the jungle" or "a tall tower at sunset"',
                        lines=3,
                    )
                    text_btn = gr.Button("✨ Find Matching Wonders", variant="primary", size="lg")
                with gr.Column(scale=2):
                    text_gallery = gr.Gallery(label="Top 3 Matches", columns=3, rows=1, height=320, object_fit="cover")
                    text_summary = gr.Textbox(label="📊 Match Details", lines=6, show_copy_button=True)
            gr.Examples(
                examples=[
                    ["ancient stone pyramid in the desert"],
                    ["tall modern skyscraper at night"],
                    ["waterfall in the tropical jungle"],
                    ["ancient Roman amphitheater"],
                    ["statue of a religious figure with outstretched arms"],
                    ["a misty stone monument at sunrise"],
                    ["white marble palace with a dome"],
                ],
                inputs=text_input,
                label="✨ Or try these example queries:",
            )
            text_btn.click(recommend_from_text, inputs=text_input, outputs=[text_gallery, text_summary])
    gr.HTML("""
<div id="footer-block">
<strong>About this app</strong><br>
<strong>Dataset:</strong> <a href="https://huggingface.co/datasets/chavajaz/wonders_dataset">chavajaz/wonders_dataset</a> — 11,544 images across 12 wonder classes (CC0).<br>
<strong>Model:</strong> <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP ViT-B/32</a> — embeds images and text into the same 512-D space for cross-modal retrieval.<br>
<strong>Method:</strong> L2-normalized cosine similarity over precomputed embeddings, with a diversity filter (threshold 0.98) to suppress near-duplicate results.
</div>
""")
# Serve the Blocks app assembled above. `demo` must not be rebound before this
# call; replacing it would discard the whole Wonder Finder UI.
demo.launch()