Auto-deploy from GitHub: bb659763110ffbe4c2a85e186bebb84edb7010de
Files changed:
- app/schemas.py +18 -0
- app/server.py +69 -2
- scripts/explain.py +184 -1
- scripts/explain_combined_runner.py +92 -0
- scripts/predict.py +126 -0
- scripts/predict_combined_runner.py +65 -0
- src/musiclime/explainer.py +85 -0
app/schemas.py
CHANGED
@@ -54,3 +54,21 @@ class AudioOnlyPredictionXAIResponse(BaseModel):
     audio_content_type: str
     audio_file_size: int
     results: dict
+
+
+class CombinedExplanationResponse(BaseModel):
+    status: str
+    lyrics: str
+    audio_file_name: str
+    audio_content_type: str
+    audio_file_size: int
+    results: dict  # Contains both multimodal and audio_only results
+
+
+class CombinedPredictionResponse(BaseModel):
+    status: str
+    lyrics: str
+    audio_file_name: str
+    audio_content_type: str
+    audio_file_size: int
+    results: dict  # Contains both multimodal and audio_only predictions
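Both new schemas mirror the existing single-mode response models; only the results payload differs, carrying the multimodal and audio_only blocks side by side. A rough sketch of how the prediction model might be populated follows; the nested keys track the dict returned by predict_combined() in scripts/predict.py further down, but every value here is invented for illustration:

# Illustrative only: the shape follows predict_combined()'s return value,
# the numbers and strings are made up.
from app.schemas import CombinedPredictionResponse

example = CombinedPredictionResponse(
    status="success",
    lyrics="la la la",
    audio_file_name="sample.mp3",
    audio_content_type="audio/mpeg",
    audio_file_size=123456,
    results={
        "multimodal": {"confidence": 0.91, "prediction": 1, "label": "Human-Composed", "probability": 0.91},
        "audio_only": {"confidence": 0.84, "prediction": 1, "label": "Human-Composed", "probability": 0.84},
        "performance": {"total_time_seconds": 4.2, "multimodal_time_seconds": 3.1, "audio_only_time_seconds": 1.1},
    },
)
print(example.results["performance"]["total_time_seconds"])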
app/server.py
CHANGED
@@ -10,13 +10,15 @@ from app.schemas import (
     AudioOnlyPredictionResponse,
     AudioOnlyPredictionXAIResponse,
     WelcomeResponse,
+    CombinedExplanationResponse,
+    CombinedPredictionResponse,
 )
 from app.utils import load_server_config, load_model_config
 from app.validators import validate_lyrics, validate_audio_source, validate_audio_only
 
 # Model/XAI-related imports
-from scripts.explain import musiclime_multimodal, musiclime_unimodal
-from scripts.predict import predict_multimodal, predict_unimodal
+from scripts.explain import musiclime_multimodal, musiclime_unimodal, musiclime_combined
+from scripts.predict import predict_multimodal, predict_unimodal, predict_combined
 
 # Other imports
 import io
@@ -64,6 +66,8 @@ def root():
             "/api/v1/explain/multimodal": "POST endpoint for multimodal explainability",
             "/api/v1/predict/audio": "POST endpoint for audio-only prediction",
             "/api/v1/explain/audio": "POST endpoint for audio-only explainability",
+            "/api/v1/predict/combined": "POST endpoint for BOTH predictions",
+            "/api/v1/explain/combined": "POST endpoint for BOTH explanations",
         },
     )
 
@@ -217,6 +221,69 @@ async def explain_audio_only_endpoint(
         raise HTTPException(status_code=500, detail=str(e))
 
 
+# New combined endpoints (multimodal and audio-only)
+@app.post("/api/v1/predict/combined", response_model=CombinedPredictionResponse)
+async def predict_combined_endpoint(
+    lyrics: str = Depends(validate_lyrics),
+    audio_data_tuple: Tuple = Depends(validate_audio_source),
+):
+    """Combined multimodal and audio-only prediction endpoint (optimized)."""
+    try:
+        audio_content, audio_file_name, audio_content_type = audio_data_tuple
+
+        try:
+            audio_data, sr = librosa.load(io.BytesIO(audio_content))
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
+
+        # Generate both predictions with shared audio processing
+        results = predict_combined(audio_data, lyrics)
+
+        return CombinedPredictionResponse(
+            status="success",
+            lyrics=lyrics,
+            audio_file_name=audio_file_name,
+            audio_content_type=audio_content_type,
+            audio_file_size=len(audio_content),
+            results=results,
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/v1/explain/combined", response_model=CombinedExplanationResponse)
+async def explain_combined_endpoint(
+    lyrics: str = Depends(validate_lyrics),
+    audio_data_tuple: Tuple = Depends(validate_audio_source),
+):
+    """Combined multimodal and audio-only explanation endpoint (optimized)."""
+    try:
+        audio_content, audio_file_name, audio_content_type = audio_data_tuple
+
+        try:
+            audio_data, sr = librosa.load(io.BytesIO(audio_content))
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
+
+        # Generate both explanations with single source separation
+        results = musiclime_combined(audio_data, lyrics)
+
+        return CombinedExplanationResponse(
+            status="success",
+            lyrics=lyrics,
+            audio_file_name=audio_file_name,
+            audio_content_type=audio_content_type,
+            audio_file_size=len(audio_content),
+            results=results,
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.get("/api/v1/model/info", response_model=ModelInfoResponse, tags=["Model"])
 async def get_model_info():
     """
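The two combined routes take the same inputs as the existing single-mode endpoints, so a client needs only one upload to get both result sets back. A hypothetical client call is sketched below; the multipart field names ("lyrics", "audio_file") and the host/port are assumptions, since the real field names are fixed by validate_lyrics and validate_audio_source in app/validators.py, which this commit does not touch:

# Hypothetical usage sketch -- URL, port, and form field names are assumptions.
import requests

with open("data/external/sample.mp3", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/predict/combined",
        data={"lyrics": "la la la"},
        files={"audio_file": ("sample.mp3", f, "audio/mpeg")},
        timeout=600,  # model loading and inference can take a while
    )
resp.raise_for_status()
print(resp.json()["results"]["multimodal"])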
scripts/explain.py
CHANGED
@@ -3,6 +3,7 @@ import numpy as np
 from datetime import datetime
 from src.musiclime.explainer import MusicLIMEExplainer
 from src.musiclime.wrapper import MusicLIMEPredictor, AudioOnlyPredictor
+from src.musiclime.print_utils import green_bold
 
 
 def musiclime_multimodal(audio_data, lyrics_text):
@@ -31,7 +32,7 @@ def musiclime_multimodal(audio_data, lyrics_text):
     print(f"[MusicLIME] Using num_samples={num_samples}, num_features={num_features}")
 
     # Create musiclime instances
-    explainer = MusicLIMEExplainer()
+    explainer = MusicLIMEExplainer(random_state=42)
     predictor = MusicLIMEPredictor()
 
     # Then generate explanations
@@ -179,3 +180,185 @@ def musiclime_unimodal(audio_data, modality="audio"):
             "timestamp": start_time.isoformat(),
         },
     }
+
+
+def musiclime_combined(audio_data, lyrics_text):
+    """
+    Generate both multimodal and audio-only MusicLIME explanations efficiently.
+
+    Performs source separation once and generates both explanation types
+    to reduce total processing time by ~50% compared to separate calls.
+
+    Parameters
+    ----------
+    audio_data : array-like
+        Audio waveform data from librosa.load or similar
+    lyrics_text : str
+        String containing song lyrics
+
+    Returns
+    -------
+    dict
+        Combined results containing both multimodal and audio-only explanations
+    """
+    from src.musiclime.factorization import OpenUnmixFactorization
+    from src.musiclime.text_utils import LineIndexedString
+
+    start_time = datetime.now()
+
+    # Get configuration
+    num_samples = int(os.getenv("MUSICLIME_NUM_SAMPLES", "1000"))
+    num_features = int(os.getenv("MUSICLIME_NUM_FEATURES", "10"))
+
+    print(
+        "[MusicLIME] Combined mode: generating both multimodal and audio-only explanations"
+    )
+    print(f"[MusicLIME] Using num_samples={num_samples}, num_features={num_features}")
+
+    # Create factorizations once
+    print("[MusicLIME] Creating factorizations once for both explanations...")
+    factorization_start = datetime.now()
+
+    audio_factorization = OpenUnmixFactorization(
+        audio_data, temporal_segmentation_params=10
+    )
+    text_factorization = LineIndexedString(lyrics_text)
+
+    factorization_time = (datetime.now() - factorization_start).total_seconds()
+    print(
+        green_bold(f"[MusicLIME] Factorization completed in {factorization_time:.2f}s")
+    )
+
+    # Create explainer and predictors
+    explainer = MusicLIMEExplainer(random_state=42)
+    multimodal_predictor = MusicLIMEPredictor()
+    audio_predictor = AudioOnlyPredictor()
+
+    # Generate multimodal explanation (reusing factorizations)
+    print("[MusicLIME] Generating multimodal explanation...")
+    multimodal_start = datetime.now()
+
+    multimodal_explanation = explainer.explain_instance_with_factorization(
+        audio_factorization,
+        text_factorization,
+        multimodal_predictor,
+        num_samples=num_samples,
+        labels=(1,),
+        modality="both",
+    )
+
+    multimodal_time = (datetime.now() - multimodal_start).total_seconds()
+    print(
+        green_bold(
+            f"[MusicLIME] Multimodal explanation completed in {multimodal_time:.2f}s"
+        )
+    )
+
+    # Generate audio-only explanation (reusing the same factorization)
+    print("[MusicLIME] Generating audio-only explanation (reusing factorizations)...")
+    audio_start = datetime.now()
+
+    audio_explanation = explainer.explain_instance_with_factorization(
+        audio_factorization,
+        text_factorization,
+        audio_predictor,
+        num_samples=num_samples,
+        labels=(1,),
+        modality="audio",
+    )
+
+    audio_time = (datetime.now() - audio_start).total_seconds()
+    print(
+        green_bold(f"[MusicLIME] Audio-only explanation completed in {audio_time:.2f}s")
+    )
+
+    # Process multimodal results
+    multimodal_prediction = multimodal_explanation.predictions[0]
+    multimodal_class = np.argmax(multimodal_prediction)
+    multimodal_confidence = float(np.max(multimodal_prediction))
+    multimodal_features = multimodal_explanation.get_explanation(
+        label=1, num_features=num_features
+    )
+
+    # Process audio-only results
+    audio_prediction = audio_explanation.predictions[0]
+    audio_class = np.argmax(audio_prediction)
+    audio_confidence = float(np.max(audio_prediction))
+    audio_features = audio_explanation.get_explanation(
+        label=1, num_features=num_features
+    )
+
+    # Calculate total runtime
+    end_time = datetime.now()
+    total_runtime = (end_time - start_time).total_seconds()
+
+    print(green_bold("[MusicLIME] Combined explanation completed!"))
+    print(f"[MusicLIME] Factorization: {factorization_time:.2f}s (done once)")
+    print(f"[MusicLIME] Multimodal: {multimodal_time:.2f}s")
+    print(f"[MusicLIME] Audio-only: {audio_time:.2f}s")
+    print(f"[MusicLIME] Total: {total_runtime:.2f}s")
+
+    return {
+        "multimodal": {
+            "prediction": {
+                "class": int(multimodal_class),
+                "class_name": (
+                    "Human-Composed" if multimodal_class == 1 else "AI-Generated"
+                ),
+                "confidence": multimodal_confidence,
+                "probabilities": multimodal_prediction.tolist(),
+            },
+            "explanations": [
+                {
+                    "rank": i + 1,
+                    "modality": item["type"],
+                    "feature_text": item["feature"],
+                    "weight": float(item["weight"]),
+                    "importance": abs(float(item["weight"])),
+                }
+                for i, item in enumerate(multimodal_features)
+            ],
+            "summary": {
+                "total_features_analyzed": len(multimodal_features),
+                "audio_features_count": len(
+                    [f for f in multimodal_features if f["type"] == "audio"]
+                ),
+                "lyrics_features_count": len(
+                    [f for f in multimodal_features if f["type"] == "lyrics"]
+                ),
+                "runtime_seconds": multimodal_time,
+                "samples_generated": num_samples,
+            },
+        },
+        "audio_only": {
+            "prediction": {
+                "class": int(audio_class),
+                "class_name": "Human-Composed" if audio_class == 1 else "AI-Generated",
+                "confidence": audio_confidence,
+                "probabilities": audio_prediction.tolist(),
+            },
+            "explanations": [
+                {
+                    "rank": i + 1,
+                    "modality": item["type"],
+                    "feature_text": item["feature"],
+                    "weight": float(item["weight"]),
+                    "importance": abs(float(item["weight"])),
+                }
+                for i, item in enumerate(audio_features)
+            ],
+            "summary": {
+                "total_features_analyzed": len(audio_features),
+                "audio_features_count": len(audio_features),
+                "lyrics_features_count": 0,
+                "runtime_seconds": audio_time,
+                "samples_generated": num_samples,
+            },
+        },
+        "combined_summary": {
+            "total_runtime_seconds": total_runtime,
+            "factorization_time_seconds": factorization_time,
+            "source_separation_reused": True,
+            "timestamp": start_time.isoformat(),
+        },
+    }
scripts/explain_combined_runner.py
ADDED
@@ -0,0 +1,92 @@
+import librosa
+from scripts.explain import musiclime_combined
+
+
+def explain_combined_runner(sample: str):
+    # Load test audio and lyrics
+    audio_path = f"data/external/{sample}.mp3"
+    lyrics_path = f"data/external/{sample}.txt"
+
+    # Load audio
+    audio_data, sr = librosa.load(audio_path)
+
+    # Load lyrics
+    with open(lyrics_path, "r", encoding="utf-8") as f:
+        lyrics_text = f.read()
+
+    print("Running combined MusicLIME explanation (optimized)...")
+    result = musiclime_combined(audio_data, lyrics_text)
+
+    # Display multimodal results
+    print(f"\n{'='*60}")
+    print("=== MULTIMODAL EXPLANATION RESULTS ===")
+    print(f"{'='*60}")
+    multimodal = result["multimodal"]
+    print(
+        f"Prediction: {multimodal['prediction']['class_name']} ({multimodal['prediction']['confidence']:.3f})"
+    )
+    print(f"Runtime: {multimodal['summary']['runtime_seconds']:.2f}s")
+
+    print("\n=== TOP MULTIMODAL FEATURES ===")
+    for feature in multimodal["explanations"]:
+        print(
+            f"Rank {feature['rank']}: {feature['modality']} | Weight: {feature['weight']:.4f} | Importance: {feature['importance']:.4f}"
+        )
+        print(f"  Feature: {feature['feature_text'][:80]}...")
+        print()
+
+    # Display audio-only results
+    print(f"\n{'='*60}")
+    print("=== AUDIO-ONLY EXPLANATION RESULTS ===")
+    print(f"{'='*60}")
+    audio_only = result["audio_only"]
+    print(
+        f"Prediction: {audio_only['prediction']['class_name']} ({audio_only['prediction']['confidence']:.3f})"
+    )
+    print(f"Runtime: {audio_only['summary']['runtime_seconds']:.2f}s")
+
+    print("\n=== TOP AUDIO-ONLY FEATURES ===")
+    for feature in audio_only["explanations"]:
+        print(
+            f"Rank {feature['rank']}: {feature['modality']} | Weight: {feature['weight']:.4f} | Importance: {feature['importance']:.4f}"
+        )
+        print(f"  Feature: {feature['feature_text'][:80]}...")
+        print()
+
+    # Display performance summary
+    print(f"\n{'='*60}")
+    print("=== PERFORMANCE SUMMARY ===")
+    print(f"{'='*60}")
+    summary = result["combined_summary"]
+    print(
+        f"Factorization time: {summary['factorization_time_seconds']:.2f}s (done once)"
+    )
+    print(f"Multimodal explanation: {multimodal['summary']['runtime_seconds']:.2f}s")
+    print(f"Audio-only explanation: {audio_only['summary']['runtime_seconds']:.2f}s")
+    print(f"Total runtime: {summary['total_runtime_seconds']:.2f}s")
+    print(f"Source separation reused: {summary['source_separation_reused']}")
+
+    # Comparison
+    print("\n=== PREDICTION COMPARISON ===")
+    print(
+        f"Multimodal: {multimodal['prediction']['class_name']} ({multimodal['prediction']['confidence']:.3f})"
+    )
+    print(
+        f"Audio-only: {audio_only['prediction']['class_name']} ({audio_only['prediction']['confidence']:.3f})"
+    )
+
+    if multimodal["prediction"]["class"] == audio_only["prediction"]["class"]:
+        print("Both modalities agree on the prediction")
+    else:
+        print("Modalities disagree on the prediction")
+
+    confidence_diff = abs(
+        multimodal["prediction"]["confidence"] - audio_only["prediction"]["confidence"]
+    )
+    print(f"Confidence difference: {confidence_diff:.3f}")
+
+
+if __name__ == "__main__":
+    sample = "sample"
+
+    explain_combined_runner(sample)
scripts/predict.py
CHANGED
@@ -126,6 +126,132 @@ def predict_unimodal(audio_file):
     }
 
 
+def predict_combined(audio_file, lyrics):
+    """
+    Generate both multimodal and audio-only predictions efficiently.
+
+    Follows the exact same logic as the separate functions but reuses audio features.
+
+    Parameters
+    ----------
+    audio_file : audio_object
+        Audio object file
+    lyrics : str
+        Lyric string
+
+    Returns
+    -------
+    dict
+        Combined results containing both multimodal and audio-only predictions
+    """
+    import time
+
+    start_time = time.time()
+
+    # Load config once
+    config = load_config("config/model_config.yml")
+
+    # [1] Multimodal prediction
+    print("[Predict] Running multimodal prediction...")
+    multimodal_start = time.time()
+
+    # 1.) Load LLM2Vec Model
+    llm2vec_model = load_llm2vec_model()
+
+    # 2.) Preprocess both audio and lyrics
+    audio_mm, lyrics_mm = single_preprocessing(audio_file, lyrics)
+
+    # 3.) Extract features
+    audio_features_mm = spectttra_predict(audio_mm)
+    audio_features_mm = audio_features_mm.reshape(1, -1)
+    lyrics_features = l2vec_single_train(llm2vec_model, lyrics_mm)
+
+    # 4.) Scale the vectors using Z-Score
+    audio_features_mm_scaled, lyrics_features_scaled = instance_scaler(
+        audio_features_mm, lyrics_features
+    )
+
+    # 5.) Reduce the lyrics using saved PCA model
+    reduced_lyrics = load_pca_model(lyrics_features_scaled)
+
+    # 6.) Concatenate the vectors
+    multimodal_features = np.concatenate(
+        [audio_features_mm_scaled, reduced_lyrics], axis=1
+    )
+
+    # Load MLP Classifier
+    multimodal_classifier = build_mlp(
+        input_dim=multimodal_features.shape[1], config=config
+    )
+    multimodal_classifier.load_model("models/mlp/mlp_best_multimodal.pth")
+    multimodal_classifier.model.eval()
+
+    # Run prediction
+    mm_confidence, mm_prediction, mm_label, mm_probability = (
+        multimodal_classifier.predict_single(multimodal_features.flatten())
+    )
+
+    multimodal_time = time.time() - multimodal_start
+    print(f"[Predict] Multimodal prediction completed in {multimodal_time:.2f}s")
+
+    # [2] Unimodal prediction (audio-only)
+    print("[Predict] Running audio-only prediction...")
+    audio_only_start = time.time()
+
+    # 1.) Preprocess the audio
+    audio_au = single_audio_preprocessing(audio_file)
+
+    # 2.) Extract audio features
+    audio_features_au = spectttra_predict(audio_au)
+    audio_features_au = audio_features_au.reshape(1, -1)
+
+    # 3.) Scale the vector using Z-Score
+    audio_features_au_scaled = audio_instance_scaler(audio_features_au)
+
+    # Load MLP Classifier
+    audio_classifier = build_mlp(
+        input_dim=audio_features_au_scaled.shape[1], config=config
+    )
+    audio_classifier.load_model("models/mlp/mlp_best_unimodal.pth")
+    audio_classifier.model.eval()
+
+    # Run prediction
+    au_confidence, au_prediction, au_label, au_probability = (
+        audio_classifier.predict_single(audio_features_au_scaled.flatten())
+    )
+
+    audio_only_time = time.time() - audio_only_start
+    print(f"[Predict] Audio-only prediction completed in {audio_only_time:.2f}s")
+
+    # Summary
+    total_time = time.time() - start_time
+
+    print("\n[Predict] Combined prediction completed!")
+    print(f"[Predict] Multimodal: {multimodal_time:.2f}s")
+    print(f"[Predict] Audio-only: {audio_only_time:.2f}s")
+    print(f"[Predict] Total: {total_time:.2f}s")
+
+    return {
+        "multimodal": {
+            "confidence": mm_confidence,
+            "prediction": mm_prediction,
+            "label": mm_label,
+            "probability": mm_probability,
+        },
+        "audio_only": {
+            "confidence": au_confidence,
+            "prediction": au_prediction,
+            "label": au_label,
+            "probability": au_probability,
+        },
+        "performance": {
+            "total_time_seconds": total_time,
+            "multimodal_time_seconds": multimodal_time,
+            "audio_only_time_seconds": audio_only_time,
+        },
+    }
+
+
 if __name__ == "__main__":
     # Example usage (replace with real inputs, place song inside data/raw.)
     data = pd.read_csv("data/raw/predict_data_final.csv")
scripts/predict_combined_runner.py
ADDED
@@ -0,0 +1,65 @@
+import librosa
+from scripts.predict import predict_combined
+
+
+def predict_combined_runner(sample: str):
+    # Load test audio and lyrics
+    audio_path = f"data/external/{sample}.mp3"
+    lyrics_path = f"data/external/{sample}.txt"
+
+    # Load audio
+    audio_data, sr = librosa.load(audio_path)
+
+    # Load lyrics
+    with open(lyrics_path, "r", encoding="utf-8") as f:
+        lyrics_text = f.read()
+
+    print("Running combined prediction (optimized)...")
+    result = predict_combined(audio_data, lyrics_text)
+
+    # Display results
+    print(f"\n{'='*50}")
+    print("=== MULTIMODAL PREDICTION ===")
+    print(f"{'='*50}")
+    mm = result["multimodal"]
+    print(f"Prediction: {mm['prediction']}")
+    print(f"Label: {mm['label']}")
+    print(f"Confidence: {mm['confidence']:.4f}")
+    print(f"Probability: {mm['probability']:.4f}")
+
+    print(f"\n{'='*50}")
+    print("=== AUDIO-ONLY PREDICTION ===")
+    print(f"{'='*50}")
+    au = result["audio_only"]
+    print(f"Prediction: {au['prediction']}")
+    print(f"Label: {au['label']}")
+    print(f"Confidence: {au['confidence']:.4f}")
+    print(f"Probability: {au['probability']:.4f}")
+
+    print(f"\n{'='*50}")
+    print("=== PERFORMANCE SUMMARY ===")
+    print(f"{'='*50}")
+    perf = result["performance"]
+    print(f"Multimodal prediction: {perf['multimodal_time_seconds']:.2f}s")
+    print(f"Audio-only prediction: {perf['audio_only_time_seconds']:.2f}s")
+    print(f"Total time: {perf['total_time_seconds']:.2f}s")
+
+    print(f"\n{'='*50}")
+    print("=== COMPARISON ===")
+    print(f"{'='*50}")
+    print(f"Multimodal: {mm['prediction']} ({mm['probability']:.4f})")
+    print(f"Audio-only: {au['prediction']} ({au['probability']:.4f})")
+
+    prob_diff = abs(mm["probability"] - au["probability"])
+    print(f"Probability difference: {prob_diff:.4f}")
+
+    if mm["prediction"] == au["prediction"]:
+        print("Both modalities agree on the prediction")
+    else:
+        print("Modalities disagree on the prediction")
+
+
+if __name__ == "__main__":
+    sample = "sample"
+
+    predict_combined_runner(sample)
src/musiclime/explainer.py
CHANGED
@@ -154,6 +154,91 @@ class MusicLIMEExplainer:
 
         return explanation
 
+    def explain_instance_with_factorization(
+        self,
+        audio_factorization,
+        text_factorization,
+        predict_fn,
+        num_samples=1000,
+        labels=(1,),
+        modality="both",
+    ):
+        """
+        Generate LIME explanations using pre-computed factorizations.
+
+        This method allows reusing expensive source separation across multiple explanations,
+        which significantly improves performance when generating both multimodal and audio-only
+        explanations for the same audio file.
+
+        Parameters
+        ----------
+        audio_factorization : OpenUnmixFactorization
+            Pre-computed audio source separation components
+        text_factorization : LineIndexedString
+            Pre-computed text line factorization
+        predict_fn : callable
+            Prediction function that takes (texts, audios) and returns probabilities
+        num_samples : int, default=1000
+            Number of perturbed samples to generate for LIME
+        labels : tuple, default=(1,)
+            Target labels to explain (0=AI-Generated, 1=Human-Composed)
+        modality : str, default='both'
+            Explanation modality: 'both', 'audio', or 'lyrical'
+
+        Returns
+        -------
+        MusicLIMEExplanation
+            Explanation object containing feature importance weights and metadata
+
+        Raises
+        ------
+        ValueError
+            If modality is not one of ['both', 'audio', 'lyrical']
+        """
+        # Validate modality
+        if modality not in ["both", "audio", "lyrical"]:
+            raise ValueError('Set modality argument to "both", "audio" or "lyrical".')
+
+        print("[MusicLIME] Using pre-computed factorizations (optimized mode)")
+        print(f"[MusicLIME] Modality: {modality}")
+        print(
+            f"[MusicLIME] Audio components: {audio_factorization.get_number_components()}"
+        )
+        print(f"[MusicLIME] Text lines: {text_factorization.num_words()}")
+
+        # Generate perturbations and get predictions
+        print(f"[MusicLIME] Generating {num_samples} perturbations...")
+        data, predictions, distances = self._generate_neighborhood(
+            audio_factorization, text_factorization, predict_fn, num_samples, modality
+        )
+
+        # LIME fitting, create explanation object
+        start_time = time.time()
+        print("[MusicLIME] Fitting LIME model...")
+        explanation = MusicLIMEExplanation(
+            audio_factorization,
+            text_factorization,
+            data,
+            predictions,
+        )
+
+        for label in labels:
+            print(f"[MusicLIME] Explaining label {label}...")
+            (
+                explanation.intercept[label],
+                explanation.local_exp[label],
+                explanation.score[label],
+                explanation.local_pred[label],
+            ) = self.base.explain_instance_with_data(
+                data, predictions, distances, label, num_features=20
+            )
+
+        lime_time = time.time() - start_time
+        print(green_bold(f"[MusicLIME] LIME fitting completed in {lime_time:.2f}s"))
+        print("[MusicLIME] MusicLIME explanation complete!")
+
+        return explanation
+
     def _generate_neighborhood(
         self, audio_fact, text_fact, predict_fn, num_samples, modality="both"
     ):
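The point of the new method is that the expensive OpenUnmix separation happens outside the explainer call, so the same factorization objects can be handed to it repeatedly. Stripped of logging and timing, the reuse pattern used by musiclime_combined() above boils down to roughly the following sketch (not part of this commit):

# Sketch: factorize once, then explain twice against the same components.
from src.musiclime.explainer import MusicLIMEExplainer
from src.musiclime.factorization import OpenUnmixFactorization
from src.musiclime.text_utils import LineIndexedString
from src.musiclime.wrapper import MusicLIMEPredictor, AudioOnlyPredictor

# audio_data / lyrics_text loaded as in the runner scripts above
audio_fact = OpenUnmixFactorization(audio_data, temporal_segmentation_params=10)
text_fact = LineIndexedString(lyrics_text)
explainer = MusicLIMEExplainer(random_state=42)

multimodal_exp = explainer.explain_instance_with_factorization(
    audio_fact, text_fact, MusicLIMEPredictor(), num_samples=1000, modality="both"
)
audio_only_exp = explainer.explain_instance_with_factorization(
    audio_fact, text_fact, AudioOnlyPredictor(), num_samples=1000, modality="audio"
)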