File size: 1,178 Bytes
8807f0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import hnswlib
import numpy as np
import os


class SearchEngine:
    """Thin wrapper around an ``hnswlib.Index`` for nearest-neighbour search.

    The underlying index object is created in the constructor, but its
    storage is only allocated by :meth:`init_index` (called lazily from
    :meth:`add_embeddings`) or populated by :meth:`load_index`.
    """

    def __init__(self, dim: int, max_elements: int, space: str = "cosine"):
        """Create the (not yet initialised) index.

        Args:
            dim: Dimensionality passed to ``hnswlib.Index``.
            max_elements: Capacity used when :meth:`init_index` is called.
            space: Distance space name passed to ``hnswlib.Index``.
        """
        self.index = hnswlib.Index(space=space, dim=dim)
        self.max_elements = max_elements
        # Flipped to True by init_index() and load_index().
        self.is_initialized = False
        self.space = space

    def init_index(self) -> None:
        """Allocate the index for ``max_elements`` items and set query ``ef``."""
        self.index.init_index(
            max_elements=self.max_elements, ef_construction=200, M=16
        )
        self.index.set_ef(50)
        self.is_initialized = True

    def add_embeddings(self, embeddings: np.ndarray) -> None:
        """Add ``embeddings`` to the index, initialising it first if needed."""
        if not self.is_initialized:
            self.init_index()
        self.index.add_items(embeddings)

    def save_index(self, path: str = "models/embeddings_index.bin") -> None:
        """Persist the index to ``path``, creating the parent directory.

        A bare filename (no directory component) is written relative to the
        current working directory.
        """
        parent_dir = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # when the path actually has a directory component.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        self.index.save_index(path)

    def load_index(self, path: str = "models/embeddings_index.bin") -> None:
        """Load a previously saved index from ``path`` and mark it ready."""
        self.index.load_index(path)
        self.is_initialized = True

    def search(self, query_vector, top_k: int = 5):
        """Return the ``top_k`` nearest neighbours of ``query_vector``.

        Returns:
            A ``(labels, distances)`` pair holding the first row of the
            ``knn_query`` result, i.e. the neighbours of the first (or only)
            query vector.
        """
        labels, distances = self.index.knn_query(query_vector, k=top_k)

        # NOTE(review): no reversal for space == "ip" -- per the hnswlib docs
        # the "ip" distance is 1 - dot product, so smaller values should still
        # mean closer matches; confirm against the installed hnswlib version.
        return labels[0], distances[0]