File size: 1,542 Bytes
47de94c
 
a7d861a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
681b241
a7d861a
 
 
 
 
47de94c
 
 
 
a7d861a
 
47de94c
a7d861a
47de94c
a7d861a
47de94c
a7d861a
47de94c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Iterable

import faiss
import numpy as np
import pandas as pd

from src.methodology import CountBasedMethodology
from src.utils import indices_distances_gen


def _count_unique_neighbours(embeddings, radius, index, all_labels):
    """
    Count the distinct labels among the neighbours returned for each query.

    :param embeddings: Query embeddings, forwarded unchanged to ``indices_distances_gen``.
    :param radius: Search radius, forwarded unchanged to ``indices_distances_gen``.
    :param index: Index searched by ``indices_distances_gen`` (the scorer passes its FAISS index).
    :param all_labels: Numpy array of labels, indexed by the neighbour indices yielded for each query.
    :returns: One int per ``(indices, distances)`` pair yielded by ``indices_distances_gen``
        (presumably one per query embedding): the number of unique labels among those indices.
    """
    # np.unique flattens its input, so len() of the result equals its shape[0].
    return [
        len(np.unique(all_labels[neighbour_ids]))
        for neighbour_ids, _distances in indices_distances_gen(embeddings, radius, index)
    ]


class EmbeddingsOriginalityScorer:
    """
    Scores embeddings based on their originality. Specifically using counts of unique neighbours within certain radii.

    :param index: FAISS index for similarity search.
    :param labels: 1-d Numpy array of labels corresponding to index entries.
    :param radii: Radii to use for neighbour counting. Any iterable is accepted; it is materialised
        once at construction so that a one-shot iterator (e.g. a generator) is not exhausted by the
        first call to ``score``.
    :param methodology: Methodology that takes dataframe where columns are the different radii, along with length of chord sequence. Each row represents an embedding to be scored.
    """
    def __init__(self, index: faiss.Index, labels: np.ndarray, radii: Iterable[float], methodology: CountBasedMethodology):
        self._index = index
        self._labels = labels
        # Materialise the radii: ``score`` iterates them on every call, and a generator/iterator
        # would be empty from the second call onwards, silently producing an empty dataframe.
        self._radii = tuple(radii)
        self._methodology = methodology

    def score(self, embeddings: np.ndarray, lengths: pd.Series) -> list[float]:
        """
        Score a batch of embeddings.

        For every configured radius, count the distinct labels among each embedding's neighbours
        within that radius. Then pass the resulting counts, together with ``lengths``, to the
        methodology.

        :param embeddings: Query embeddings. NOTE(review): presumably a 2-d float32 array, one row
            per embedding, as FAISS search expects — confirm.
        :param lengths: Chord-sequence lengths, forwarded unchanged to the methodology.
            NOTE(review): the count dataframe built here has a default RangeIndex; if the
            methodology aligns on index, ``lengths`` should use the same index — confirm.
        :returns: The result of ``methodology.execute(neighbours_df, lengths)``.
        """
        counts = {str(r): _count_unique_neighbours(embeddings, r, self._index, self._labels) for r in self._radii}
        # One column per radius (named ``str(radius)``), one row per count yielded for each query.
        neighbours_df = pd.DataFrame(counts)
        return self._methodology.execute(neighbours_df, lengths)