File size: 6,774 Bytes
fc7b4a9
 
 
61f21af
fc7b4a9
 
 
 
 
 
 
 
 
 
7633e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc7b4a9
7633e2f
 
 
 
 
 
fc7b4a9
 
 
 
 
 
 
 
7633e2f
 
 
 
 
fc7b4a9
7633e2f
 
 
 
 
 
fc7b4a9
7633e2f
 
 
 
 
fc7b4a9
 
 
 
 
 
 
 
 
91f3c16
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
61f21af
 
 
 
 
 
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
61f21af
 
 
 
 
fc7b4a9
 
 
 
 
 
 
75d43d2
fc7b4a9
75d43d2
fc7b4a9
75d43d2
fc7b4a9
75d43d2
fc7b4a9
75d43d2
fc7b4a9
 
 
 
75d43d2
 
 
 
 
 
 
 
e1ee8d1
 
 
 
 
 
 
 
fc7b4a9
75d43d2
 
fc7b4a9
 
 
e1ee8d1
fc7b4a9
 
 
 
 
 
75d43d2
fc7b4a9
 
 
 
 
 
 
 
 
 
75d43d2
fc7b4a9
 
 
 
75d43d2
fc7b4a9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import time
import joblib
import numpy as np
import torch

from src.preprocessing.preprocessor import single_preprocessing
from src.spectttra.spectttra_trainer import spectttra_train
from src.llm2vectrain.llm2vec_trainer import l2vec_train
from src.llm2vectrain.model import load_llm2vec_model
from src.models.mlp import build_mlp, load_config
from src.musiclime.print_utils import green_bold


class MusicLIMEPredictor:
    """
    Batch prediction wrapper for MusicLIME explanations.

    Integrates the complete Bach or Bot pipeline (SpecTTTra + LLM2Vec + MLP)
    into a single callable for LIME perturbation processing. Optimized for
    batch processing of multiple perturbed audio-lyrics pairs with detailed
    timing analysis.

    Attributes
    ----------
    llm2vec_model : LLM2Vec
        Pre-loaded LLM2Vec model for lyrics feature extraction
    classifier : MLPClassifier or None
        MLP classifier, built and loaded lazily on the first call because its
        input dimension is only known once features have been extracted
    config : dict
        Model configuration parameters
    """

    # Locations of the trained artifacts (relative to the working directory).
    CONFIG_PATH = "config/model_config.yml"
    AUDIO_SCALER_PATH = "models/fusion/audio_scaler.pkl"
    LYRICS_SCALER_PATH = "models/fusion/lyrics_scaler.pkl"
    PCA_PATH = "models/fusion/pca.pkl"
    PCA_SCALER_PATH = "models/fusion/pca_scaler.pkl"
    MLP_WEIGHTS_PATH = "models/mlp/mlp_best.pth"

    # Lazily-populated (audio_scaler, lyrics_scaler, pca, pca_scaler) tuple.
    # Declared at class level so every instance has the attribute.
    _fusion_artifacts = None

    def __init__(self):
        """
        Initialize MusicLIME prediction wrapper with pre-trained models.

        Loads LLM2Vec model and MLP configuration for batch processing
        of perturbed audio-lyrics pairs during LIME explanation.
        """
        print("[MusicLIME] Loading models for MusicLIME...")
        self.llm2vec_model = load_llm2vec_model()
        self.config = load_config(self.CONFIG_PATH)
        self.classifier = None
        self._fusion_artifacts = None

    def __call__(self, texts, audios):
        """
        Batch prediction function for MusicLIME perturbations.

        Processes multiple perturbed audio-lyrics pairs through the complete
        pipeline: preprocessing -> feature extraction -> scaling -> MLP prediction.
        Optimized for batch processing of LIME perturbations.

        Parameters
        ----------
        texts : list of str
            List of perturbed lyrics strings from LIME
        audios : list of array-like
            List of perturbed audio waveforms from LIME (same length as texts)

        Returns
        -------
        ndarray
            Prediction probabilities in format [[P(AI), P(Human)], ...]
            for each input pair, shape (n_samples, 2)

        Raises
        ------
        ValueError
            If ``texts`` and ``audios`` have different lengths.
        """
        n_samples = len(texts)
        # zip() would silently truncate, returning fewer rows than LIME expects.
        if n_samples != len(audios):
            raise ValueError(
                "texts and audios must have the same length, "
                f"got {n_samples} and {len(audios)}"
            )

        print(f"[MusicLIME] Processing {n_samples} samples with batch functions...")
        if n_samples == 0:
            # Nothing to run through the models; keep the documented 2-column shape.
            return np.empty((0, 2))

        # Step 1: Preprocess all samples (still needs to be individual)
        start_time = time.perf_counter()
        print("[MusicLIME] Preprocessing samples...")
        processed_audios, processed_lyrics = self._preprocess(texts, audios)
        preprocessing_time = time.perf_counter() - start_time
        print(
            green_bold(
                f"[MusicLIME] Preprocessing completed in {preprocessing_time:.2f}s"
            )
        )

        # Step 2: Batch feature extraction (audio, then lyrics)
        start_time = time.perf_counter()
        print("[MusicLIME] Extracting audio features (batch)...")
        audio_features_batch = spectttra_train(processed_audios)
        self._release_gpu_cache()
        audio_time = time.perf_counter() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
            )
        )

        start_time = time.perf_counter()
        print("[MusicLIME] Extracting lyrics features (batch)...")
        lyrics_features_batch = l2vec_train(
            self.llm2vec_model, processed_lyrics
        )  # (batch, 2048)
        self._release_gpu_cache()
        lyrics_time = time.perf_counter() - start_time
        print(
            green_bold(
                f"[MusicLIME] Lyrics feature extraction completed in {lyrics_time:.2f}s"
            )
        )

        # Steps 3-6: scale, PCA-reduce lyrics, rescale, concatenate
        start_time = time.perf_counter()
        print("[MusicLIME] Scaling and reducing features (batch)...")
        combined_features_batch = self._scale_and_fuse(
            audio_features_batch, lyrics_features_batch
        )
        scaling_time = time.perf_counter() - start_time
        print(green_bold(f"[MusicLIME] Scaling completed in {scaling_time:.2f}s"))

        # Step 7: Batch MLP prediction
        start_time = time.perf_counter()
        print("[MusicLIME] Running MLP predictions (batch)...")
        classifier = self._get_classifier(combined_features_batch.shape[1])
        probabilities, _ = classifier.predict(combined_features_batch)
        batch_results = self._to_class_probabilities(probabilities)
        mlp_time = time.perf_counter() - start_time
        print(green_bold(f"[MusicLIME] MLP prediction completed in {mlp_time:.2f}s"))

        # Total time summary
        total_time = (
            preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time
        )
        print("[MusicLIME] Batch processing complete!")
        print(
            green_bold(
                f"[MusicLIME] Total time: {total_time:.2f}s "
                f"(Preprocessing: {preprocessing_time:.2f}s, "
                f"Audio: {audio_time:.2f}s, "
                f"Lyrics: {lyrics_time:.2f}s, "
                f"Scaling: {scaling_time:.2f}s, "
                f"MLP: {mlp_time:.2f}s)"
            )
        )

        return batch_results

    @staticmethod
    def _preprocess(texts, audios):
        """Preprocess each (lyrics, audio) pair; return (audios, lyrics) lists."""
        processed_audios = []
        processed_lyrics = []
        for text, audio in zip(texts, audios):
            processed_audio, processed_lyric = single_preprocessing(audio, text)
            processed_audios.append(processed_audio)
            processed_lyrics.append(processed_lyric)
        return processed_audios, processed_lyrics

    @staticmethod
    def _release_gpu_cache():
        """Release cached CUDA memory between the two feature extractors."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _get_fusion_artifacts(self):
        """Load (once) and return the audio/lyrics scalers, PCA, and PCA scaler."""
        if self._fusion_artifacts is None:
            # joblib.load unpickles: these paths must point at trusted files.
            self._fusion_artifacts = (
                joblib.load(self.AUDIO_SCALER_PATH),
                joblib.load(self.LYRICS_SCALER_PATH),
                joblib.load(self.PCA_PATH),
                joblib.load(self.PCA_SCALER_PATH),
            )
        return self._fusion_artifacts

    def _scale_and_fuse(self, audio_features_batch, lyrics_features_batch):
        """Scale both modalities, PCA-reduce lyrics, and concatenate along axis 1."""
        audio_scaler, lyric_scaler, pca_model, pca_scaler = (
            self._get_fusion_artifacts()
        )

        scaled_audio_batch = audio_scaler.transform(
            audio_features_batch
        )  # (batch, 384)
        scaled_lyrics_batch = lyric_scaler.transform(
            lyrics_features_batch
        )  # (batch, 2048)

        # Step 4: Apply PCA to lyrics batch
        print("[MusicLIME] Applying PCA to lyrics (batch)")
        reduced_lyrics_batch = pca_model.transform(scaled_lyrics_batch)  # (batch, 512)

        # Step 5: Apply scaler to PCA-scaled lyrics batch
        print("[MusicLIME] Reapplying scaler to PCA-scaled batch")
        reduced_lyrics_batch = pca_scaler.transform(
            reduced_lyrics_batch
        )  # (batch, 512)

        # Step 6: Concatenate features -> (batch, audio dims + lyrics dims)
        return np.concatenate([scaled_audio_batch, reduced_lyrics_batch], axis=1)

    def _get_classifier(self, input_dim):
        """Build and load the MLP on first use, caching it only once loaded."""
        if self.classifier is None:
            classifier = build_mlp(input_dim=input_dim, config=self.config)
            # Load before caching: if loading fails, a later call must retry
            # rather than silently predict with untrained weights.
            classifier.load_model(self.MLP_WEIGHTS_PATH)
            self.classifier = classifier
        return self.classifier

    @staticmethod
    def _to_class_probabilities(probabilities):
        """
        Convert per-sample P(Human) values into [[P(AI), P(Human)], ...].

        Assumes ``probabilities`` holds one value per sample, shaped (n,) or
        (n, 1). TODO confirm against MLPClassifier.predict.
        """
        p_human = np.asarray(probabilities).reshape(-1)
        return np.column_stack((1 - p_human, p_human))