Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| from transformers import CLIPModel, CLIPProcessor | |
# -----------------------------
# 1. Load model & processor
# -----------------------------
# Hugging Face Hub checkpoint shared by the model and its preprocessing pipeline.
model_name = "openai/clip-vit-base-patch32"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor prepares raw images/text as model inputs; the model produces
# the features used for similarity search.
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name)

# Move the weights to the chosen device and switch to evaluation mode, since
# this app only runs inference.
model.to(device)
model.eval()
# -----------------------------
# 2. Load your saved embeddings
# -----------------------------
from datasets import load_dataset

# Precomputed embedding table: one row per image, with "label" and "index"
# bookkeeping columns next to the embedding columns.
df = pd.read_parquet("animal_embeddings.parquet")
labels = df["label"].tolist()
indices = df["index"].tolist()

# Everything that is not a bookkeeping column is treated as embedding data.
embedding_frame = df.drop(columns=["label", "index"])
embeddings = embedding_frame.values

# Load dataset to retrieve images
dataset = load_dataset("mountassir/animals-10")["train"]
label_names = dataset.features["label"].names

# Restrict the dataset to the rows listed in the parquet file so that row i of
# `embeddings` can be looked up as `sampled_data[i]`
# (presumably `select` keeps the given order — verify against the datasets docs).
sampled_data = dataset.select(indices)
| # ----------------------------- | |
| # 3. Helper functions | |
| # ----------------------------- | |
def embed_image_query(pil_image):
    """Encode an image as an L2-normalised CLIP feature vector.

    Args:
        pil_image: The query image; the UI supplies it as a PIL image.

    Returns:
        A NumPy array holding the normalised image features, with
        singleton dimensions squeezed away.
    """
    # Preprocessing does not need gradient tracking, so only the forward pass
    # and the normalisation run under no_grad.
    batch = processor(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        image_features = model.get_image_features(**batch)
        # Scale to unit length so a dot product equals cosine similarity.
        unit_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return unit_features.squeeze().cpu().numpy()
def embed_text_query(text):
    """Encode a text query as an L2-normalised CLIP feature vector.

    Args:
        text: The free-form query string typed by the user.

    Returns:
        A NumPy array holding the normalised text features, with
        singleton dimensions squeezed away.
    """
    with torch.no_grad():
        # CLIP's text encoder has a fixed maximum context length. Without
        # truncation, a long query yields more tokens than the model has
        # position embeddings for and the forward pass fails, so clip the
        # input to the tokenizer's limit.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        feats = model.get_text_features(**inputs)
        # Scale to unit length so a dot product equals cosine similarity.
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.squeeze().cpu().numpy()
| from sklearn.metrics.pairwise import cosine_similarity | |
def get_top_k(query_emb, k=3):
    """Find the stored embeddings most similar to a query embedding.

    Args:
        query_emb: Query embedding as a NumPy array; it is reshaped to a
            single row before comparison.
        k: Number of best matches to return.

    Returns:
        A tuple ``(idxs, scores)``: row positions into the module-level
        ``embeddings`` matrix, ordered from most to least similar, and the
        cosine similarity of each of those rows.
    """
    query_row = query_emb.reshape(1, -1)
    # cosine_similarity returns a (1, n_rows) matrix; take its only row.
    similarity = cosine_similarity(query_row, embeddings)[0]
    # Ascending argsort reversed gives best-first ordering.
    ranked = np.argsort(similarity)[::-1]
    best = ranked[:k]
    return best, similarity[best]
| # ----------------------------- | |
| # 4. Gradio functions | |
| # ----------------------------- | |
def gradio_image_search(image):
    """Return the three dataset images most similar to an uploaded image.

    Args:
        image: The query image from the Gradio image input, or ``None``
            when the user pressed "Search" without providing one.

    Returns:
        A list of up to three images for the results gallery; an empty
        list when no query image was supplied.
    """
    # Nothing to embed: return an empty gallery instead of crashing inside
    # the CLIP processor.
    if image is None:
        return []
    query_emb = embed_image_query(image)
    idxs, _scores = get_top_k(query_emb, 3)
    # get_top_k yields NumPy integers; cast to plain ints because some
    # `datasets` versions may reject NumPy scalars as row keys (harmless
    # otherwise).
    return [sampled_data[int(i)]["image"] for i in idxs]
def gradio_text_search(text):
    """Return the three dataset images best matching a text description.

    Args:
        text: The query string from the Gradio textbox; may be empty or
            ``None`` when the user pressed "Search" without typing.

    Returns:
        A list of up to three images for the results gallery; an empty
        list when the query is missing or only whitespace.
    """
    # A blank query carries no meaning, so its nearest neighbours would be
    # arbitrary images; show an empty gallery instead.
    if not text or not text.strip():
        return []
    query_emb = embed_text_query(text)
    idxs, _scores = get_top_k(query_emb, 3)
    # get_top_k yields NumPy integers; cast to plain ints because some
    # `datasets` versions may reject NumPy scalars as row keys (harmless
    # otherwise).
    return [sampled_data[int(i)]["image"] for i in idxs]
# -----------------------------
# 5. Build Gradio App
# -----------------------------
with gr.Blocks() as demo:
    # The Markdown text is kept flush-left so that no line is indented far
    # enough to be rendered as a code block.
    # NOTE(review): the arrows and emoji below were restored from mis-decoded
    # bytes; the emoji (in particular the butterfly) are a best-effort
    # reconstruction — confirm against the original file.
    # NOTE(review): the section titles ("How it works", "Image Search", ...)
    # are plain paragraphs here; consider "##" headings if that was the intent.
    gr.Markdown("""
# 🐾 Animal Similarity Finder

Welcome! This app allows you to find animals that look visually similar using image and text embeddings.

How it works

- The model uses **CLIP embeddings** to compare your input with a database of animal images.
- It returns the **Top 3 most similar images** from the Animals-10 dataset.

Image Search

Upload a picture of an animal (dog, cat, spider, butterfly, horse, etc.).
The app will analyze the image and show you the 3 closest matches based on **visual similarity**.

Text Search

Type a description like:

- **"pet"** → finds dogs & cats
- **"bug"** → finds spiders
- **"farm animal"** → finds sheep, cows, horses
- **"bird"** → finds chickens

The model converts your text into an embedding and returns the 3 images most related to your description.

Behind the scenes

- Embeddings generated with **CLIP (ViT-B/32)**
- Similarity is computed using **cosine similarity**
- All embeddings are precomputed for speed

Enjoy exploring the animal dataset! 🐶🐱🐴🦋🕷️
""")

    # Tab 1: query by uploaded image (delivered to the callback as a PIL image).
    with gr.Tab("Image Search"):
        img_in = gr.Image(type="pil")
        img_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn1 = gr.Button("Search")
        btn1.click(fn=gradio_image_search, inputs=img_in, outputs=img_out)

    # Tab 2: query by free-form text description.
    with gr.Tab("Text Search"):
        txt_in = gr.Textbox(label="e.g. 'pet', 'bug', 'farm animal'")
        txt_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn2 = gr.Button("Search")
        btn2.click(fn=gradio_text_search, inputs=txt_in, outputs=txt_out)

# Start the web server once the interface has been fully declared.
demo.launch()