File size: 2,513 Bytes
a7d861a
 
 
 
 
 
 
 
 
681b241
 
a7d861a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681b241
 
a7d861a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b50a375
 
 
 
 
 
 
 
 
a7d861a
 
 
b50a375
 
a7d861a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

from typing import NamedTuple

import faiss
import numpy as np
import pandas as pd

from src.utils import indices_distances_gen

# Default similarity threshold above which an indexed embedding is considered
# "close" to a query; used when the caller does not pass `close_threshold`.
_CLOSE_THRESHOLD_DEFAULT = 0.99


class Neighbour(NamedTuple):
    """A single indexed entry reported as close to a query embedding."""

    # Score returned by the index search for this entry.
    distance: float
    # Label of the indexed entry.
    label: str
    # Length value stored alongside the indexed entry.
    length: int
    # Metadata row of the entry, converted to a plain dict.
    metadata: dict


class EmbeddingClosestNeighbours:
    """
    Analyzes embeddings to find close neighbors based on a similarity threshold.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param lengths: Numpy array of lengths, indexed with the same positions as labels.
    :param metadata: Pandas DataFrame containing metadata for each indexed entry.
        Rows are looked up by label via ``.loc``, so its index must contain the labels.
    :param close_threshold: Similarity threshold to consider embeddings as "close".
    """
    def __init__(self, index: faiss.Index, labels: np.ndarray, lengths: np.ndarray,
                 metadata: pd.DataFrame, close_threshold: float = _CLOSE_THRESHOLD_DEFAULT):
        self._index = index
        self._labels = labels
        self._lengths = lengths
        self._metadata = metadata
        self._close_threshold = close_threshold

    def get(self, embeddings: np.ndarray, limit: int = None) -> list[list[Neighbour]]:
        """
        Return the close neighbours of each query embedding, best match first.

        One result list is produced per (indices, distances) pair yielded by
        ``indices_distances_gen`` (presumably one per row of ``embeddings`` --
        confirm against that helper). Within a result list:

        * neighbours are ordered by decreasing ``distance`` value (the value is
          treated as a similarity: higher means closer);
        * each label appears at most once, represented by its highest-scoring hit;
        * each ``(title, artist)`` metadata pair appears at most once. Entries whose
          metadata lacks both keys all share the key ``(None, None)``, so only the
          first of them is kept.

        :param embeddings: Query embeddings passed through to ``indices_distances_gen``.
        :param limit: Maximum number of neighbours per query; ``None`` means no limit.
            The final cut is a plain slice ``[:limit]``.
        :return: A list with one list of ``Neighbour`` per query.
        :raises KeyError: If a hit's label is not present in the metadata index.
        """
        # Stopping early is only equivalent to the final slice for non-negative limits.
        stop_early = limit is not None and limit >= 0
        all_neighbours = []
        for indices_, distances_ in indices_distances_gen(embeddings, self._close_threshold, self._index):
            hit_labels = self._labels[indices_]
            hit_lengths = self._lengths[indices_]

            # Rank hits best-first *before* de-duplicating labels. The search results
            # are not assumed to arrive sorted, so picking the first occurrence of a
            # label in raw order could keep a weaker hit than the label's best one.
            order = np.flip(np.argsort(distances_, kind="stable"))
            ranked_labels = hit_labels[order]
            ranked_distances = distances_[order]
            ranked_lengths = hit_lengths[order]

            # First occurrence of every label in ranked order == its best hit.
            # Sorting the positions restores the best-first ordering.
            _, best_positions = np.unique(ranked_labels, return_index=True)
            best_positions.sort()

            seen_songs = set()
            neighbours = []
            for pos in best_positions:
                if stop_early and len(neighbours) >= limit:
                    # Enough results collected; skip the remaining metadata lookups.
                    break
                label = ranked_labels[pos]
                meta = self._metadata.loc[label].to_dict()
                # Different labels may refer to the same song; keep only the best one.
                key = (meta.get('title'), meta.get('artist'))
                if key in seen_songs:
                    continue
                seen_songs.add(key)
                neighbours.append(Neighbour(
                    distance=float(ranked_distances[pos]),
                    label=label,
                    length=int(ranked_lengths[pos]),
                    metadata=meta,
                ))
            if limit is not None:
                neighbours = neighbours[:limit]
            all_neighbours.append(neighbours)
        return all_neighbours