File size: 1,932 Bytes
5ce8318
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import faiss
import numpy as np
from src.modules.feature_extractor import FeatureExtractor


def init_models(index_path: str, onnx_path: str) -> tuple[faiss.IndexFlatIP, FeatureExtractor]:
    """Load the FAISS index and construct the ONNX-backed feature extractor.

    Args:
        index_path: Path to a serialized FAISS index file.
        onnx_path: Path to the ONNX model file handed to ``FeatureExtractor``.

    Returns:
        tuple: ``(index, feature_extractor)``. The index is whatever type
        ``faiss.read_index`` deserializes from ``index_path``; the
        ``IndexFlatIP`` annotation is the expected type but is not checked.

    Raises:
        FileNotFoundError: If ``index_path`` does not exist.
        RuntimeError: If loading the index or initializing the feature
            extractor fails. The underlying exception is chained as
            ``__cause__`` so its traceback is preserved.
    """
    # Fail fast with a clear, specific error before handing the path to FAISS.
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"Index file not found: {index_path}")

    try:
        index = faiss.read_index(index_path)
        print(f"Successfully loaded FAISS index from {index_path}")

        # "vit_b_16" is the base model name; onnx_path points at the ONNX model.
        feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
        print("Successfully initialized feature extractor with ONNX support")
    except Exception as e:
        # Keep the original exception as the cause instead of discarding it.
        raise RuntimeError(f"Error initializing models: {e}") from e

    return index, feature_extractor


def search_similar_images(
    index: "faiss.IndexFlatIP",
    features: np.ndarray,
    k: int = 1
) -> tuple[np.ndarray, np.ndarray]:
    """Search a FAISS index for the ``k`` nearest neighbours of each query.

    Query vectors are flattened to 2-D, converted to C-contiguous float32 and
    L2-normalized row-wise before searching. With an inner-product index this
    makes the scores cosine similarities, assuming the indexed vectors were
    normalized the same way (not checked here).

    Args:
        index: FAISS index to query.
        features: Query features. A 1-D array is treated as a single query;
            arrays with more than two dimensions are flattened to
            ``(features.shape[0], -1)``. Rows with zero norm are left as
            zero vectors rather than becoming NaN.
        k: Number of neighbours to return per query.

    Returns:
        tuple: ``(D, I)`` exactly as returned by ``index.search`` —
        per-query scores and neighbour indices.
    """
    features = np.asarray(features)
    if features.ndim == 1:
        # A single feature vector: search it as a batch of one query.
        features = features.reshape(1, -1)
    elif features.ndim > 2:
        # Flatten everything after the leading (batch) axis.
        features = features.reshape(features.shape[0], -1)

    # FAISS operates on C-contiguous float32 arrays; converting up front also
    # avoids a float64 intermediate during normalization.
    features = np.ascontiguousarray(features, dtype=np.float32)

    # Row-wise L2 normalization. Zero rows would otherwise yield 0/0 = NaN,
    # so their norm is replaced by 1 and they stay all-zero.
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    norms[norms == 0] = 1.0
    # Division allocates a new array, so the caller's input is never mutated.
    features = features / norms

    D, I = index.search(features, k)
    return D, I