ashish-001 committed on
Commit
779c855
·
verified ·
1 Parent(s): 1cc5148

Upload 7 files

Browse files
Clustering.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hdbscan
2
+ import numpy as np
3
+ import os
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+
7
class ClusteringData:
    """Cluster image embeddings with HDBSCAN and retrieve similar images.

    Row ``i`` of the embedding matrix corresponds to ``image_list[i]``, which
    is why the directory listing is sorted.  Labels, membership probabilities
    and raw embeddings are persisted under the ``embeddings/`` directory.
    """

    def __init__(self, min_num_clusters=5, embeddings=None,
                 image_dir=os.path.join('coco', 'val2017', 'val2017')):
        """Create the clusterer.

        Args:
            min_num_clusters: HDBSCAN ``min_cluster_size``.
            embeddings: optional precomputed embedding matrix (N, D).
            image_dir: directory holding the images; default keeps the
                original hard-coded COCO val2017 location for compatibility.
        """
        self.clusterer = hdbscan.HDBSCAN(min_cluster_size=min_num_clusters)
        self.labels = None
        self.probabilities = None
        self.image_dir = image_dir
        # Sorted so index positions stay aligned with the embedding rows.
        self.image_list = sorted(os.listdir(image_dir))
        self.embeddings = embeddings

    def create_clusters(self, embeddings):
        """Fit HDBSCAN on ``embeddings`` and cache labels / probabilities."""
        self.clusterer.fit(embeddings)
        self.labels = self.clusterer.labels_
        self.probabilities = self.clusterer.probabilities_

    def save_model_data(self):
        """Persist labels, probabilities and embeddings as ``.npy`` files."""
        # Create the target directory if needed; np.save does not do this and
        # would raise FileNotFoundError on a fresh checkout.
        os.makedirs("embeddings", exist_ok=True)
        np.save(os.path.join("embeddings", "labels.npy"),
                self.clusterer.labels_.astype(np.int32))
        np.save(os.path.join("embeddings", "probabilities.npy"),
                self.clusterer.probabilities_.astype(np.float32))
        np.save(os.path.join("embeddings", "image_embeddings.npy"),
                self.embeddings.astype(np.float32))

    def load_model_data(self):
        """Load previously saved labels, probabilities and embeddings."""
        self.labels = np.load(os.path.join("embeddings", "labels.npy"))
        self.probabilities = np.load(os.path.join("embeddings", "probabilities.npy"))
        self.embeddings = np.load(os.path.join("embeddings", "image_embeddings.npy"))

    def find_similar_records(self, embedding, k=10):
        """Return file names of the ``k`` images most similar to ``embedding``.

        The search is narrowed to the cluster of the single best cosine match;
        noise points (label ``-1``) fall back to a full scan.  Final ranking
        blends cosine similarity (0.7) with cluster membership probability (0.3).
        """
        # Accept plain Python lists (e.g. a JSON-decoded API response) as well
        # as numpy arrays, and guard against a zero-norm query vector.
        embedding = np.asarray(embedding, dtype=np.float64).squeeze()
        norm = np.linalg.norm(embedding)
        if norm > 0:
            embedding = embedding / norm
        cosine_similarities = np.dot(self.embeddings, embedding)
        best_match_idx = np.argmax(cosine_similarities)
        most_similar_label = self.labels[best_match_idx]
        # Narrow the candidate set using the best match's cluster label.
        if most_similar_label == -1:
            candidates = np.arange(len(self.labels))
        else:
            candidates = np.where(self.labels == most_similar_label)[0]
        final_scores = 0.7 * cosine_similarities[candidates] + 0.3 * self.probabilities[candidates]
        final_indices = candidates[np.argsort(-final_scores)[:k]]
        return [self.image_list[i] for i in final_indices]

    def display_similar_records(self, embedding, k=10):
        """Plot the top-k most similar images with matplotlib."""
        top_images = self.find_similar_records(embedding, k)
        fig, axs = plt.subplots(1, len(top_images), figsize=(15, 5))
        axs = np.atleast_1d(axs)  # subplots returns a bare Axes when k == 1
        for ax, img_name in zip(axs, top_images):
            img_path = os.path.join(self.image_dir, img_name)
            img = Image.open(img_path).convert('RGB')
            ax.imshow(img)
            ax.axis("off")
        plt.show()
55
+
56
+
Image.jpg ADDED
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from Clustering import ClusteringData
3
+ import numpy as np
4
+ from PIL import Image
5
+ import requests
6
+ import tempfile
7
+ import os
8
+ import logging
9
+ import json
10
+
11
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Load the precomputed cluster labels / probabilities / embeddings once at
# startup so every search request can reuse them.  NOTE(review): this runs at
# import time and requires the embeddings/ directory and the COCO image
# directory to exist — the app cannot start without them.
cd = ClusteringData()
cd.load_model_data()
logger.info("Clustering data loaded")
18
+
19
+
20
def search_images(text_query, uploaded_image, search_mode, top_k):
    """Embed the query via a remote API and return matching gallery images.

    Args:
        text_query: free-text query (used when ``search_mode == "Text"``).
        uploaded_image: file path of the uploaded image, or ``None``.
        search_mode: ``"Text"`` or ``"Image"``.
        top_k: number of results to return.

    Returns:
        ``(preview, results)`` where ``preview`` is the image path to show
        (``None`` for text search) and ``results`` is a list of image paths
        for the gallery (empty when the embedding API call fails).
    """
    image_root = os.path.join("coco", "val2017", "val2017")
    preview = None
    results = []

    if search_mode == "Text" and text_query.strip():
        # Pass the query via `params` so requests percent-encodes it; the
        # original f-string URL broke on spaces, '&', '?' and non-ASCII text.
        response = requests.get(
            "https://ashish-001-text-embedding-api.hf.space/embedding",
            params={"text": text_query.strip()},
            timeout=30,
        )
        if response.status_code == 200:
            logger.info("Embedding returned successfully by text API")
            embedding = np.asarray(response.json()["embedding"]).squeeze()
            results = cd.find_similar_records(embedding, k=top_k)
        else:
            logger.info("%s returned by the text API", response.status_code)
        return None, [os.path.join(image_root, fname) for fname in results]

    elif search_mode == "Image":
        if uploaded_image is not None:
            preview = uploaded_image
            tmp_path = uploaded_image
        else:
            # Fall back to the bundled demo image.
            preview = 'Image.jpg'
            tmp_path = 'Image.jpg'
        url = "https://ashish-001-clip-image-embedding-api.hf.space/clip/process"
        # Context manager closes the upload handle; the original leaked it.
        with open(tmp_path, "rb") as fh:
            response = requests.post(url, files={"file": fh}, timeout=60)
        if response.status_code == 200:
            embedding = np.array(response.json()['embedding']).squeeze()
            logger.info("Embedding returned successfully by image API")
            results = cd.find_similar_records(embedding, k=top_k)
        else:
            logger.info("%s returned by the image API", response.status_code)
        return preview, [os.path.join(image_root, fname) for fname in results]

    # Text mode with an empty query (or an unknown mode): nothing to show.
    return preview, results
64
+
65
+
66
# Gradio UI: a mode radio toggles between a text box and an image upload,
# and the search button routes both through search_images().
with gr.Blocks() as demo:
    gr.Markdown("## Multimodal Image Search with CLIP")
    gr.Markdown("Search images using **text** or **image upload**.")

    with gr.Row():
        with gr.Column(scale=1):
            # Inputs
            search_mode = gr.Radio(
                ["Text", "Image"], label="Search Mode", value="Text")
            text_input = gr.Textbox(
                label="Enter text query", placeholder="Type something...", visible=True, value='Empty street')
            file_input = gr.Image(
                type="filepath",
                label="Upload image",
                value="Image.jpg",
                visible=False
            )
            top_k = gr.Slider(1, 20, value=6, step=1,
                              label="Number of results")
            submit_btn = gr.Button("Search")

        with gr.Column(scale=2):
            # Outputs: the query preview image and the result gallery.
            preview_img = gr.Image(label="Uploaded / Default Image")
            result_gallery = gr.Gallery(
                label="Results", columns=3, height="auto")

    def toggle_inputs(mode):
        # Show only the input widget for the selected mode, and clear stale
        # results/preview so the UI doesn't mix text and image output.
        if mode == "Text":
            return (
                gr.update(visible=True),
                gr.update(visible=False, value=None),
                [],
                None
            )
        else:
            return (
                gr.update(visible=False),
                gr.update(visible=True, value=None),
                [],
                "Image.jpg"
            )

    search_mode.change(toggle_inputs, inputs=search_mode,
                       outputs=[text_input, file_input, result_gallery, preview_img])

    submit_btn.click(fn=search_images,
                     inputs=[text_input,
                             file_input, search_mode, top_k],
                     outputs=[preview_img, result_gallery,])

demo.launch()
embeddings/image_embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:080499a88f33cfae389d37eae0d50d76ca3e11e444c31ab69f49f9f35930dc2e
3
+ size 15360128
embeddings/labels.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f988d14d0ee40b57b7e69a64890a3b320e25c3d768f8f7b4275a1f846eba72b
3
+ size 20128
embeddings/probabilities.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16931a633a03b1b5d6eba5eaf0d9c8af42aefa7a2835bf7a1396817fd2388b3a
3
+ size 20128
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch==2.8.0
2
+ torchvision==0.23.0
3
+ hdbscan==0.8.40
4
+ gradio==5.44.1
5
+ numpy==2.2.6
6
+ transformers==4.56.0
7
+ matplotlib==3.10.6
8
+ Pillow
9
+ requests