from typing import Iterable

import faiss
import numpy as np
import pandas as pd

from src.methodology import CountBasedMethodology
from src.utils import indices_distances_gen


def _count_unique_neighbours(embeddings, radius, index, all_labels):
    """
    Count the distinct neighbour labels found for each search result.

    :param embeddings: Query embeddings, passed through to
        ``indices_distances_gen``.
    :param radius: Search radius, passed through to ``indices_distances_gen``.
    :param index: Search index, passed through to ``indices_distances_gen``.
    :param all_labels: Array of labels that is indexed with the neighbour
        indices yielded by ``indices_distances_gen``.
    :return: List with one integer per item yielded by
        ``indices_distances_gen`` (presumably one per embedding -- confirm
        against that helper): the number of unique labels among the
        neighbours.
    """
    # Only the indices are needed; the second element of each yielded pair
    # (distances, judging by the helper's name) is ignored.
    return [
        np.unique(all_labels[neighbour_indices]).shape[0]
        for neighbour_indices, _ in indices_distances_gen(embeddings, radius, index)
    ]


class EmbeddingsOriginalityScorer:
    """
    Scores embeddings based on their originality. Specifically using counts of
    unique neighbours within certain radii.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param radii: Radii to use for neighbour counting. Any iterable is
        accepted (including one-shot iterators); it is consumed once at
        construction time.
    :param methodology: Methodology that takes dataframe where columns are the
        different radii, along with length of chord sequence. Each row
        represents an embedding to be scored.
    """

    def __init__(self,
                 index: faiss.Index,
                 labels: np.ndarray,
                 radii: Iterable[float],
                 methodology: CountBasedMethodology):
        self._index = index
        self._labels = labels
        # Materialise the radii: ``score`` iterates them on every call, so a
        # one-shot iterable (e.g. a generator) would otherwise be exhausted
        # after the first call and silently produce an empty dataframe.
        self._radii = tuple(radii)
        self._methodology = methodology

    def score(self, embeddings: np.ndarray, lengths: pd.Series) -> list[float]:
        """
        Score the given embeddings with the configured methodology.

        Builds a dataframe with one column per radius (column name is
        ``str(radius)``) holding the unique-neighbour counts, then hands it,
        together with ``lengths``, to the methodology.

        :param embeddings: Embeddings to be scored.
        :param lengths: Sequence lengths, forwarded unchanged to the
            methodology.
        :return: Whatever ``methodology.execute`` returns for the counts
            dataframe and ``lengths``.
        """
        counts = {
            str(r): _count_unique_neighbours(embeddings, r, self._index, self._labels)
            for r in self._radii
        }
        neighbours_df = pd.DataFrame(counts)
        return self._methodology.execute(neighbours_df, lengths)