import os

import hdbscan
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image


class ClusteringData:
    """Cluster image embeddings with HDBSCAN and run cluster-aware similarity search.

    ``image_list`` is the sorted listing of ``image_dir``. Row ``i`` of
    ``embeddings`` is treated as describing ``image_list[i]``.
    NOTE(review): that pairing relies on the embeddings having been produced in
    sorted-filename order — confirm against the embedding pipeline.
    """

    DEFAULT_IMAGE_DIR = os.path.join('coco', 'val2017', 'val2017')
    DEFAULT_DATA_DIR = "embeddings"

    def __init__(self, min_num_clusters=5, embeddings=None, image_dir=None, data_dir=None):
        """Create an unfitted clusterer and index the images on disk.

        Args:
            min_num_clusters: forwarded to HDBSCAN as ``min_cluster_size``.
            embeddings: optional pre-computed embedding matrix (one row per image).
            image_dir: directory of images; defaults to ``coco/val2017/val2017``.
            data_dir: directory used by ``save_model_data``/``load_model_data``;
                defaults to ``embeddings``.

        Raises:
            OSError: if ``image_dir`` cannot be listed.
        """
        self.clusterer = hdbscan.HDBSCAN(min_cluster_size=min_num_clusters)
        self.labels = None
        self.probabilities = None
        self.image_dir = self.DEFAULT_IMAGE_DIR if image_dir is None else image_dir
        self.data_dir = self.DEFAULT_DATA_DIR if data_dir is None else data_dir
        self.image_list = sorted(os.listdir(self.image_dir))
        self.embeddings = embeddings

    def create_clusters(self, embeddings):
        """Fit HDBSCAN on ``embeddings`` and keep its labels and membership probabilities.

        The fitted matrix is also stored as ``self.embeddings``, so later searches
        and saves use exactly the data the labels were computed from.
        """
        embeddings = np.asarray(embeddings)
        self.clusterer.fit(embeddings)
        self.embeddings = embeddings
        self.labels = self.clusterer.labels_
        self.probabilities = self.clusterer.probabilities_

    def save_model_data(self):
        """Write labels, probabilities and embeddings as ``.npy`` files into ``data_dir``.

        Uses the instance's current ``labels``/``probabilities``, rather than the
        clusterer's, so data restored via ``load_model_data`` can be re-saved.

        Raises:
            ValueError: if there is no cluster data or no embeddings to save.
        """
        if self.labels is None or self.probabilities is None:
            raise ValueError("No cluster data to save; call create_clusters() or load_model_data() first.")
        if self.embeddings is None:
            raise ValueError("No embeddings to save.")
        os.makedirs(self.data_dir, exist_ok=True)
        np.save(os.path.join(self.data_dir, "labels.npy"), np.asarray(self.labels).astype(np.int32))
        np.save(os.path.join(self.data_dir, "probabilities.npy"), np.asarray(self.probabilities).astype(np.float32))
        np.save(os.path.join(self.data_dir, "image_embeddings.npy"), np.asarray(self.embeddings).astype(np.float32))

    def load_model_data(self):
        """Restore labels, probabilities and embeddings previously written by ``save_model_data``.

        Only the arrays are restored; ``self.clusterer`` itself stays unfitted.
        """
        self.labels = np.load(os.path.join(self.data_dir, "labels.npy"))
        self.probabilities = np.load(os.path.join(self.data_dir, "probabilities.npy"))
        self.embeddings = np.load(os.path.join(self.data_dir, "image_embeddings.npy"))

    def find_similar_records(self, embedding, k=10, similarity_weight=0.7, probability_weight=0.3):
        """Return up to ``k`` image filenames most similar to ``embedding``.

        The single best cosine match picks a cluster. Candidates are that
        cluster's members, or every image if the best match is noise (label -1).
        Candidates are ranked by
        ``similarity_weight * cosine + probability_weight * membership_probability``.

        Raises:
            ValueError: if no cluster data is available or ``embedding`` has zero norm.
        """
        if self.embeddings is None or self.labels is None or self.probabilities is None:
            raise ValueError("No cluster data; call create_clusters() or load_model_data() first.")
        if k <= 0:
            return []

        query = np.asarray(embedding, dtype=float)
        query_norm = np.linalg.norm(query)
        if query_norm == 0:
            raise ValueError("Query embedding has zero norm; cosine similarity is undefined.")
        query = query / query_norm

        # Normalize the stored rows too, so the scores are true cosines in [-1, 1]
        # and blend sensibly with probabilities. This is a no-op if rows are already unit-length.
        stored = np.asarray(self.embeddings, dtype=float)
        row_norms = np.linalg.norm(stored, axis=1, keepdims=True)
        row_norms[row_norms == 0] = 1.0  # all-zero rows keep a similarity of 0 instead of NaN
        cosine_similarities = (stored / row_norms) @ query

        labels = np.asarray(self.labels)
        probabilities = np.asarray(self.probabilities)
        best_label = labels[int(np.argmax(cosine_similarities))]
        # Narrow the search to the best match's cluster; noise (-1) means search everything.
        if best_label == -1:
            candidates = np.arange(len(labels))
        else:
            candidates = np.flatnonzero(labels == best_label)

        final_scores = (similarity_weight * cosine_similarities[candidates]
                        + probability_weight * probabilities[candidates])
        final_indices = candidates[np.argsort(-final_scores)[:k]]
        return [self.image_list[i] for i in final_indices]

    def display_similar_records(self, embedding, k=10, similarity_weight=0.7, probability_weight=0.3):
        """Plot the images returned by ``find_similar_records`` in a single row.

        Does nothing when there are no matches.
        """
        top_images = self.find_similar_records(embedding, k, similarity_weight, probability_weight)
        if not top_images:
            return
        _, axs = plt.subplots(1, len(top_images), figsize=(15, 5))
        for ax, img_name in zip(np.atleast_1d(axs), top_images):
            # Open in a context manager so the file handle is released; convert() returns a loaded copy.
            with Image.open(os.path.join(self.image_dir, img_name)) as img:
                ax.imshow(img.convert('RGB'))
            ax.axis("off")
        plt.show()