Spaces:
Sleeping
Sleeping
File size: 3,973 Bytes
90dc591 e3fb24d 90dc591 3c3a7a5 90dc591 3c3a7a5 90dc591 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 | import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor
# -----------------------------
# 1. Load model & processor
# -----------------------------
# Query embeddings are compared against the precomputed vectors in
# animal_embeddings.parquet; presumably those were produced with this same
# checkpoint (assumption - confirm against the script that wrote the file).
model_name = "openai/clip-vit-base-patch32"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained(model_name)
# Module.to() and Module.eval() both return the module itself, so the
# loaded model can be moved to the device and switched to inference mode
# in a single chained expression.
model = CLIPModel.from_pretrained(model_name).to(device).eval()
# -----------------------------
# 2. Load your saved embeddings
# -----------------------------
# One row per sampled image. "label" and "index" are metadata columns;
# every remaining column is used as one dimension of the embedding matrix.
df = pd.read_parquet("animal_embeddings.parquet")
labels = df["label"].tolist()
indices = df["index"].tolist()
embeddings = df.drop(["label", "index"], axis="columns").values
# Load dataset to retrieve images
from datasets import load_dataset

dataset = load_dataset("mountassir/animals-10")["train"]
label_names = dataset.features["label"].names
# The search functions index `sampled_data` with row positions of
# `embeddings`, so this relies on select() returning rows in the order
# given by `indices`.
sampled_data = dataset.select(indices)
# -----------------------------
# 3. Helper functions
# -----------------------------
def embed_image_query(pil_image):
    """Embed one image with CLIP and return it as a unit-length 1-D vector.

    Args:
        pil_image: the image to embed, handed straight to the CLIP processor
            (the Gradio input is declared as ``gr.Image(type="pil")``).

    Returns:
        numpy.ndarray: the L2-normalised image feature vector, moved to CPU.
    """
    with torch.no_grad():
        batch = processor(images=pil_image, return_tensors="pt").to(device)
        vec = model.get_image_features(**batch)
    # vec was produced under no_grad, so this division builds no graph either.
    unit_vec = vec / vec.norm(dim=-1, keepdim=True)
    return unit_vec.squeeze().cpu().numpy()
def embed_text_query(text):
    """Embed a text prompt with CLIP and return it as a unit-length 1-D vector.

    Args:
        text: the user's free-text description (e.g. "pet", "farm animal").

    Returns:
        numpy.ndarray: the L2-normalised text feature vector, moved to CPU.
        It is compared against the stored image embeddings by get_top_k.
    """
    with torch.no_grad():
        # CLIP's text encoder has a fixed maximum context length (77 tokens
        # for openai/clip-vit-base-patch32). truncation=True clips longer
        # prompts to the tokenizer's max length instead of letting the model
        # fail on an over-long sequence.
        inputs = processor(
            text=[text], return_tensors="pt", truncation=True
        ).to(device)
        feats = model.get_text_features(**inputs)
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.squeeze().cpu().numpy()
from sklearn.metrics.pairwise import cosine_similarity
def get_top_k(query_emb, k=3):
    """Return the ``k`` stored embeddings most similar to ``query_emb``.

    Args:
        query_emb: 1-D query vector; it is reshaped to a single row here.
        k: number of neighbours to return. Values <= 0 yield no results;
            values larger than the number of stored rows return every row.

    Returns:
        tuple: (row positions into the module-level ``embeddings`` matrix,
        best match first; their cosine similarities to the query).
    """
    sims = cosine_similarity(query_emb.reshape(1, -1), embeddings)[0]
    # A negative slice bound would silently mean "all but the last |k|",
    # so clamp k at zero before slicing.
    k = max(k, 0)
    idxs = np.argsort(sims)[::-1][:k]
    return idxs, sims[idxs]
# -----------------------------
# 4. Gradio functions
# -----------------------------
def gradio_image_search(image):
    """Return the 3 dataset images most similar to an uploaded image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Search is clicked before anything has been uploaded.

    Returns:
        list: dataset images for the results gallery, best match first;
        an empty list when no image was provided.
    """
    if image is None:
        # Nothing to embed: show an empty gallery instead of raising inside
        # the CLIP processor.
        return []
    query_emb = embed_image_query(image)
    idxs, _scores = get_top_k(query_emb, 3)
    # idxs are row positions in `embeddings`, which line up with the rows of
    # `sampled_data`; cast the NumPy integers to plain ints for indexing.
    return [sampled_data[int(i)]["image"] for i in idxs]
def gradio_text_search(text):
    """Return the 3 dataset images most related to a text description.

    Args:
        text: the query typed into the Text Search textbox.

    Returns:
        list: dataset images for the results gallery, best match first;
        an empty list when the query is missing or only whitespace.
    """
    if text is None or not text.strip():
        # An empty prompt carries no meaning; skip the model entirely.
        return []
    query_emb = embed_text_query(text)
    idxs, _scores = get_top_k(query_emb, 3)
    # idxs are row positions in `embeddings`, which line up with the rows of
    # `sampled_data`; cast the NumPy integers to plain ints for indexing.
    return [sampled_data[int(i)]["image"] for i in idxs]
# -----------------------------
# 5. Build Gradio App
# -----------------------------
# The intro text is rendered as Markdown, so headings and blank lines are
# significant: without a blank line, a paragraph that follows a list is
# folded into the list's last item.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# 🐾 Animal Similarity Finder

Welcome! This app allows you to find animals that look visually similar using image and text embeddings.

### How it works

- The model uses **CLIP embeddings** to compare your input with a database of animal images.
- It returns the **Top 3 most similar images** from the Animals-10 dataset.

### Image Search

Upload a picture of an animal (dog, cat, spider, butterfly, horse, etc.).
The app will analyze the image and show you the 3 closest matches based on **visual similarity**.

### Text Search

Type a description like:

- **"pet"** → finds dogs & cats
- **"bug"** → finds spiders
- **"farm animal"** → finds sheep, cows, horses
- **"bird"** → finds chickens

The model converts your text into an embedding and returns the 3 images most related to your description.

### Behind the scenes

- Embeddings generated with **CLIP (ViT-B/32)**
- Similarity is computed using **cosine similarity**
- All embeddings are precomputed for speed

Enjoy exploring the animal dataset! 🐶🐱🐴🦋🕷️
"""
    )
    # Tab 1: query by example image.
    with gr.Tab("Image Search"):
        img_in = gr.Image(type="pil")
        img_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn1 = gr.Button("Search")
        btn1.click(fn=gradio_image_search, inputs=img_in, outputs=img_out)
    # Tab 2: query by free-text description.
    with gr.Tab("Text Search"):
        txt_in = gr.Textbox(label="e.g. 'pet', 'bug', 'farm animal'")
        txt_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn2 = gr.Button("Search")
        btn2.click(fn=gradio_text_search, inputs=txt_in, outputs=txt_out)
demo.launch()
|