# Commit f67d590 by AdrianHR: "feat: Remove unused imports and app2.py"
#import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
# Helper class for visualising the similarity-search results
class neighbor_info:
    """Plain record describing one neighbour returned by a similarity search.

    Attributes:
        label: Label attached to the neighbour.
        data: Payload stored with the neighbour (type is not fixed here).
        distance: Distance/score reported for the neighbour.
    """

    def __init__(self,label,data,distance):
        self.label, self.data, self.distance = label, data, distance
def GeM(x, p=3):
    """Generalized-mean (GeM) pooling over axes 1 and 2 of ``x``.

    Computes ``mean(max(x, 1e-6) ** p, axis=[1, 2]) ** (1 / p)``; with
    ``p=1`` this is average pooling and large ``p`` approaches max pooling.
    Presumably ``x`` is a (batch, height, width, channels) feature map, giving
    a (batch, channels) result -- confirm against the calling model.

    Args:
        x: TensorFlow tensor (or tensor-like) to pool.
        p: Pooling exponent. Kept as a fixed hyper-parameter, not trained;
            defaults to 3 as before.

    Returns:
        The pooled tensor with axes 1 and 2 removed.
    """
    # The module-level TensorFlow import is commented out, so ``tf`` was an
    # undefined name here. Importing locally keeps the module importable
    # without TensorFlow while making GeM work whenever TensorFlow is installed.
    import tensorflow as tf

    # Clamp to a small positive value so the fractional power 1/p is defined.
    x = tf.math.maximum(x, 1e-6)
    x = tf.pow(x, p)
    x = tf.reduce_mean(x, axis=[1, 2], keepdims=False)
    return tf.pow(x, 1.0 / p)
def from_path_to_image(path):
    """Load an image file with OpenCV and return it in RGB channel order.

    Args:
        path: Filesystem path passed to ``cv2.imread``.

    Returns:
        A height x width x 3 array in RGB order (``cv2.imread``'s default
        colour mode always yields three BGR channels).

    Raises:
        OSError: if OpenCV cannot read the file (missing, unreadable, or an
            unsupported format). ``cv2.imread`` signals failure by returning
            ``None`` rather than raising, which previously surfaced as an
            opaque ``TypeError`` on subscripting.
    """
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        raise OSError(f"Could not read image: {path!r}")
    # cvtColor returns a contiguous array; the old [:, :, ::-1] slice gave a
    # negative-stride view that some consumers reject.
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
def string_row_to_array(string):
    """Parse a bracketed, comma-separated list of numbers into a float array.

    Example: ``"[0.1, 2, -3e-2]"`` -> ``array([0.1, 2.0, -0.03])``.

    Args:
        string: Text such as ``"[1.0, 2.0]"``. Surrounding whitespace and
            newlines are tolerated; ``"[]"`` yields an empty array.

    Returns:
        1-D ``numpy`` float64 array.

    Raises:
        ValueError: if any comma-separated item is not a valid float.
    """
    # Strip outer whitespace first so a trailing newline (e.g. a line read from
    # a file) does not stop the bracket strip and corrupt the last token.
    body = string.strip().strip('[ ]')
    if not body:
        # "".split(',') gives [''], and float('') would raise.
        return np.array([], dtype=float)
    return np.array([float(val) for val in body.split(',')])
def search(query_number, index, queries, relative_index_to_image_index, k_neighbours=5):
    """Find the nearest neighbours of one query embedding in a vector index.

    Args:
        query_number: Positional row of ``queries.embeddings`` used as query.
        index: Object with ``search(x, k=...)`` returning ``(distances,
            indices)`` arrays with one row per query -- presumably a FAISS
            index built on L2-normalised vectors (confirm with the caller).
        queries: Object whose ``embeddings`` column holds 1-D array-like
            embeddings, accessed positionally via ``.iloc``.
        relative_index_to_image_index: Mapping or sequence from the index's
            internal ids to absolute image indexes.
        k_neighbours: Number of neighbours to request from the index.

    Returns:
        ``(distances, absolute_indexes)`` for the neighbours actually found.
        Result slots the index reports as empty (id ``-1``) are dropped from
        both, so the two always have the same length.
    """
    # Build a fresh float32 row vector. FAISS operates on float32, and working
    # on a copy means the in-place normalisation below cannot mutate the
    # embedding stored in ``queries`` (the old reshape returned a view of it).
    query = np.array(queries.embeddings.iloc[query_number], dtype=np.float32).reshape(1, -1)
    # L2-normalise, leaving an all-zero vector unchanged -- the same effect as
    # faiss.normalize_L2, but without needing the faiss module in scope.
    norm = np.linalg.norm(query)
    if norm > 0:
        query /= norm
    distances, relative_index = index.search(query, k=k_neighbours)
    # FAISS pads missing results with id -1; looking that up in a sequence
    # would silently return its last element, so drop those slots.
    found = relative_index[0] != -1
    absolute_indexes = [relative_index_to_image_index[rel_idx]
                        for rel_idx in relative_index[0][found]]
    return distances[0][found], absolute_indexes
def plot_nns(embeddings_df,absolute_indexes,queries,query_number,title=None):
    """Show a query image next to its nearest-neighbour images in one row.

    Args:
        embeddings_df: Table (presumably a pandas DataFrame) with a
            ``path_to_image`` column, indexed by absolute image index; the
            neighbour rows are selected with ``.loc``.
        absolute_indexes: Row labels of the neighbours in ``embeddings_df``,
            e.g. the second value returned by ``search``.
        queries: Table with a ``path_to_image`` column; row ``query_number``
            (positional, via ``.iloc``) is the query.
        query_number: Positional index of the query row in ``queries``.
        title: Optional figure super-title.
    """
    nns = embeddings_df.loc[absolute_indexes]
    image_paths = [queries.iloc[query_number].path_to_image] + nns.path_to_image.to_list()

    # One panel per image instead of a fixed six, so results for
    # k_neighbours != 5 are neither truncated nor padded with blank axes.
    # Five inches per panel keeps the original 30x6 figure for 5 neighbours.
    n_panels = len(image_paths)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 6), squeeze=False)
    for i, (ax, image_path) in enumerate(zip(axes[0], image_paths)):
        # plt.imread is matplotlib.image.imread; the ``mpimg`` alias used
        # previously was never imported in this module.
        ax.imshow(plt.imread(image_path))
        ax.set_title('Query Image' if i == 0 else f'{i} Closest Neighbour')
        ax.axis('off')
    if title:
        plt.suptitle(title, fontsize=20)
    fig.tight_layout()
    plt.show()
def search_and_plot(query_number, embeddings_df, index, queries, relative_index_to_image_index,title=None):
    """Search the neighbours of one query and display them in a single figure.

    Calls ``search`` with its default neighbour count, discards the returned
    distances, and passes the neighbour indexes to ``plot_nns``.
    """
    _distances, neighbour_indexes = search(
        query_number, index, queries, relative_index_to_image_index
    )
    plot_nns(embeddings_df, neighbour_indexes, queries, query_number, title)