File size: 841 Bytes
a7d861a
 
 
 
 
681b241
 
 
 
 
 
 
 
 
a7d861a
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import Iterator, Tuple

import faiss
import numpy as np


def indices_distances_gen(
    embeddings: np.ndarray, radius: float, index: faiss.Index
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """
    Yield the neighbors found within ``radius`` for each embedding, one row at a time.

    The query embeddings are copied into a C-contiguous float32 array and
    L2-normalized before a single ``index.range_search`` call, so the caller's
    array is never modified. Because this is a generator, the search only runs
    when the first item is requested.

    :param embeddings: 2-d Numpy array where each row is an embedding to search neighbors for.
    :param radius: Similarity radius to search within. How it is compared depends on the
        index metric (per the FAISS docs: squared L2 distance below ``radius`` for L2
        indexes, inner product above ``radius`` for inner-product indexes).
    :param index: FAISS index for similarity search. NOTE(review): since the queries are
        L2-normalized, this presumably expects an inner-product index built over
        normalized vectors (i.e. cosine similarity) -- confirm against callers.
    :return: Iterator of ``(indices, distances)`` pairs, one per row of ``embeddings`` and
        in the same order. ``indices`` are ids stored in ``index``; ``distances`` are the
        matching scores FAISS returned for those ids.
    """
    # Exactly one copy: normalize_L2 works in place (so we must not touch the
    # caller's array), and FAISS requires C-contiguous float32 input.
    queries = np.array(embeddings, dtype=np.float32, order="C", copy=True)
    faiss.normalize_L2(queries)

    # lims has one more entry than there are queries; the results for query i
    # occupy the half-open slice lims[i]:lims[i + 1] of D and I.
    lims, D, I = index.range_search(queries, radius)
    for start, end in zip(lims[:-1], lims[1:]):
        yield I[start:end], D[start:end]