uleeberber commited on
Commit
90dc591
·
verified ·
1 Parent(s): 3ba80c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py CHANGED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ import torch
5
+ from transformers import CLIPModel, CLIPProcessor
6
+
7
+ # -----------------------------
8
+ # 1. Load model & processor
9
+ # -----------------------------
10
+ model_name = "openai/clip-vit-base-patch32"
11
+ model = CLIPModel.from_pretrained(model_name)
12
+ processor = CLIPProcessor.from_pretrained(model_name)
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ model.to(device)
16
+ model.eval()
17
+
18
+ # -----------------------------
19
+ # 2. Load your saved embeddings
20
+ # -----------------------------
21
+ df = pd.read_parquet("animal_embeddings.parquet")
22
+
23
+ embeddings = df.drop(columns=["label", "index"]).values
24
+ labels = df["label"].tolist()
25
+ indices = df["index"].tolist()
26
+
27
+ # Load dataset to retrieve images
28
+ from datasets import load_dataset
29
+ dataset = load_dataset("mountassir/animals-10")["train"]
30
+ sampled_data = dataset.select(indices)
31
+
32
+ label_names = dataset.features["label"].names
33
+
34
+ # -----------------------------
35
+ # 3. Helper functions
36
+ # -----------------------------
37
+
38
+ def embed_image_query(pil_image):
39
+ with torch.no_grad():
40
+ inputs = processor(images=pil_image, return_tensors="pt").to(device)
41
+ feats = model.get_image_features(**inputs)
42
+ feats = feats / feats.norm(dim=-1, keepdim=True)
43
+ return feats.squeeze().cpu().numpy()
44
+
45
+ def embed_text_query(text):
46
+ with torch.no_grad():
47
+ inputs = processor(text=[text], return_tensors="pt").to(device)
48
+ feats = model.get_text_features(**inputs)
49
+ feats = feats / feats.norm(dim=-1, keepdim=True)
50
+ return feats.squeeze().cpu().numpy()
51
+
52
+ from sklearn.metrics.pairwise import cosine_similarity
53
+
54
+ def get_top_k(query_emb, k=3):
55
+ sims = cosine_similarity(query_emb.reshape(1, -1), embeddings)[0]
56
+ idxs = np.argsort(sims)[::-1][:k]
57
+ return idxs, sims[idxs]
58
+
59
+ # -----------------------------
60
+ # 4. Gradio functions
61
+ # -----------------------------
62
+
63
+ def gradio_image_search(image):
64
+ query_emb = embed_image_query(image)
65
+ idxs, scores = get_top_k(query_emb, 3)
66
+ results = [sampled_data[i]["image"] for i in idxs]
67
+ return results
68
+
69
+ def gradio_text_search(text):
70
+ query_emb = embed_text_query(text)
71
+ idxs, scores = get_top_k(query_emb, 3)
72
+ results = [sampled_data[i]["image"] for i in idxs]
73
+ return results
74
+
75
+ # -----------------------------
76
+ # 5. Build Gradio App
77
+ # -----------------------------
78
+ with gr.Blocks() as demo:
79
+ gr.Markdown("# 🐾 Animal Similarity Finder\nUpload an image or enter a text description.")
80
+
81
+ with gr.Tab("Image Search"):
82
+ img_in = gr.Image(type="pil")
83
+ img_out = gr.Gallery(label="Top 3 Results").columns(3)
84
+ btn1 = gr.Button("Search")
85
+ btn1.click(fn=gradio_image_search, inputs=img_in, outputs=img_out)
86
+
87
+ with gr.Tab("Text Search"):
88
+ txt_in = gr.Textbox(label="e.g. 'pet', 'bug', 'farm animal'")
89
+ txt_out = gr.Gallery(label="Top 3 Results").columns(3)
90
+ btn2 = gr.Button("Search")
91
+ btn2.click(fn=gradio_text_search, inputs=txt_in, outputs=txt_out)
92
+
93
+ demo.launch()