|
|
import hdbscan
|
|
|
import numpy as np
|
|
|
import os
|
|
|
import matplotlib.pyplot as plt
|
|
|
from PIL import Image
|
|
|
|
|
|
class ClusteringData:
    """Cluster image embeddings with HDBSCAN and retrieve similar images.

    Row ``i`` of ``embeddings`` (and entry ``i`` of ``labels`` /
    ``probabilities``) is reported as ``image_list[i]``, the ``i``-th file
    name of the sorted listing of ``image_dir``.
    NOTE(review): this pairing is only valid if the embeddings were computed
    over that same sorted listing -- confirm against the embedding pipeline.
    """

    # Default location of the images whose embeddings are clustered.
    DEFAULT_IMAGE_DIR = os.path.join('coco', 'val2017', 'val2017')

    def __init__(self, min_num_clusters=5, embeddings=None, image_dir=None):
        """Create an unfitted clusterer and index the image directory.

        Args:
            min_num_clusters: Forwarded to HDBSCAN as ``min_cluster_size``
                (a minimum cluster size, not a target number of clusters,
                despite the parameter name).
            embeddings: Optional embedding matrix, one row per image.
            image_dir: Directory holding the images; defaults to
                ``DEFAULT_IMAGE_DIR``.
        """
        self.clusterer = hdbscan.HDBSCAN(min_cluster_size=min_num_clusters)
        self.labels = None
        self.probabilities = None
        self.image_dir = self.DEFAULT_IMAGE_DIR if image_dir is None else image_dir
        # os.listdir order is arbitrary; sort so row i maps to a stable file name.
        self.image_list = sorted(os.listdir(self.image_dir))
        self.embeddings = embeddings

    def _require_model_data(self):
        """Raise RuntimeError unless embeddings, labels and probabilities are set."""
        if self.embeddings is None or self.labels is None or self.probabilities is None:
            raise RuntimeError(
                "no clustering data: call create_clusters() or load_model_data() first")

    def create_clusters(self, embeddings):
        """Fit HDBSCAN on ``embeddings`` and record its labels and probabilities.

        The fitted matrix is also stored on ``self.embeddings`` so that
        similarity search and ``save_model_data`` operate on exactly the rows
        the labels refer to.
        """
        self.clusterer.fit(embeddings)
        self.labels = self.clusterer.labels_
        self.probabilities = self.clusterer.probabilities_
        self.embeddings = np.asarray(embeddings)

    def save_model_data(self, directory="embeddings"):
        """Write labels (int32), probabilities and embeddings (float32) as ``.npy`` files.

        Args:
            directory: Output directory; created if it does not exist.

        Raises:
            RuntimeError: If there is no clustering data to save.
        """
        self._require_model_data()
        os.makedirs(directory, exist_ok=True)
        # Read from self.* rather than self.clusterer so data restored by
        # load_model_data() (whose clusterer is unfitted) can be re-saved.
        np.save(os.path.join(directory, "labels.npy"),
                np.asarray(self.labels).astype(np.int32))
        np.save(os.path.join(directory, "probabilities.npy"),
                np.asarray(self.probabilities).astype(np.float32))
        np.save(os.path.join(directory, "image_embeddings.npy"),
                np.asarray(self.embeddings).astype(np.float32))

    def load_model_data(self, directory="embeddings"):
        """Load the arrays written by ``save_model_data`` from ``directory``.

        Only labels, probabilities and embeddings are restored; the HDBSCAN
        clusterer itself stays unfitted.
        """
        self.labels = np.load(os.path.join(directory, "labels.npy"))
        self.probabilities = np.load(os.path.join(directory, "probabilities.npy"))
        self.embeddings = np.load(os.path.join(directory, "image_embeddings.npy"))

    def find_similar_records(self, embedding, k=10, similarity_weight=0.7,
                             probability_weight=0.3):
        """Return the file names of up to ``k`` images most similar to ``embedding``.

        The stored row with the highest cosine similarity to the query picks a
        cluster; every member of that cluster is ranked by
        ``similarity_weight * cosine + probability_weight * probability``.
        If the best match is HDBSCAN noise (label ``-1``), all stored rows are
        ranked instead.

        Args:
            embedding: Query vector of any shape with as many elements as an
                embedding row (e.g. ``(d,)`` or ``(1, d)``).
            k: Maximum number of results; ``k <= 0`` yields an empty list.
            similarity_weight: Weight of the cosine similarity in the score.
            probability_weight: Weight of the cluster-membership probability.

        Returns:
            List of file names from ``image_list``, best match first.

        Raises:
            RuntimeError: If no clustering data has been created or loaded.
            ValueError: If the stored arrays disagree in length or the query
                has zero norm.
        """
        self._require_model_data()
        stored = np.asarray(self.embeddings)
        labels = np.asarray(self.labels)
        probabilities = np.asarray(self.probabilities)
        if not stored.shape[0] == labels.shape[0] == probabilities.shape[0]:
            raise ValueError(
                "embeddings, labels and probabilities must have the same length")
        if stored.shape[0] == 0:
            return []

        query = np.asarray(embedding, dtype=np.float64).ravel()
        query_norm = np.linalg.norm(query)
        if query_norm == 0:
            raise ValueError("query embedding has zero norm; cosine similarity is undefined")
        query = query / query_norm

        # Divide by the row norms so this is a true cosine similarity even when
        # the stored embeddings are not unit-length; all-zero rows score 0.
        dots = stored @ query
        row_norms = np.linalg.norm(stored, axis=1)
        cosine_similarities = np.divide(
            dots, row_norms, out=np.zeros_like(dots), where=row_norms > 0)

        best_label = labels[int(np.argmax(cosine_similarities))]
        if best_label == -1:
            # Noise points belong to no cluster, so fall back to ranking everything.
            candidates = np.arange(labels.shape[0])
        else:
            candidates = np.flatnonzero(labels == best_label)

        final_scores = (similarity_weight * cosine_similarities[candidates]
                        + probability_weight * probabilities[candidates])
        ranked = candidates[np.argsort(-final_scores, kind="stable")[:max(int(k), 0)]]
        return [self.image_list[i] for i in ranked]

    def display_similar_records(self, embedding, k=10):
        """Show the images returned by ``find_similar_records`` in one row.

        Nothing is drawn when no images are found, because ``plt.subplots``
        rejects a column count of zero.
        """
        top_images = self.find_similar_records(embedding, k)
        if not top_images:
            return

        _, axs = plt.subplots(1, len(top_images), figsize=(15, 5))
        # With a single column subplots returns one Axes instead of an array.
        for ax, img_name in zip(np.atleast_1d(axs), top_images):
            # The context manager closes the file once the pixels are decoded.
            with Image.open(os.path.join(self.image_dir, img_name)) as img:
                ax.imshow(img.convert('RGB'))
            ax.axis("off")
        plt.show()
|
|
|
|
|
|
|
|
|
|