File size: 12,216 Bytes
fc7b4a9
 
 
646d85a
fc7b4a9
50aaa2a
 
 
 
fc7b4a9
 
 
 
 
 
 
 
7633e2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc7b4a9
7633e2f
 
 
 
 
 
fc7b4a9
 
 
 
 
 
 
 
7633e2f
 
 
 
 
fc7b4a9
7633e2f
 
 
 
 
 
fc7b4a9
7633e2f
 
 
 
 
fc7b4a9
 
 
 
 
 
 
 
 
91f3c16
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
646d85a
 
 
 
 
 
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
646d85a
 
 
 
 
fc7b4a9
 
 
 
 
 
 
75d43d2
fc7b4a9
75d43d2
fc7b4a9
75d43d2
fc7b4a9
75d43d2
fc7b4a9
75d43d2
fc7b4a9
 
 
 
75d43d2
 
 
 
 
 
 
 
bb88b91
 
e1ee8d1
bb88b91
 
 
 
 
e1ee8d1
 
fc7b4a9
75d43d2
 
fc7b4a9
 
 
e1ee8d1
fc7b4a9
 
 
 
 
 
50aaa2a
fc7b4a9
 
 
 
 
 
 
 
 
 
75d43d2
fc7b4a9
50aaa2a
fc7b4a9
 
75d43d2
fc7b4a9
 
 
 
50aaa2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import time
import joblib
import numpy as np
import torch

from src.preprocessing.preprocessor import (
    single_preprocessing,
    single_audio_preprocessing,
)
from src.spectttra.spectttra_trainer import spectttra_train
from src.llm2vectrain.llm2vec_trainer import l2vec_train
from src.llm2vectrain.model import load_llm2vec_model
from src.models.mlp import build_mlp, load_config
from src.musiclime.print_utils import green_bold


class MusicLIMEPredictor:
    """
    Batch prediction wrapper for MusicLIME explanations.

    Integrates the complete Bach or Bot pipeline (SpecTTTra + LLM2Vec + MLP)
    into a single callable for LIME perturbation processing. Optimized for
    batch processing of multiple perturbed audio-lyrics pairs with detailed
    timing analysis.

    Attributes
    ----------
    llm2vec_model : LLM2Vec
        Pre-loaded LLM2Vec model for lyrics feature extraction
    classifier : MLPClassifier
        Lazy-loaded MLP classifier for final predictions
    config : dict
        Model configuration parameters
    """

    def __init__(self):
        """
        Initialize MusicLIME prediction wrapper with pre-trained models.

        Loads LLM2Vec model and MLP configuration for batch processing
        of perturbed audio-lyrics pairs during LIME explanation.
        """
        print("[MusicLIME] Loading models for MusicLIME...")
        self.llm2vec_model = load_llm2vec_model()
        self.config = load_config("config/model_config.yml")
        self.classifier = None
        # Lazily-loaded, cached fusion artifacts. __call__ runs once per
        # LIME perturbation batch, so re-reading these from disk on every
        # call (as was done previously) wastes time for no benefit —
        # they are invariant across calls, like the classifier above.
        self._audio_scaler = None
        self._lyric_scaler = None
        self._pca_model = None

    def __call__(self, texts, audios):
        """
        Batch prediction function for MusicLIME perturbations.

        Processes multiple perturbed audio-lyrics pairs through the complete
        pipeline: preprocessing -> feature extraction -> scaling -> MLP prediction.
        Optimized for batch processing of LIME perturbations.

        Parameters
        ----------
        texts : list of str
            List of perturbed lyrics strings from LIME
        audios : list of array-like
            List of perturbed audio waveforms from LIME

        Returns
        -------
        ndarray
            Prediction probabilities in format [[P(AI), P(Human)], ...]
            for each input pair, shape (n_samples, 2)
        """
        print(f"[MusicLIME] Processing {len(texts)} samples with batch functions...")

        # Step 1: Preprocess all samples (still needs to be individual)
        start_time = time.time()
        print("[MusicLIME] Preprocessing samples...")
        processed_audios = []
        processed_lyrics = []

        for text, audio in zip(texts, audios):
            processed_audio, processed_lyric = single_preprocessing(audio, text)
            processed_audios.append(processed_audio)
            processed_lyrics.append(processed_lyric)

        preprocessing_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Preprocessing completed in {preprocessing_time:.2f}s"
            )
        )

        # Step 2: Batch feature extraction
        start_time = time.time()
        print("[MusicLIME] Extracting audio features (batch)...")
        audio_features_batch = spectttra_train(processed_audios)

        # Clear GPU cache after audio processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        audio_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
            )
        )

        start_time = time.time()
        print("[MusicLIME] Extracting lyrics features (batch)...")
        lyrics_features_batch = l2vec_train(
            self.llm2vec_model, processed_lyrics
        )  # (batch, 2048)

        # Clear GPU cache after lyrics processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        lyrics_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Lyrics feature extraction completed in {lyrics_time:.2f}s"
            )
        )

        # Step 3: Scale and reduce in batch
        start_time = time.time()
        print("[MusicLIME] Scaling and reducing features (batch)...")

        # Load the trained scalers once and cache them for subsequent calls
        if self._audio_scaler is None:
            self._audio_scaler = joblib.load("models/fusion/audio_scaler.pkl")
        if self._lyric_scaler is None:
            self._lyric_scaler = joblib.load("models/fusion/lyrics_scaler.pkl")

        # Then apply scaling to the batch
        scaled_audio_batch = self._audio_scaler.transform(
            audio_features_batch
        )  # (batch, 384)
        scaled_lyrics_batch = self._lyric_scaler.transform(
            lyrics_features_batch
        )  # (batch, 2048)

        # Step 4: Apply PCA to lyrics batch (model cached like the scalers)
        print("[MusicLIME] Applying PCA to lyrics (batch)")
        if self._pca_model is None:
            self._pca_model = joblib.load("models/fusion/pca.pkl")
        reduced_lyrics_batch = self._pca_model.transform(
            scaled_lyrics_batch
        )  # (batch, 512)

        # NOTE: Scaling after PCA produces underperforming models compared to non-scaling.
        # One can toggle it on for experimentation/testing purposes.
        # Step 5: Apply scaler to PCA-scaled lyrics batch
        # print("[MusicLIME] Reapplying scaler to PCA-scaled batch")
        # pca_scaler = joblib.load("models/fusion/pca_scaler.pkl")
        # reduced_lyrics_batch = pca_scaler.transform(
        #     reduced_lyrics_batch
        # )  # (batch, 512)

        # Step 6: Concatenate features
        combined_features_batch = np.concatenate(
            [scaled_audio_batch, reduced_lyrics_batch], axis=1
        )  # (batch, sum of lyrics & audio vector dims)
        scaling_time = time.time() - start_time
        print(green_bold(f"[MusicLIME] Scaling completed in {scaling_time:.2f}s"))

        # Step 7: Batch MLP prediction
        start_time = time.time()
        print("[MusicLIME] Running MLP predictions (batch)...")
        if self.classifier is None:
            self.classifier = build_mlp(
                input_dim=combined_features_batch.shape[1], config=self.config
            )
            self.classifier.load_model("models/mlp/mlp_best_multimodal.pth")

        # predict() returns (probabilities, hard labels); only the
        # probabilities are needed for the LIME output format.
        probabilities, _ = self.classifier.predict(combined_features_batch)

        # Convert to expected format: column 0 = P(AI), column 1 = P(Human)
        probs = np.asarray(probabilities, dtype=float).ravel()
        batch_results = np.column_stack((1.0 - probs, probs))
        mlp_time = time.time() - start_time
        print(green_bold(f"[MusicLIME] MLP prediction completed in {mlp_time:.2f}s"))

        # Total time summary
        total_time = (
            preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time
        )
        print("[MusicLIME] Batch processing complete!")
        print(
            green_bold(
                f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Lyrics: {lyrics_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
            )
        )

        return batch_results


class AudioOnlyPredictor:
    """
    Audio-only prediction wrapper for MusicLIME explanations.

    Integrates the audio-only Bach or Bot pipeline (SpecTTTra + MLP) into a single
    callable for LIME perturbation processing. Optimized for batch processing of
    multiple perturbed audio samples while ignoring lyrics input. Mirrors the
    multimodal MusicLIMEPredictor but processes only audio features.

    This predictor is specifically designed for audio-only explainability where
    lyrics are kept constant and only audio components are perturbed through
    source separation techniques.

    Attributes
    ----------
    classifier : MLPClassifier or None
        Lazy-loaded MLP classifier for audio-only predictions
    config : dict
        Model configuration parameters loaded from config files
    """

    def __init__(self):
        """
        Initialize audio-only prediction wrapper.

        Loads model configuration for batch processing of perturbed audio samples
        during LIME explanation. The MLP classifier is lazy-loaded on first use
        to optimize memory usage.
        """
        print("[MusicLIME] Loading models for Audio-Only MusicLIME...")
        self.config = load_config("config/model_config.yml")
        self.classifier = None
        # Lazily-loaded, cached audio scaler. __call__ runs once per LIME
        # perturbation batch; re-reading the invariant scaler from disk on
        # every call (as was done previously) is unnecessary overhead.
        self._audio_scaler = None

    def __call__(self, texts, audios):
        """
        Batch prediction function for audio-only MusicLIME perturbations.

        Processes multiple perturbed audio samples through the audio-only pipeline:
        preprocessing -> SpecTTTra feature extraction -> scaling -> MLP prediction.
        Text inputs are ignored as this is audio-only mode. Optimized for batch
        processing of LIME perturbations with detailed timing analysis.

        Parameters
        ----------
        texts : list of str
            List of text strings (ignored in audio-only mode, kept for API compatibility)
        audios : list of array-like
            List of perturbed audio waveforms from LIME perturbations

        Returns
        -------
        ndarray
            Prediction probabilities in format [[P(AI), P(Human)], ...]
            for each input audio sample, shape (n_samples, 2)
        """
        print(
            f"[MusicLIME] Processing {len(audios)} samples with batch functions (audio-only mode)..."
        )

        # Step 1: Preprocess all audio samples
        start_time = time.time()
        print("[MusicLIME] Preprocessing audio samples...")
        processed_audios = [single_audio_preprocessing(audio) for audio in audios]

        preprocessing_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio preprocessing completed in {preprocessing_time:.2f}s"
            )
        )

        # Step 2: Batch audio feature extraction
        start_time = time.time()
        print("[MusicLIME] Extracting audio features (batch)...")
        audio_features_batch = spectttra_train(processed_audios)

        # Clear GPU cache after audio processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        audio_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
            )
        )

        # Step 3: Scale audio features in batch
        start_time = time.time()
        print("[MusicLIME] Scaling audio features (batch)...")

        # Load the audio scaler once and cache it for subsequent calls
        if self._audio_scaler is None:
            self._audio_scaler = joblib.load("models/fusion/audio_scaler.pkl")
        scaled_audio_batch = self._audio_scaler.transform(audio_features_batch)

        scaling_time = time.time() - start_time
        print(green_bold(f"[MusicLIME] Audio scaling completed in {scaling_time:.2f}s"))

        # Step 4: Audio-only MLP prediction
        start_time = time.time()
        print("[MusicLIME] Running audio-only MLP predictions (batch)...")

        if self.classifier is None:
            self.classifier = build_mlp(
                input_dim=scaled_audio_batch.shape[1], config=self.config
            )
            self.classifier.load_model("models/mlp/mlp_best_unimodal.pth")

        # predict() returns (probabilities, hard labels); only the
        # probabilities are needed for the LIME output format.
        probabilities, _ = self.classifier.predict(scaled_audio_batch)

        # Clear GPU cache after MLP processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Convert to expected format: column 0 = P(AI), column 1 = P(Human)
        probs = np.asarray(probabilities, dtype=float).ravel()
        batch_results = np.column_stack((1.0 - probs, probs))
        mlp_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio-only MLP prediction completed in {mlp_time:.2f}s"
            )
        )

        # Total time summary
        total_time = preprocessing_time + audio_time + scaling_time + mlp_time
        print("[MusicLIME] Audio-only batch processing complete!")
        print(
            green_bold(
                f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
            )
        )

        return batch_results