"""Gradio demo: find Animals-10 images similar to an uploaded image or a text query.

Queries are embedded with CLIP and compared, by cosine similarity, against
embeddings that were precomputed and stored in ``animal_embeddings.parquet``.
"""

import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPModel, CLIPProcessor

# Number of matches shown in each results gallery.
TOP_K = 3

# -----------------------------
# 1. Load model & processor
# -----------------------------
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only

# -----------------------------
# 2. Load your saved embeddings
# -----------------------------
df = pd.read_parquet("animal_embeddings.parquet")
# Every column except "label" and "index" is treated as an embedding dimension.
embeddings = df.drop(columns=["label", "index"]).values
labels = df["label"].tolist()
indices = df["index"].tolist()

# Load dataset to retrieve images.
dataset = load_dataset("mountassir/animals-10")["train"]
# Selecting with the stored indices, in the stored order, keeps row i of
# `embeddings` aligned with `sampled_data[i]`.
sampled_data = dataset.select(indices)
label_names = dataset.features["label"].names


# -----------------------------
# 3. Helper functions
# -----------------------------
def _to_unit_vector(feats):
    """L2-normalise a features tensor and return it as a squeezed NumPy array."""
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.squeeze().cpu().numpy()


def embed_image_query(pil_image):
    """Return the L2-normalised CLIP image embedding of ``pil_image``.

    Args:
        pil_image: The query image (the UI supplies a PIL image).

    Returns:
        A NumPy array holding the embedding, with the singleton batch axis
        squeezed away.
    """
    with torch.no_grad():
        inputs = processor(images=pil_image, return_tensors="pt").to(device)
        feats = model.get_image_features(**inputs)
        return _to_unit_vector(feats)


def embed_text_query(text):
    """Return the L2-normalised CLIP text embedding of ``text``.

    Args:
        text: The query string.

    Returns:
        A NumPy array holding the embedding, with the singleton batch axis
        squeezed away.
    """
    with torch.no_grad():
        # CLIP's text encoder has a fixed maximum sequence length; truncate an
        # over-long description instead of failing inside the model.
        inputs = processor(
            text=[text], return_tensors="pt", truncation=True
        ).to(device)
        feats = model.get_text_features(**inputs)
        return _to_unit_vector(feats)


def get_top_k(query_emb, k=TOP_K):
    """Find the ``k`` stored embeddings most similar to ``query_emb``.

    Args:
        query_emb: Query embedding; it is reshaped to a single row.
        k: Maximum number of matches to return.

    Returns:
        A tuple ``(idxs, scores)``: row positions into ``embeddings`` ordered
        from most to least similar, and the matching cosine similarities.
    """
    sims = cosine_similarity(query_emb.reshape(1, -1), embeddings)[0]
    idxs = np.argsort(sims)[::-1][:k]
    return idxs, sims[idxs]


# -----------------------------
# 4. Gradio functions
# -----------------------------
def _top_images(query_emb, k=TOP_K):
    """Return the images of the ``k`` entries closest to ``query_emb``."""
    idxs, _scores = get_top_k(query_emb, k)
    # Convert NumPy integers to plain ints before indexing the dataset, so the
    # row lookup does not depend on how the library treats NumPy scalar keys.
    return [sampled_data[int(i)]["image"] for i in idxs]


def gradio_image_search(image):
    """Gradio callback: return the top matches for an uploaded image.

    Returns an empty list when no image has been provided, so pressing
    "Search" on an empty upload box clears the gallery instead of raising.
    """
    if image is None:
        return []
    return _top_images(embed_image_query(image))


def gradio_text_search(text):
    """Gradio callback: return the top matches for a text description.

    Returns an empty list for an empty or whitespace-only query.
    """
    if not text or not text.strip():
        return []
    return _top_images(embed_text_query(text))


# -----------------------------
# 5. Build Gradio App
# -----------------------------
# Kept flush-left so the Markdown is never read as an indented code block.
APP_DESCRIPTION = """
# 🐾 Animal Similarity Finder

Welcome! This app allows you to find animals that look visually similar using image and text embeddings.

How it works

- The model uses **CLIP embeddings** to compare your input with a database of animal images.
- It returns the **Top 3 most similar images** from the Animals-10 dataset.

Image Search

Upload a picture of an animal (dog, cat, spider, butterfly, horse, etc.).
The app will analyze the image and show you the 3 closest matches based on **visual similarity**.

Text Search

Type a description like:

- **"pet"** → finds dogs & cats
- **"bug"** → finds spiders
- **"farm animal"** → finds sheep, cows, horses
- **"bird"** → finds chickens

The model converts your text into an embedding and returns the 3 images most related to your description.

Behind the scenes

- Embeddings generated with **CLIP (ViT-B/32)**
- Similarity is computed using **cosine similarity**
- All embeddings are precomputed for speed

Enjoy exploring the animal dataset! 🐶🐱🐴🦋🕷️
"""

with gr.Blocks() as demo:
    gr.Markdown(APP_DESCRIPTION)

    with gr.Tab("Image Search"):
        img_in = gr.Image(type="pil")
        img_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn1 = gr.Button("Search")
        btn1.click(fn=gradio_image_search, inputs=img_in, outputs=img_out)

    with gr.Tab("Text Search"):
        txt_in = gr.Textbox(label="e.g. 'pet', 'bug', 'farm animal'")
        txt_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn2 = gr.Button("Search")
        btn2.click(fn=gradio_text_search, inputs=txt_in, outputs=txt_out)

# Launch only when run as a script, so importing this module (e.g. to reuse
# the helpers) does not start a web server.
if __name__ == "__main__":
    demo.launch()