File size: 5,244 Bytes
982b011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Description:
#   This script queries the index for images similar to a set of randomly selected images.
#   The script uses the FeatureExtractor class to extract the features from the images and the Faiss index to search for similar images.
#
# Usage:
#
#   Example commands (--feat_extractor is required because a separate index is built for each model):
#       python3 search_query.py --feat_extractor resnet50
#       python3 search_query.py --feat_extractor resnet101
#       python3 search_query.py --feat_extractor resnet50 --n 5
#       python3 search_query.py --feat_extractor resnet50 --k 20
#       python3 search_query.py --feat_extractor resnet50 --n 10 --k 12
#
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import faiss
import PIL
import PIL.Image

from modules import FeatureExtractor
from config import *


def select_random_images(n, image_list):
    """Open ``n`` images picked at random from ``image_list``.

    Each pick is an independent uniform draw from ``np.random.randint``, so
    the same file may be selected more than once. The global NumPy RNG is
    used, which lets the caller control reproducibility via ``np.random.seed``.

    Args:
        n (int): How many images to pick.
        image_list (list[str]): Image file names, relative to ``IMAGES_DIR``.

    Returns:
        list[PIL.Image]: The opened images, in draw order.
    """
    picks = np.random.randint(len(image_list), size=n)
    opened = []
    for pick in picks:
        path = os.path.join(IMAGES_DIR, image_list[pick])
        opened.append(PIL.Image.open(path))
    return opened


def plot_query_results(query_img, similar_imgs, distances, out_filepath):
    """Plot the query image and its similar images in a grid and save it to disk.

    The first row shows only the query image. The retrieved images fill the
    next two rows, left to right, and each is titled with its distance. The
    number of columns is derived from ``len(similar_imgs)`` (rounded up so
    an odd count still fits in two rows), so the function does not depend
    on any global state.

    Args:
        query_img (PIL.Image): The query image.
        similar_imgs (list[PIL.Image]): The similar images, in ranked order.
        distances (list[float]): Distances paired one-to-one with ``similar_imgs``.
        out_filepath (str): The file path to save the plot.

    Returns:
        None
    """
    # two result rows; round up so an odd number of results still fits
    ncols = max(1, (len(similar_imgs) + 1) // 2)
    # squeeze=False keeps ``axes`` 2-D even when ncols == 1
    fig, axes = plt.subplots(3, ncols, figsize=(20, 10), squeeze=False)
    try:
        # hide every frame up front so unused cells (rest of the first row,
        # trailing slots when results are missing) stay blank
        for ax in axes.flat:
            ax.axis("off")

        # plot the query image alone in the first row
        axes[0, 0].imshow(query_img)
        axes[0, 0].set_title("Query Image")

        # plot the similar images in rows 1 and 2
        for i, (img, dist) in enumerate(zip(similar_imgs, distances)):
            row, col = divmod(i, ncols)
            ax = axes[row + 1, col]
            ax.imshow(img)
            ax.set_title(f"{dist:.4f}")

        fig.tight_layout()
        fig.savefig(out_filepath, bbox_inches="tight", dpi=200)
    finally:
        # release the figure; otherwise one figure per query accumulates in memory
        plt.close(fig)


def main(args=None):
    """Query the feature index with random images and save one result plot per query.

    Args:
        args (argparse.Namespace): Parsed command-line arguments providing
            ``feat_extractor``, ``n``, ``k`` and ``seed``.

    Raises:
        ValueError: If ``args`` is not provided.
    """
    if args is None:
        raise ValueError("main() requires the parsed command-line arguments")

    # set the random seed for reproducibility
    np.random.seed(args.seed)

    # load the vector database index built for this feature extractor
    index_filepath = os.path.join(DATA_DIR, f"db_{args.feat_extractor}.index")
    index = faiss.read_index(index_filepath)

    # initialize the feature extractor with the base model specified in the arguments
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)

    # sorted order matches the order the index was built in, so id i -> image_list[i]
    image_list = sorted(os.listdir(IMAGES_DIR))
    # select n random images
    query_images = select_random_images(args.n, image_list)

    # the output folder is the same for every query; create it once
    query_results_folderpath = f"{RESULTS_DIR}/results_{args.feat_extractor}"
    os.makedirs(query_results_folderpath, exist_ok=True)

    with torch.no_grad():
        for query_idx, img in enumerate(query_images, start=1):
            output = feature_extractor.extract_features(img)
            # flatten to (batch, features)
            output = output.view(output.size(0), -1)
            # L2-normalize each row
            output = output / output.norm(p=2, dim=1, keepdim=True)
            distances, ids = index.search(output.cpu().numpy(), args.k)

            # faiss pads results with id -1 when fewer than k neighbours exist;
            # indexing image_list with -1 would silently show the last file
            hits = [
                (dist, image_id)
                for dist, image_id in zip(distances[0], ids[0])
                if image_id >= 0
            ]
            similar_images = [
                PIL.Image.open(os.path.join(IMAGES_DIR, image_list[image_id]))
                for _, image_id in hits
            ]
            hit_distances = [dist for dist, _ in hits]

            query_results_filepath = (
                f"{query_results_folderpath}/query_{query_idx:03}.jpg"
            )
            try:
                plot_query_results(
                    img, similar_images, hit_distances, out_filepath=query_results_filepath
                )
            finally:
                # release file handles held by the lazily-opened images
                for opened in similar_images:
                    opened.close()
                img.close()


if __name__ == "__main__":
    # parse arguments (separate names for the parser and the parsed namespace)
    parser = argparse.ArgumentParser(
        description="Query the feature index with random images and plot the most similar results."
    )
    parser.add_argument(
        "--feat_extractor",
        type=str,
        choices=FEATURE_EXTRACTOR_MODELS,
        required=True,
        help="Feature extractor whose index should be queried",
    )
    parser.add_argument(
        "--n",
        type=int,
        default=10,
        help="Number of random images to select",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=12,
        help="Number of similar images to retrieve",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=777,
        help="Random seed for reproducibility",
    )

    args = parser.parse_args()

    # reject non-positive counts up front instead of failing deep inside faiss/matplotlib
    if args.n < 1:
        parser.error("--n must be a positive integer")
    if args.k < 1:
        parser.error("--k must be a positive integer")

    # run the main function
    main(args)