File size: 1,894 Bytes
0ec5620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import faiss
import torch
from src.modules.feature_extractor import FeatureExtractor


def init_models(index_path: str, onnx_path: str) -> tuple["faiss.Index", "FeatureExtractor"]:
    """Load the FAISS index and build the ONNX-backed feature extractor.

    Args:
        index_path: Path to a FAISS index file, read with ``faiss.read_index``.
        onnx_path: Path to the ONNX model file, passed straight through to
            ``FeatureExtractor``; its existence is not checked here.

    Returns:
        tuple: ``(index, feature_extractor)``.

    Raises:
        FileNotFoundError: If ``index_path`` does not exist.
        RuntimeError: If loading the index or constructing the feature
            extractor fails. The underlying exception is chained as
            ``__cause__`` so the real failure stays visible in tracebacks.
    """
    # Fail fast with a clear message before handing the path to FAISS.
    if not os.path.exists(index_path):
        raise FileNotFoundError(f"Index file not found: {index_path}")

    try:
        index = faiss.read_index(index_path)
        print(f"Successfully loaded FAISS index from {index_path}")

        # Backbone is hard-coded to "vit_b_16"; presumably it must match the
        # model the index was built with — NOTE(review): confirm.
        feature_extractor = FeatureExtractor(base_model="vit_b_16", onnx_path=onnx_path)
        print("Successfully initialized feature extractor with ONNX support")
    except Exception as e:
        # Chain explicitly so the original error is reported as the direct
        # cause rather than as "during handling of the above exception".
        raise RuntimeError(f"Error initializing models: {e}") from e

    return index, feature_extractor


def search_similar_images(
    index: "faiss.Index",
    features: torch.Tensor,
    k: int = 1
) -> tuple["numpy.ndarray", "numpy.ndarray"]:
    """Search a FAISS index for the nearest neighbours of a batch of features.

    Every dimension after the first is flattened, so each row becomes one
    query. Each row is then L2-normalised. With an inner-product index such
    as ``faiss.IndexFlatIP``, the scores are cosine similarities. This
    assumes the vectors stored in the index were also L2-normalised
    (NOTE(review): confirm against the indexing code).

    Args:
        index: FAISS index to query. Its dimensionality must equal the
            flattened per-sample feature size.
        features: Query features whose first dimension is the batch size.
            The tensor may live on any device, use any floating dtype,
            require grad, or be non-contiguous.
        k: Number of neighbours to return per query.

    Returns:
        tuple: ``(D, I)`` exactly as returned by ``index.search``. These are
        NumPy arrays holding, for each query, the ``k`` scores and the ``k``
        ids of the matched database vectors.
    """
    # reshape (not view) so non-contiguous inputs such as transposed or
    # sliced tensors are accepted; the resulting layout is the same.
    queries = features.reshape(features.size(0), -1)

    # Detach so tensors that require grad can be converted to NumPy. Cast to
    # float32, the precision FAISS float indexes work in, before normalising
    # so the norm is computed in that same precision.
    queries = queries.detach().to(torch.float32)

    # normalize() clamps the norm with a small eps, so an all-zero row stays
    # zero instead of becoming NaN as with a manual x / x.norm().
    queries = torch.nn.functional.normalize(queries, p=2.0, dim=1)

    D, I = index.search(queries.cpu().contiguous().numpy(), k)

    return D, I