| | |
| | |
| | |
| | |
| |
|
| | from multiprocessing.pool import ThreadPool |
| | import faiss |
| | from typing import List, Tuple |
| |
|
| | from . import rpc |
| |
|
| | |
| | |
| | |
| |
|
| |
|
class SearchServer(rpc.Server):
    """Server-side wrapper that exposes a faiss index over RPC.

    Methods defined on this class are served directly; any other attribute
    lookup falls through to the wrapped index (see ``__getattr__``).
    """

    def __init__(self, s: int, index: faiss.Index):
        rpc.Server.__init__(self, s)
        self.index = index
        # IVF layer of ``index`` as located by faiss.extract_index_ivf;
        # set_nprobe writes to this object rather than to ``index`` itself.
        self.index_ivf = faiss.extract_index_ivf(index)

    def set_nprobe(self, nprobe: int) -> None:
        """Set the ``nprobe`` field of the IVF part of the served index."""
        self.index_ivf.nprobe = nprobe

    def get_ntotal(self) -> int:
        """Return ``ntotal`` of the served index."""
        return self.index.ntotal

    def __getattr__(self, f):
        """Forward attributes not found on the server to the wrapped index."""
        # __getattr__ only runs when normal lookup fails. If ``index`` itself
        # is missing (e.g. while rpc.Server.__init__ runs, before ``index`` is
        # assigned, or on an instance created without __init__), evaluating
        # ``self.index`` below would call __getattr__('index') again and
        # recurse until RecursionError. Report it as a normal missing
        # attribute instead.
        if f == "index":
            raise AttributeError(f)
        return getattr(self.index, f)
| |
|
| |
|
def run_index_server(index: faiss.Index, port: int, v6: bool = False):
    """Serve RPC search requests for ``index`` on ``port``.

    Each server instance created by ``rpc.run_server`` wraps the same
    ``index`` in a ``SearchServer``. Per the original documentation this
    serves requests forever (i.e. is not expected to return).
    """

    def make_server(s):
        # factory invoked by rpc.run_server with its per-server argument
        return SearchServer(s, index)

    rpc.run_server(make_server, port, v6=v6)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class ClientIndex:
    """Manages a set of distant (remote) sub-indexes.

    Each sub-index is an RPC client connected to a server that searches a
    subset of the inverted lists. Search results from all servers are merged
    on the client afterwards.
    """

    def __init__(self, machine_ports: List[Tuple[str, int]], v6: bool = False):
        """Connect to a series of (host, port) pairs.

        Args:
            machine_ports: (host, port) pairs, one per server.
            v6: passed through to ``rpc.Client`` for every connection.

        Raises:
            ValueError: if ``machine_ports`` is empty.
        """
        self.sub_indexes = [
            rpc.Client(machine, port, v6) for machine, port in machine_ports
        ]
        if not self.sub_indexes:
            # ThreadPool(0) would fail anyway, with a less helpful message.
            raise ValueError(
                "machine_ports must contain at least one (host, port) pair")

        self.ni = len(self.sub_indexes)
        # One worker thread per server, so calls fanned out with pool.map /
        # pool.imap are issued to all servers concurrently.
        self.pool = ThreadPool(self.ni)

        # Total computed once at construction time; not refreshed later.
        self.ntotal = self.get_ntotal()
        self.verbose = False

    def set_nprobe(self, nprobe: int) -> None:
        """Call ``set_nprobe(nprobe)`` on every sub-index."""
        self.pool.map(
            lambda idx: idx.set_nprobe(nprobe),
            self.sub_indexes
        )

    def set_omp_num_threads(self, nt: int) -> None:
        """Call ``set_omp_num_threads(nt)`` on every sub-index."""
        # NOTE(review): SearchServer does not define this method; confirm that
        # rpc.Server (or the wrapped index) provides it on the server side.
        self.pool.map(
            lambda idx: idx.set_omp_num_threads(nt),
            self.sub_indexes
        )

    def get_ntotal(self) -> int:
        """Return the sum of ``get_ntotal()`` over all sub-indexes."""
        return sum(self.pool.map(
            lambda idx: idx.get_ntotal(),
            self.sub_indexes
        ))

    def search(self, x, k: int):
        """Search all sub-indexes for ``x`` and merge the top-``k`` results.

        ``x`` is presumably a 2-D numpy array of queries (one row per query,
        ``x.shape[0]`` rows) — confirm against callers.

        Returns:
            ``(D, I)`` taken from the merged ``faiss.ResultHeap``.
        """
        # NOTE(review): ResultHeap is built with its defaults, which presumably
        # keep the smallest distances (L2-style); confirm before using this
        # with max-is-better (e.g. inner-product) indexes.
        rh = faiss.ResultHeap(x.shape[0], k)

        # Each server returns its own (D, I); fold them into one heap.
        for Di, Ii in self.pool.imap(lambda idx: idx.search(x, k),
                                     self.sub_indexes):
            rh.add_result(Di, Ii)
        rh.finalize()
        return rh.D, rh.I
| |
|