# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
# 
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
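
"""FAISS-based dense vector indexer.

Wraps a flat inner-product index sharded across all available GPUs (with a
CPU fallback), plus a mapping from internal faiss row ids back to external
database ids and optional labels.
"""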

import os
import pickle
from typing import List, Tuple

import faiss
import numpy as np
from tqdm import tqdm

class Indexer:

    def __init__(self, vector_sz, n_subquantizers=0, n_bits=16):
        # n_subquantizers and n_bits are kept for backward compatibility; the
        # product-quantizer path is currently disabled in favor of a flat
        # inner-product index sharded across all available GPUs.
        self.vector_sz = vector_sz
        self.index = self._create_sharded_index()
        # Maps internal faiss row ids (0..ntotal-1) to external db ids.
        self.index_id_to_db_id = []
        # Maps external db ids to labels; must be populated by the caller
        # before search_knn is used.
        self.label_dict = {}

    def _create_sharded_index(self):
        # Determine the number of available GPUs
        ngpu = faiss.get_num_gpus()

        # Fall back to a plain CPU index when no GPUs are available
        if ngpu == 0:
            print("No GPUs detected. Using CPU index.")
            return faiss.IndexFlatIP(self.vector_sz)

        # IndexShards(d, threaded=True, successive_ids=True): queries fan out
        # to the shards in parallel and ids stay globally unique across them
        index = faiss.IndexShards(self.vector_sz, True, True)
        # Create a sub-index for each GPU and add it to the IndexShards container
        for i in range(ngpu):
            # Create a standard GPU resource object
            res = faiss.StandardGpuResources()
            # Configure the GPU index
            flat_config = faiss.GpuIndexFlatConfig()
            # flat_config.useFloat16 = True  # enable to reduce memory usage with half precision
            flat_config.device = i  # assign the GPU device id
            # Create the GPU index
            sub_index = faiss.GpuIndexFlatIP(res, self.vector_sz, flat_config)
            # Add the sub-index into the sharded index
            index.add_shard(sub_index)
        return index

    def index_data(self, ids, embeddings):
        self._update_id_mapping(ids)
        self.index.add(embeddings)
        print(f'Total data indexed {self.index.ntotal}')

    def search_knn(self, query_vectors: np.ndarray, top_docs: int, index_batch_size: int = 8) -> List[Tuple[List[str], np.ndarray, List[object]]]:
        result = []
        nbatch = (len(query_vectors) - 1) // index_batch_size + 1
        for k in tqdm(range(nbatch)):
            start_idx = k * index_batch_size
            end_idx = min((k + 1) * index_batch_size, len(query_vectors))
            q = query_vectors[start_idx:end_idx]
            scores, indexes = self.index.search(q, top_docs)
            # Convert internal faiss row ids back to external db ids and labels
            db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
            db_labels = [[self.label_dict[self.index_id_to_db_id[i]] for i in query_top_idxs] for query_top_idxs in indexes]
            result.extend([(db_ids[i], scores[i], db_labels[i]) for i in range(len(db_ids))])
        return result

    def serialize(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Serializing index to {index_file}, meta data to {meta_file}')

        # faiss cannot write GPU indexes directly; index_gpu_to_cpu merges the
        # GPU shards into a single serializable CPU index
        index = self.index
        if faiss.get_num_gpus() > 0:
            index = faiss.index_gpu_to_cpu(self.index)
        faiss.write_index(index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, dir_path):
        index_file = os.path.join(dir_path, 'index.faiss')
        meta_file = os.path.join(dir_path, 'index_meta.faiss')
        print(f'Loading index from {index_file}, meta data from {meta_file}')

        # Note: read_index returns a CPU index, so searches after
        # deserialization run on CPU
        self.index = faiss.read_index(index_file)
        print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(self.index_id_to_db_id) == self.index.ntotal, \
            'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)

    def reset(self):
        self.index.reset()
        self.index_id_to_db_id = []
        print(f'Index reset, total data indexed {self.index.ntotal}')
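

# A minimal usage sketch, not part of the original module: the 128-dim
# embeddings, integer db ids, and string labels below are all illustrative
# assumptions. It shows the intended call order: fill label_dict, index the
# embeddings, then search.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    dim = 128
    n_docs = 1000

    indexer = Indexer(vector_sz=dim)

    # Fake corpus: db ids 0..999 with random unit-norm float32 embeddings
    # (faiss expects float32)
    ids = list(range(n_docs))
    embeddings = rng.standard_normal((n_docs, dim)).astype(np.float32)
    embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

    # search_knn looks labels up by db id, so label_dict must be filled first
    indexer.label_dict = {i: f'label_{i % 10}' for i in ids}

    indexer.index_data(ids, embeddings)

    # Query with 4 of the indexed vectors; each result is a
    # (db_ids, scores, labels) triple for one query
    for db_ids, scores, labels in indexer.search_knn(embeddings[:4], top_docs=5):
        print(db_ids, scores, labels)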