"""Build a FAISS inner-product index from the features of every image in ``images_dir``."""

import faulthandler

# Enabled before the heavy native imports below so that a crash inside them
# still produces a Python traceback.
faulthandler.enable()

import torch
from tqdm import tqdm
import argparse
import faiss
# ``import PIL`` alone does not load the ``Image`` submodule; importing it
# explicitly guarantees ``PIL.Image.open`` is available (and still binds ``PIL``).
import PIL.Image
import os

from modules import FeatureExtractor
from config import *

images_dir = "../data"
model_dir = "../model"

# File suffixes (lower-case) accepted as images when scanning ``images_dir``.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp')

# Lower bound for the L2 norm so an all-zero feature vector does not turn into
# NaNs when it is normalized.
_NORM_EPS = 1e-12


def is_image_file(filename: str) -> bool:
    """Return True if *filename* ends with a known image extension (case-insensitive)."""
    return filename.lower().endswith(_IMAGE_EXTENSIONS)


def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace, carrying ``feat_extractor``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feat_extractor",
        type=str,
        default="vit_b_16",
        choices=FEATURE_EXTRACTOR_MODELS,
    )
    return parser.parse_args(argv)


def main(args=None):
    """Extract features for every image in ``images_dir`` and save them as a FAISS index.

    The index is written to ``<model_dir>/db_<feat_extractor>.index``.

    Args:
        args: Namespace with a ``feat_extractor`` attribute naming the base
            model. When ``None``, the options are parsed from the command line.
    """
    if args is None:
        args = parse_args()

    # initialize the feature extractor with the base model specified in the arguments
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)

    # inner-product index; combined with L2-normalized vectors this yields cosine similarity
    index = faiss.IndexFlatIP(feature_extractor.feat_dims)

    # Sorted listing restricted to image files. Vectors are added in this order
    # and no file is skipped -- presumably the retrieval side maps index row ids
    # back to this same sorted listing (TODO confirm against the caller).
    image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))

    with torch.no_grad():
        for img_filename in tqdm(image_list):
            img_path = os.path.join(images_dir, img_filename)
            # The context manager closes the underlying file as soon as the
            # RGB copy has been decoded, instead of leaking one handle per image.
            with PIL.Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")

            output = feature_extractor.extract_features(img)

            # keep only the leading (batch) dimension, flatten everything else
            output = output.view(output.size(0), -1)

            # L2-normalize each row; the norm is clamped so a zero vector stays
            # zero rather than becoming NaN
            norms = output.norm(p=2, dim=1, keepdim=True).clamp_min(_NORM_EPS)
            output = output / norms

            # faiss consumes float32 arrays on the CPU
            index.add(output.cpu().float().numpy())

    # make sure the destination directory exists before writing the index
    os.makedirs(model_dir, exist_ok=True)
    index_filepath = os.path.join(model_dir, f"db_{args.feat_extractor}.index")
    faiss.write_index(index, index_filepath)


if __name__ == "__main__":
    # parse arguments and run the main function
    main(parse_args())