Spaces:
Runtime error
Runtime error
File size: 2,030 Bytes
982b011 351bcee 982b011 351bcee 982b011 351bcee 982b011 95c2d79 982b011 351bcee 982b011 351bcee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import faulthandler
faulthandler.enable()
import torch
from tqdm import tqdm
import argparse
import faiss
import PIL
import os
from modules import FeatureExtractor
from config import *
# Directory scanned (non-recursively, via os.listdir) for image files to index.
# Relative paths resolve against the current working directory, not this file.
images_dir = "../data"
# Directory where the FAISS index file "db_<feat_extractor>.index" is written.
model_dir = "../model"
def is_image_file(filename):
    """Return True if *filename* ends with a recognised image extension.

    The check is case-insensitive (``"photo.JPG"`` matches) and looks only at
    the name's suffix; the file itself is never opened.
    """
    lowered = filename.lower()
    image_suffixes = (".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tiff", ".webp")
    return any(lowered.endswith(suffix) for suffix in image_suffixes)
def main(args=None):
    """Build and save a FAISS inner-product index over the images in ``images_dir``.

    Every image file directly inside ``images_dir`` (as accepted by
    ``is_image_file``) is processed in sorted filename order. Each image is
    converted to RGB and passed through ``FeatureExtractor.extract_features``.
    The result is flattened to one row per leading (batch) dimension and
    L2-normalised, so that inner product equals cosine similarity. It is then
    appended to a ``faiss.IndexFlatIP``.

    Row ``i`` of the index should correspond to the ``i``-th sorted filename.
    This assumes the extractor returns exactly one row per image (TODO
    confirm).

    The index is written to ``<model_dir>/db_<feat_extractor>.index``; the
    directory is created if it does not exist.

    Args:
        args: Parsed CLI namespace providing ``feat_extractor``, the backbone
            name passed to ``FeatureExtractor``.

    Raises:
        ValueError: if ``args`` is None.
    """
    # Import the submodule explicitly: a bare ``import PIL`` does not guarantee
    # that ``PIL.Image`` has been loaded.
    from PIL import Image

    if args is None:
        raise ValueError(
            "main() requires parsed arguments with a 'feat_extractor' attribute"
        )

    feature_extractor = FeatureExtractor(base_model=args.feat_extractor)
    # Exact (brute-force) inner-product index sized to the extractor's output.
    index = faiss.IndexFlatIP(feature_extractor.feat_dims)

    # Sorted order keeps index rows aligned with a reproducible filename order.
    image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))

    with torch.no_grad():
        for img_filename in tqdm(image_list):
            img_path = os.path.join(images_dir, img_filename)
            # The context manager closes the file handle. convert() returns a
            # new, fully loaded image that does not depend on that handle.
            with Image.open(img_path) as raw_img:
                img = raw_img.convert("RGB")
            output = feature_extractor.extract_features(img)
            # Flatten everything except the leading dimension. reshape (unlike
            # view) also works when the extractor returns a non-contiguous tensor.
            output = output.reshape(output.size(0), -1)
            # L2-normalise so inner product == cosine similarity. The clamp
            # avoids NaNs for an all-zero feature vector.
            norms = output.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
            output = output / norms
            # faiss expects a C-contiguous float32 array on the CPU.
            index.add(output.float().cpu().contiguous().numpy())

    os.makedirs(model_dir, exist_ok=True)
    index_filepath = os.path.join(model_dir, f"db_{args.feat_extractor}.index")
    faiss.write_index(index, index_filepath)
if __name__ == "__main__":
    # Command-line interface: a single option selects the feature-extractor
    # backbone, restricted to the names listed in FEATURE_EXTRACTOR_MODELS.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--feat_extractor",
        type=str,
        default="vit_b_16",
        choices=FEATURE_EXTRACTOR_MODELS,
    )
    # Parse sys.argv and hand the resulting namespace to the indexing routine.
    main(parser.parse_args())