Auto-deploy from GitHub: e4d0ee2ddb3dc15442ce902b31f6de26098a6291
Files changed:
- app/schemas.py +14 -5
- app/server.py +97 -84
- app/validators.py +75 -0
- scripts/explain.py +110 -8
- scripts/explain_runner.py +60 -21
- scripts/predict.py +6 -3
- scripts/predict_runner.py +13 -8
- src/musiclime/explainer.py +88 -28
- src/musiclime/wrapper.py +146 -3
app/schemas.py  CHANGED

@@ -40,8 +40,17 @@ class PredictionXAIResponse(BaseModel):
     results: Optional[Dict] = None


+class AudioOnlyPredictionResponse(BaseModel):
+    status: str
+    audio_file_name: str
+    audio_content_type: str
+    audio_file_size: int
+    results: dict
+
+
+class AudioOnlyPredictionXAIResponse(BaseModel):
+    status: str
+    audio_file_name: str
+    audio_content_type: str
+    audio_file_size: int
+    results: dict
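For reference, a minimal sketch of constructing the new audio-only response model; the field values and the `results` payload are illustrative only, and `.dict()` assumes the pydantic v1 API:

    from app.schemas import AudioOnlyPredictionResponse

    resp = AudioOnlyPredictionResponse(
        status="success",
        audio_file_name="sample.mp3",        # illustrative values
        audio_content_type="audio/mpeg",
        audio_file_size=3_145_728,
        results={"class_name": "Human-Composed"},  # actual shape comes from the pipeline
    )
    print(resp.dict())  # use model_dump() instead on pydantic v2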
app/server.py  CHANGED

@@ -1,26 +1,27 @@
 # Fast API imports
-from fastapi import Depends, FastAPI,
+from fastapi import Depends, FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware

 # Utils/schemas imports
 from app.schemas import (
-    ErrorResponse,
     ModelInfoResponse,
     PredictionResponse,
     PredictionXAIResponse,
+    AudioOnlyPredictionResponse,
+    AudioOnlyPredictionXAIResponse,
     WelcomeResponse,
 )
 from app.utils import load_server_config, load_model_config
+from app.validators import validate_lyrics, validate_audio_source, validate_audio_only

 # Model/XAI-related imports
-from scripts.explain import
-from scripts.predict import predict_multimodal
+from scripts.explain import musiclime_multimodal, musiclime_unimodal
+from scripts.predict import predict_multimodal, predict_unimodal

 # Other imports
 import io
 import librosa
-from typing import
+from typing import Tuple

 # Load configs at startup
 server_config = load_server_config()

@@ -47,70 +48,9 @@ app.add_middleware(
 )

-def validate_lyrics(lyrics: str = Form(...)):
-    """Validate lyrics length and content."""
-    if len(lyrics) > MAX_LYRICS_LENGTH:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Lyrics too long. Maximum length is {MAX_LYRICS_LENGTH} characters.",
-        )
-
-    # Basic sanitization, remove excessive whitespace
-    lyrics = lyrics.strip()
-    if not lyrics:
-        raise HTTPException(
-            status_code=400,
-            detail="Lyrics cannot be empty.",
-        )
-
-    return lyrics
-
-
-async def validate_audio_source(
-    audio_file: Optional[UploadFile] = File(None),
-    youtube_url: Optional[str] = Form(None),
-) -> Tuple[Optional[bytes], str, str]:
-    """
-    Validate and process audio source (either file or YouTube URL).
-    Returns: (audio_content, file_name, content_type)
-    """
-    if not audio_file and not youtube_url:
-        raise HTTPException(
-            status_code=400, detail="Either audio_file or youtube_url must be provided"
-        )
-
-    if audio_file and youtube_url:
-        raise HTTPException(
-            status_code=400, detail="Provide either audio_file or youtube_url, not both"
-        )
-
-    # Process YouTube URL
-    if youtube_url:
-        audio_content = download_youtube_audio(youtube_url)
-        return audio_content, "youtube_audio.wav", "audio/wav"
-
-    # Process uploaded file
-    if audio_file.content_type not in ALLOWED_AUDIO_TYPES:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Invalid file type. Supported formats: {', '.join(ALLOWED_AUDIO_TYPES)}",
-        )
-
-    audio_content = await audio_file.read()
-    if len(audio_content) > MAX_FILE_SIZE:
-        raise HTTPException(
-            status_code=400,
-            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB.",
-        )
-
-    return audio_content, audio_file.filename, audio_file.content_type
-
-
 @app.get("/", response_model=WelcomeResponse, tags=["Root"])
 def root():
-    """
-    Root endpoint to check if the API is running.
-    """
+    """Root endpoint to check if the API is running."""
     return WelcomeResponse(
         status="success",
         message="Welcome to Bach or Bot API!",

@@ -118,18 +58,38 @@ def root():
         "/": "This welcome message",
         "/docs": "FastAPI auto-generated API docs",
         "/api/v1/model/info": "Model information and capabilities",
-        "/api/v1/predict": "POST endpoint for bach-or-bot prediction",
-        "/api/v1/explain": "POST endpoint for prediction with explainability",
+        "/api/v1/predict": "POST endpoint for bach-or-bot prediction (legacy)",
+        "/api/v1/explain": "POST endpoint for prediction with explainability (legacy)",
+        "/api/v1/predict/multimodal": "POST endpoint for multimodal prediction",
+        "/api/v1/explain/multimodal": "POST endpoint for multimodal explainability",
+        "/api/v1/predict/audio": "POST endpoint for audio-only prediction",
+        "/api/v1/explain/audio": "POST endpoint for audio-only explainability",
     },
 )


+# Legacy endpoints (backward compatibility)
+@app.post("/api/v1/predict", response_model=PredictionResponse)
+async def predict_music_legacy(
+    lyrics: str = Depends(validate_lyrics),
+    audio_data_tuple: Tuple = Depends(validate_audio_source),
+):
+    """Legacy multimodal prediction endpoint."""
+    return await predict_multimodal_endpoint(lyrics, audio_data_tuple)
+
+
+@app.post("/api/v1/explain", response_model=PredictionXAIResponse)
+async def explain_music_legacy(
+    lyrics: str = Depends(validate_lyrics),
+    audio_data_tuple: Tuple = Depends(validate_audio_source),
+):
+    """Legacy multimodal explanation endpoint."""
+    return await explain_multimodal_endpoint(lyrics, audio_data_tuple)
+
+
+# New multimodal endpoints
+@app.post("/api/v1/predict/multimodal", response_model=PredictionResponse)
+async def predict_multimodal_endpoint(
     lyrics: str = Depends(validate_lyrics),
     audio_data_tuple: Tuple = Depends(validate_audio_source),
 ):

@@ -164,12 +124,8 @@ async def predict_music(
     raise HTTPException(status_code=500, detail=str(e))


-@app.post(
-    response_model=PredictionXAIResponse,
-    responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
-)
-async def predict_music_with_xai(
+@app.post("/api/v1/explain/multimodal", response_model=PredictionXAIResponse)
+async def explain_multimodal_endpoint(
     lyrics: str = Depends(validate_lyrics),
     audio_data_tuple: Tuple = Depends(validate_audio_source),
 ):

@@ -188,7 +144,7 @@ async def predict_music_with_xai(
         raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")

     # Call musiclime runner script
-    results =
+    results = musiclime_multimodal(audio_data, lyrics)

     return PredictionXAIResponse(
         status="success",

@@ -204,6 +160,63 @@ async def predict_music_with_xai(
     raise HTTPException(status_code=500, detail=str(e))


+# New audio-only endpoints
+@app.post("/api/v1/predict/audio", response_model=AudioOnlyPredictionResponse)
+async def predict_audio_only_endpoint(
+    audio_data_tuple: Tuple = Depends(validate_audio_only),
+):
+    """Audio-only prediction endpoint."""
+    try:
+        audio_content, audio_file_name, audio_content_type = audio_data_tuple
+
+        try:
+            audio_data, sr = librosa.load(io.BytesIO(audio_content))
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
+
+        results = predict_unimodal(audio_data)
+
+        return AudioOnlyPredictionResponse(
+            status="success",
+            audio_file_name=audio_file_name,
+            audio_content_type=audio_content_type,
+            audio_file_size=len(audio_content),
+            results=results,
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/v1/explain/audio", response_model=AudioOnlyPredictionXAIResponse)
+async def explain_audio_only_endpoint(
+    audio_data_tuple: Tuple = Depends(validate_audio_only),
+):
+    """Audio-only explanation endpoint."""
+    try:
+        audio_content, audio_file_name, audio_content_type = audio_data_tuple
+
+        try:
+            audio_data, sr = librosa.load(io.BytesIO(audio_content))
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
+
+        results = musiclime_unimodal(audio_data, modality="audio")
+
+        return AudioOnlyPredictionXAIResponse(
+            status="success",
+            audio_file_name=audio_file_name,
+            audio_content_type=audio_content_type,
+            audio_file_size=len(audio_content),
+            results=results,
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.get("/api/v1/model/info", response_model=ModelInfoResponse, tags=["Model"])
 async def get_model_info():
     """
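A minimal client sketch for the new audio-only prediction route; the host, port, file path, and content type below are assumptions for a local dev run, not part of the commit:

    import requests  # third-party HTTP client, not used by the server itself

    with open("data/external/sample.mp3", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/api/v1/predict/audio",           # assumed local host/port
            files={"audio_file": ("sample.mp3", f, "audio/mpeg")},  # type must be in ALLOWED_AUDIO_TYPES
        )
    resp.raise_for_status()
    print(resp.json()["results"])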
app/validators.py  ADDED

@@ -0,0 +1,75 @@
+from fastapi import File, Form, HTTPException, UploadFile
+from typing import Optional, Tuple
+from app.utils import download_youtube_audio
+
+
+# Import config values
+def get_config_values():
+    from app.server import MAX_FILE_SIZE, MAX_LYRICS_LENGTH, ALLOWED_AUDIO_TYPES
+
+    return MAX_FILE_SIZE, MAX_LYRICS_LENGTH, ALLOWED_AUDIO_TYPES
+
+
+def validate_lyrics(lyrics: str = Form(...)):
+    """Validate lyrics length and content for multimodal endpoints."""
+    _, MAX_LYRICS_LENGTH, _ = get_config_values()
+
+    if len(lyrics) > MAX_LYRICS_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Lyrics too long. Maximum length is {MAX_LYRICS_LENGTH} characters.",
+        )
+
+    lyrics = lyrics.strip()
+    if not lyrics:
+        raise HTTPException(
+            status_code=400,
+            detail="Lyrics cannot be empty.",
+        )
+    return lyrics
+
+
+async def validate_audio_source(
+    audio_file: Optional[UploadFile] = File(None),
+    youtube_url: Optional[str] = Form(None),
+) -> Tuple[Optional[bytes], str, str]:
+    """Validate and process audio source from file upload or YouTube URL."""
+    MAX_FILE_SIZE, _, ALLOWED_AUDIO_TYPES = get_config_values()
+
+    if not audio_file and not youtube_url:
+        raise HTTPException(
+            status_code=400, detail="Either audio_file or youtube_url must be provided"
+        )
+
+    if audio_file and youtube_url:
+        raise HTTPException(
+            status_code=400, detail="Provide either audio_file or youtube_url, not both"
+        )
+
+    if youtube_url:
+        audio_content = download_youtube_audio(youtube_url)
+        return audio_content, "youtube_audio.wav", "audio/wav"
+
+    if audio_file.content_type not in ALLOWED_AUDIO_TYPES:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid file type. Supported formats: {', '.join(ALLOWED_AUDIO_TYPES)}",
+        )
+
+    audio_content = await audio_file.read()
+    if len(audio_content) > MAX_FILE_SIZE:
+        raise HTTPException(
+            status_code=400,
+            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB.",
+        )
+
+    return audio_content, audio_file.filename, audio_file.content_type
+
+
+async def validate_audio_only(
+    audio_file: Optional[UploadFile] = File(None),
+    youtube_url: Optional[str] = Form(None),
+) -> Tuple[Optional[bytes], str, str]:
+    """Validate audio source for audio-only endpoints (no lyrics required)."""
+    # Same validation as validate_audio_source but clearer naming for audio-only
+    return await validate_audio_source(audio_file, youtube_url)
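One design note on this file, as a reading of the code rather than anything stated in the commit: `get_config_values` imports from `app.server` inside the function body because `app.server` itself imports `app.validators` at module top. A top-level import here would close a circular-import loop at startup; deferring it to call time means `app.server` has finished initializing before the limits are first read. The pattern in miniature:

    # app/server.py (module top) pulls in the validators:
    #     from app.validators import validate_lyrics, validate_audio_source, validate_audio_only
    # app/validators.py defers the reverse edge until a request arrives:
    def get_config_values():
        from app.server import MAX_FILE_SIZE, MAX_LYRICS_LENGTH, ALLOWED_AUDIO_TYPES  # resolved at call time
        return MAX_FILE_SIZE, MAX_LYRICS_LENGTH, ALLOWED_AUDIO_TYPES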
scripts/explain.py  CHANGED

@@ -2,17 +2,25 @@ import os
 import numpy as np
 from datetime import datetime
 from src.musiclime.explainer import MusicLIMEExplainer
-from src.musiclime.wrapper import MusicLIMEPredictor
+from src.musiclime.wrapper import MusicLIMEPredictor, AudioOnlyPredictor


-def musiclime(audio_data, lyrics_text):
+def musiclime_multimodal(audio_data, lyrics_text):
     """
-    MusicLIME
+    Generate multimodal MusicLIME explanations for audio and lyrics.
+
+    Parameters
+    ----------
+    audio_data : array-like
+        Audio waveform data from librosa.load or similar
+    lyrics_text : str
+        String containing song lyrics
+
+    Returns
+    -------
+    dict
+        Structured explanation results containing prediction info, feature explanations,
+        and processing metadata
     """
     start_time = datetime.now()

@@ -77,3 +85,97 @@ def musiclime(audio_data, lyrics_text):
             "timestamp": start_time.isoformat(),
         },
     }
+
+
+def musiclime_unimodal(audio_data, modality="audio"):
+    """
+    Generate unimodal MusicLIME explanations for single modality.
+
+    Parameters
+    ----------
+    audio_data : array-like
+        Audio waveform data from librosa.load or similar
+    modality : str, default='audio'
+        Explanation modality, currently only supports 'audio'
+
+    Returns
+    -------
+    dict
+        Structured explanation results containing prediction info, audio-only feature
+        explanations, and processing metadata
+
+    Raises
+    ------
+    ValueError
+        If modality is not 'audio' (lyrics is not yet implemented)
+    """
+    if modality != "audio":
+        raise ValueError(
+            "Currently only 'audio' modality is supported for unimodal explanations"
+        )
+
+    start_time = datetime.now()
+
+    # Get number of samples from environment variable, default to 1000
+    num_samples = int(os.getenv("MUSICLIME_NUM_SAMPLES", "1000"))
+    num_features = int(os.getenv("MUSICLIME_NUM_FEATURES", "10"))
+
+    print(
+        f"[MusicLIME] Using num_samples={num_samples}, num_features={num_features} (audio-only mode)"
+    )
+
+    # Create musiclime instances
+    explainer = MusicLIMEExplainer(random_state=42)
+    predictor = AudioOnlyPredictor()
+
+    # Use empty lyrics for audio-only since they're ignored anyways
+    dummy_lyrics = ""
+
+    # Generate explanation
+    explanation = explainer.explain_instance(
+        audio=audio_data,
+        lyrics=dummy_lyrics,
+        predict_fn=predictor,
+        num_samples=num_samples,
+        labels=(1,),
+        modality=modality,
+    )
+
+    # Get prediction info
+    original_prediction = explanation.predictions[0]
+    predicted_class = np.argmax(original_prediction)
+    confidence = float(np.max(original_prediction))
+
+    # Get top features
+    top_features = explanation.get_explanation(label=1, num_features=num_features)
+
+    # Calculate runtime
+    end_time = datetime.now()
+    runtime_seconds = (end_time - start_time).total_seconds()
+
+    return {
+        "prediction": {
+            "class": int(predicted_class),
+            "class_name": "Human-Composed" if predicted_class == 1 else "AI-Generated",
+            "confidence": confidence,
+            "probabilities": original_prediction.tolist(),
+        },
+        "explanations": [
+            {
+                "rank": i + 1,
+                "modality": item["type"],  # "audio" for all features
+                "feature_text": item["feature"],
+                "weight": float(item["weight"]),
+                "importance": abs(float(item["weight"])),
+            }
+            for i, item in enumerate(top_features)
+        ],
+        "summary": {
+            "total_features_analyzed": len(top_features),
+            "audio_features_count": len(top_features),  # All features are audio
+            "lyrics_features_count": 0,  # No lyrics features
+            "runtime_seconds": runtime_seconds,
+            "samples_generated": num_samples,
+            "timestamp": start_time.isoformat(),
+        },
+    }
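Since `musiclime_unimodal` reads its sampling budget from the environment via `os.getenv` at call time (and `musiclime_multimodal` appears to follow the same pattern), a quick way to trade fidelity for speed is to set the variables before calling; the values here are illustrative:

    import os

    os.environ["MUSICLIME_NUM_SAMPLES"] = "500"  # fewer perturbations, faster run
    os.environ["MUSICLIME_NUM_FEATURES"] = "5"   # report only the top-5 features

    from scripts.explain import musiclime_unimodal
    # result = musiclime_unimodal(audio_data)  # audio_data from librosa.load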
scripts/explain_runner.py  CHANGED

@@ -1,30 +1,69 @@
 import librosa
-from scripts.explain import
+from scripts.explain import musiclime_multimodal, musiclime_unimodal

-# Load test audio and lyrics
-audio_path = "data/external/sample_1.mp3"
-lyrics_path = "data/external/sample_1.txt"
-    lyrics_text = f.read()
-    f"Prediction: {result['prediction']['class_name']} ({result['prediction']['confidence']:.3f})"
-)
-print(f"Runtime: {result['summary']['runtime_seconds']:.2f}s")
-for feature in result["explanations"]:

+def explain_multimodal_runner(sample: str):
+    # Load test audio and lyrics
+    audio_path = f"data/external/{sample}.mp3"
+    lyrics_path = f"data/external/{sample}.txt"
+
+    # Load audio
+    audio_data, sr = librosa.load(audio_path)
+
+    # Load lyrics
+    with open(lyrics_path, "r", encoding="utf-8") as f:
+        lyrics_text = f.read()
+
+    print("Running multimodal MusicLIME explanation...")
+    result = musiclime_multimodal(audio_data, lyrics_text)
+
+    print("\n=== MULTIMODAL EXPLANATION RESULTS ===")
+    print(
+        f"Prediction: {result['prediction']['class_name']} ({result['prediction']['confidence']:.3f})"
+    )
+    print(f"Runtime: {result['summary']['runtime_seconds']:.2f}s")
+
+    print("\n=== TOP FEATURES (by absolute importance) ===")
+    for feature in result["explanations"]:
+        print(
+            f"Rank {feature['rank']}: {feature['modality']} | Weight: {feature['weight']:.4f} | Importance: {feature['importance']:.4f}"
+        )
+        print(f"  Feature: {feature['feature_text'][:80]}...")
+        print()
+
+
+def explain_unimodal_runner(sample: str):
+    # Load test audio
+    audio_path = f"data/external/{sample}.mp3"
+
+    # Load audio
+    audio_data, sr = librosa.load(audio_path)
+
+    print("Running audio-only MusicLIME explanation...")
+    result = musiclime_unimodal(audio_data, modality="audio")
+
+    print("\n=== AUDIO-ONLY EXPLANATION RESULTS ===")
+    print(
+        f"Prediction: {result['prediction']['class_name']} ({result['prediction']['confidence']:.3f})"
+    )
+    print(f"Runtime: {result['summary']['runtime_seconds']:.2f}s")
+
+    print("\n=== TOP FEATURES (by absolute importance) ===")
+    for feature in result["explanations"]:
+        print(
+            f"Rank {feature['rank']}: {feature['modality']} | Weight: {feature['weight']:.4f} | Importance: {feature['importance']:.4f}"
+        )
+        print(f"  Feature: {feature['feature_text'][:80]}...")
+        print()
+
+
+if __name__ == "__main__":
+    sample = "sample"
+
+    # Run multimodal explanation
+    explain_multimodal_runner(sample)
+
+    print("\n" + "=" * 60 + "\n")
+
+    # Run audio-only explanation
+    explain_unimodal_runner(sample)
scripts/predict.py  CHANGED

@@ -1,4 +1,7 @@
-from src.preprocessing.preprocessor import
+from src.preprocessing.preprocessor import (
+    single_preprocessing,
+    single_audio_preprocessing,
+)
 from src.spectttra.spectttra_trainer import spectttra_predict
 from src.llm2vectrain.model import load_llm2vec_model
 from src.llm2vectrain.llm2vec_trainer import l2vec_single_train, load_pca_model

@@ -55,7 +58,7 @@ def predict_multimodal(audio_file, lyrics):
     classifier = build_mlp(input_dim=results.shape[1], config=config)

     # 7.) Load trained weights
-    model_path = "models/mlp/
+    model_path = "models/mlp/mlp_best_multimodal.pth"
     classifier.load_model(model_path)
     classifier.model.eval()

@@ -106,7 +109,7 @@ def predict_unimodal(audio_file):
     classifier = build_mlp(input_dim=audio_features.shape[1], config=config)

     # 6.) Load trained weights
-    model_path = "models/
+    model_path = "models/mlp/mlp_best_unimodal.pth"
     classifier.load_model(model_path)
     classifier.model.eval()
scripts/predict_runner.py  CHANGED

@@ -14,28 +14,33 @@ def predict_multimodal_runner(sample: str):
     with open(lyrics_path, "r", encoding="utf-8") as f:
         lyrics_text = f.read()

-    print("Running prediction pipeline...")
+    print("Running multimodal prediction pipeline...")
     prediction = predict_multimodal(audio_data, lyrics_text)

-    print(
+    print("\n=== MULTIMODAL PREDICTION RESULT ===")
     print(f"Prediction: {prediction}")


 def predict_unimodal_runner(sample: str):
     # Load test audio
-    audio_path = f"data/
+    audio_path = f"data/external/{sample}.mp3"

     # Load audio
     audio_data, sr = librosa.load(audio_path)

-    print("Running prediction pipeline...")
+    print("Running audio-only prediction pipeline...")
     prediction = predict_unimodal(audio_data)

-    print(
+    print("\n=== AUDIO-ONLY PREDICTION RESULT ===")
     print(f"Prediction: {prediction}")


 if __name__ == "__main__":
-    sample = "
+    sample = "sample"
+
+    # Run both predictions
+    predict_multimodal_runner(sample)
+
+    print("\n" + "=" * 50 + "\n")

     predict_unimodal_runner(sample)
src/musiclime/explainer.py  CHANGED

@@ -57,9 +57,15 @@ class MusicLIMEExplainer:
         num_samples=1000,
         labels=(1,),
         temporal_segments=10,
+        modality="both",
     ):
         """
-        Generate LIME explanations for a music instance using audio and lyrics.
+        Generate LIME explanations for a music instance using audio and/or lyrics.
+
+        This method creates local explanations by perturbing audio components (via source
+        separation) and/or lyrics lines, then analyzing their impact on model predictions.
+        Supports three modality modes: 'both' (multimodal), 'audio' (audio-only), and
+        'lyrical' (lyrics-only) following the original MusicLIME paper implementation.

         Parameters
         ----------

@@ -75,18 +81,26 @@
             Target labels to explain (0=AI-Generated, 1=Human-Composed)
         temporal_segments : int, default=10
             Number of temporal segments for audio factorization
+        modality : str, default='both'
+            Explanation modality: 'both' (multimodal), 'audio' (audio-only), or 'lyrical' (lyrics-only)

         Returns
         -------
         MusicLIMEExplanation
-            Explanation object containing feature importance weights
+            Explanation object containing feature importance weights and metadata
         """
+        # Validation for modality choice
+        if modality not in ["both", "audio", "lyrical"]:
+            raise ValueError("Set modality argument to 'both', 'audio', 'lyrical'.")
+
         # These are for debugging only I have to see THAT progress
         print("[MusicLIME] Starting MusicLIME explanation...")
         print(
             f"[MusicLIME] Audio length: {len(audio)/22050:.1f}s, Temporal segments: {temporal_segments}"
         )
         print(f"[MusicLIME] Lyrics lines: {len(lyrics.split(chr(10)))}")
+        print("[MusicLIME] Starting MusicLIME explanation...")
+        print(f"[MusicLIME] Modality: {modality}")

         # Create factorizations
         print("[MusicLIME] Creating audio factorization (source separation)...")

@@ -111,7 +125,7 @@
         # Generate perturbations and get predictions
         print(f"[MusicLIME] Generating {num_samples} perturbations...")
         data, predictions, distances = self._generate_neighborhood(
-            audio_factorization, text_factorization, predict_fn, num_samples
+            audio_factorization, text_factorization, predict_fn, num_samples, modality
         )

         # LIME fitting, create explanation object

@@ -140,33 +154,55 @@
         return explanation

-    def _generate_neighborhood(
+    def _generate_neighborhood(
+        self, audio_fact, text_fact, predict_fn, num_samples, modality="both"
+    ):
         """
-        Generate perturbed samples and predictions for LIME explanation.
+        Generate perturbed samples and predictions for LIME explanation based on modality.
+
+        Creates binary perturbation masks and generates corresponding perturbed audio-text
+        pairs. The perturbation strategy depends on the specified modality:
+        - 'both': Perturbs both audio components and lyrics lines independently
+        - 'audio': Perturbs only audio components, keeps original lyrics constant
+        - 'lyrical': Perturbs only lyrics lines, keeps original audio constant

         Parameters
         ----------
         audio_fact : OpenUnmixFactorization
-            Audio factorization object for source separation
+            Audio factorization object for source separation-based perturbations
         text_fact : LineIndexedString
-            Text factorization object for line-based perturbations
+            Text factorization object for line-based lyrics perturbations
         predict_fn : callable
-            Model prediction function
+            Model prediction function that processes (texts, audios) batches
         num_samples : int
-            Number of
+            Number of perturbation samples to generate for LIME
+        modality : str, default='both'
+            Perturbation modality: 'both', 'audio', or 'lyrical'

         Returns
         -------
         data : ndarray
-            Binary perturbation masks (num_samples, total_features)
+            Binary perturbation masks of shape (num_samples, total_features)
         predictions : ndarray
-            Model predictions for perturbed instances
+            Model predictions for perturbed instances of shape (num_samples, n_classes)
         distances : ndarray
-            Cosine distances from original instance
+            Cosine distances from original instance of shape (num_samples,)
+
+        Notes
+        -----
+        The first sample (index 0) is always the original unperturbed instance.
+        Feature ordering: [audio_components, lyrics_lines] for 'both' modality.
         """
         n_audio = audio_fact.get_number_components()
         n_text = text_fact.num_words()
+
+        # Set total features based on modality
+        if modality == "both":
+            total_features = n_audio + n_text
+        elif modality == "audio":
+            total_features = n_audio
+        elif modality == "lyrical":
+            total_features = n_text

         print(
             f"[MusicLIME] Total features: {total_features} ({n_audio} audio + {n_text} text)"

@@ -187,22 +223,46 @@
         texts = []
         audios = []

+        for _, row in enumerate(data):
+            if modality == "both":
+                # Audio perturbation & reconstruction
+                audio_mask = row[:n_audio]
+                active_audio_components = np.where(audio_mask != 0)[0]
+                perturbed_audio = audio_fact.compose_model_input(
+                    active_audio_components
+                )
+                audios.append(perturbed_audio)
+
+                # Text perturbation & reconstruction
+                text_mask = row[n_audio:]
+                inactive_lines = np.where(text_mask == 0)[0]
+                perturbed_text = text_fact.inverse_removing(inactive_lines)
+                texts.append(perturbed_text)
+
+            elif modality == "audio":
+                # Audio perturbation, original lyrics
+                active_audio_components = np.where(row != 0)[0]
+                perturbed_audio = audio_fact.compose_model_input(
+                    active_audio_components
+                )
+                audios.append(perturbed_audio)
+
+                # Use original lyrics (no perturbation)
+                perturbed_text = text_fact.inverse_removing(
+                    []
+                )  # Empty array = no removal
+                texts.append(perturbed_text)

+            elif modality == "lyrical":
+                # Original audio, lyrics perturbation
+                all_audio_components = np.arange(n_audio)  # Use all audio components
+                perturbed_audio = audio_fact.compose_model_input(all_audio_components)
+                audios.append(perturbed_audio)

+                # Perturb lyrics
+                inactive_lines = np.where(row == 0)[0]
+                perturbed_text = text_fact.inverse_removing(inactive_lines)
+                texts.append(perturbed_text)

         perturbation_time = time.time() - start_time
         print(

@@ -221,7 +281,7 @@
         confidence = original_prediction[predicted_class]

         # Print original prediction
-        print(
+        print("[MusicLIME] Original Prediction:")
         print(
             f"  Raw probabilities: [AI: {original_prediction[0]:.3f}, Human: {original_prediction[1]:.3f}]"
         )
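A minimal sketch of calling the extended `explain_instance` directly with the new modality flag, mirroring what `scripts/explain.py` does internally; the sample path and the reduced `num_samples` are illustrative:

    import librosa
    from src.musiclime.explainer import MusicLIMEExplainer
    from src.musiclime.wrapper import AudioOnlyPredictor

    audio, sr = librosa.load("data/external/sample.mp3")  # illustrative path
    explainer = MusicLIMEExplainer(random_state=42)
    explanation = explainer.explain_instance(
        audio=audio,
        lyrics="",                      # ignored when modality="audio"
        predict_fn=AudioOnlyPredictor(),
        num_samples=200,                # smaller than the 1000 default, for speed
        labels=(1,),
        modality="audio",
    )
    print(explanation.get_explanation(label=1, num_features=5))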
src/musiclime/wrapper.py  CHANGED

@@ -3,7 +3,10 @@ import joblib
 import numpy as np
 import torch

-from src.preprocessing.preprocessor import
+from src.preprocessing.preprocessor import (
+    single_preprocessing,
+    single_audio_preprocessing,
+)
 from src.spectttra.spectttra_trainer import spectttra_train
 from src.llm2vectrain.llm2vec_trainer import l2vec_train
 from src.llm2vectrain.model import load_llm2vec_model

@@ -159,7 +162,7 @@
         self.classifier = build_mlp(
             input_dim=combined_features_batch.shape[1], config=self.config
         )
-        self.classifier.load_model("models/mlp/
+        self.classifier.load_model("models/mlp/mlp_best_multimodal.pth")

         probabilities, predictions = self.classifier.predict(combined_features_batch)

@@ -172,7 +175,7 @@
         total_time = (
             preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time
         )
-        print(
+        print("[MusicLIME] Batch processing complete!")
         print(
             green_bold(
                 f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Lyrics: {lyrics_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"

@@ -180,3 +183,143 @@
         )

         return np.array(batch_results)
+
+
+class AudioOnlyPredictor:
+    """
+    Audio-only prediction wrapper for MusicLIME explanations.
+
+    Integrates the audio-only Bach or Bot pipeline (SpecTTTra + MLP) into a single
+    callable for LIME perturbation processing. Optimized for batch processing of
+    multiple perturbed audio samples while ignoring lyrics input. Mirrors the
+    multimodal MusicLIMEPredictor but processes only audio features.
+
+    This predictor is specifically designed for audio-only explainability where
+    lyrics are kept constant and only audio components are perturbed through
+    source separation techniques.
+
+    Attributes
+    ----------
+    classifier : MLPClassifier or None
+        Lazy-loaded MLP classifier for audio-only predictions
+    config : dict
+        Model configuration parameters loaded from config files
+    """
+
+    def __init__(self):
+        """
+        Initialize audio-only prediction wrapper.
+
+        Loads model configuration for batch processing of perturbed audio samples
+        during LIME explanation. The MLP classifier is lazy-loaded on first use
+        to optimize memory usage.
+        """
+        print("[MusicLIME] Loading models for Audio-Only MusicLIME...")
+        config = load_config("config/model_config.yml")
+        self.classifier = None
+        self.config = config
+
+    def __call__(self, texts, audios):
+        """
+        Batch prediction function for audio-only MusicLIME perturbations.
+
+        Processes multiple perturbed audio samples through the audio-only pipeline:
+        preprocessing -> SpecTTTra feature extraction -> scaling -> MLP prediction.
+        Text inputs are ignored as this is audio-only mode. Optimized for batch
+        processing of LIME perturbations with detailed timing analysis.
+
+        Parameters
+        ----------
+        texts : list of str
+            List of text strings (ignored in audio-only mode, kept for API compatibility)
+        audios : list of array-like
+            List of perturbed audio waveforms from LIME perturbations
+
+        Returns
+        -------
+        ndarray
+            Prediction probabilities in format [[P(AI), P(Human)], ...]
+            for each input audio sample, shape (n_samples, 2)
+        """
+        print(
+            f"[MusicLIME] Processing {len(audios)} samples with batch functions (audio-only mode)..."
+        )
+
+        # Step 1: Preprocess all audio samples
+        start_time = time.time()
+        print("[MusicLIME] Preprocessing audio samples...")
+        processed_audios = []
+
+        for audio in audios:
+            processed_audio = single_audio_preprocessing(audio)
+            processed_audios.append(processed_audio)
+
+        preprocessing_time = time.time() - start_time
+        print(
+            green_bold(
+                f"[MusicLIME] Audio preprocessing completed in {preprocessing_time:.2f}s"
+            )
+        )
+
+        # Step 2: Batch audio feature extraction
+        start_time = time.time()
+        print("[MusicLIME] Extracting audio features (batch)...")
+        audio_features_batch = spectttra_train(processed_audios)
+
+        # Clear GPU cache after audio processing
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        audio_time = time.time() - start_time
+        print(
+            green_bold(
+                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
+            )
+        )
+
+        # Step 3: Scale audio features in batch
+        start_time = time.time()
+        print("[MusicLIME] Scaling audio features (batch)...")
+
+        # Load the audio scaler
+        audio_scaler = joblib.load("models/fusion/audio_scaler.pkl")
+        scaled_audio_batch = audio_scaler.transform(audio_features_batch)
+
+        scaling_time = time.time() - start_time
+        print(green_bold(f"[MusicLIME] Audio scaling completed in {scaling_time:.2f}s"))
+
+        # Step 4: Audio-only MLP prediction
+        start_time = time.time()
+        print("[MusicLIME] Running audio-only MLP predictions (batch)...")
+
+        if self.classifier is None:
+            self.classifier = build_mlp(
+                input_dim=scaled_audio_batch.shape[1], config=self.config
+            )
+            self.classifier.load_model("models/mlp/mlp_best_unimodal.pth")
+
+        probabilities, predictions = self.classifier.predict(scaled_audio_batch)
+
+        # Clear GPU cache after MLP processing
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # Convert to expected format
+        batch_results = [[1 - prob, prob] for prob in probabilities]
+        mlp_time = time.time() - start_time
+        print(
+            green_bold(
+                f"[MusicLIME] Audio-only MLP prediction completed in {mlp_time:.2f}s"
+            )
+        )
+
+        # Total time summary
+        total_time = preprocessing_time + audio_time + scaling_time + mlp_time
+        print("[MusicLIME] Audio-only batch processing complete!")
+        print(
+            green_bold(
+                f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
+            )
+        )
+
+        return np.array(batch_results)
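A small sketch of the predictor contract LIME relies on: a batch of (texts, audios) in, an (n_samples, 2) probability array out. The dummy silent waveforms are illustrative only, and the snippet assumes it is run from the repo root so the config and model paths resolve:

    import numpy as np
    from src.musiclime.wrapper import AudioOnlyPredictor

    predictor = AudioOnlyPredictor()
    dummy_audio = [np.zeros(22050, dtype=np.float32)] * 2  # two 1-second silent clips
    probs = predictor(texts=["", ""], audios=dummy_audio)  # texts are ignored
    assert probs.shape == (2, 2)  # rows are [P(AI), P(Human)]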