Spaces:
Runtime error
Runtime error
File size: 5,244 Bytes
982b011 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# Description:
# This script is used to query the index for similar images to a set of random images.
# The script uses the FeatureExtractor class to extract the features from the images and the Faiss index to search for similar images.
#
# Usage:
#
# To use this script, run one of the following commands. You MUST pass --feat_extractor, because each model has its own index file.
# python3 search_query.py --feat_extractor resnet50
# python3 search_query.py --feat_extractor resnet101
# python3 search_query.py --feat_extractor resnet50 --n 5
# python3 search_query.py --feat_extractor resnet50 --k 20
# python3 search_query.py --feat_extractor resnet50 --n 10 --k 12
#
import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
import faiss
import PIL
import os
from modules import FeatureExtractor
from config import *
def select_random_images(n, image_list):
    """Draw n images at random from the image list and open them.

    Args:
        n (int): How many images to draw.
        image_list (list[str]): File names of the candidate images.
    Returns:
        list[PIL.Image]: The opened images, in the order they were drawn.
    """
    # draw n positions in [0, len(image_list)); the same position may repeat
    picks = np.random.randint(0, len(image_list), size=n)
    opened = []
    for pick in picks:
        # file names are resolved against the configured images directory
        path = os.path.join(IMAGES_DIR, image_list[pick])
        opened.append(PIL.Image.open(path))
    return opened
def plot_query_results(query_img, similar_imgs, distances, out_filepath):
    """Plot the query image above its similar images and save the figure.

    The figure has three rows: the query image alone in the first row, and the
    similar images spread over the two rows below it, each titled with its
    distance.

    Args:
        query_img (PIL.Image): The query image.
        similar_imgs (list[PIL.Image]): The list of similar images.
        distances (list[float]): The list of distances of the similar images.
        out_filepath (str): The file path to save the plot.
    Returns:
        None
    """
    # Size the grid from the images actually passed in rather than from a
    # module-level `args`, so the function also works when imported.
    n_results = len(similar_imgs)
    # Two result rows; round up so an odd number of results still fits, and
    # keep at least one column for the query image.
    n_cols = max(1, (n_results + 1) // 2)

    # squeeze=False keeps `axes` two-dimensional even with a single column
    fig, axes = plt.subplots(3, n_cols, figsize=(20, 10), squeeze=False)
    try:
        # plot the query image
        axes[0, 0].imshow(query_img)
        axes[0, 0].set_title("Query Image")

        # plot the similar images, filling the two lower rows left to right
        for i, (img, dist) in enumerate(zip(similar_imgs, distances)):
            ax = axes[i // n_cols + 1, i % n_cols]
            ax.imshow(img)
            ax.set_title(f"{dist:.4f}")

        # hide ticks and frames everywhere, including the unused cells
        for ax in axes.flat:
            ax.axis("off")

        fig.tight_layout()
        # save the plot
        fig.savefig(out_filepath, bbox_inches="tight", dpi=200)
    finally:
        # release the figure; otherwise one figure per query stays in memory
        plt.close(fig)
def main(args=None):
    """Query the index with random images and save one result plot per query.

    Args:
        args (argparse.Namespace): Parsed command-line options. The attributes
            ``seed``, ``feat_extractor``, ``n`` and ``k`` are read.
    Returns:
        None
    """
    # set the random seed for reproducibility
    np.random.seed(args.seed)
    # load the vector database index
    index_filepath = os.path.join(DATA_DIR, f"db_{args.feat_extractor}.index")
    index = faiss.read_index(index_filepath)
    # initialize the feature extractor with the base model specified in the arguments
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)
    # get the list of images in sorted order since the index is built in the same order
    image_list = sorted(os.listdir(IMAGES_DIR))
    # select n random images
    query_images = select_random_images(args.n, image_list)
    # the output folder is the same for every query, so create it once
    query_results_folderpath = f"{RESULTS_DIR}/results_{args.feat_extractor}"
    os.makedirs(query_results_folderpath, exist_ok=True)
    with torch.no_grad():
        # iterate over the selected/query images
        for query_idx, img in enumerate(query_images, start=1):
            # extract the features of the query image
            output = feature_extractor.extract_features(img)
            # keep only batch dimension
            output = output.view(output.size(0), -1)
            # normalize
            output = output / output.norm(p=2, dim=1, keepdim=True)
            # search for similar images
            found_distances, found_labels = index.search(
                output.cpu().numpy(), args.k
            )
            # Faiss pads the result with label -1 when it finds fewer than k
            # neighbours; drop those, since image_list[-1] would silently
            # return the last image instead of a real match.
            hits = [
                (int(label), dist)
                for label, dist in zip(found_labels[0], found_distances[0])
                if label >= 0
            ]
            # get the similar images
            similar_images = [
                PIL.Image.open(os.path.join(IMAGES_DIR, image_list[label]))
                for label, _ in hits
            ]
            hit_distances = [dist for _, dist in hits]
            # plot the query results and save the plot
            query_results_filepath = (
                f"{query_results_folderpath}/query_{query_idx:03}.jpg"
            )
            plot_query_results(
                img, similar_images, hit_distances, out_filepath=query_results_filepath
            )
if __name__ == "__main__":
    # build the command-line interface
    parser = argparse.ArgumentParser()
    # base model whose index is queried (mandatory)
    parser.add_argument(
        "--feat_extractor", type=str, choices=FEATURE_EXTRACTOR_MODELS, required=True
    )
    # number of query images
    parser.add_argument(
        "--n", type=int, default=10, help="Number of random images to select"
    )
    # number of neighbours per query
    parser.add_argument(
        "--k", type=int, default=12, help="Number of similar images to retrieve"
    )
    # seed for the random image selection
    parser.add_argument(
        "--seed", type=int, default=777, help="Random seed for reproducibility"
    )
    # the parsed namespace stays bound to the module-level name `args`
    args = parser.parse_args()
    # run the main function
    main(args)
|