# File size: 2,401 Bytes
# Commit hashes shown by the file viewer: 7e50653 8ce87fc f67d590 8ce87fc ed729de 8ce87fc 7d383bb (line-number gutter 1-75 omitted)
#import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
# Helper class for visualising the similarity search
class neighbor_info:
    """Plain record bundling one neighbour's label, data and distance.

    The three constructor arguments are stored unchanged as attributes of
    the same name; the class does not interpret or validate them.
    """

    def __init__(self, label, data, distance):
        # Store all three fields in a single unpacking assignment.
        self.label, self.data, self.distance = label, data, distance
def GeM(x, p=3, eps=1e-6):
    """Apply generalized-mean (GeM) pooling over axes 1 and 2 of ``x``.

    Computes ``mean(max(x, eps) ** p, axis=[1, 2]) ** (1 / p)``.

    Args:
        x: Tensor with at least three axes; axes 1 and 2 are reduced.
            Presumably a (batch, height, width, channels) feature map --
            TODO confirm against the caller.
        p: Pooling exponent. Defaults to 3, the value that was previously
            hard-coded. Can make this hyper-param trainable but will not
            for now.
        eps: Lower clamp applied before exponentiation. Defaults to 1e-6,
            the value that was previously hard-coded.

    Returns:
        The pooled tensor with axes 1 and 2 removed (``keepdims=False``).
    """
    # The module-level tensorflow import is commented out, so bring `tf`
    # into scope here; otherwise every call fails with NameError.
    import tensorflow as tf

    x = tf.math.maximum(x, eps)
    x = tf.pow(x, p)
    x = tf.reduce_mean(x, axis=[1, 2], keepdims=False)
    return tf.pow(x, 1.0 / p)
def from_path_to_image(path):
    """Load the image at ``path`` with OpenCV and return it in RGB order.

    Args:
        path: Location of the image file, passed straight to ``cv2.imread``.

    Returns:
        The decoded image with its last axis reversed (OpenCV decodes to
        BGR channel order, so the reversal yields RGB). The result is a
        reversed view of the decoded array, not a copy.

    Raises:
        OSError: If ``cv2.imread`` could not read an image from ``path``.
    """
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the slice below fails with an opaque TypeError.
        raise OSError(f"Could not read an image from {path!r}")
    return bgr_image[:, :, ::-1]
def string_row_to_array(string):
    """Parse a bracketed, comma-separated list of numbers into a 1-D array.

    Example: ``"[1.0, 2.5, 3]"`` becomes ``array([1.0, 2.5, 3.0])``.

    Args:
        string: Text such as ``"[0.1, 0.2]"``. Surrounding brackets and
            whitespace (including trailing newlines) are ignored.

    Returns:
        A float64 NumPy array with one element per comma-separated value.
        An empty row (``"[]"`` or ``""``) gives an empty array.

    Raises:
        ValueError: If any comma-separated token is not a valid float.
    """
    # Strip brackets and spaces as before, plus tabs and newlines, so rows
    # read from text files with trailing line breaks still parse.
    stripped = string.strip('[] \t\r\n')
    if not stripped:
        # ''.split(',') yields [''] and float('') raises, so handle the
        # empty row explicitly.
        return np.array([], dtype=float)
    return np.array([float(val) for val in stripped.split(',')], dtype=float)
def search(query_number, index, queries, relative_index_to_image_index, k_neighbours=5):
    """Find the nearest neighbours of one query embedding in ``index``.

    Args:
        query_number: Positional row (``.iloc``) of the query in ``queries``.
        index: Object exposing ``search(query, k=...)`` that returns
            ``(distances, labels)``; presumably a FAISS index -- confirm
            against the caller.
        queries: Object with an ``embeddings`` column supporting ``.iloc``
            (presumably a pandas DataFrame holding one vector per row).
        relative_index_to_image_index: Mapping (or sequence) from the
            positions returned by ``index`` to image identifiers.
        k_neighbours: Number of neighbours to request.

    Returns:
        A tuple ``(distances, absolute_indexes)``: the distances of the
        first (only) query, and the corresponding image identifiers
        looked up through ``relative_index_to_image_index``.
    """
    # `faiss` is not imported at module level, so bring it into scope here;
    # otherwise every call fails with NameError.
    import faiss

    # Copy into a fresh float32 row vector: normalize_L2 works in place and
    # needs float32, and a reshape view would otherwise let it overwrite
    # the embedding stored in `queries`.
    query = np.array(queries.embeddings.iloc[query_number], dtype=np.float32).reshape(1, -1)
    faiss.normalize_L2(query)

    distances, relative_index = index.search(query, k=k_neighbours)

    # FAISS pads missing results with label -1 when fewer than k vectors are
    # available; drop them so -1 is not used as a real lookup key (on a list
    # it would silently select the last element).
    found = relative_index[0] >= 0
    absolute_indexes = [relative_index_to_image_index[rel_idx]
                        for rel_idx in relative_index[0][found]]
    return distances[0][found], absolute_indexes
def plot_nns(embeddings_df, absolute_indexes, queries, query_number, title=None):
    """Show the query image next to its nearest-neighbour images in one row.

    Args:
        embeddings_df: Object supporting ``.loc[absolute_indexes]`` whose
            result has a ``path_to_image`` column (presumably a pandas
            DataFrame of indexed images).
        absolute_indexes: Labels of the neighbour rows in ``embeddings_df``,
            in the order they should be displayed.
        queries: Object supporting ``.iloc[query_number].path_to_image``.
        query_number: Positional row of the query in ``queries``.
        title: Optional figure-level title; omitted when falsy.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    nns = embeddings_df.loc[absolute_indexes]

    # Query image first, then the neighbours in the order given.
    image_paths = [queries.iloc[query_number].path_to_image] + nns.path_to_image.to_list()

    # One panel per image rather than a fixed 6, so no neighbour is dropped
    # when more than five are passed. Five inches per panel reproduces the
    # previous 30x6 figure for the usual query + 5 neighbours.
    # squeeze=False keeps `axes` a 2-D array even for a single panel.
    n_panels = len(image_paths)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 6), squeeze=False)

    for i, (ax, image_path) in enumerate(zip(axes.flat, image_paths)):
        # plt.imread is used because `mpimg` is not imported in this module.
        ax.imshow(plt.imread(image_path))
        ax.set_title('Query Image' if i == 0 else f'{i} Closest Neighbour')
        ax.axis('off')

    if title:
        fig.suptitle(title, fontsize=20)
    fig.tight_layout()
    plt.show()
def search_and_plot(query_number, embeddings_df, index, queries, relative_index_to_image_index, title=None, k_neighbours=5):
    """Run a nearest-neighbour search for one query and plot the results.

    Args:
        query_number: Positional row of the query in ``queries``.
        embeddings_df: Table of indexed images, forwarded to ``plot_nns``.
        index: Search index, forwarded to ``search``.
        queries: Table of query rows, forwarded to both helpers.
        relative_index_to_image_index: Mapping from index positions to
            image identifiers, forwarded to ``search``.
        title: Optional figure title, forwarded to ``plot_nns``.
        k_neighbours: Number of neighbours to request from ``search``.
            Defaults to 5, which is also ``search``'s own default, so
            existing callers are unaffected.

    Returns:
        None. The figure is displayed by ``plot_nns``.
    """
    # The distances are not needed for plotting, so only the ids are kept.
    _, absolute_indexes = search(
        query_number,
        index,
        queries,
        relative_index_to_image_index,
        k_neighbours=k_neighbours,
    )
    plot_nns(embeddings_df, absolute_indexes, queries, query_number, title)