harmonic-analysis / src /neighbours.py
ohollo's picture
Gradio bug workaround
b50a375
from typing import NamedTuple, Optional

import faiss
import numpy as np
import pandas as pd

from src.utils import indices_distances_gen
_CLOSE_THRESHOLD_DEFAULT = 0.99
# One match produced by EmbeddingClosestNeighbours.get:
#   distance -- similarity score reported for the match (get() lists higher scores first)
#   label    -- label of the matched index entry
#   length   -- length value stored for that entry
#   metadata -- the entry's metadata row, converted to a dict
Neighbour = NamedTuple(
    "Neighbour",
    [
        ("distance", float),
        ("label", str),
        ("length", int),
        ("metadata", dict),
    ],
)
class EmbeddingClosestNeighbours:
    """
    Analyzes embeddings to find close neighbours based on a similarity threshold.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param lengths: 1-d Numpy array of lengths corresponding to index entries, aligned with
        ``labels``; reported as ``Neighbour.length``.
    :param metadata: Pandas DataFrame containing metadata for each indexed entry. Index should
        be aligned with labels. The ``title`` and ``artist`` values, when present, are used to
        drop repeated songs from the results.
    :param close_threshold: Similarity threshold to consider embeddings as "close".
    """

    def __init__(self, index: faiss.Index, labels: np.ndarray, lengths: np.ndarray,
                 metadata: pd.DataFrame, close_threshold: float = _CLOSE_THRESHOLD_DEFAULT):
        self._index = index
        self._labels = labels
        self._lengths = lengths
        self._metadata = metadata
        self._close_threshold = close_threshold

    def get(self, embeddings: np.ndarray, limit: Optional[int] = None) -> list[list[Neighbour]]:
        """
        Return the close neighbours of each query embedding.

        Hits from ``indices_distances_gen`` are reduced to one per label (the best-scoring
        one) and ordered from highest to lowest score. They are then reduced again to one per
        ``(title, artist)`` pair taken from the metadata.

        :param embeddings: Query embeddings, forwarded to ``indices_distances_gen``.
        :param limit: If given, each result list is truncated as ``neighbours[:limit]``.
        :return: One list of :class:`Neighbour` for each ``(indices, distances)`` pair yielded
            by ``indices_distances_gen`` (presumably one per query row; confirm).
        """
        all_neighbours = []
        for indices_, distances_ in indices_distances_gen(embeddings, self._close_threshold,
                                                          self._index):
            labels, distances, lengths = self._best_per_label(
                indices_, distances_, self._labels, self._lengths)
            all_neighbours.append(self._distinct_songs(labels, distances, lengths, limit))
        return all_neighbours

    @staticmethod
    def _best_per_label(indices: np.ndarray, distances: np.ndarray, labels: np.ndarray,
                        lengths: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Keep each label's best-scoring hit; return (labels, distances, lengths) best-first."""
        indices = np.asarray(indices)
        distances = np.asarray(distances)
        # FAISS uses id -1 for missing results. Indexing with it would silently pick the last
        # entry instead of failing, so drop such hits.
        valid = indices >= 0
        indices, distances = indices[valid], distances[valid]
        # Sort hits best-first *before* de-duplicating. np.unique(return_index=True) keeps the
        # first occurrence of each label, and FAISS range-search results are not ordered by
        # score, so de-duplicating unsorted hits could keep a worse hit than the best one.
        order = np.argsort(-distances, kind='stable')
        _, first = np.unique(labels[indices[order]], return_index=True)
        # np.unique reports positions in label order; sorting them restores best-first order.
        keep = order[np.sort(first)]
        kept_indices = indices[keep]
        return labels[kept_indices], distances[keep], lengths[kept_indices]

    def _distinct_songs(self, labels: np.ndarray, distances: np.ndarray, lengths: np.ndarray,
                        limit: Optional[int]) -> list[Neighbour]:
        """Build Neighbours in the given order, skipping repeats of a (title, artist) pair."""
        seen_songs = set()
        neighbours = []
        for label, distance, length in zip(labels, distances, lengths):
            # Stop once a non-negative limit is met. Later rows would be sliced off anyway,
            # and each one costs a DataFrame lookup.
            if limit is not None and 0 <= limit <= len(neighbours):
                break
            meta = self._metadata.loc[label].to_dict()
            title, artist = meta.get('title'), meta.get('artist')
            # With neither field available, rows cannot be identified as the same song. Key on
            # the (unique) label instead, so they are not all merged into one (None, None)
            # "song". A 1-tuple never equals a (title, artist) 2-tuple.
            key = (label,) if title is None and artist is None else (title, artist)
            if key in seen_songs:
                continue
            seen_songs.add(key)
            neighbours.append(Neighbour(
                distance=float(distance),
                label=label,
                length=int(length),
                metadata=meta,
            ))
        # Plain slicing keeps the original semantics, including for negative limits.
        return neighbours if limit is None else neighbours[:limit]