|
|
|
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import cv2 |
|
|
|
|
|
|
|
|
class neighbor_info:
    """Lightweight record bundling a neighbour's ``label``, ``data`` and ``distance``.

    Values are stored exactly as given; no validation or conversion is done.
    NOTE(review): field meanings (e.g. ``distance`` being the distance to a
    query) are inferred from the names only — confirm against callers.
    """

    def __init__(self, label, data, distance):
        # Bind all three attributes in one statement (left-to-right order).
        self.label, self.data, self.distance = label, data, distance
|
|
|
|
|
def GeM(x, p=3, eps=1e-6):
    """Generalised-mean (GeM) pooling over axes 1 and 2 of ``x``.

    Computes ``mean(max(x, eps) ** p, axis=[1, 2]) ** (1 / p)``. With
    ``p = 1`` this is average pooling of the clamped input. As ``p`` grows,
    the result approaches max pooling.

    Args:
        x: Tensor of rank >= 3, reduced over axes 1 and 2. Presumably a
            ``(batch, height, width, channels)`` feature map — TODO confirm.
        p: Pooling exponent. The default ``3`` is the previously hard-coded value.
        eps: Lower clamp applied before the power. The default ``1e-6`` is the
            previously hard-coded value.

    Returns:
        Tensor with axes 1 and 2 removed (``keepdims=False``).
    """
    # NOTE(review): `tf` is not imported in the visible part of this file; this
    # relies on `import tensorflow as tf` being available at module level.

    # Clamp to a small positive value so the mean is strictly positive and the
    # final 1/p root cannot produce NaN.
    x = tf.math.maximum(x, eps)
    x = tf.pow(x, p)
    x = tf.reduce_mean(x, axis=[1, 2], keepdims=False)
    return tf.pow(x, 1.0 / p)
|
|
|
|
|
def from_path_to_image(path):
    """Load the image at ``path`` with OpenCV and return it in RGB channel order.

    ``cv2.imread`` yields pixels in BGR order. Reversing the last axis converts
    them to RGB.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        A C-contiguous ``(height, width, 3)`` NumPy array in RGB order.

    Raises:
        OSError: If OpenCV could not read the file (missing, unreadable, or
            unsupported format).
    """
    bgr_image = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising. Without
    # this check the slicing below fails with an opaque TypeError.
    if bgr_image is None:
        raise OSError(f"cv2.imread could not read image: {path!r}")
    # Flipping the channel axis gives a negative-stride view. Some consumers
    # (e.g. torch.from_numpy) reject such arrays, so return a contiguous copy.
    return np.ascontiguousarray(bgr_image[:, :, ::-1])
|
|
|
|
|
def string_row_to_array(string):
    """Parse a stringified vector such as ``"[0.1, 0.2, 0.3]"`` into a float array.

    Both element separators are accepted:

    * the comma-separated form produced by ``str(list)``;
    * the whitespace-separated form produced by ``str(np.ndarray)``, which may
      be wrapped over several lines.

    Surrounding whitespace (e.g. a trailing newline read from a file) is
    ignored.

    Args:
        string: Text of the vector, optionally wrapped in square brackets.

    Returns:
        1-D float64 NumPy array. Returns an empty array for ``"[]"``.

    Raises:
        ValueError: If an element is not a valid float literal.
    """
    # Strip outer whitespace first. Otherwise a trailing "\n" shields the
    # closing bracket from strip('[]').
    inner = string.strip().strip('[]')
    # Treat commas as whitespace so both list-style and ndarray-style text
    # split the same way. str.split() with no argument also absorbs newlines
    # and runs of spaces.
    tokens = inner.replace(',', ' ').split()
    return np.array([float(token) for token in tokens], dtype=float)
|
|
|
|
|
def search(query_number, index, queries, relative_index_to_image_index, k_neighbours=5):
    """Find the nearest neighbours of one query embedding in a vector index.

    The query vector is copied, cast to float32, reshaped to a single row and
    L2-normalised before searching. The embedding stored in ``queries`` is
    never modified.

    Args:
        query_number: Positional index of the query row in ``queries``.
        index: Object exposing a faiss-style ``search(x, k) -> (distances, ids)``
            method. Presumably a faiss index over L2-normalised float32
            vectors — confirm.
        queries: Table whose ``embeddings`` column holds one vector per row,
            accessed via ``queries.embeddings.iloc[query_number]``.
        relative_index_to_image_index: Mapping (list or dict) from the index's
            internal ids to image ids.
        k_neighbours: Number of neighbours to request from the index.

    Returns:
        ``(distances, absolute_indexes)``: the index's scores for the returned
        hits and the corresponding image ids, in the index's order. Slots the
        index left empty (id ``-1``) are dropped from both.
    """
    # Work on a fresh float32 row. faiss indexes require float32 input, and
    # normalising a reshape *view* in place would overwrite the stored embedding.
    query = np.array(queries.embeddings.iloc[query_number], dtype=np.float32).reshape(1, -1)

    # L2-normalise in place on our own copy (equivalent to faiss.normalize_L2).
    # Zero-norm rows are left unchanged rather than divided by zero.
    norms = np.linalg.norm(query, axis=1, keepdims=True)
    np.divide(query, norms, out=query, where=norms > 0)

    distances, relative_ids = index.search(query, k=k_neighbours)
    distances, relative_ids = distances[0], relative_ids[0]

    # faiss pads with id -1 when fewer than k results exist. Left in, -1 would
    # silently select the last entry of a list mapping.
    found = relative_ids != -1
    absolute_indexes = [relative_index_to_image_index[rel_idx] for rel_idx in relative_ids[found]]
    return distances[found], absolute_indexes
|
|
|
|
|
def plot_nns(embeddings_df, absolute_indexes, queries, query_number, title=None):
    """Show a query image followed by its retrieved neighbours in a single row.

    Args:
        embeddings_df: DataFrame labelled by the ids in ``absolute_indexes``,
            with a ``path_to_image`` column.
        absolute_indexes: Row labels of the neighbours in ``embeddings_df``,
            presumably nearest first (e.g. as returned by ``search``).
        queries: DataFrame with a ``path_to_image`` column. Row
            ``query_number`` (positional) is the query.
        query_number: Positional index of the query row in ``queries``.
        title: Optional figure-level title.
    """
    neighbours = embeddings_df.loc[absolute_indexes]
    image_paths = [queries.iloc[query_number].path_to_image] + neighbours.path_to_image.to_list()

    # One column per image (query + every neighbour) rather than a fixed six,
    # so no neighbour is dropped and no empty panel is drawn. squeeze=False
    # keeps `axes` a 2-D array even when there is only one column.
    n_cols = len(image_paths)
    fig, axes = plt.subplots(1, n_cols, figsize=(5 * n_cols, 6), squeeze=False)

    for i, (ax, image_path) in enumerate(zip(axes.flat, image_paths)):
        # pyplot re-exports matplotlib.image.imread, so no extra import is needed.
        ax.imshow(plt.imread(image_path))
        ax.set_title('Query Image' if i == 0 else f'{i} Closest Neighbour')
        ax.axis('off')

    if title:
        plt.suptitle(title, fontsize=20)

    fig.tight_layout()
    plt.show()
|
|
|
|
|
def search_and_plot(query_number, embeddings_df, index, queries, relative_index_to_image_index, title=None, k_neighbours=5):
    """Retrieve the neighbours of one query and plot them next to the query image.

    Thin wrapper that calls ``search`` and then ``plot_nns``. The distances
    returned by ``search`` are not used.

    Args:
        query_number: Positional index of the query row in ``queries``.
        embeddings_df: DataFrame holding the neighbours' ``path_to_image``
            entries, labelled by image id.
        index: Vector index passed through to ``search``.
        queries: Query table passed to both ``search`` and ``plot_nns``.
        relative_index_to_image_index: Mapping from index ids to image ids,
            passed to ``search``.
        title: Optional figure title forwarded to ``plot_nns``.
        k_neighbours: Number of neighbours to request from ``search``. The
            default 5 keeps the previous behaviour.
    """
    _, absolute_indexes = search(
        query_number, index, queries, relative_index_to_image_index, k_neighbours=k_neighbours
    )
    plot_nns(embeddings_df, absolute_indexes, queries, query_number, title)