harmonic-analysis / src /scorer.py
ohollo's picture
Make Radii needed be passed through
47de94c
from typing import Iterable
import faiss
import numpy as np
import pandas as pd
from src.methodology import CountBasedMethodology
from src.utils import indices_distances_gen
def _count_unique_neighbours(embeddings, radius, index, all_labels):
    """
    Count the distinct neighbour labels found within ``radius`` for each query.

    ``indices_distances_gen`` yields ``(indices, distances)`` pairs, presumably one per
    row of ``embeddings`` (TODO confirm against ``src.utils``). For each pair, the
    indices are used to look up entries of ``all_labels`` and the distances are ignored.

    :param embeddings: Query embeddings forwarded to ``indices_distances_gen``.
    :param radius: Search radius forwarded to ``indices_distances_gen``.
    :param index: Search index forwarded to ``indices_distances_gen``.
    :param all_labels: Array of labels indexed by the neighbour indices.
    :return: List with one count of unique labels per yielded pair, in yield order.
    """
    return [
        np.unique(all_labels[neighbour_ids]).shape[0]
        for neighbour_ids, _distances in indices_distances_gen(embeddings, radius, index)
    ]
class EmbeddingsOriginalityScorer:
    """
    Scores embeddings based on their originality. Specifically, it uses counts of unique neighbour labels within certain radii.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param radii: Radii to use for neighbour counting. The iterable is materialised once at
        construction, so one-shot iterables (e.g. generators) work across repeated ``score`` calls.
    :param methodology: Methodology that takes a dataframe whose columns are the different radii,
        along with the length of each chord sequence. Each row represents an embedding to be scored.
    """

    def __init__(self, index: faiss.Index, labels: np.ndarray, radii: Iterable[float], methodology: CountBasedMethodology):
        self._index = index
        self._labels = labels
        # Materialise once. ``score`` iterates the radii on every call, so a generator
        # would be exhausted after the first call and later calls would silently build
        # an empty dataframe.
        self._radii = tuple(radii)
        self._methodology = methodology

    def score(self, embeddings: np.ndarray, lengths: pd.Series) -> list[float]:
        """
        Score a batch of embeddings with the configured methodology.

        For every radius, the number of unique neighbour labels is counted for each query.
        The counts are assembled into a dataframe with one column per radius, named ``str(radius)``.
        That dataframe is then passed to the methodology together with ``lengths``.

        :param embeddings: Embeddings to score; presumably one row per embedding (TODO confirm).
        :param lengths: Length of the chord sequence for each embedding being scored.
        :return: Whatever ``methodology.execute`` returns for the counts dataframe and lengths.
        """
        counts = {str(r): _count_unique_neighbours(embeddings, r, self._index, self._labels) for r in self._radii}
        neighbours_df = pd.DataFrame(counts)
        return self._methodology.execute(neighbours_df, lengths)