# app.py — Hugging Face Space source.
# Last commit: "Update app.py" (e3fb24d, verified), author: uleeberber.
import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor
# -----------------------------
# 1. Load model & processor
# -----------------------------
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# One checkpoint name drives both the processor and the model so that the
# pre-processing always matches the weights.
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name).to(device)

# The model is only used for inference in this app.
model.eval()
# -----------------------------
# 2. Load your saved embeddings
# -----------------------------
from datasets import load_dataset

# Saved embedding table: the "label" and "index" columns are metadata, every
# remaining column is treated as part of the embedding vector.
df = pd.read_parquet("animal_embeddings.parquet")
labels = df["label"].tolist()
indices = df["index"].tolist()
embeddings = df.drop(columns=["label", "index"]).to_numpy()

# Load the dataset to retrieve images. Selecting by the stored "index" values
# keeps row i of `sampled_data` aligned with row i of `embeddings`.
dataset = load_dataset("mountassir/animals-10")["train"]
label_names = dataset.features["label"].names
sampled_data = dataset.select(indices)
# -----------------------------
# 3. Helper functions
# -----------------------------
def embed_image_query(pil_image):
    """Encode an image into a unit-length CLIP embedding.

    Args:
        pil_image: Image to embed. It is passed to the CLIP processor as-is
            (the UI supplies a PIL image).

    Returns:
        NumPy array with the L2-normalised image features, size-1 axes
        squeezed out.
    """
    # Gradients are never needed for a query embedding.
    with torch.no_grad():
        batch = processor(images=pil_image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**batch)
        # Scale to unit length along the feature axis.
        lengths = image_features.norm(dim=-1, keepdim=True)
        unit_features = image_features / lengths
    return unit_features.squeeze().cpu().numpy()
def embed_text_query(text):
    """Encode a text query into a unit-length CLIP embedding.

    Args:
        text: Query string typed by the user.

    Returns:
        NumPy array with the L2-normalised text features, size-1 axes
        squeezed out.
    """
    with torch.no_grad():
        # CLIP's text encoder has a fixed context length. Without truncation a
        # long query produces more tokens than the model has position
        # embeddings for and the forward pass raises; truncating keeps the
        # leading tokens and always yields a valid input.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        feats = model.get_text_features(**inputs)
        # Normalise to unit length along the feature axis.
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.squeeze().cpu().numpy()
from sklearn.metrics.pairwise import cosine_similarity
def get_top_k(query_emb, k=3):
    """Find the stored embeddings most similar to a query embedding.

    Args:
        query_emb: Query embedding; it is reshaped to a single row before
            being compared against the module-level ``embeddings`` matrix.
        k: Number of matches to return.

    Returns:
        Tuple ``(idxs, scores)``: row positions into ``embeddings`` ordered
        from most to least similar, and the cosine similarity at each of
        those positions.
    """
    query_row = query_emb.reshape(1, -1)
    similarities = cosine_similarity(query_row, embeddings)[0]
    # argsort is ascending, so reverse it to rank best-first.
    ranked = np.argsort(similarities)[::-1]
    best = ranked[:k]
    return best, similarities[best]
# -----------------------------
# 4. Gradio functions
# -----------------------------
def _top_images(query_emb, k=3):
    """Return the images of the ``k`` stored items closest to ``query_emb``."""
    idxs, _scores = get_top_k(query_emb, k)
    # get_top_k yields NumPy integers; cast so the dataset is indexed with
    # plain Python ints.
    return [sampled_data[int(i)]["image"] for i in idxs]


def gradio_image_search(image):
    """Gradio callback: return the 3 stored images most similar to ``image``.

    Args:
        image: Uploaded image, or ``None`` when the user pressed "Search"
            without providing one.

    Returns:
        List of result images for the gallery; empty when no image was given.
    """
    # Nothing uploaded: show an empty gallery instead of failing inside the
    # CLIP processor.
    if image is None:
        return []
    return _top_images(embed_image_query(image))


def gradio_text_search(text):
    """Gradio callback: return the 3 stored images best matching ``text``.

    Args:
        text: Free-text description typed by the user.

    Returns:
        List of result images for the gallery; empty when the query is
        missing or contains only whitespace.
    """
    # A blank query carries no meaning to search for, so return no results
    # rather than three arbitrary images.
    if not text or not text.strip():
        return []
    return _top_images(embed_text_query(text))
# -----------------------------
# 5. Build Gradio App
# -----------------------------
with gr.Blocks() as demo:
    # Intro text. The body is kept flush-left so no line is indented enough
    # to be rendered as a Markdown code block. Section titles are real
    # headings separated by blank lines so they do not merge into the
    # surrounding paragraphs.
    gr.Markdown("""
# 🐾 Animal Similarity Finder

Welcome! This app allows you to find animals that look visually similar using image and text embeddings.

## How it works
- The model uses **CLIP embeddings** to compare your input with a database of animal images.
- It returns the **Top 3 most similar images** from the Animals-10 dataset.

## Image Search
Upload a picture of an animal (dog, cat, spider, butterfly, horse, etc.).
The app will analyze the image and show you the 3 closest matches based on **visual similarity**.

## Text Search
Type a description like:
- **"pet"** → finds dogs & cats
- **"bug"** → finds spiders
- **"farm animal"** → finds sheep, cows, horses
- **"bird"** → finds chickens

The model converts your text into an embedding and returns the 3 images most related to your description.

## Behind the scenes
- Embeddings generated with **CLIP (ViT-B/32)**
- Similarity is computed using **cosine similarity**
- All embeddings are precomputed for speed

Enjoy exploring the animal dataset! 🐶🐱🐴🦋🕷️
""")

    # Tab 1: query by uploaded image.
    with gr.Tab("Image Search"):
        img_in = gr.Image(type="pil")
        img_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn1 = gr.Button("Search")
        btn1.click(fn=gradio_image_search, inputs=img_in, outputs=img_out)

    # Tab 2: query by free-text description.
    with gr.Tab("Text Search"):
        txt_in = gr.Textbox(label="e.g. 'pet', 'bug', 'farm animal'")
        txt_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn2 = gr.Button("Search")
        btn2.click(fn=gradio_text_search, inputs=txt_in, outputs=txt_out)

# Start the web server for the app.
demo.launch()