"""Build a FAISS inner-product index of L2-normalised image embeddings.

Every image file in the images directory is processed in sorted filename
order. Each image goes through the selected feature extractor, and the
resulting vectors are written to ``<model_dir>/db_<feat_extractor>.index``.
"""
import argparse
import faulthandler
import os

import faiss
import PIL.Image  # importing the submodule also binds the top-level ``PIL`` name
import torch
from tqdm import tqdm

from src.config.settings import FEATURE_EXTRACTOR_MODELS
from src.modules.feature_extractor import FeatureExtractor

faulthandler.enable()

DEFAULT_IMAGES_DIR = "./data/images"
DEFAULT_MODEL_DIR = "./model"
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp')


def is_image_file(filename):
    """Return True if *filename* ends with a known image extension (case-insensitive)."""
    return filename.lower().endswith(IMAGE_EXTENSIONS)


def _parse_args(argv=None):
    """Parse command-line options; *argv* defaults to ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description="Build a FAISS inner-product index of image embeddings."
    )
    parser.add_argument(
        "--feat_extractor",
        type=str,
        default="vit_b_16",
        choices=FEATURE_EXTRACTOR_MODELS,
        help="Feature extractor model to use",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default=DEFAULT_IMAGES_DIR,
        help="Directory containing the images to index",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=DEFAULT_MODEL_DIR,
        help="Directory for the ONNX model and the output index",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Only index the first N images (sorted by name); default: all",
    )
    return parser.parse_args(argv)


def main(args=None):
    """Extract features for every image and save them as a FAISS index.

    Args:
        args: Namespace with at least ``feat_extractor``. It may also carry
            ``images_dir``, ``model_dir`` and ``limit``; any that are missing
            fall back to the module defaults and "no limit". If ``None``, the
            options are parsed from the command line.

    Raises:
        ValueError: if ``limit`` is negative.
    """
    if args is None:
        args = _parse_args()

    # getattr keeps callers that build a Namespace with only feat_extractor working.
    images_dir = getattr(args, "images_dir", DEFAULT_IMAGES_DIR)
    model_dir = getattr(args, "model_dir", DEFAULT_MODEL_DIR)
    limit = getattr(args, "limit", None)
    if limit is not None and limit < 0:
        raise ValueError(f"limit must be non-negative, got {limit}")

    onnx_path = os.path.join(model_dir, f"{args.feat_extractor}_feature_extractor.onnx")

    # Create model directory if it doesn't exist
    os.makedirs(model_dir, exist_ok=True)

    # Initialize the feature extractor with ONNX support
    print(f"Initializing feature extractor with model: {args.feat_extractor}")
    feature_extractor = FeatureExtractor(
        base_model=args.feat_extractor,
        onnx_path=onnx_path,
    )

    # Exact inner-product index; inner product on unit vectors is cosine similarity.
    print("Initializing FAISS index...")
    index = faiss.IndexFlatIP(feature_extractor.feat_dims)

    # Sorted order determines each image's position in the index.
    image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))
    print(f"Found {len(image_list)} images to process")

    if limit is not None:
        image_list = image_list[:limit]
        print(f"Testing with first {limit} images")

    failed = []
    print("Processing images and building index...")
    with torch.no_grad():
        for img_filename in tqdm(image_list, desc="Processing images"):
            img_path = os.path.join(images_dir, img_filename)
            try:
                # convert() returns a fully loaded copy, so the file can be closed here.
                with PIL.Image.open(img_path) as raw_img:
                    img = raw_img.convert("RGB")

                output = feature_extractor.extract_features(img)

                # Flatten to (rows, dim). NOTE(review): assumes extract_features
                # returns a leading batch dimension — TODO confirm.
                output = output.reshape(output.size(0), -1)
                # L2-normalise; unlike x / x.norm(), this does not produce NaN for a zero vector.
                output = torch.nn.functional.normalize(output, p=2, dim=1)
                # FAISS requires a C-contiguous float32 array.
                index.add(output.cpu().float().contiguous().numpy())
            except Exception as e:
                # Best-effort: skip unreadable images but remember them.
                print(f"Error processing image {img_filename}: {str(e)}")
                failed.append(img_filename)
                continue

    if failed:
        print(
            f"Warning: {len(failed)} of {len(image_list)} images failed and were not "
            "indexed; index positions are shifted relative to the sorted image list."
        )

    # Save the index
    index_filepath = os.path.join(model_dir, f"db_{args.feat_extractor}.index")
    print(f"Saving index to {index_filepath}")
    faiss.write_index(index, index_filepath)
    print("Index saved successfully!")


if __name__ == "__main__":
    main(_parse_args())