Spaces:
Runtime error
Runtime error
| import faulthandler | |
| faulthandler.enable() | |
| import torch | |
| from tqdm import tqdm | |
| import argparse | |
| import faiss | |
| import PIL | |
| import os | |
| from modules import FeatureExtractor | |
| from config import * | |
# Directory whose entries are listed and opened as input images by main().
# The path is relative, so it is resolved against the current working directory.
images_dir = "../data"
# Directory into which main() writes the serialized FAISS index file.
# Also relative to the current working directory.
model_dir = "../model"
def is_image_file(filename):
    """Tell whether *filename* carries one of the accepted image suffixes.

    The name is lower-cased first, so the check is case-insensitive. Only the
    trailing characters of the name are examined; the file is never opened.

    Args:
        filename: File name (or path) to test.

    Returns:
        True if the lower-cased name ends with a known image suffix,
        False otherwise.
    """
    lowered = filename.lower()
    for suffix in ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp'):
        if lowered.endswith(suffix):
            return True
    return False
def main(args=None):
    """Build a FAISS inner-product index from the images in ``images_dir``.

    Every file in ``images_dir`` accepted by ``is_image_file`` is processed in
    sorted name order: it is opened, converted to RGB, run through the feature
    extractor, flattened to two dimensions, L2-normalised and appended to a
    flat inner-product index. The index is finally written to
    ``<model_dir>/db_<feat_extractor>.index``.

    Args:
        args: Namespace exposing a ``feat_extractor`` attribute. That value
            selects the base model and is embedded in the output file name.
            Despite the ``None`` default the attribute is read
            unconditionally, so a namespace must be supplied.
    """
    # ``import PIL`` on its own does not load the ``PIL.Image`` submodule, so
    # import it explicitly instead of relying on another module's side effect.
    from PIL import Image

    # initialize the feature extractor with the base model specified in the arguments
    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)

    # initialize the vector database indexing
    index = faiss.IndexFlatIP(feature_extractor.feat_dims)

    # get the list of images in sorted order and filter out non-image files
    image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))

    with torch.no_grad():
        # iterate over the images and add their extracted features to the index
        for img_filename in tqdm(image_list):
            img_path = os.path.join(images_dir, img_filename)
            # The context manager closes the underlying file as soon as the
            # pixel data has been copied into the converted image.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")

            output = feature_extractor.extract_features(img)
            # keep the first dimension and flatten everything else
            output = output.view(output.size(0), -1)
            # unit-length vectors make the inner product equal to cosine similarity
            output = output / output.norm(p=2, dim=1, keepdim=True)
            # add to the index
            index.add(output.cpu().numpy())

    # Make sure the destination exists so the extraction work is not lost to
    # a missing-directory error at write time.
    os.makedirs(model_dir, exist_ok=True)

    # save the index
    index_filepath = os.path.join(model_dir, f"db_{args.feat_extractor}.index")
    faiss.write_index(index, index_filepath)
if __name__ == "__main__":
    # Build the command-line interface: a single option selecting which
    # feature-extractor backbone to use, restricted to the configured models.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feat_extractor",
        type=str,
        default="vit_b_16",
        choices=FEATURE_EXTRACTOR_MODELS,
    )
    args = parser.parse_args()

    # Hand the parsed options to the indexing routine.
    main(args)