# Description:
# This script is used to query the index for similar images to a set of random images.
# The script uses the FeatureExtractor class to extract the features from the images
# and the Faiss index to search for similar images.
#
# Usage:
#
# To use this script, you can run the following commands:
# (You MUST define a feat_extractor since the indexes are different for each model)
# python3 search_query.py --feat_extractor resnet50
# python3 search_query.py --feat_extractor resnet101
# python3 search_query.py --feat_extractor resnet50 --n 5
# python3 search_query.py --feat_extractor resnet50 --k 20
# python3 search_query.py --feat_extractor resnet50 --n 10 --k 12
#

import argparse
import os

import faiss
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded
import torch

from modules import FeatureExtractor
from config import *


def select_random_images(n, image_list):
    """Select n random images from the image list.

    Indices are drawn independently with ``np.random.randint``, so the same
    image may be selected more than once. The draw uses NumPy's global random
    state, which makes it reproducible after ``np.random.seed``.

    Args:
        n (int): The number of images to select.
        image_list (list[str]): The list of image file names, relative to
            IMAGES_DIR.

    Returns:
        list[PIL.Image]: The list of selected images. The caller is
            responsible for closing them.

    Raises:
        ValueError: If images are requested but image_list is empty.
    """
    if n > 0 and not image_list:
        raise ValueError(
            f"No images found in {IMAGES_DIR!r}; cannot select query images."
        )

    selected_indices = np.random.randint(len(image_list), size=n)
    return [
        PIL.Image.open(os.path.join(IMAGES_DIR, image_list[i]))
        for i in selected_indices
    ]


def plot_query_results(query_img, similar_imgs, distances, out_filepath):
    """Plot the query image and the similar images, then save the plot.

    The figure has three rows: the query image sits alone in the first row and
    the similar images fill the two rows below it, left to right. The number
    of columns is derived from the number of results, so any result count
    (including odd counts and zero) produces a valid grid.

    Args:
        query_img (PIL.Image): The query image.
        similar_imgs (list[PIL.Image]): The list of similar images.
        distances (list[float]): The distances of the similar images, in the
            same order as similar_imgs.
        out_filepath (str): The file path to save the plot.

    Returns:
        None
    """
    results = list(zip(similar_imgs, distances))

    # two result rows -> ceil(len / 2) columns, with at least one column so
    # that the query image always has a cell
    n_cols = max(1, (len(results) + 1) // 2)

    # squeeze=False keeps `axes` two-dimensional even when n_cols == 1
    fig, axes = plt.subplots(3, n_cols, figsize=(20, 10), squeeze=False)
    try:
        # plot the query image
        axes[0, 0].imshow(query_img)
        axes[0, 0].set_title("Query Image")

        # plot the similar images in the second and third rows
        for i, (img, dist) in enumerate(results):
            ax = axes[i // n_cols + 1, i % n_cols]
            ax.imshow(img)
            ax.set_title(f"{dist:.4f}")

        # hide the axis decorations everywhere; this also blanks the unused
        # plots in the first row and any empty trailing cell
        for ax in axes.ravel():
            ax.axis("off")

        fig.tight_layout()

        # save the plot
        fig.savefig(out_filepath, bbox_inches="tight", dpi=200)
    finally:
        # release the figure; otherwise one figure per query accumulates
        plt.close(fig)


def parse_args(argv=None):
    """Parse the command line arguments.

    Args:
        argv (list[str] | None): The arguments to parse. When None, argparse
            reads them from ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed arguments (feat_extractor, n, k, seed).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feat_extractor",
        type=str,
        choices=FEATURE_EXTRACTOR_MODELS,
        required=True,
    )
    parser.add_argument(
        "--n",
        type=int,
        default=10,
        help="Number of random images to select",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=12,
        help="Number of similar images to retrieve",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=777,
        help="Random seed for reproducibility",
    )
    return parser.parse_args(argv)


def main(args=None):
    """Query the index with random images and save one result plot per query.

    Args:
        args (argparse.Namespace | None): The parsed arguments. When None, the
            arguments are parsed from the command line.

    Returns:
        None
    """
    if args is None:
        args = parse_args()

    # set the random seed for reproducibility
    np.random.seed(args.seed)

    # load the vector database index
    index_filepath = os.path.join(DATA_DIR, f"db_{args.feat_extractor}.index")
    index = faiss.read_index(index_filepath)

    # initialize the feature extractor with the base model specified in the arguments
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)

    # get the list of images in sorted order since the index is built in the same order
    image_list = sorted(os.listdir(IMAGES_DIR))

    # select n random images
    query_images = select_random_images(args.n, image_list)

    # the output folder does not depend on the query, so create it once
    query_results_folderpath = f"{RESULTS_DIR}/results_{args.feat_extractor}"
    os.makedirs(query_results_folderpath, exist_ok=True)

    with torch.no_grad():
        # iterate over the selected/query images
        for query_idx, img in enumerate(query_images, start=1):
            # extract the features of the query image
            output = feature_extractor.extract_features(img)

            # keep only batch dimension
            output = output.view(output.size(0), -1)

            # normalize
            output = output / output.norm(p=2, dim=1, keepdim=True)

            # search for similar images
            D, I = index.search(output.cpu().numpy(), args.k)

            # faiss pads the labels with -1 when fewer than k neighbours
            # exist; drop those slots instead of indexing image_list[-1]
            hits = [
                (int(neighbor_id), float(dist))
                for neighbor_id, dist in zip(I[0], D[0])
                if neighbor_id >= 0
            ]

            # get the similar images
            similar_images = [
                PIL.Image.open(os.path.join(IMAGES_DIR, image_list[neighbor_id]))
                for neighbor_id, _ in hits
            ]

            # plot the query results and save the plot
            query_results_filepath = (
                f"{query_results_folderpath}/query_{query_idx:03}.jpg"
            )
            try:
                plot_query_results(
                    img,
                    similar_images,
                    [dist for _, dist in hits],
                    out_filepath=query_results_filepath,
                )
            finally:
                # release the file handles held by the opened images
                for similar_img in similar_images:
                    similar_img.close()
                img.close()


if __name__ == "__main__":
    # parse arguments
    args = parse_args()

    # run the main function
    main(args)