Spaces:
Sleeping
Sleeping
| import hdbscan | |
| import numpy as np | |
| import os | |
| import matplotlib.pyplot as plt | |
| from PIL import Image | |
class ClusteringData:
    """Cluster image embeddings with HDBSCAN and retrieve similar images.

    Row ``i`` of ``embeddings`` (and of ``labels`` / ``probabilities``) is
    paired with ``image_list[i]``, the ``i``-th entry of the sorted listing of
    ``IMAGE_DIR``. NOTE(review): this assumes the embeddings were produced
    in that same sorted order -- confirm against the embedding script.
    """

    # Directory holding the images whose sorted listing indexes the rows.
    IMAGE_DIR = os.path.join('coco', 'val2017', 'val2017')
    # Directory the .npy model artifacts are written to / read from.
    DATA_DIR = "embeddings"

    def __init__(self, min_num_clusters=5, embeddings=None):
        """Create an unfitted HDBSCAN clusterer and index the image directory.

        Args:
            min_num_clusters: forwarded to HDBSCAN as ``min_cluster_size``
                (the smallest group HDBSCAN will report as a cluster).
            embeddings: optional embedding matrix, one row per image.
        """
        self.clusterer = hdbscan.HDBSCAN(min_cluster_size=min_num_clusters)
        self.labels = None
        self.probabilities = None
        self.image_list = sorted(os.listdir(self.IMAGE_DIR))
        self.embeddings = embeddings

    def create_clusters(self, embeddings):
        """Fit HDBSCAN on ``embeddings`` and keep labels and membership strengths.

        The embeddings are stored on the instance as well, so that
        ``save_model_data`` and ``find_similar_records`` use the same matrix
        the labels were computed from.
        """
        self.clusterer.fit(embeddings)
        self.embeddings = np.asarray(embeddings)
        self.labels = self.clusterer.labels_
        self.probabilities = self.clusterer.probabilities_

    def _require_model_data(self):
        """Raise RuntimeError unless labels, probabilities and embeddings are set."""
        missing = [
            name
            for name in ("labels", "probabilities", "embeddings")
            if getattr(self, name, None) is None
        ]
        if missing:
            raise RuntimeError(
                f"missing {', '.join(missing)}; "
                "call create_clusters() or load_model_data() first"
            )

    def save_model_data(self):
        """Write labels (int32), probabilities and embeddings (float32) to ``DATA_DIR``.

        Uses the instance attributes rather than the clusterer's fitted state,
        so data obtained via ``load_model_data`` can be saved again.

        Raises:
            RuntimeError: if any of the three arrays is not available yet.
        """
        self._require_model_data()
        os.makedirs(self.DATA_DIR, exist_ok=True)
        np.save(os.path.join(self.DATA_DIR, "labels.npy"),
                np.asarray(self.labels).astype(np.int32))
        np.save(os.path.join(self.DATA_DIR, "probabilities.npy"),
                np.asarray(self.probabilities).astype(np.float32))
        np.save(os.path.join(self.DATA_DIR, "image_embeddings.npy"),
                np.asarray(self.embeddings).astype(np.float32))

    def load_model_data(self):
        """Load labels, probabilities and embeddings previously written by ``save_model_data``."""
        self.labels = np.load(os.path.join(self.DATA_DIR, "labels.npy"))
        self.probabilities = np.load(os.path.join(self.DATA_DIR, "probabilities.npy"))
        self.embeddings = np.load(os.path.join(self.DATA_DIR, "image_embeddings.npy"))

    def find_similar_records(self, embedding, k=10, *,
                             similarity_weight=0.7, probability_weight=0.3):
        """Return up to ``k`` image file names most similar to ``embedding``.

        The stored row with the highest cosine similarity to the query picks
        the cluster to search. If that row is HDBSCAN noise (label ``-1``),
        every row is a candidate. Candidates are ranked by
        ``similarity_weight * cosine + probability_weight * membership
        probability``; ties keep index order.

        Args:
            embedding: query vector; any shape that flattens to length D
                (e.g. ``(D,)`` or ``(1, D)``).
            k: maximum number of names to return (non-positive -> ``[]``).
            similarity_weight: weight on cosine similarity (default 0.7).
            probability_weight: weight on cluster membership probability
                (default 0.3).

        Returns:
            List of names from ``image_list``, best first. It may be shorter
            than ``k`` if the selected cluster has fewer members.

        Raises:
            RuntimeError: if no model data is loaded.
            ValueError: if the query has zero norm.
        """
        self._require_model_data()
        embeddings = np.asarray(self.embeddings)
        labels = np.asarray(self.labels)
        probabilities = np.asarray(self.probabilities)
        if len(embeddings) == 0 or k <= 0:
            return []

        query = np.ravel(np.asarray(embedding))
        query_norm = np.linalg.norm(query)
        if query_norm == 0:
            raise ValueError("query embedding must have a non-zero norm")
        query = query / query_norm

        # Normalize the stored rows too: without it this is a raw dot product,
        # not a cosine similarity. Zero rows get divisor 1 and score 0 (no NaN).
        row_norms = np.linalg.norm(embeddings, axis=1)
        row_norms = np.where(row_norms == 0, 1.0, row_norms)
        cosine_similarities = (embeddings @ query) / row_norms

        best_match_idx = int(np.argmax(cosine_similarities))
        most_similar_label = labels[best_match_idx]
        # Narrow the search to the best match's cluster; noise means "search all".
        if most_similar_label == -1:
            candidates = np.arange(len(labels))
        else:
            candidates = np.flatnonzero(labels == most_similar_label)

        final_scores = (similarity_weight * cosine_similarities[candidates]
                        + probability_weight * probabilities[candidates])
        order = np.argsort(-final_scores, kind="stable")[:k]
        return [self.image_list[i] for i in candidates[order]]

    def display_similar_records(self, embedding, k=10):
        """Show the images returned by ``find_similar_records`` in one matplotlib row.

        Does nothing when the search returns no images.
        """
        top_images = self.find_similar_records(embedding, k)
        if not top_images:
            return
        _, axs = plt.subplots(1, len(top_images), figsize=(15, 5))
        for ax, img_name in zip(np.atleast_1d(axs), top_images):
            # Close the file handle; convert() returns an independent copy.
            with Image.open(os.path.join(self.IMAGE_DIR, img_name)) as img:
                ax.imshow(img.convert('RGB'))
            ax.axis("off")
        plt.show()