from typing import NamedTuple, Optional

import faiss
import numpy as np
import pandas as pd

from src.utils import indices_distances_gen

_CLOSE_THRESHOLD_DEFAULT = 0.99


class Neighbour(NamedTuple):
    """A single close match returned by :meth:`EmbeddingClosestNeighbours.get`."""

    # Score reported for the hit; higher scores are ranked first.
    distance: float
    # Label of the matched index entry.
    label: str
    # Value taken from the ``lengths`` array for the matched index entry.
    length: int
    # Metadata row looked up by ``label``, converted to a dict.
    metadata: dict


class EmbeddingClosestNeighbours:
    """
    Analyzes embeddings to find close neighbors based on a similarity threshold.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param lengths: 1-d Numpy array of lengths corresponding to index entries
        (indexed with the same positions as ``labels``).
    :param metadata: Pandas DataFrame containing metadata for each indexed entry.
        Index should be aligned with labels.
    :param close_threshold: Similarity threshold to consider embeddings as "close".
    """

    def __init__(self,
                 index: faiss.Index,
                 labels: np.ndarray,
                 lengths: np.ndarray,
                 metadata: pd.DataFrame,
                 close_threshold: float = _CLOSE_THRESHOLD_DEFAULT):
        self._index = index
        self._labels = labels
        self._lengths = lengths
        self._metadata = metadata
        self._close_threshold = close_threshold

    def get(self, embeddings: np.ndarray, limit: Optional[int] = None) -> list[list[Neighbour]]:
        """
        Find the close neighbours of every query in ``embeddings``.

        Hits are ranked by score, highest first. Hits sharing a label are
        collapsed to the highest-scoring one, and labels whose metadata share
        the same ``(title, artist)`` pair are collapsed the same way.

        :param embeddings: Query embeddings, passed to ``indices_distances_gen``
            together with the threshold and the index.
        :param limit: Maximum number of neighbours kept per result list.
            ``None`` keeps all of them. The value is applied as a slice bound,
            so a negative value drops that many entries from the end.
        :return: One list of neighbours per ``(indices, distances)`` pair
            yielded by ``indices_distances_gen`` (presumably one per query
            embedding -- confirm against that helper).
        """
        # Early exit is only valid for non-negative limits; negative limits
        # need the full list before slicing.
        stop_early = limit is not None and limit >= 0

        all_neighbours = []
        for indices_, distances_ in indices_distances_gen(embeddings,
                                                          self._close_threshold,
                                                          self._index):
            hit_labels = self._labels[indices_]
            hit_lengths = self._lengths[indices_]

            # Rank hits best-first *before* deduplicating: np.unique reports the
            # first occurrence of each label, so on the ranked arrays that is
            # the label's highest-scoring hit, whatever order the hits came in.
            order = np.argsort(-distances_, kind="stable")
            ranked_labels = hit_labels[order]
            ranked_lengths = hit_lengths[order]
            ranked_distances = distances_[order]

            _, first_hits = np.unique(ranked_labels, return_index=True)
            # Ascending positions in the ranked arrays == descending score.
            first_hits = np.sort(first_hits)

            seen_songs = set()
            neighbours = []
            for position in first_hits:
                if stop_early and len(neighbours) >= limit:
                    # Enough neighbours collected; skip further metadata lookups.
                    break

                label = ranked_labels[position]
                meta = self._metadata.loc[label].to_dict()

                # Different labels may describe the same song; keep the best one.
                key = (meta.get('title'), meta.get('artist'))
                if key in seen_songs:
                    continue
                seen_songs.add(key)

                neighbours.append(Neighbour(
                    distance=float(ranked_distances[position]),
                    label=label,
                    length=int(ranked_lengths[position]),
                    metadata=meta,
                ))

            if limit is not None:
                neighbours = neighbours[:limit]
            all_neighbours.append(neighbours)

        return all_neighbours