Spaces:
Running
Running
File size: 2,513 Bytes
a7d861a 681b241 a7d861a 681b241 a7d861a b50a375 a7d861a b50a375 a7d861a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from typing import NamedTuple
import faiss
import numpy as np
import pandas as pd
from src.utils import indices_distances_gen
# Default value for the ``close_threshold`` argument of
# ``EmbeddingClosestNeighbours``: the similarity cut-off handed to
# ``indices_distances_gen`` when the caller does not supply one.
_CLOSE_THRESHOLD_DEFAULT = 0.99
class Neighbour(NamedTuple):
    """A single close match produced by ``EmbeddingClosestNeighbours.get``."""

    # Score reported by the index search for this hit, converted to a plain
    # ``float``. NOTE(review): ``get`` orders results by this value in
    # descending order, so despite the name it appears to behave like a
    # similarity (higher = closer) -- confirm against the index metric.
    distance: float
    # Label of the matched index entry, taken from the ``labels`` array.
    label: str
    # Entry of the ``lengths`` array for the matched index entry, as ``int``.
    length: int
    # Metadata row for ``label``, converted to a dict (column name -> value).
    metadata: dict
class EmbeddingClosestNeighbours:
    """
    Analyzes embeddings to find close neighbors based on a similarity threshold.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param lengths: 1-d Numpy array of lengths corresponding to index entries. Aligned with labels.
    :param metadata: Pandas DataFrame containing metadata for each indexed entry. Index should be aligned with labels.
    :param close_threshold: Similarity threshold to consider embeddings as "close".
    """

    def __init__(self, index: faiss.Index, labels: np.ndarray, lengths: np.ndarray,
                 metadata: pd.DataFrame, close_threshold: float = _CLOSE_THRESHOLD_DEFAULT):
        self._index = index
        self._labels = labels
        self._lengths = lengths
        self._metadata = metadata
        self._close_threshold = close_threshold

    def get(self, embeddings: np.ndarray, limit: int = None) -> list[list[Neighbour]]:
        """
        Find the close neighbours of each query embedding.

        :param embeddings: Query embeddings, passed unchanged to ``indices_distances_gen``.
        :param limit: Maximum number of neighbours to keep per result list. ``None`` (default) keeps all of them.
        :return: One list of ``Neighbour`` per ``(indices, distances)`` pair yielded by
            ``indices_distances_gen`` (presumably one per query embedding), ordered by
            descending score, with at most one entry per label and per (title, artist) pair.
        """
        return [
            self._rank_hits(indices_, distances_, limit)
            for indices_, distances_ in indices_distances_gen(embeddings, self._close_threshold, self._index)
        ]

    def _rank_hits(self, indices, distances, limit) -> list[Neighbour]:
        """Turn the raw hits of one query into a de-duplicated, best-first neighbour list."""
        indices = np.asarray(indices, dtype=np.intp)
        distances = np.asarray(distances)

        # Rank the hits best-first *before* de-duplicating. ``np.unique`` reports
        # the first occurrence of each label, so ranking first guarantees that the
        # hit kept for a label is its highest-scoring one even when the search
        # helper does not return hits in sorted order.
        order = np.argsort(-distances, kind="stable")
        ranked_indices = indices[order]
        ranked_distances = distances[order]
        ranked_labels = self._labels[ranked_indices]

        _, best_positions = np.unique(ranked_labels, return_index=True)
        # ``np.unique`` orders its output by label; sorting the positions restores
        # the best-first ranking.
        best_positions.sort()

        # Only a non-negative limit can be used to stop early. Any other value
        # falls through to the final slice so its result stays as before.
        cap = limit if limit is not None and limit >= 0 else None

        seen_songs = set()
        neighbours = []
        for pos in best_positions:
            if cap is not None and len(neighbours) >= cap:
                # Enough neighbours collected; skip the remaining metadata lookups.
                break
            label = ranked_labels[pos]
            meta = self._metadata.loc[label].to_dict()
            # Different labels can describe the same song; keep only the
            # best-scoring entry per (title, artist). Rows without these columns
            # all share the key (None, None) and therefore collapse to one entry.
            key = (meta.get('title'), meta.get('artist'))
            if key in seen_songs:
                continue
            seen_songs.add(key)
            neighbours.append(Neighbour(
                distance=float(ranked_distances[pos]),
                label=label,
                length=int(self._lengths[ranked_indices[pos]]),
                metadata=meta,
            ))

        if limit is not None:
            neighbours = neighbours[:limit]
        return neighbours