Triventure-AI / src /utils /model_utils.py
ABAO77's picture
Upload 37 files
5ce8318 verified
import os
import faiss
import numpy as np
from src.modules.feature_extractor import FeatureExtractor
def init_models(index_path: str, onnx_path: str) -> tuple[faiss.IndexFlatIP, FeatureExtractor]:
    """Load the FAISS index and build the ONNX-backed feature extractor.

    Args:
        index_path: Path to the serialized FAISS index file.
        onnx_path: Path to the ONNX model file, passed through to
            ``FeatureExtractor``.

    Returns:
        tuple: ``(index, feature_extractor)``.

    Raises:
        FileNotFoundError: If ``index_path`` does not point to an existing file.
        RuntimeError: If loading the index or constructing the feature
            extractor fails; the underlying exception is chained as the cause.
    """
    # Fail fast with a clear message. isfile() (unlike exists()) also rejects
    # directories, which would otherwise reach faiss.read_index and fail there.
    if not os.path.isfile(index_path):
        raise FileNotFoundError(f"Index file not found: {index_path}")

    try:
        # Load FAISS index
        index = faiss.read_index(index_path)
        print(f"Successfully loaded FAISS index from {index_path}")

        # Initialize feature extractor with ONNX support
        feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
        print("Successfully initialized feature extractor with ONNX support")
    except Exception as e:
        # Chain the cause so the original traceback survives the re-wrap.
        raise RuntimeError(f"Error initializing models: {e}") from e

    return index, feature_extractor
def search_similar_images(
    index: "faiss.IndexFlatIP",
    features: np.ndarray,
    k: int = 1,
) -> tuple[np.ndarray, np.ndarray]:
    """Search a FAISS index for the entries most similar to ``features``.

    Each query row is L2-normalised before searching. With an inner-product
    index this makes the scores behave like cosine similarity, presumably
    assuming the indexed vectors were normalised the same way (verify
    against the index-building code).

    Args:
        index: FAISS index to query (anything exposing ``search(x, k)``).
        features: Query feature(s). A 1-D array is treated as a single
            query; arrays with more than two dimensions are flattened to one
            feature vector per leading-axis entry.
        k: Number of similar images to return per query.

    Returns:
        tuple: ``(distances, indices)`` exactly as returned by ``index.search``.
    """
    # FAISS expects float32 input; converting here also accepts lists or
    # float64/int arrays from callers.
    query = np.asarray(features, dtype=np.float32)

    if query.ndim == 1:
        # A single embedding: treat it as a batch of one so the row-wise
        # normalisation below (axis=1) works.
        query = query.reshape(1, -1)
    elif query.ndim > 2:
        # Flatten trailing dimensions into a single feature vector per row.
        query = query.reshape(query.shape[0], -1)

    # Normalise each row to unit length. All-zero rows are left as zeros
    # rather than divided by zero, which would feed NaNs into the search.
    norms = np.linalg.norm(query, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    query = np.ascontiguousarray(query / norms, dtype=np.float32)

    distances, indices = index.search(query, k)
    return distances, indices