import time import joblib import numpy as np import torch from src.preprocessing.preprocessor import ( single_preprocessing, single_audio_preprocessing, ) from src.spectttra.spectttra_trainer import spectttra_train from src.llm2vectrain.llm2vec_trainer import l2vec_train from src.llm2vectrain.model import load_llm2vec_model from src.models.mlp import build_mlp, load_config from src.musiclime.print_utils import green_bold class MusicLIMEPredictor: """ Batch prediction wrapper for MusicLIME explanations. Integrates the complete Bach or Bot pipeline (SpecTTTra + LLM2Vec + MLP) into a single callable for LIME perturbation processing. Optimized for batch processing of multiple perturbed audio-lyrics pairs with detailed timing analysis. Attributes ---------- llm2vec_model : LLM2Vec Pre-loaded LLM2Vec model for lyrics feature extraction classifier : MLPClassifier Lazy-loaded MLP classifier for final predictions config : dict Model configuration parameters """ def __init__(self): """ Initialize MusicLIME prediction wrapper with pre-trained models. Loads LLM2Vec model and MLP configuration for batch processing of perturbed audio-lyrics pairs during LIME explanation. """ print("[MusicLIME] Loading models for MusicLIME...") self.llm2vec_model = load_llm2vec_model() config = load_config("config/model_config.yml") self.classifier = None self.config = config def __call__(self, texts, audios): """ Batch prediction function for MusicLIME perturbations. Processes multiple perturbed audio-lyrics pairs through the complete pipeline: preprocessing -> feature extraction -> scaling -> MLP prediction. Optimized for batch processing of LIME perturbations. Parameters ---------- texts : list of str List of perturbed lyrics strings from LIME audios : list of array-like List of perturbed audio waveforms from LIME Returns ------- ndarray Prediction probabilities in format [[P(AI), P(Human)], ...] for each input pair, shape (n_samples, 2) """ print(f"[MusicLIME] Processing {len(texts)} samples with batch functions...") # Step 1: Preprocess all samples (still needs to be individual) start_time = time.time() print("[MusicLIME] Preprocessing samples...") processed_audios = [] processed_lyrics = [] for _, (text, audio) in enumerate(zip(texts, audios)): processed_audio, processed_lyric = single_preprocessing(audio, text) processed_audios.append(processed_audio) processed_lyrics.append(processed_lyric) preprocessing_time = time.time() - start_time print( green_bold( f"[MusicLIME] Preprocessing completed in {preprocessing_time:.2f}s" ) ) # Step 2: Batch feature extraction start_time = time.time() print("[MusicLIME] Extracting audio features (batch)...") audio_features_batch = spectttra_train(processed_audios) # Clear GPU cache after audio processing if torch.cuda.is_available(): torch.cuda.empty_cache() audio_time = time.time() - start_time print( green_bold( f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s" ) ) start_time = time.time() print("[MusicLIME] Extracting lyrics features (batch)...") lyrics_features_batch = l2vec_train( self.llm2vec_model, processed_lyrics ) # (batch, 2048) # Clear GPU cache after lyrics processing if torch.cuda.is_available(): torch.cuda.empty_cache() lyrics_time = time.time() - start_time print( green_bold( f"[MusicLIME] Lyrics feature extraction completed in {lyrics_time:.2f}s" ) ) # Step 3: Scale and reduce in batch start_time = time.time() print("[MusicLIME] Scaling and reducing features (batch)...") # Load the trained scalers audio_scaler = joblib.load("models/fusion/audio_scaler.pkl") lyric_scaler = joblib.load("models/fusion/lyrics_scaler.pkl") # Then apply scaling to the batch scaled_audio_batch = audio_scaler.transform( audio_features_batch ) # (batch, 384) scaled_lyrics_batch = lyric_scaler.transform( lyrics_features_batch ) # (batch, 2048) # Step 4: Apply PCA to lyrics batch print("[MusicLIME] Applying PCA to lyrics (batch)") pca_model = joblib.load("models/fusion/pca.pkl") reduced_lyrics_batch = pca_model.transform(scaled_lyrics_batch) # (batch, 512) # NOTE: Scaling after PCA produces underperforming models compared to non-scaling. # One can toggle it on for experimentation/testing purposes. # Step 5: Apply scaler to PCA-scaled lyrics batch # print("[MusicLIME] Reapplying scaler to PCA-scaled batch") # pca_scaler = joblib.load("models/fusion/pca_scaler.pkl") # reduced_lyrics_batch = pca_scaler.transform( # reduced_lyrics_batch # ) # (batch, 512) # Step 6: Concatenate features combined_features_batch = np.concatenate( [scaled_audio_batch, reduced_lyrics_batch], axis=1 ) # (batch, sum of lyrics & audio vector dims) scaling_time = time.time() - start_time print(green_bold(f"[MusicLIME] Scaling completed in {scaling_time:.2f}s")) # Step 7: Batch MLP prediction start_time = time.time() print("[MusicLIME] Running MLP predictions (batch)...") if self.classifier is None: self.classifier = build_mlp( input_dim=combined_features_batch.shape[1], config=self.config ) self.classifier.load_model("models/mlp/mlp_best_multimodal.pth") probabilities, predictions = self.classifier.predict(combined_features_batch) # Convert to expected format batch_results = [[1 - prob, prob] for prob in probabilities] mlp_time = time.time() - start_time print(green_bold(f"[MusicLIME] MLP prediction completed in {mlp_time:.2f}s")) # Total time summary total_time = ( preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time ) print("[MusicLIME] Batch processing complete!") print( green_bold( f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Lyrics: {lyrics_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)" ) ) return np.array(batch_results) class AudioOnlyPredictor: """ Audio-only prediction wrapper for MusicLIME explanations. Integrates the audio-only Bach or Bot pipeline (SpecTTTra + MLP) into a single callable for LIME perturbation processing. Optimized for batch processing of multiple perturbed audio samples while ignoring lyrics input. Mirrors the multimodal MusicLIMEPredictor but processes only audio features. This predictor is specifically designed for audio-only explainability where lyrics are kept constant and only audio components are perturbed through source separation techniques. Attributes ---------- classifier : MLPClassifier or None Lazy-loaded MLP classifier for audio-only predictions config : dict Model configuration parameters loaded from config files """ def __init__(self): """ Initialize audio-only prediction wrapper. Loads model configuration for batch processing of perturbed audio samples during LIME explanation. The MLP classifier is lazy-loaded on first use to optimize memory usage. """ print("[MusicLIME] Loading models for Audio-Only MusicLIME...") config = load_config("config/model_config.yml") self.classifier = None self.config = config def __call__(self, texts, audios): """ Batch prediction function for audio-only MusicLIME perturbations. Processes multiple perturbed audio samples through the audio-only pipeline: preprocessing -> SpecTTTra feature extraction -> scaling -> MLP prediction. Text inputs are ignored as this is audio-only mode. Optimized for batch processing of LIME perturbations with detailed timing analysis. Parameters ---------- texts : list of str List of text strings (ignored in audio-only mode, kept for API compatibility) audios : list of array-like List of perturbed audio waveforms from LIME perturbations Returns ------- ndarray Prediction probabilities in format [[P(AI), P(Human)], ...] for each input audio sample, shape (n_samples, 2) """ print( f"[MusicLIME] Processing {len(audios)} samples with batch functions (audio-only mode)..." ) # Step 1: Preprocess all audio samples start_time = time.time() print("[MusicLIME] Preprocessing audio samples...") processed_audios = [] for audio in audios: processed_audio = single_audio_preprocessing(audio) processed_audios.append(processed_audio) preprocessing_time = time.time() - start_time print( green_bold( f"[MusicLIME] Audio preprocessing completed in {preprocessing_time:.2f}s" ) ) # Step 2: Batch audio feature extraction start_time = time.time() print("[MusicLIME] Extracting audio features (batch)...") audio_features_batch = spectttra_train(processed_audios) # Clear GPU cache after audio processing if torch.cuda.is_available(): torch.cuda.empty_cache() audio_time = time.time() - start_time print( green_bold( f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s" ) ) # Step 3: Scale audio features in batch start_time = time.time() print("[MusicLIME] Scaling audio features (batch)...") # Load the audio scaler audio_scaler = joblib.load("models/fusion/audio_scaler.pkl") scaled_audio_batch = audio_scaler.transform(audio_features_batch) scaling_time = time.time() - start_time print(green_bold(f"[MusicLIME] Audio scaling completed in {scaling_time:.2f}s")) # Step 4: Audio-only MLP prediction start_time = time.time() print("[MusicLIME] Running audio-only MLP predictions (batch)...") if self.classifier is None: self.classifier = build_mlp( input_dim=scaled_audio_batch.shape[1], config=self.config ) self.classifier.load_model("models/mlp/mlp_best_unimodal.pth") probabilities, predictions = self.classifier.predict(scaled_audio_batch) # Clear GPU cache after MLP processing if torch.cuda.is_available(): torch.cuda.empty_cache() # Convert to expected format batch_results = [[1 - prob, prob] for prob in probabilities] mlp_time = time.time() - start_time print( green_bold( f"[MusicLIME] Audio-only MLP prediction completed in {mlp_time:.2f}s" ) ) # Total time summary total_time = preprocessing_time + audio_time + scaling_time + mlp_time print("[MusicLIME] Audio-only batch processing complete!") print( green_bold( f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)" ) ) return np.array(batch_results)