CLIP-based-Image-Search / Clustering.py
ashish-001's picture
Upload 7 files
779c855 verified
import hdbscan
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
class ClusteringData:
    """Cluster image embeddings with HDBSCAN and run cluster-narrowed similarity search.

    Row ``i`` of the embedding matrix is looked up as ``image_list[i]``, where
    ``image_list`` is the sorted listing of ``image_dir``. The embeddings are
    presumably generated in that same sorted order -- confirm with the producer.
    """

    # Class-level defaults; __init__ may override them per instance.
    image_dir = os.path.join('coco', 'val2017', 'val2017')
    data_dir = "embeddings"

    def __init__(self, min_num_clusters=5, embeddings=None, image_dir=None, data_dir=None):
        """Create an unfitted clusterer and index the image directory.

        Args:
            min_num_clusters: forwarded to HDBSCAN as ``min_cluster_size`` (the
                minimum size of a cluster, not a number of clusters).
            embeddings: optional embedding matrix, one row per image.
            image_dir: directory holding the images; defaults to
                ``coco/val2017/val2017``.
            data_dir: directory used by save/load; defaults to ``embeddings``.
        """
        if image_dir is not None:
            self.image_dir = image_dir
        if data_dir is not None:
            self.data_dir = data_dir
        self.clusterer = hdbscan.HDBSCAN(min_cluster_size=min_num_clusters)
        self.labels = None
        self.probabilities = None
        # Sorted so the index -> file-name mapping is deterministic.
        self.image_list = sorted(os.listdir(self.image_dir))
        self.embeddings = embeddings

    def create_clusters(self, embeddings):
        """Fit HDBSCAN on ``embeddings`` and keep its labels and probabilities.

        The fitted matrix is also stored on ``self.embeddings``, so search and
        saving use exactly the rows the labels were computed for.
        """
        self.clusterer.fit(embeddings)
        self.embeddings = embeddings
        self.labels = self.clusterer.labels_
        self.probabilities = self.clusterer.probabilities_

    def save_model_data(self):
        """Write labels, probabilities and embeddings as ``.npy`` files under ``data_dir``.

        Works after either ``create_clusters`` or ``load_model_data``.

        Raises:
            RuntimeError: if labels, probabilities or embeddings are missing.
        """
        if self.labels is None or self.probabilities is None or self.embeddings is None:
            raise RuntimeError(
                "Nothing to save: call create_clusters() or load_model_data() first."
            )
        os.makedirs(self.data_dir, exist_ok=True)
        np.save(os.path.join(self.data_dir, "labels.npy"),
                np.asarray(self.labels).astype(np.int32))
        np.save(os.path.join(self.data_dir, "probabilities.npy"),
                np.asarray(self.probabilities).astype(np.float32))
        np.save(os.path.join(self.data_dir, "image_embeddings.npy"),
                np.asarray(self.embeddings).astype(np.float32))

    def load_model_data(self):
        """Load labels, probabilities and embeddings previously written by ``save_model_data``."""
        self.labels = np.load(os.path.join(self.data_dir, "labels.npy"))
        self.probabilities = np.load(os.path.join(self.data_dir, "probabilities.npy"))
        self.embeddings = np.load(os.path.join(self.data_dir, "image_embeddings.npy"))

    def find_similar_records(self, embedding, k=10):
        """Return the file names of up to ``k`` images most similar to ``embedding``.

        The single best cosine match selects a cluster, and only that cluster's
        members are ranked. If the best match is HDBSCAN noise (label -1), all
        images are ranked. Candidates are scored as
        ``0.7 * cosine_similarity + 0.3 * membership_probability``.

        Args:
            embedding: query vector, shape ``(D,)`` or ``(1, D)``.
            k: maximum number of file names to return.

        Returns:
            List of file names from ``image_list``, best first.

        Raises:
            ValueError: if the query vector has zero norm.
        """
        query = np.asarray(embedding, dtype=np.float64).ravel()
        query_norm = np.linalg.norm(query)
        if query_norm == 0:
            raise ValueError("Query embedding has zero norm; cosine similarity is undefined.")
        query = query / query_norm

        matrix = np.asarray(self.embeddings)
        # Normalise stored rows as well, so scores are true cosines in [-1, 1]
        # on the same scale as the [0, 1] probabilities. Zero rows score 0.
        row_norms = np.linalg.norm(matrix, axis=1)
        row_norms[row_norms == 0] = 1.0
        cosine_similarities = matrix.dot(query) / row_norms

        labels = np.asarray(self.labels)
        best_match_idx = int(np.argmax(cosine_similarities))
        most_similar_label = labels[best_match_idx]
        # Narrow the search to the best match's cluster; noise (-1) searches all.
        if most_similar_label == -1:
            candidates = np.arange(len(labels))
        else:
            candidates = np.where(labels == most_similar_label)[0]

        probabilities = np.asarray(self.probabilities)
        final_scores = (0.7 * cosine_similarities[candidates]
                        + 0.3 * probabilities[candidates])
        final_indices = candidates[np.argsort(-final_scores)[:k]]
        return [self.image_list[i] for i in final_indices]

    def display_similar_records(self, embedding, k=10):
        """Show the top-``k`` matches for ``embedding`` side by side with matplotlib."""
        top_images = self.find_similar_records(embedding, k)
        if not top_images:
            # plt.subplots(1, 0) raises, and there is nothing to draw anyway.
            return
        _, axs = plt.subplots(1, len(top_images), figsize=(15, 5))
        # A single subplot comes back as a bare Axes; wrap it so zip() works.
        axs = np.atleast_1d(axs)
        for ax, img_name in zip(axs, top_images):
            img_path = os.path.join(self.image_dir, img_name)
            # Close the file handle as soon as the RGB copy has been made.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
            ax.imshow(img)
            ax.axis("off")
        plt.show()