# image-retrieval-full/src/utils/model_utils.py
# Uploaded by ABAO77 ("Upload 12 files", commit 0ec5620, verified)
import os

import faiss
import numpy as np
import torch

from src.modules.feature_extractor import FeatureExtractor
def init_models(index_path: str, onnx_path: str) -> tuple[faiss.IndexFlatIP, FeatureExtractor]:
    """Load the FAISS index and build the ONNX-backed feature extractor.

    Args:
        index_path: Path to the FAISS index file.
        onnx_path: Path to the ONNX model file; forwarded to
            ``FeatureExtractor`` as ``onnx_path``.

    Returns:
        tuple: ``(index, feature_extractor)``.

    Raises:
        FileNotFoundError: If ``index_path`` does not exist.
        RuntimeError: If reading the index or constructing the feature
            extractor fails. The underlying exception is attached as
            ``__cause__``.
    """
    # Checked up front so a missing index is reported as FileNotFoundError
    # rather than being wrapped in the generic RuntimeError below.
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"Index file not found: {index_path}")

    try:
        # Load FAISS index
        index = faiss.read_index(index_path)
        print(f"Successfully loaded FAISS index from {index_path}")

        # Initialize feature extractor with ONNX support
        feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
        print("Successfully initialized feature extractor with ONNX support")
    except Exception as e:
        # Chain explicitly so the root-cause traceback stays attached.
        raise RuntimeError(f"Error initializing models: {e}") from e

    return index, feature_extractor
def search_similar_images(
    index: faiss.IndexFlatIP,
    features: torch.Tensor,
    k: int = 1
) -> tuple[np.ndarray, np.ndarray]:
    """Search the FAISS index for the nearest neighbours of ``features``.

    Each sample is flattened to one row, L2-normalized, converted to a
    contiguous float32 NumPy array on the CPU, and passed to
    ``index.search``.

    Args:
        index: FAISS index to query.
        features: Query features whose first dimension is the batch;
            all remaining dimensions are flattened into one vector per sample.
        k: Number of neighbours to return per query (must be >= 1).

    Returns:
        tuple: ``(D, I)`` exactly as returned by ``index.search`` — the
        scores and the indices of the ``k`` best matches per query row.
        NOTE(review): scores are cosine similarities only if the indexed
        vectors were L2-normalized as well — confirm against the index
        builder.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    # detach(): Tensor.numpy() refuses tensors that require grad.
    # reshape(): unlike view(), also accepts non-contiguous input.
    queries = features.detach().reshape(features.size(0), -1)

    # normalize() clamps the norm to a tiny epsilon, so an all-zero row
    # stays zero instead of turning into NaN through a 0/0 division.
    queries = torch.nn.functional.normalize(queries.to(torch.float32), p=2, dim=1)

    # FAISS's Python API expects a C-contiguous float32 matrix.
    D, I = index.search(queries.cpu().contiguous().numpy(), k)
    return D, I