File size: 3,973 Bytes
90dc591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3fb24d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90dc591
 
 
3c3a7a5
90dc591
 
 
 
 
3c3a7a5
90dc591
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import pandas as pd
import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor

# -----------------------------
# 1. Load model & processor
# -----------------------------
# Checkpoint identifier shared by the model weights and the preprocessing pipeline.
model_name = "openai/clip-vit-base-patch32"

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The app only runs inference: place the model on the chosen device and
# switch it to evaluation mode right after loading.
model = CLIPModel.from_pretrained(model_name).to(device).eval()
processor = CLIPProcessor.from_pretrained(model_name)

# -----------------------------
# 2. Load your saved embeddings
# -----------------------------
# Table of precomputed embeddings plus two metadata columns ("label", "index").
df = pd.read_parquet("animal_embeddings.parquet")

# Metadata columns, pulled out as plain Python lists.
labels = df["label"].tolist()
indices = df["index"].tolist()

# Every remaining column is treated as part of the embedding matrix
# (one row per table row, in table order).
embeddings = df.drop(columns=["label", "index"]).to_numpy()

# Load dataset to retrieve images
from datasets import load_dataset
# Only the "train" split of the dataset is used.
dataset = load_dataset("mountassir/animals-10")["train"]

# Class names as declared by the dataset's "label" feature.
label_names = dataset.features["label"].names

# Subset restricted to the rows listed in `indices`; search results index into
# this subset to fetch the image to display.
sampled_data = dataset.select(indices)

# -----------------------------
# 3. Helper functions
# -----------------------------

@torch.no_grad()
def embed_image_query(pil_image):
    """Embed an image with CLIP and return it as a unit-length NumPy vector.

    Args:
        pil_image: Image handed to the CLIP processor (the UI supplies a PIL
            image).

    Returns:
        The image features divided by their L2 norm along the last axis, with
        singleton dimensions squeezed away, as a NumPy array on the CPU.
    """
    batch = processor(images=pil_image, return_tensors="pt").to(device)
    image_features = model.get_image_features(**batch)
    # Scale to unit length so the vector is ready for cosine comparison.
    unit_features = image_features / image_features.norm(dim=-1, keepdim=True)
    return unit_features.squeeze().cpu().numpy()

def embed_text_query(text):
    """Embed a text query with CLIP and return it as a unit-length NumPy vector.

    Args:
        text: The query string typed by the user.

    Returns:
        The text features divided by their L2 norm along the last axis, with
        singleton dimensions squeezed away, as a NumPy array on the CPU.
    """
    with torch.no_grad():
        # CLIP's text encoder has a fixed context length; without truncation a
        # long query produces more tokens than the model has position
        # embeddings for and the forward pass fails. Truncate to the
        # tokenizer's maximum length instead of crashing.
        inputs = processor(
            text=[text],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        feats = model.get_text_features(**inputs)
        # Scale to unit length so the vector is ready for cosine comparison.
        feats = feats / feats.norm(dim=-1, keepdim=True)
    return feats.squeeze().cpu().numpy()

from sklearn.metrics.pairwise import cosine_similarity

def get_top_k(query_emb, k=3):
    """Find the stored embeddings most similar to a query embedding.

    Args:
        query_emb: Query embedding as a NumPy array; it is reshaped to a
            single row before comparison.
        k: Number of best matches to return (default 3).

    Returns:
        A tuple ``(positions, scores)``: the row positions in ``embeddings``
        of the ``k`` highest cosine similarities, best first, and the
        similarity value at each of those positions.
    """
    query_row = query_emb.reshape(1, -1)
    similarity = cosine_similarity(query_row, embeddings)[0]
    # argsort is ascending, so reverse it to rank from most to least similar.
    ranking = np.argsort(similarity)[::-1]
    best = ranking[:k]
    return best, similarity[best]

# -----------------------------
# 4. Gradio functions
# -----------------------------

def gradio_image_search(image):
    """Gradio callback: return the 3 stored images most similar to an upload.

    Args:
        image: The uploaded image, or ``None`` when the user pressed Search
            without providing one.

    Returns:
        A list with the ``"image"`` field of the 3 best-matching rows of
        ``sampled_data``, most similar first.

    Raises:
        gr.Error: If no image was provided, so the UI shows a clear message
            instead of an opaque preprocessing failure.
    """
    if image is None:
        raise gr.Error("Please upload an image before searching.")

    query_emb = embed_image_query(image)
    # Only the positions are needed here; the similarity scores are not shown.
    idxs, _scores = get_top_k(query_emb, 3)
    return [sampled_data[i]["image"] for i in idxs]

def gradio_text_search(text):
    """Gradio callback: return the 3 stored images that best match a text query.

    Args:
        text: The description typed by the user.

    Returns:
        A list with the ``"image"`` field of the 3 best-matching rows of
        ``sampled_data``, most similar first.
    """
    # The similarity scores are returned by get_top_k but not displayed.
    top_rows, _similarities = get_top_k(embed_text_query(text), 3)
    return [sampled_data[row]["image"] for row in top_rows]

# -----------------------------
# 5. Build Gradio App
# -----------------------------
# Build the UI: a Markdown introduction followed by two tabs, one per search
# mode. Each tab wires a "Search" button to its callback and shows the result
# in a 3-column gallery.
with gr.Blocks() as demo:
    # Introductory text. Section titles are Markdown headings so they do not
    # run into the paragraph that follows them, and the arrows/emoji are the
    # real characters rather than mis-decoded byte sequences.
    gr.Markdown("""
# 🐾 Animal Similarity Finder

Welcome! This app allows you to find animals that look visually similar using image and text embeddings.

## How it works
- The model uses **CLIP embeddings** to compare your input with a database of animal images.
- It returns the **Top 3 most similar images** from the Animals-10 dataset.

## Image Search
Upload a picture of an animal (dog, cat, spider, butterfly, horse, etc.).  
The app will analyze the image and show you the 3 closest matches based on **visual similarity**.

## Text Search
Type a description like:
- **"pet"** → finds dogs & cats  
- **"bug"** → finds spiders  
- **"farm animal"** → finds sheep, cows, horses  
- **"bird"** → finds chickens  

The model converts your text into an embedding and returns the 3 images most related to your description.

## Behind the scenes
- Embeddings generated with **CLIP (ViT-B/32)**  
- Similarity is computed using **cosine similarity**  
- All embeddings are precomputed for speed

Enjoy exploring the animal dataset! 🐶🐱🐴🦋🕷️
""")


    with gr.Tab("Image Search"):
        # type="pil" asks Gradio to pass the upload to the callback as a PIL image.
        img_in = gr.Image(type="pil")
        img_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn1 = gr.Button("Search")
        btn1.click(fn=gradio_image_search, inputs=img_in, outputs=img_out)

    with gr.Tab("Text Search"):
        txt_in = gr.Textbox(label="e.g. 'pet', 'bug', 'farm animal'")
        txt_out = gr.Gallery(label="Top 3 Results", columns=3)
        btn2 = gr.Button("Search")
        btn2.click(fn=gradio_text_search, inputs=txt_in, outputs=txt_out)

# Start the web server for the app.
demo.launch()