main-model / src /musiclime /wrapper.py
krislette's picture
Auto-deploy from GitHub: c58d63fae21b59bebcd6268e0b9ecb36714b289a
bb88b91
import time
import joblib
import numpy as np
import torch
from src.preprocessing.preprocessor import (
single_preprocessing,
single_audio_preprocessing,
)
from src.spectttra.spectttra_trainer import spectttra_train
from src.llm2vectrain.llm2vec_trainer import l2vec_train
from src.llm2vectrain.model import load_llm2vec_model
from src.models.mlp import build_mlp, load_config
from src.musiclime.print_utils import green_bold
class MusicLIMEPredictor:
    """
    Batch prediction wrapper for MusicLIME explanations.
    Integrates the complete Bach or Bot pipeline (SpecTTTra + LLM2Vec + MLP)
    into a single callable for LIME perturbation processing. Optimized for
    batch processing of multiple perturbed audio-lyrics pairs with detailed
    timing analysis.
    Attributes
    ----------
    llm2vec_model : LLM2Vec
        Pre-loaded LLM2Vec model for lyrics feature extraction
    classifier : MLPClassifier or None
        Lazy-loaded MLP classifier for final predictions
    config : dict
        Model configuration parameters
    """

    def __init__(self):
        """
        Initialize MusicLIME prediction wrapper with pre-trained models.
        Loads LLM2Vec model and MLP configuration for batch processing
        of perturbed audio-lyrics pairs during LIME explanation.
        """
        print("[MusicLIME] Loading models for MusicLIME...")
        self.llm2vec_model = load_llm2vec_model()
        self.config = load_config("config/model_config.yml")
        self.classifier = None
        # Fitted preprocessing artifacts are lazy-loaded once and cached:
        # LIME invokes __call__ many times per explanation, so re-reading
        # them from disk on every call wastes significant time.
        self._audio_scaler = None
        self._lyric_scaler = None
        self._pca_model = None

    def _load_feature_transformers(self):
        """
        Lazy-load and cache the fitted scalers and PCA model.
        Loads all three artifacts together on first use; subsequent calls
        are no-ops. Raises whatever ``joblib.load`` raises if a file is
        missing or unreadable.
        """
        if self._audio_scaler is None:
            self._audio_scaler = joblib.load("models/fusion/audio_scaler.pkl")
            self._lyric_scaler = joblib.load("models/fusion/lyrics_scaler.pkl")
            self._pca_model = joblib.load("models/fusion/pca.pkl")

    def __call__(self, texts, audios):
        """
        Batch prediction function for MusicLIME perturbations.
        Processes multiple perturbed audio-lyrics pairs through the complete
        pipeline: preprocessing -> feature extraction -> scaling -> MLP prediction.
        Optimized for batch processing of LIME perturbations.
        Parameters
        ----------
        texts : list of str
            List of perturbed lyrics strings from LIME
        audios : list of array-like
            List of perturbed audio waveforms from LIME
        Returns
        -------
        ndarray
            Prediction probabilities in format [[P(AI), P(Human)], ...]
            for each input pair, shape (n_samples, 2)
        """
        print(f"[MusicLIME] Processing {len(texts)} samples with batch functions...")

        # Step 1: Preprocess all samples (still needs to be individual)
        start_time = time.time()
        print("[MusicLIME] Preprocessing samples...")
        processed_audios = []
        processed_lyrics = []
        for text, audio in zip(texts, audios):
            processed_audio, processed_lyric = single_preprocessing(audio, text)
            processed_audios.append(processed_audio)
            processed_lyrics.append(processed_lyric)
        preprocessing_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Preprocessing completed in {preprocessing_time:.2f}s"
            )
        )

        # Step 2: Batch feature extraction
        start_time = time.time()
        print("[MusicLIME] Extracting audio features (batch)...")
        audio_features_batch = spectttra_train(processed_audios)
        # Clear GPU cache after audio processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        audio_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
            )
        )

        start_time = time.time()
        print("[MusicLIME] Extracting lyrics features (batch)...")
        lyrics_features_batch = l2vec_train(
            self.llm2vec_model, processed_lyrics
        )  # (batch, 2048)
        # Clear GPU cache after lyrics processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        lyrics_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Lyrics feature extraction completed in {lyrics_time:.2f}s"
            )
        )

        # Step 3: Scale and reduce in batch
        start_time = time.time()
        print("[MusicLIME] Scaling and reducing features (batch)...")
        # Load the trained scalers (cached after the first call)
        self._load_feature_transformers()
        # Then apply scaling to the batch
        scaled_audio_batch = self._audio_scaler.transform(
            audio_features_batch
        )  # (batch, 384)
        scaled_lyrics_batch = self._lyric_scaler.transform(
            lyrics_features_batch
        )  # (batch, 2048)

        # Step 4: Apply PCA to lyrics batch
        print("[MusicLIME] Applying PCA to lyrics (batch)")
        reduced_lyrics_batch = self._pca_model.transform(
            scaled_lyrics_batch
        )  # (batch, 512)

        # NOTE: Scaling after PCA produces underperforming models compared to non-scaling.
        # One can toggle it on for experimentation/testing purposes.
        # Step 5: Apply scaler to PCA-scaled lyrics batch
        # print("[MusicLIME] Reapplying scaler to PCA-scaled batch")
        # pca_scaler = joblib.load("models/fusion/pca_scaler.pkl")
        # reduced_lyrics_batch = pca_scaler.transform(
        #     reduced_lyrics_batch
        # )  # (batch, 512)

        # Step 6: Concatenate features
        combined_features_batch = np.concatenate(
            [scaled_audio_batch, reduced_lyrics_batch], axis=1
        )  # (batch, sum of lyrics & audio vector dims)
        scaling_time = time.time() - start_time
        print(green_bold(f"[MusicLIME] Scaling completed in {scaling_time:.2f}s"))

        # Step 7: Batch MLP prediction
        start_time = time.time()
        print("[MusicLIME] Running MLP predictions (batch)...")
        if self.classifier is None:
            self.classifier = build_mlp(
                input_dim=combined_features_batch.shape[1], config=self.config
            )
            self.classifier.load_model("models/mlp/mlp_best_multimodal.pth")
        # Only the probabilities are needed; hard predictions are discarded.
        probabilities, _ = self.classifier.predict(combined_features_batch)
        # Clear GPU cache after MLP processing (mirrors AudioOnlyPredictor)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Convert to expected format: prob is P(Human), so row = [P(AI), P(Human)]
        batch_results = [[1 - prob, prob] for prob in probabilities]
        mlp_time = time.time() - start_time
        print(green_bold(f"[MusicLIME] MLP prediction completed in {mlp_time:.2f}s"))

        # Total time summary
        total_time = (
            preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time
        )
        print("[MusicLIME] Batch processing complete!")
        print(
            green_bold(
                f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Lyrics: {lyrics_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
            )
        )

        return np.array(batch_results)
class AudioOnlyPredictor:
    """
    Audio-only prediction wrapper for MusicLIME explanations.
    Integrates the audio-only Bach or Bot pipeline (SpecTTTra + MLP) into a single
    callable for LIME perturbation processing. Optimized for batch processing of
    multiple perturbed audio samples while ignoring lyrics input. Mirrors the
    multimodal MusicLIMEPredictor but processes only audio features.
    This predictor is specifically designed for audio-only explainability where
    lyrics are kept constant and only audio components are perturbed through
    source separation techniques.
    Attributes
    ----------
    classifier : MLPClassifier or None
        Lazy-loaded MLP classifier for audio-only predictions
    config : dict
        Model configuration parameters loaded from config files
    """

    def __init__(self):
        """
        Initialize audio-only prediction wrapper.
        Loads model configuration for batch processing of perturbed audio samples
        during LIME explanation. The MLP classifier is lazy-loaded on first use
        to optimize memory usage.
        """
        print("[MusicLIME] Loading models for Audio-Only MusicLIME...")
        self.config = load_config("config/model_config.yml")
        self.classifier = None
        # Fitted audio scaler is lazy-loaded once and cached: LIME invokes
        # __call__ many times per explanation, so re-reading it from disk
        # on every call wastes time.
        self._audio_scaler = None

    def __call__(self, texts, audios):
        """
        Batch prediction function for audio-only MusicLIME perturbations.
        Processes multiple perturbed audio samples through the audio-only pipeline:
        preprocessing -> SpecTTTra feature extraction -> scaling -> MLP prediction.
        Text inputs are ignored as this is audio-only mode. Optimized for batch
        processing of LIME perturbations with detailed timing analysis.
        Parameters
        ----------
        texts : list of str
            List of text strings (ignored in audio-only mode, kept for API compatibility)
        audios : list of array-like
            List of perturbed audio waveforms from LIME perturbations
        Returns
        -------
        ndarray
            Prediction probabilities in format [[P(AI), P(Human)], ...]
            for each input audio sample, shape (n_samples, 2)
        """
        print(
            f"[MusicLIME] Processing {len(audios)} samples with batch functions (audio-only mode)..."
        )

        # Step 1: Preprocess all audio samples
        start_time = time.time()
        print("[MusicLIME] Preprocessing audio samples...")
        processed_audios = [single_audio_preprocessing(audio) for audio in audios]
        preprocessing_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio preprocessing completed in {preprocessing_time:.2f}s"
            )
        )

        # Step 2: Batch audio feature extraction
        start_time = time.time()
        print("[MusicLIME] Extracting audio features (batch)...")
        audio_features_batch = spectttra_train(processed_audios)
        # Clear GPU cache after audio processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        audio_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
            )
        )

        # Step 3: Scale audio features in batch
        start_time = time.time()
        print("[MusicLIME] Scaling audio features (batch)...")
        # Load the audio scaler (cached after the first call)
        if self._audio_scaler is None:
            self._audio_scaler = joblib.load("models/fusion/audio_scaler.pkl")
        scaled_audio_batch = self._audio_scaler.transform(audio_features_batch)
        scaling_time = time.time() - start_time
        print(green_bold(f"[MusicLIME] Audio scaling completed in {scaling_time:.2f}s"))

        # Step 4: Audio-only MLP prediction
        start_time = time.time()
        print("[MusicLIME] Running audio-only MLP predictions (batch)...")
        if self.classifier is None:
            self.classifier = build_mlp(
                input_dim=scaled_audio_batch.shape[1], config=self.config
            )
            self.classifier.load_model("models/mlp/mlp_best_unimodal.pth")
        # Only the probabilities are needed; hard predictions are discarded.
        probabilities, _ = self.classifier.predict(scaled_audio_batch)
        # Clear GPU cache after MLP processing
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Convert to expected format: prob is P(Human), so row = [P(AI), P(Human)]
        batch_results = [[1 - prob, prob] for prob in probabilities]
        mlp_time = time.time() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio-only MLP prediction completed in {mlp_time:.2f}s"
            )
        )

        # Total time summary
        total_time = preprocessing_time + audio_time + scaling_time + mlp_time
        print("[MusicLIME] Audio-only batch processing complete!")
        print(
            green_bold(
                f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
            )
        )

        return np.array(batch_results)