# Source page metadata (file-viewer residue, kept for provenance):
#   Author: krislette
#   Commit message: Auto-deploy from GitHub: 7c591156b27da3e33cf2a35fbb1d3fdf593c7e3f
#   Commit: 61f21af
#   File size: 6.77 kB
import time
import joblib
import numpy as np
import torch
from src.preprocessing.preprocessor import single_preprocessing
from src.spectttra.spectttra_trainer import spectttra_train
from src.llm2vectrain.llm2vec_trainer import l2vec_train
from src.llm2vectrain.model import load_llm2vec_model
from src.models.mlp import build_mlp, load_config
from src.musiclime.print_utils import green_bold
class MusicLIMEPredictor:
    """
    Batch prediction wrapper for MusicLIME explanations.

    Integrates the complete Bach or Bot pipeline (SpecTTTra + LLM2Vec + MLP)
    into a single callable for LIME perturbation processing. Optimized for
    batch processing of multiple perturbed audio-lyrics pairs with detailed
    timing analysis.

    Attributes
    ----------
    llm2vec_model : LLM2Vec
        Pre-loaded LLM2Vec model for lyrics feature extraction
    classifier : MLPClassifier
        Lazy-loaded MLP classifier for final predictions
    config : dict
        Model configuration parameters
    """

    # Locations of the configuration and trained artifacts used by the pipeline.
    _CONFIG_PATH = "config/model_config.yml"
    _MLP_WEIGHTS_PATH = "models/mlp/mlp_best.pth"
    # Order matters: audio scaler, lyrics scaler, PCA, post-PCA scaler.
    _FUSION_ARTIFACT_PATHS = (
        "models/fusion/audio_scaler.pkl",
        "models/fusion/lyrics_scaler.pkl",
        "models/fusion/pca.pkl",
        "models/fusion/pca_scaler.pkl",
    )

    def __init__(self):
        """
        Initialize MusicLIME prediction wrapper with pre-trained models.

        Loads the LLM2Vec model and the MLP configuration. The MLP classifier
        and the fusion artifacts (scalers and PCA) are loaded lazily on the
        first call and reused afterwards.
        """
        print("[MusicLIME] Loading models for MusicLIME...")
        self.llm2vec_model = load_llm2vec_model()
        self.config = load_config(self._CONFIG_PATH)
        self.classifier = None
        self._fusion_artifacts = None

    def __call__(self, texts, audios):
        """
        Batch prediction function for MusicLIME perturbations.

        Processes multiple perturbed audio-lyrics pairs through the complete
        pipeline: preprocessing -> feature extraction -> scaling -> MLP prediction.

        Parameters
        ----------
        texts : list of str
            List of perturbed lyrics strings from LIME
        audios : list of array-like
            List of perturbed audio waveforms from LIME

        Returns
        -------
        ndarray
            Prediction probabilities in format [[P(AI), P(Human)], ...]
            for each input pair, shape (n_samples, 2). An empty batch yields
            an array of shape (0, 2).

        Raises
        ------
        ValueError
            If `texts` and `audios` do not contain the same number of items.
        """
        n_samples = len(texts)
        n_audios = len(audios)
        if n_samples != n_audios:
            # zip() would silently drop the unmatched tail and return fewer
            # rows than the caller passed in, so fail loudly instead.
            raise ValueError(
                "texts and audios must have the same length, "
                f"got {n_samples} texts and {n_audios} audios"
            )

        print(f"[MusicLIME] Processing {n_samples} samples with batch functions...")
        if n_samples == 0:
            # Nothing to predict; skip the feature extractors entirely.
            return np.empty((0, 2))

        # Step 1: Preprocess all samples (still needs to be individual)
        start_time = time.perf_counter()
        print("[MusicLIME] Preprocessing samples...")
        processed_audios = []
        processed_lyrics = []
        for text, audio in zip(texts, audios):
            processed_audio, processed_lyric = single_preprocessing(audio, text)
            processed_audios.append(processed_audio)
            processed_lyrics.append(processed_lyric)
        preprocessing_time = time.perf_counter() - start_time
        print(
            green_bold(
                f"[MusicLIME] Preprocessing completed in {preprocessing_time:.2f}s"
            )
        )

        # Step 2: Batch feature extraction
        start_time = time.perf_counter()
        print("[MusicLIME] Extracting audio features (batch)...")
        audio_features_batch = spectttra_train(processed_audios)
        # Clear GPU cache after audio processing
        self._release_gpu_cache()
        audio_time = time.perf_counter() - start_time
        print(
            green_bold(
                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
            )
        )

        start_time = time.perf_counter()
        print("[MusicLIME] Extracting lyrics features (batch)...")
        lyrics_features_batch = l2vec_train(
            self.llm2vec_model, processed_lyrics
        )  # (batch, 2048)
        # Clear GPU cache after lyrics processing
        self._release_gpu_cache()
        lyrics_time = time.perf_counter() - start_time
        print(
            green_bold(
                f"[MusicLIME] Lyrics feature extraction completed in {lyrics_time:.2f}s"
            )
        )

        # Step 3: Scale and reduce in batch
        start_time = time.perf_counter()
        print("[MusicLIME] Scaling and reducing features (batch)...")
        audio_scaler, lyric_scaler, pca_model, pca_scaler = (
            self._load_fusion_artifacts()
        )
        scaled_audio_batch = audio_scaler.transform(
            audio_features_batch
        )  # (batch, 384)
        scaled_lyrics_batch = lyric_scaler.transform(
            lyrics_features_batch
        )  # (batch, 2048)

        # Step 4: Apply PCA to lyrics batch
        print("[MusicLIME] Applying PCA to lyrics (batch)")
        reduced_lyrics_batch = pca_model.transform(scaled_lyrics_batch)  # (batch, 512)

        # Step 5: Apply scaler to PCA-scaled lyrics batch
        print("[MusicLIME] Reapplying scaler to PCA-scaled batch")
        reduced_lyrics_batch = pca_scaler.transform(
            reduced_lyrics_batch
        )  # (batch, 512)

        # Step 6: Concatenate features
        combined_features_batch = np.concatenate(
            [scaled_audio_batch, reduced_lyrics_batch], axis=1
        )  # (batch, sum of lyrics & audio vector dims)
        scaling_time = time.perf_counter() - start_time
        print(green_bold(f"[MusicLIME] Scaling completed in {scaling_time:.2f}s"))

        # Step 7: Batch MLP prediction
        start_time = time.perf_counter()
        print("[MusicLIME] Running MLP predictions (batch)...")
        classifier = self._get_classifier(combined_features_batch.shape[1])
        # predict() also returns hard labels; only the probabilities are used.
        probabilities, _ = classifier.predict(combined_features_batch)
        batch_results = self._to_class_probabilities(probabilities)
        mlp_time = time.perf_counter() - start_time
        print(green_bold(f"[MusicLIME] MLP prediction completed in {mlp_time:.2f}s"))

        # Total time summary
        total_time = (
            preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time
        )
        print("[MusicLIME] Batch processing complete!")
        print(
            green_bold(
                f"[MusicLIME] Total time: {total_time:.2f}s "
                f"(Preprocessing: {preprocessing_time:.2f}s, "
                f"Audio: {audio_time:.2f}s, "
                f"Lyrics: {lyrics_time:.2f}s, "
                f"Scaling: {scaling_time:.2f}s, "
                f"MLP: {mlp_time:.2f}s)"
            )
        )
        return batch_results

    @staticmethod
    def _release_gpu_cache():
        """Release cached CUDA memory when a CUDA device is available."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _load_fusion_artifacts(self):
        """Return (audio_scaler, lyric_scaler, pca_model, pca_scaler), loading them once."""
        # getattr keeps this working for instances created without __init__.
        artifacts = getattr(self, "_fusion_artifacts", None)
        if artifacts is None:
            # NOTE: joblib.load unpickles its input and can execute arbitrary
            # code; these paths must only ever point at trusted model files.
            artifacts = tuple(
                joblib.load(path) for path in self._FUSION_ARTIFACT_PATHS
            )
            self._fusion_artifacts = artifacts
        return artifacts

    def _get_classifier(self, input_dim):
        """Build the MLP and load its weights on first use, then reuse it."""
        if self.classifier is None:
            self.classifier = build_mlp(input_dim=input_dim, config=self.config)
            self.classifier.load_model(self._MLP_WEIGHTS_PATH)
        return self.classifier

    @staticmethod
    def _to_class_probabilities(probabilities):
        """Turn per-sample probabilities p into rows [1 - p, p], shape (n, 2)."""
        # Flatten first so a column vector (n, 1) still yields (n, 2).
        positive = np.asarray(probabilities).reshape(-1)
        return np.column_stack([1 - positive, positive])