File size: 1,676 Bytes
a7d861a
 
 
 
 
681b241
a7d861a
47de94c
a7d861a
 
c5184df
 
 
 
 
 
 
 
 
007017f
a7d861a
 
 
 
007017f
47de94c
a7d861a
 
681b241
 
a7d861a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from dataclasses import dataclass

import pandas as pd
from src.methodology import SimpleMethodology
from src.neighbours import EmbeddingClosestNeighbours
from src.scorer import EmbeddingsOriginalityScorer

# Integer key looked up in the per-length scaler table built by
# EmbeddingsAnalysis.__init__; the scaler stored under this key is handed to
# SimpleMethodology as its second argument.
# NOTE(review): presumably that scaler is applied to lengths with no dedicated
# entry -- confirm against SimpleMethodology.
_FALLBACK_INDEX = 99

class EmbeddingsAnalysis:
    """
    Facade for analyzing embeddings, combining neighbor search and originality scoring.

    :param index: FAISS index for similarity search.
    :param all_labels: DataFrame containing 'track_id' and 'length' columns for indexed entries.
    :param lookup: Pandas DataFrame containing metadata for each indexed entry.
    :param scalers: Dictionary mapping length ranges to quantile transformers for score normalization.
        Each key is a ``(start, stop)`` pair that is expanded with ``range(start, stop)``,
        so ``stop`` is exclusive. One range must cover ``_FALLBACK_INDEX``; the scaler
        found there is passed to ``SimpleMethodology`` as its second argument.
    :param radii: Forwarded unchanged to ``EmbeddingsOriginalityScorer``.
    :param close_threshold: Similarity threshold for neighbor search.
    :param score_power: Forwarded to ``SimpleMethodology`` as ``score_power``.
    :raises KeyError: If no range in ``scalers`` covers ``_FALLBACK_INDEX``.
    """
    def __init__(self, index, all_labels, lookup, scalers, radii, close_threshold=0.95, score_power=1.0):
        track_ids = all_labels['track_id'].to_numpy()
        track_lengths = all_labels['length'].to_numpy()
        self._ecn = EmbeddingClosestNeighbours(
            index, track_ids, track_lengths, lookup, close_threshold=close_threshold
        )

        # Expand every (start, stop) range into one entry per integer length.
        # When ranges overlap, the range appearing later in `scalers` wins.
        scaler_by_length = {
            length: scaler
            for (start, stop), scaler in scalers.items()
            for length in range(start, stop)
        }
        try:
            fallback_scaler = scaler_by_length[_FALLBACK_INDEX]
        except KeyError:
            # Same exception type as before, but say what is actually missing
            # instead of surfacing a bare "KeyError: 99".
            raise KeyError(
                f"no scaler range in `scalers` covers the fallback length {_FALLBACK_INDEX}"
            ) from None

        methodology = SimpleMethodology(scaler_by_length, fallback_scaler, score_power=score_power)
        self._scorer = EmbeddingsOriginalityScorer(index, track_ids, radii, methodology)

    def get_scores(self, embeddings, lengths):
        """
        Score the given embeddings with the underlying originality scorer.

        :param embeddings: Embeddings passed through unchanged to the scorer.
        :param lengths: Lengths matching ``embeddings``; wrapped in a ``pd.Series``
            before being handed to the scorer.
        :return: Whatever the scorer's ``score`` method returns.
        """
        return self._scorer.score(embeddings, pd.Series(lengths))

    def get_neighbours(self, embeddings, limit=None):
        """
        Look up the closest neighbours of the given embeddings.

        :param embeddings: Embeddings passed through unchanged to the neighbour search.
        :param limit: Passed through as the second argument of the neighbour search;
            ``None`` by default.
        :return: Whatever the neighbour search's ``get`` method returns.
        """
        return self._ecn.get(embeddings, limit)