# (hosting-page residue, not part of the module) Spaces: Running / Running
from typing import NamedTuple, Optional

import faiss
import numpy as np
import pandas as pd

from src.utils import indices_distances_gen
# Default for the ``close_threshold`` argument of EmbeddingClosestNeighbours:
# the similarity cut-off handed to the index search when the caller gives none.
_CLOSE_THRESHOLD_DEFAULT = 0.99
class Neighbour(NamedTuple):
    """One close match produced by ``EmbeddingClosestNeighbours.get``."""

    # Score reported by the index search for this hit; results are ordered by
    # this value in descending order.
    distance: float
    # Entry of the ``labels`` array for the matched index position.
    label: str
    # Entry of the ``lengths`` array for the matched index position.
    length: int
    # The metadata row looked up by ``label``, converted to a plain dict.
    metadata: dict
class EmbeddingClosestNeighbours:
    """
    Analyzes embeddings to find close neighbors based on a similarity threshold.

    Hits come from ``indices_distances_gen``; they are reduced to one hit per label,
    then to one per (title, artist) pair, and returned as ``Neighbour`` tuples.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param lengths: Numpy array addressed by the same positions as ``labels``; the value at a
        hit's position is reported as ``Neighbour.length``.
    :param metadata: Pandas DataFrame containing metadata for each indexed entry. Rows are
        fetched with ``.loc[label]``, so the index should be aligned with labels.
    :param close_threshold: Similarity threshold to consider embeddings as "close".
    """

    def __init__(self, index: faiss.Index, labels: np.ndarray, lengths: np.ndarray,
                 metadata: pd.DataFrame, close_threshold: float = _CLOSE_THRESHOLD_DEFAULT):
        self._index = index
        self._labels = labels
        self._lengths = lengths
        self._metadata = metadata
        self._close_threshold = close_threshold

    def get(self, embeddings: np.ndarray, limit: Optional[int] = None) -> list[list[Neighbour]]:
        """
        Return the close neighbours for each result yielded by the index search.

        :param embeddings: Query embeddings, passed through to ``indices_distances_gen``
            (presumably one row per query -- confirm against that helper).
        :param limit: Maximum number of neighbours kept per result list; ``None`` keeps all.
        :return: One list per (indices, distances) pair yielded by the helper, each ordered
            by descending distance value and free of repeated (title, artist) pairs.
        :raises KeyError: If a hit's label is not present in the metadata index.
        """
        # A non-negative limit lets the inner loop stop as soon as enough neighbours
        # are collected. Any other value relies solely on the final slice, exactly
        # as before.
        cap = limit if limit is not None and limit >= 0 else None

        all_neighbours = []
        for indices_, distances_ in indices_distances_gen(embeddings, self._close_threshold, self._index):
            hit_labels = self._labels[indices_]
            hit_lengths = self._lengths[indices_]

            # Rank hits best-first before de-duplicating. np.unique keeps the first
            # occurrence of each label, so this makes it keep the best-scoring hit
            # per label whatever order the helper returned them in. The sort is
            # stable, so hits that already arrive best-first are picked unchanged.
            by_score = np.argsort(-distances_, kind="stable")
            labels, first_pos = np.unique(hit_labels[by_score], return_index=True)
            picked = by_score[first_pos]  # positions in the helper's original order
            distances = distances_[picked]
            lengths = hit_lengths[picked]

            # Final ordering: highest distance value first.
            order = np.flip(np.argsort(distances))

            seen_songs = set()
            neighbours = []
            for label, distance, length in zip(labels[order], distances[order], lengths[order]):
                if cap is not None and len(neighbours) >= cap:
                    # Enough results: skip the remaining metadata lookups.
                    break
                meta = self._metadata.loc[label].to_dict()
                # Different labels may describe the same song; keep only the
                # best-ranked entry per (title, artist).
                key = (meta.get('title'), meta.get('artist'))
                if key in seen_songs:
                    continue
                seen_songs.add(key)
                neighbours.append(Neighbour(
                    distance=float(distance),
                    label=label,
                    length=int(length),
                    metadata=meta,
                ))
            if limit is not None:
                neighbours = neighbours[:limit]
            all_neighbours.append(neighbours)
        return all_neighbours