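# NOTE(review): the text "Spaces: Sleeping Sleeping" that preceded this script
# appears to be status-banner residue captured when the file was copied from a
# web page; it is kept only as a comment so the module can be parsed.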
import argparse
import faulthandler
import os

# Third-party. NOTE(review): torch/tqdm/faiss/PIL keep their original relative
# order on purpose -- native extensions can be sensitive to import order.
import torch
from tqdm import tqdm
import faiss
import PIL
# `import PIL` alone does not load the PIL.Image submodule; main() calls
# PIL.Image.open, so import it explicitly instead of relying on some other
# library having imported it first.
import PIL.Image

from src.modules.feature_extractor import FeatureExtractor
from src.config.settings import FEATURE_EXTRACTOR_MODELS
# Make Python dump tracebacks if the process is killed by a fatal signal
# (e.g. a segmentation fault). NOTE(review): presumably aimed at crashes
# inside native code such as FAISS / ONNX runtime -- confirm.
faulthandler.enable()
def is_image_file(filename):
    """Report whether *filename* ends with a recognised raster-image extension.

    Matching is case-insensitive, so ``PHOTO.JPG`` is accepted just like
    ``photo.jpg``. Only the name is inspected; the file is never opened.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp')
    lowered = filename.lower()
    return any(lowered.endswith(suffix) for suffix in image_suffixes)
def _load_rgb(path):
    """Open the image file at *path* and return an RGB copy of it.

    ``Image.convert`` returns a new, fully loaded image, so the file handle
    opened by ``PIL.Image.open`` can be closed as soon as the ``with`` block
    exits instead of being left for the garbage collector.
    """
    with PIL.Image.open(path) as img:
        return img.convert("RGB")


def _prepare_features(features):
    """Convert a feature tensor into a FAISS-ready ``float32`` NumPy array.

    A 1-D tensor is treated as a single sample. Otherwise the first dimension
    indexes samples and all remaining dimensions are flattened, giving one
    row per sample. Each row is scaled to unit L2 norm so that the
    inner-product index built in ``main`` ranks by cosine similarity.

    Args:
        features: Tensor returned by ``FeatureExtractor.extract_features``.

    Returns:
        A C-contiguous ``float32`` array of shape ``(samples, dims)``.
    """
    if features.dim() == 1:
        features = features.unsqueeze(0)
    # flatten() works on non-contiguous tensors, unlike view().
    rows = features.detach().flatten(start_dim=1).float()
    # F.normalize divides by max(norm, eps): an all-zero row stays zero
    # instead of turning into NaNs as a bare division by the norm would.
    rows = torch.nn.functional.normalize(rows, p=2, dim=1)
    return rows.cpu().contiguous().numpy()


def main(args=None):
    """Embed every image under ``./data/images`` and write a FAISS index.

    The feature extractor is built from ``args.feat_extractor`` with its ONNX
    model path set to ``./model/<name>_feature_extractor.onnx``. Images are
    processed in sorted filename order. Their features are L2-normalised and
    added to a flat inner-product index (``faiss.IndexFlatIP``). The index is
    saved to ``./model/db_<name>.index``.

    Images that fail to load or embed are reported and skipped, so the index
    only holds vectors for the images that succeeded.

    Args:
        args: Namespace with a ``feat_extractor`` attribute, as produced by
            this script's argument parser. When ``None``, the command-line
            default ``"vit_b_16"`` is used.
    """
    if args is None:
        # Mirror the command-line default so main() can be called directly.
        args = argparse.Namespace(feat_extractor="vit_b_16")

    # Set paths
    images_dir = "./data/images"
    model_dir = "./model"
    onnx_path = os.path.join(model_dir, f"{args.feat_extractor}_feature_extractor.onnx")
    os.makedirs(model_dir, exist_ok=True)

    print(f"Initializing feature extractor with model: {args.feat_extractor}")
    feature_extractor = FeatureExtractor(
        base_model=args.feat_extractor,
        onnx_path=onnx_path,
    )

    print("Initializing FAISS index...")
    index = faiss.IndexFlatIP(feature_extractor.feat_dims)

    # Sorted order makes the insertion order of vectors deterministic.
    image_list = sorted(f for f in os.listdir(images_dir) if is_image_file(f))
    print(f"Found {len(image_list)} images to process")

    print("Processing images and building index...")
    failed = []
    with torch.no_grad():
        for img_filename in tqdm(image_list, desc="Processing images"):
            try:
                img = _load_rgb(os.path.join(images_dir, img_filename))
                output = feature_extractor.extract_features(img)
                index.add(_prepare_features(output))
            except Exception as e:
                # Best effort: one unreadable or corrupt file should not abort
                # the whole build, but it is recorded for the summary below.
                print(f"Error processing image {img_filename}: {e}")
                failed.append(img_filename)

    if failed:
        # NOTE(review): every skipped image shifts all later rows of the index
        # relative to image_list; confirm whether the consumer of this index
        # maps row ids back to the sorted filename list.
        print(f"Warning: skipped {len(failed)} of {len(image_list)} images: {', '.join(failed)}")
    print(f"Index contains {index.ntotal} vectors")

    # Save the index
    index_filepath = os.path.join(model_dir, f"db_{args.feat_extractor}.index")
    print(f"Saving index to {index_filepath}")
    faiss.write_index(index, index_filepath)
    print("Index saved successfully!")
if __name__ == "__main__":
    # Command-line entry point: pick the feature extractor model, then hand
    # the parsed options to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--feat_extractor",
        default="vit_b_16",
        type=str,
        choices=FEATURE_EXTRACTOR_MODELS,
        help="Feature extractor model to use",
    )
    main(cli.parse_args())