import time

import joblib
import numpy as np
import torch

from src.preprocessing.preprocessor import single_preprocessing
from src.spectttra.spectttra_trainer import spectttra_train
from src.llm2vectrain.llm2vec_trainer import l2vec_train
from src.llm2vectrain.model import load_llm2vec_model
from src.models.mlp import build_mlp, load_config
from src.musiclime.print_utils import green_bold


class MusicLIMEPredictor:
"""
Batch prediction wrapper for MusicLIME explanations.
Integrates the complete Bach or Bot pipeline (SpecTTTra + LLM2Vec + MLP)
into a single callable for LIME perturbation processing. Optimized for
batch processing of multiple perturbed audio-lyrics pairs with detailed
timing analysis.
Attributes
----------
llm2vec_model : LLM2Vec
Pre-loaded LLM2Vec model for lyrics feature extraction
classifier : MLPClassifier
Lazy-loaded MLP classifier for final predictions
config : dict
Model configuration parameters
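
    Examples
    --------
    Illustrative usage only; ``perturbed_lyrics`` and ``perturbed_audios``
    stand in for the perturbed inputs produced by LIME:

    >>> predictor = MusicLIMEPredictor()
    >>> probs = predictor(perturbed_lyrics, perturbed_audios)
    >>> probs.shape  # (n_samples, 2), each row is [P(AI), P(Human)]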
"""
def __init__(self):
"""
        Initialize MusicLIME prediction wrapper with pre-trained models.

Loads LLM2Vec model and MLP configuration for batch processing
of perturbed audio-lyrics pairs during LIME explanation.
"""
print("[MusicLIME] Loading models for MusicLIME...")
self.llm2vec_model = load_llm2vec_model()
config = load_config("config/model_config.yml")
self.classifier = None
        self.config = config

def __call__(self, texts, audios):
"""
Batch prediction function for MusicLIME perturbations.
Processes multiple perturbed audio-lyrics pairs through the complete
pipeline: preprocessing -> feature extraction -> scaling -> MLP prediction.
Optimized for batch processing of LIME perturbations.
Parameters
----------
texts : list of str
List of perturbed lyrics strings from LIME
audios : list of array-like
List of perturbed audio waveforms from LIME
Returns
-------
ndarray
Prediction probabilities in format [[P(AI), P(Human)], ...]
for each input pair, shape (n_samples, 2)
"""
print(f"[MusicLIME] Processing {len(texts)} samples with batch functions...")

        # Step 1: Preprocess all samples (must be done one sample at a time)
start_time = time.time()
print("[MusicLIME] Preprocessing samples...")
processed_audios = []
processed_lyrics = []
        for text, audio in zip(texts, audios):
processed_audio, processed_lyric = single_preprocessing(audio, text)
processed_audios.append(processed_audio)
processed_lyrics.append(processed_lyric)
preprocessing_time = time.time() - start_time
print(
green_bold(
f"[MusicLIME] Preprocessing completed in {preprocessing_time:.2f}s"
)
)

        # Step 2: Batch feature extraction
start_time = time.time()
print("[MusicLIME] Extracting audio features (batch)...")
audio_features_batch = spectttra_train(processed_audios)
# Clear GPU cache after audio processing
if torch.cuda.is_available():
torch.cuda.empty_cache()
audio_time = time.time() - start_time
print(
green_bold(
f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
)
)

        start_time = time.time()
print("[MusicLIME] Extracting lyrics features (batch)...")
lyrics_features_batch = l2vec_train(
self.llm2vec_model, processed_lyrics
) # (batch, 2048)
# Clear GPU cache after lyrics processing
if torch.cuda.is_available():
torch.cuda.empty_cache()
lyrics_time = time.time() - start_time
print(
green_bold(
f"[MusicLIME] Lyrics feature extraction completed in {lyrics_time:.2f}s"
)
)

        # Step 3: Scale and reduce in batch
start_time = time.time()
print("[MusicLIME] Scaling and reducing features (batch)...")
# Load the trained scalers
audio_scaler = joblib.load("models/fusion/audio_scaler.pkl")
lyric_scaler = joblib.load("models/fusion/lyrics_scaler.pkl")
# Then apply scaling to the batch
scaled_audio_batch = audio_scaler.transform(
audio_features_batch
) # (batch, 384)
scaled_lyrics_batch = lyric_scaler.transform(
lyrics_features_batch
) # (batch, 2048)

        # Step 4: Apply PCA to lyrics batch
        print("[MusicLIME] Applying PCA to lyrics (batch)...")
pca_model = joblib.load("models/fusion/pca.pkl")
reduced_lyrics_batch = pca_model.transform(scaled_lyrics_batch) # (batch, 512)

        # Step 5: Re-scale the PCA-reduced lyrics batch
        print("[MusicLIME] Re-scaling PCA-reduced lyrics (batch)...")
pca_scaler = joblib.load("models/fusion/pca_scaler.pkl")
reduced_lyrics_batch = pca_scaler.transform(
reduced_lyrics_batch
) # (batch, 512)

        # Step 6: Concatenate audio and lyrics features
        combined_features_batch = np.concatenate(
            [scaled_audio_batch, reduced_lyrics_batch], axis=1
        )  # (batch, 384 + 512 = 896)
scaling_time = time.time() - start_time
print(green_bold(f"[MusicLIME] Scaling completed in {scaling_time:.2f}s"))

        # Step 7: Batch MLP prediction
start_time = time.time()
print("[MusicLIME] Running MLP predictions (batch)...")
if self.classifier is None:
self.classifier = build_mlp(
input_dim=combined_features_batch.shape[1], config=self.config
)
self.classifier.load_model("models/mlp/mlp_best.pth")
probabilities, predictions = self.classifier.predict(combined_features_batch)
        # The classifier returns one probability per sample (P(Human));
        # expand each to a [P(AI), P(Human)] row to match the documented format
batch_results = [[1 - prob, prob] for prob in probabilities]
mlp_time = time.time() - start_time
print(green_bold(f"[MusicLIME] MLP prediction completed in {mlp_time:.2f}s"))

        # Total time summary
total_time = (
preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time
)
print(f"[MusicLIME] Batch processing complete!")
print(
green_bold(
f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Lyrics: {lyrics_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
)
)
return np.array(batch_results)
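

if __name__ == "__main__":
    # Illustrative smoke-test sketch only (not part of the original pipeline).
    # It assumes the local model artifacts exist (config/model_config.yml,
    # models/fusion/*, models/mlp/mlp_best.pth); the waveform length, assumed
    # 16 kHz sample rate, and lyrics below are placeholders, not real
    # MusicLIME perturbations.
    rng = np.random.default_rng(0)
    dummy_audio = rng.standard_normal(16000 * 5).astype(np.float32)  # ~5 s of noise
    dummy_lyrics = "placeholder lyrics for a quick end-to-end check"

    predictor = MusicLIMEPredictor()
    probs = predictor([dummy_lyrics], [dummy_audio])
    print(f"[MusicLIME] P(AI)={probs[0, 0]:.3f}, P(Human)={probs[0, 1]:.3f}")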