# image-retrieval / src / build_vector_database.py
# Uploaded by ABAO77 ("Upload 22 files", commit 351bcee, verified)
import faulthandler
faulthandler.enable()
import torch
from tqdm import tqdm
import argparse
import faiss
import PIL
import os
from modules import FeatureExtractor
from config import *
# Directory scanned for the images to index (relative path).
images_dir = "../data"
# Directory the serialized FAISS index file is written to (relative path).
model_dir = "../model"
def is_image_file(filename):
    """Tell whether *filename* ends with a known image extension.

    The comparison is case-insensitive, so ``"A.JPG"`` and ``"a.jpg"`` are
    treated the same way. Only the name is inspected; the file is never opened.
    """
    lowered = filename.lower()
    return any(
        lowered.endswith(suffix)
        for suffix in ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp')
    )
def main(args=None):
    """Build a FAISS inner-product index from the images in ``images_dir``.

    Every file in ``images_dir`` whose name passes ``is_image_file`` is loaded,
    run through the feature extractor, L2-normalized and appended to the index.
    Images are processed in sorted filename order. The finished index is
    written to ``<model_dir>/db_<feat_extractor>.index``.

    Args:
        args: Parsed command-line namespace. Must expose ``feat_extractor``,
            the name of the base model passed to ``FeatureExtractor``.
            Despite the ``None`` default, a namespace is required.
    """
    # A bare ``import PIL`` does not load the ``Image`` submodule, so import it
    # explicitly instead of relying on another module having done so.
    from PIL import Image

    # Initialize the feature extractor with the requested base model.
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)

    # Inner-product index; with unit-length vectors this is cosine similarity.
    index = faiss.IndexFlatIP(feature_extractor.feat_dims)

    # Sorted listing of the image files, skipping anything that is not an image.
    image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))

    with torch.no_grad():
        for img_filename in tqdm(image_list):
            img_path = os.path.join(images_dir, img_filename)
            # The context manager releases the file handle right after the
            # pixel data has been converted to an in-memory RGB image.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")

            output = feature_extractor.extract_features(img)
            # Keep only the batch dimension (assumes dim 0 is the batch axis).
            output = output.view(output.size(0), -1)
            # Unit-normalize each row; ``normalize`` clamps the norm with a
            # small epsilon so an all-zero feature vector does not become NaN.
            output = torch.nn.functional.normalize(output, p=2, dim=1)
            index.add(output.cpu().numpy())

    # Make sure the output directory exists before writing the index file.
    os.makedirs(model_dir, exist_ok=True)
    index_filepath = os.path.join(model_dir, f"db_{args.feat_extractor}.index")
    faiss.write_index(index, index_filepath)
if __name__ == "__main__":
    # Command-line interface: the only option selects the backbone model used
    # for feature extraction, restricted to the names in FEATURE_EXTRACTOR_MODELS.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feat_extractor",
        default="vit_b_16",
        type=str,
        choices=FEATURE_EXTRACTOR_MODELS,
    )
    args = parser.parse_args()

    # Build and save the vector database for the selected model.
    main(args)