File size: 2,401 Bytes
7e50653
8ce87fc
f67d590
8ce87fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed729de
8ce87fc
7d383bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Helper class for visualising the Sim search
class neighbor_info:
    """Simple record describing a single similarity-search result.

    The three fields are stored exactly as given; no validation or
    conversion is performed.

    Attributes:
        label: Label associated with the result.
        data: Payload associated with the result.
        distance: Distance/score associated with the result.
    """

    def __init__(self, label, data, distance):
        # Store every field verbatim, in declaration order.
        self.label, self.data, self.distance = label, data, distance
        
def GeM(x, p=3, eps=1e-6):
    """Generalized-mean (GeM) pooling over axes 1 and 2 of ``x``.

    Computes ``mean(max(x, eps) ** p, axis=(1, 2)) ** (1 / p)``. With
    ``p = 1`` this is plain average pooling; larger ``p`` moves the result
    towards max pooling.

    Args:
        x: Tensor to pool. Presumably a (batch, height, width, channels)
            feature map, since axes 1 and 2 are reduced; TODO confirm with
            callers.
        p: Pooling exponent. Defaults to 3, the previously hard-coded value.
            It is a plain constant, not a trainable variable.
        eps: Lower clamp applied before the power. Keeping every value
            strictly positive prevents the fractional root from producing NaN.
            Defaults to 1e-6, the previously hard-coded value.

    Returns:
        Tensor with axes 1 and 2 reduced (``keepdims=False``).
    """
    # The module-level TensorFlow import is commented out, so ``tf`` would be
    # an unresolved name here. Import lazily so that loading this module does
    # not require TensorFlow unless GeM is actually used.
    import tensorflow as tf

    x = tf.math.maximum(x, eps)
    x = tf.pow(x, p)
    x = tf.reduce_mean(x, axis=[1, 2], keepdims=False)
    return tf.pow(x, 1.0 / p)

def from_path_to_image(path):
    """Load an image file with OpenCV and return it in RGB channel order.

    ``cv2.imread`` decodes to BGR; this converts the result to RGB.

    Args:
        path: Filesystem path of the image to read.

    Returns:
        ``numpy.ndarray`` of shape (height, width, channels), channels in RGB
        order.

    Raises:
        FileNotFoundError: If OpenCV could not read or decode the file.
    """
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        # Without this check the caller gets an opaque
        # "'NoneType' object is not subscriptable" TypeError.
        raise FileNotFoundError(f"Could not read image at {path!r}")
    # cvtColor returns a fresh contiguous array. A [:, :, ::-1] view has a
    # negative stride, which some consumers of numpy arrays reject.
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
  
def string_row_to_array(string):
    """Parse a bracketed, stringified vector into a 1-D float array.

    Accepted forms:
    - comma-separated text such as ``"[0.1, 0.2, 0.3]"``;
    - whitespace-separated text such as ``"[0.1 0.2 0.3]"``, which is numpy's
      ``str`` form and may span several lines.

    Whitespace and brackets at either end, including a trailing newline, are
    ignored.

    Args:
        string: Text to parse.

    Returns:
        ``numpy.ndarray`` of float64. It is empty for ``"[]"``.

    Raises:
        ValueError: If any field is not a valid float, including an empty
            field between two commas.
    """
    # Strip brackets AND whitespace from both ends. The previous
    # ``strip('[ ]')`` left a trailing newline (and the closing bracket in
    # front of it) in place, which then broke the last ``float()``.
    body = string.strip('[] \t\r\n\f\v')
    # Commas take precedence; otherwise fall back to whitespace separation.
    # ''.split() is [], so an empty row yields an empty array.
    tokens = body.split(',') if ',' in body else body.split()
    return np.array([float(token) for token in tokens], dtype=float)

def search(query_number, index, queries, relative_index_to_image_index, k_neighbours=5):
    """Find the index entries nearest to one query embedding.

    The query embedding is L2-normalised before searching, as the original
    code did with ``faiss.normalize_L2``. This presumably matches how the
    indexed vectors were prepared; TODO confirm.

    Args:
        query_number: Positional row (``.iloc``) of ``queries`` to search for.
        index: Search index exposing ``search(x, k=...)``, which returns a
            ``(distances, ids)`` pair of (n_queries, k) arrays, as faiss
            indexes do.
        queries: DataFrame with an ``embeddings`` column of 1-D numeric arrays.
        relative_index_to_image_index: Mapping (or sequence) from ids returned
            by ``index`` to the labels returned to the caller.
        k_neighbours: Number of neighbours to request from ``index``.

    Returns:
        ``(distances, absolute_indexes)`` in the order returned by ``index``.
        faiss pads missing results with id ``-1``; those entries are dropped
        together with their distances, so both outputs have equal length.
    """
    # Build a fresh float32 row vector. faiss works on float32, and the
    # in-place normalisation below must not write back into the embedding
    # stored inside ``queries``. The previous ``reshape`` returned a view of
    # that stored array, so it was being modified in place.
    query = np.array(queries.embeddings.iloc[query_number], dtype=np.float32).reshape(1, -1)
    norm = np.linalg.norm(query)
    if norm > 0:  # a zero vector is left unchanged, as faiss.normalize_L2 does
        query /= norm

    distances, relative_index = index.search(query, k=k_neighbours)
    ids = np.asarray(relative_index[0])
    found = ids != -1
    absolute_indexes = [relative_index_to_image_index[rel_idx] for rel_idx in ids[found]]
    return np.asarray(distances[0])[found], absolute_indexes

def plot_nns(embeddings_df, absolute_indexes, queries, query_number, title=None):
    """Show a query image followed by its neighbour images in a single row.

    Args:
        embeddings_df: DataFrame with a ``path_to_image`` column, indexed so
            that ``.loc[absolute_indexes]`` selects the neighbour rows.
        absolute_indexes: Row labels of ``embeddings_df`` for the neighbours,
            in display order.
        queries: DataFrame with a ``path_to_image`` column.
        query_number: Positional row (``.iloc``) of the query in ``queries``.
        title: Optional figure-level title.
    """
    nns = embeddings_df.loc[absolute_indexes]
    # The query image comes first, then every neighbour.
    image_paths = [queries.iloc[query_number].path_to_image] + nns.path_to_image.to_list()

    # One panel per image. This replaces a fixed 6, which truncated results
    # for k > 5 and left empty, framed axes for k < 5. The panel width stays
    # at 30 / 6 = 5 inches, so the default 5-neighbour figure is unchanged.
    # squeeze=False keeps `axes` 2-D even when there is only one panel.
    n_panels = len(image_paths)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 6), squeeze=False)

    for i, (ax, image_path) in enumerate(zip(axes[0], image_paths)):
        # plt.imread replaces ``mpimg.imread``, which failed because
        # matplotlib.image was never imported.
        ax.imshow(plt.imread(image_path))
        ax.set_title('Query Image' if i == 0 else f'{i} Closest Neighbour')
        ax.axis('off')

    if title:
        plt.suptitle(title, fontsize=20)

    fig.tight_layout()
    plt.show()

def search_and_plot(query_number, embeddings_df, index, queries, relative_index_to_image_index, title=None, k_neighbours=5):
    """Retrieve the nearest neighbours of one query and plot them.

    Runs ``search`` and passes the resulting neighbour labels to ``plot_nns``.
    The distances returned by ``search`` are not used.

    Args:
        query_number: Positional row of the query in ``queries``.
        embeddings_df: DataFrame the neighbour labels are looked up in.
        index: Search index passed through to ``search``.
        queries: DataFrame holding the query rows.
        relative_index_to_image_index: Mapping from index ids to labels of
            ``embeddings_df``.
        title: Optional figure title.
        k_neighbours: Number of neighbours to retrieve. Defaults to 5, the
            value previously implied by ``search``'s default.
    """
    _, absolute_indexes = search(query_number, index, queries,
                                 relative_index_to_image_index,
                                 k_neighbours=k_neighbours)
    plot_nns(embeddings_df, absolute_indexes, queries, query_number, title)