krislette committed
Commit 97eaafb · 1 Parent(s): c84f2c4

Auto-deploy from GitHub: e4d0ee2ddb3dc15442ce902b31f6de26098a6291

app/schemas.py CHANGED
@@ -40,8 +40,17 @@ class PredictionXAIResponse(BaseModel):
     results: Optional[Dict] = None
 
 
-# Pydantic model for the error response
-class ErrorResponse(BaseModel):
-    status: str = "error"
-    code: int
-    message: str
+class AudioOnlyPredictionResponse(BaseModel):
+    status: str
+    audio_file_name: str
+    audio_content_type: str
+    audio_file_size: int
+    results: dict
+
+
+class AudioOnlyPredictionXAIResponse(BaseModel):
+    status: str
+    audio_file_name: str
+    audio_content_type: str
+    audio_file_size: int
+    results: dict
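For reference, a successful response matching the new AudioOnlyPredictionResponse schema would serialize roughly as below; the field values and the keys inside results are illustrative, since results is an unconstrained dict filled in by predict_unimodal:

{
    "status": "success",
    "audio_file_name": "youtube_audio.wav",
    "audio_content_type": "audio/wav",
    "audio_file_size": 1048576,
    "results": {"prediction": "Human-Composed", "confidence": 0.93}
}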
app/server.py CHANGED
@@ -1,26 +1,27 @@
 # Fast API imports
-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile
+from fastapi import Depends, FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 
 # Utils/schemas imports
 from app.schemas import (
-    ErrorResponse,
     ModelInfoResponse,
     PredictionResponse,
     PredictionXAIResponse,
+    AudioOnlyPredictionResponse,
+    AudioOnlyPredictionXAIResponse,
     WelcomeResponse,
 )
-from app.utils import load_server_config, load_model_config, download_youtube_audio
+from app.utils import load_server_config, load_model_config
+from app.validators import validate_lyrics, validate_audio_source, validate_audio_only
 
 # Model/XAI-related imports
-from scripts.explain import musiclime
-from scripts.predict import predict_multimodal
+from scripts.explain import musiclime_multimodal, musiclime_unimodal
+from scripts.predict import predict_multimodal, predict_unimodal
 
 # Other imports
 import io
 import librosa
-from typing import Optional, Tuple
-
+from typing import Tuple
 
 # Load configs at startup
 server_config = load_server_config()
@@ -47,70 +48,9 @@ app.add_middleware(
 )
 
 
-def validate_lyrics(lyrics: str = Form(...)):
-    """Validate lyrics length and content."""
-    if len(lyrics) > MAX_LYRICS_LENGTH:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Lyrics too long. Maximum length is {MAX_LYRICS_LENGTH} characters.",
-        )
-
-    # Basic sanitization, remove excessive whitespace
-    lyrics = lyrics.strip()
-    if not lyrics:
-        raise HTTPException(
-            status_code=400,
-            detail="Lyrics cannot be empty.",
-        )
-
-    return lyrics
-
-
-async def validate_audio_source(
-    audio_file: Optional[UploadFile] = File(None),
-    youtube_url: Optional[str] = Form(None),
-) -> Tuple[Optional[bytes], str, str]:
-    """
-    Validate and process audio source (either file or YouTube URL).
-    Returns: (audio_content, file_name, content_type)
-    """
-    if not audio_file and not youtube_url:
-        raise HTTPException(
-            status_code=400, detail="Either audio_file or youtube_url must be provided"
-        )
-
-    if audio_file and youtube_url:
-        raise HTTPException(
-            status_code=400, detail="Provide either audio_file or youtube_url, not both"
-        )
-
-    # Process YouTube URL
-    if youtube_url:
-        audio_content = download_youtube_audio(youtube_url)
-        return audio_content, "youtube_audio.wav", "audio/wav"
-
-    # Process uploaded file
-    if audio_file.content_type not in ALLOWED_AUDIO_TYPES:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Invalid file type. Supported formats: {', '.join(ALLOWED_AUDIO_TYPES)}",
-        )
-
-    audio_content = await audio_file.read()
-    if len(audio_content) > MAX_FILE_SIZE:
-        raise HTTPException(
-            status_code=400,
-            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB.",
-        )
-
-    return audio_content, audio_file.filename, audio_file.content_type
-
-
 @app.get("/", response_model=WelcomeResponse, tags=["Root"])
 def root():
-    """
-    Root endpoint to check if the API is running.
-    """
+    """Root endpoint to check if the API is running."""
     return WelcomeResponse(
         status="success",
         message="Welcome to Bach or Bot API!",
@@ -118,18 +58,38 @@ def root():
             "/": "This welcome message",
             "/docs": "FastAPI auto-generated API docs",
             "/api/v1/model/info": "Model information and capabilities",
-            "/api/v1/predict": "POST endpoint for bach-or-bot prediction",
-            "/api/v1/explain": "POST endpoint for prediction with explainability",
+            "/api/v1/predict": "POST endpoint for bach-or-bot prediction (legacy)",
+            "/api/v1/explain": "POST endpoint for prediction with explainability (legacy)",
+            "/api/v1/predict/multimodal": "POST endpoint for multimodal prediction",
+            "/api/v1/explain/multimodal": "POST endpoint for multimodal explainability",
+            "/api/v1/predict/audio": "POST endpoint for audio-only prediction",
+            "/api/v1/explain/audio": "POST endpoint for audio-only explainability",
         },
     )
 
 
-@app.post(
-    "/api/v1/predict",
-    response_model=PredictionResponse,
-    responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
-)
-async def predict_music(
+# Legacy endpoints (backward compatibility)
+@app.post("/api/v1/predict", response_model=PredictionResponse)
+async def predict_music_legacy(
+    lyrics: str = Depends(validate_lyrics),
+    audio_data_tuple: Tuple = Depends(validate_audio_source),
+):
+    """Legacy multimodal prediction endpoint."""
+    return await predict_multimodal_endpoint(lyrics, audio_data_tuple)
+
+
+@app.post("/api/v1/explain", response_model=PredictionXAIResponse)
+async def explain_music_legacy(
+    lyrics: str = Depends(validate_lyrics),
+    audio_data_tuple: Tuple = Depends(validate_audio_source),
+):
+    """Legacy multimodal explanation endpoint."""
+    return await explain_multimodal_endpoint(lyrics, audio_data_tuple)
+
+
+# New multimodal endpoints
+@app.post("/api/v1/predict/multimodal", response_model=PredictionResponse)
+async def predict_multimodal_endpoint(
     lyrics: str = Depends(validate_lyrics),
    audio_data_tuple: Tuple = Depends(validate_audio_source),
 ):
@@ -164,12 +124,8 @@ async def predict_music(
         raise HTTPException(status_code=500, detail=str(e))
 
 
-@app.post(
-    "/api/v1/explain",
-    response_model=PredictionXAIResponse,
-    responses={400: {"model": ErrorResponse}, 500: {"model": ErrorResponse}},
-)
-async def predict_music_with_xai(
+@app.post("/api/v1/explain/multimodal", response_model=PredictionXAIResponse)
+async def explain_multimodal_endpoint(
     lyrics: str = Depends(validate_lyrics),
     audio_data_tuple: Tuple = Depends(validate_audio_source),
 ):
@@ -188,7 +144,7 @@ async def predict_music_with_xai(
             raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
 
         # Call musiclime runner script
-        results = musiclime(audio_data, lyrics)
+        results = musiclime_multimodal(audio_data, lyrics)
 
         return PredictionXAIResponse(
             status="success",
@@ -204,6 +160,63 @@ async def predict_music_with_xai(
         raise HTTPException(status_code=500, detail=str(e))
 
 
+# New audio-only endpoints
+@app.post("/api/v1/predict/audio", response_model=AudioOnlyPredictionResponse)
+async def predict_audio_only_endpoint(
+    audio_data_tuple: Tuple = Depends(validate_audio_only),
+):
+    """Audio-only prediction endpoint."""
+    try:
+        audio_content, audio_file_name, audio_content_type = audio_data_tuple
+
+        try:
+            audio_data, sr = librosa.load(io.BytesIO(audio_content))
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
+
+        results = predict_unimodal(audio_data)
+
+        return AudioOnlyPredictionResponse(
+            status="success",
+            audio_file_name=audio_file_name,
+            audio_content_type=audio_content_type,
+            audio_file_size=len(audio_content),
+            results=results,
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/v1/explain/audio", response_model=AudioOnlyPredictionXAIResponse)
+async def explain_audio_only_endpoint(
+    audio_data_tuple: Tuple = Depends(validate_audio_only),
+):
+    """Audio-only explanation endpoint."""
+    try:
+        audio_content, audio_file_name, audio_content_type = audio_data_tuple
+
+        try:
+            audio_data, sr = librosa.load(io.BytesIO(audio_content))
+        except Exception as e:
+            raise HTTPException(status_code=400, detail=f"Invalid audio file: {str(e)}")
+
+        results = musiclime_unimodal(audio_data, modality="audio")
+
+        return AudioOnlyPredictionXAIResponse(
+            status="success",
+            audio_file_name=audio_file_name,
+            audio_content_type=audio_content_type,
+            audio_file_size=len(audio_content),
+            results=results,
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.get("/api/v1/model/info", response_model=ModelInfoResponse, tags=["Model"])
 async def get_model_info():
     """
app/validators.py ADDED
@@ -0,0 +1,75 @@
+from fastapi import File, Form, HTTPException, UploadFile
+from typing import Optional, Tuple
+from app.utils import download_youtube_audio
+
+
+# Import config values
+def get_config_values():
+    from app.server import MAX_FILE_SIZE, MAX_LYRICS_LENGTH, ALLOWED_AUDIO_TYPES
+
+    return MAX_FILE_SIZE, MAX_LYRICS_LENGTH, ALLOWED_AUDIO_TYPES
+
+
+def validate_lyrics(lyrics: str = Form(...)):
+    """Validate lyrics length and content for multimodal endpoints."""
+    _, MAX_LYRICS_LENGTH, _ = get_config_values()
+
+    if len(lyrics) > MAX_LYRICS_LENGTH:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Lyrics too long. Maximum length is {MAX_LYRICS_LENGTH} characters.",
+        )
+
+    lyrics = lyrics.strip()
+    if not lyrics:
+        raise HTTPException(
+            status_code=400,
+            detail="Lyrics cannot be empty.",
+        )
+    return lyrics
+
+
+async def validate_audio_source(
+    audio_file: Optional[UploadFile] = File(None),
+    youtube_url: Optional[str] = Form(None),
+) -> Tuple[Optional[bytes], str, str]:
+    """Validate and process audio source from file upload or YouTube URL."""
+    MAX_FILE_SIZE, _, ALLOWED_AUDIO_TYPES = get_config_values()
+
+    if not audio_file and not youtube_url:
+        raise HTTPException(
+            status_code=400, detail="Either audio_file or youtube_url must be provided"
+        )
+
+    if audio_file and youtube_url:
+        raise HTTPException(
+            status_code=400, detail="Provide either audio_file or youtube_url, not both"
+        )
+
+    if youtube_url:
+        audio_content = download_youtube_audio(youtube_url)
+        return audio_content, "youtube_audio.wav", "audio/wav"
+
+    if audio_file.content_type not in ALLOWED_AUDIO_TYPES:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid file type. Supported formats: {', '.join(ALLOWED_AUDIO_TYPES)}",
+        )
+
+    audio_content = await audio_file.read()
+    if len(audio_content) > MAX_FILE_SIZE:
+        raise HTTPException(
+            status_code=400,
+            detail=f"File too large. Maximum size is {MAX_FILE_SIZE // (1024*1024)}MB.",
+        )
+
+    return audio_content, audio_file.filename, audio_file.content_type
+
+
+async def validate_audio_only(
+    audio_file: Optional[UploadFile] = File(None),
+    youtube_url: Optional[str] = Form(None),
+) -> Tuple[Optional[bytes], str, str]:
+    """Validate audio source for audio-only endpoints (no lyrics required)."""
+    # Same validation as validate_audio_source but clearer naming for audio-only
+    return await validate_audio_source(audio_file, youtube_url)
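Note that get_config_values defers its import of app.server to call time, which is presumably what lets app.server import app.validators at module load without a circular import; the limits are only resolved when a validator actually runs. A minimal sketch of exercising the lyrics validator in isolation (hypothetical usage; assumes the package and its configs load outside a running server):

from fastapi import HTTPException

from app.validators import validate_lyrics

try:
    print(repr(validate_lyrics("  La la la  ")))  # whitespace stripped -> 'La la la'
    validate_lyrics("   ")  # blank input should raise a 400
except HTTPException as exc:
    print(exc.status_code, exc.detail)  # 400 'Lyrics cannot be empty.'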
scripts/explain.py CHANGED
@@ -2,17 +2,25 @@ import os
 import numpy as np
 from datetime import datetime
 from src.musiclime.explainer import MusicLIMEExplainer
-from src.musiclime.wrapper import MusicLIMEPredictor
+from src.musiclime.wrapper import MusicLIMEPredictor, AudioOnlyPredictor
 
 
-def musiclime(audio_data, lyrics_text):
+def musiclime_multimodal(audio_data, lyrics_text):
     """
-    MusicLIME wrapper for API usage.
-    Args:
-        audio_data: Audio array (from librosa.load or similar)
-        lyrics_text: String containing lyrics
-    Returns:
-        dict: Structured explanation results
+    Generate multimodal MusicLIME explanations for audio and lyrics.
+
+    Parameters
+    ----------
+    audio_data : array-like
+        Audio waveform data from librosa.load or similar
+    lyrics_text : str
+        String containing song lyrics
+
+    Returns
+    -------
+    dict
+        Structured explanation results containing prediction info, feature explanations,
+        and processing metadata
     """
     start_time = datetime.now()
 
@@ -77,3 +85,97 @@ def musiclime_multimodal(audio_data, lyrics_text):
             "timestamp": start_time.isoformat(),
         },
     }
+
+
+def musiclime_unimodal(audio_data, modality="audio"):
+    """
+    Generate unimodal MusicLIME explanations for single modality.
+
+    Parameters
+    ----------
+    audio_data : array-like
+        Audio waveform data from librosa.load or similar
+    modality : str, default='audio'
+        Explanation modality, currently only supports 'audio'
+
+    Returns
+    -------
+    dict
+        Structured explanation results containing prediction info, audio-only feature
+        explanations, and processing metadata
+
+    Raises
+    ------
+    ValueError
+        If modality is not 'audio' (lyrics is not yet implemented)
+    """
+    if modality != "audio":
+        raise ValueError(
+            "Currently only 'audio' modality is supported for unimodal explanations"
+        )
+
+    start_time = datetime.now()
+
+    # Get number of samples from environment variable, default to 1000
+    num_samples = int(os.getenv("MUSICLIME_NUM_SAMPLES", "1000"))
+    num_features = int(os.getenv("MUSICLIME_NUM_FEATURES", "10"))
+
+    print(
+        f"[MusicLIME] Using num_samples={num_samples}, num_features={num_features} (audio-only mode)"
+    )
+
+    # Create musiclime instances
+    explainer = MusicLIMEExplainer(random_state=42)
+    predictor = AudioOnlyPredictor()
+
+    # Use empty lyrics for audio-only since they're ignored anyways
+    dummy_lyrics = ""
+
+    # Generate explanation
+    explanation = explainer.explain_instance(
+        audio=audio_data,
+        lyrics=dummy_lyrics,
+        predict_fn=predictor,
+        num_samples=num_samples,
+        labels=(1,),
+        modality=modality,
+    )
+
+    # Get prediction info
+    original_prediction = explanation.predictions[0]
+    predicted_class = np.argmax(original_prediction)
+    confidence = float(np.max(original_prediction))
+
+    # Get top features
+    top_features = explanation.get_explanation(label=1, num_features=num_features)
+
+    # Calculate runtime
+    end_time = datetime.now()
+    runtime_seconds = (end_time - start_time).total_seconds()
+
+    return {
+        "prediction": {
+            "class": int(predicted_class),
+            "class_name": "Human-Composed" if predicted_class == 1 else "AI-Generated",
+            "confidence": confidence,
+            "probabilities": original_prediction.tolist(),
+        },
+        "explanations": [
+            {
+                "rank": i + 1,
+                "modality": item["type"],  # "audio" for all features
+                "feature_text": item["feature"],
+                "weight": float(item["weight"]),
+                "importance": abs(float(item["weight"])),
+            }
+            for i, item in enumerate(top_features)
+        ],
+        "summary": {
+            "total_features_analyzed": len(top_features),
+            "audio_features_count": len(top_features),  # All features are audio
+            "lyrics_features_count": 0,  # No lyrics features
+            "runtime_seconds": runtime_seconds,
+            "samples_generated": num_samples,
+            "timestamp": start_time.isoformat(),
+        },
+    }
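Since musiclime_unimodal reads MUSICLIME_NUM_SAMPLES and MUSICLIME_NUM_FEATURES from the environment, a run can be tuned without code changes. A minimal sketch (the values and sample path are illustrative, and the model checkpoints and config are assumed to be in place):

import os

# Fewer perturbation samples -> faster but noisier explanations (illustrative values)
os.environ["MUSICLIME_NUM_SAMPLES"] = "500"
os.environ["MUSICLIME_NUM_FEATURES"] = "5"

import librosa
from scripts.explain import musiclime_unimodal

audio_data, sr = librosa.load("data/external/sample.mp3")  # path assumed
result = musiclime_unimodal(audio_data, modality="audio")
print(result["summary"]["samples_generated"])  # 500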
scripts/explain_runner.py CHANGED
@@ -1,30 +1,69 @@
 import librosa
-from scripts.explain import musiclime
+from scripts.explain import musiclime_multimodal, musiclime_unimodal
 
-# Load test audio and lyrics
-audio_path = "data/external/sample_1.mp3"
-lyrics_path = "data/external/sample_1.txt"
 
-# Load audio
-audio_data, sr = librosa.load(audio_path)
+def explain_multimodal_runner(sample: str):
+    # Load test audio and lyrics
+    audio_path = f"data/external/{sample}.mp3"
+    lyrics_path = f"data/external/{sample}.txt"
 
-# Load lyrics
-with open(lyrics_path, "r", encoding="utf-8") as f:
-    lyrics_text = f.read()
+    # Load audio
+    audio_data, sr = librosa.load(audio_path)
 
-print("Running MusicLIME explanation...")
-result = musiclime(audio_data, lyrics_text)
+    # Load lyrics
+    with open(lyrics_path, "r", encoding="utf-8") as f:
+        lyrics_text = f.read()
 
-print("\n=== EXPLANATION RESULTS ===")
-print(
-    f"Prediction: {result['prediction']['class_name']} ({result['prediction']['confidence']:.3f})"
-)
-print(f"Runtime: {result['summary']['runtime_seconds']:.2f}s")
+    print("Running multimodal MusicLIME explanation...")
+    result = musiclime_multimodal(audio_data, lyrics_text)
 
-print("\n=== TOP FEATURES (by absolute importance) ===")
-for feature in result["explanations"]:
-    print(
-        f"Rank {feature['rank']}: {feature['modality']} | Weight: {feature['weight']:.4f} | Importance: {feature['importance']:.4f}"
-    )
-    print(f" Feature: {feature['feature_text'][:80]}...")
-    print()
+    print("\n=== MULTIMODAL EXPLANATION RESULTS ===")
+    print(
+        f"Prediction: {result['prediction']['class_name']} ({result['prediction']['confidence']:.3f})"
+    )
+    print(f"Runtime: {result['summary']['runtime_seconds']:.2f}s")
+
+    print("\n=== TOP FEATURES (by absolute importance) ===")
+    for feature in result["explanations"]:
+        print(
+            f"Rank {feature['rank']}: {feature['modality']} | Weight: {feature['weight']:.4f} | Importance: {feature['importance']:.4f}"
+        )
+        print(f" Feature: {feature['feature_text'][:80]}...")
+        print()
+
+
+def explain_unimodal_runner(sample: str):
+    # Load test audio
+    audio_path = f"data/external/{sample}.mp3"
+
+    # Load audio
+    audio_data, sr = librosa.load(audio_path)
+
+    print("Running audio-only MusicLIME explanation...")
+    result = musiclime_unimodal(audio_data, modality="audio")
+
+    print("\n=== AUDIO-ONLY EXPLANATION RESULTS ===")
+    print(
+        f"Prediction: {result['prediction']['class_name']} ({result['prediction']['confidence']:.3f})"
+    )
+    print(f"Runtime: {result['summary']['runtime_seconds']:.2f}s")
+
+    print("\n=== TOP FEATURES (by absolute importance) ===")
+    for feature in result["explanations"]:
+        print(
+            f"Rank {feature['rank']}: {feature['modality']} | Weight: {feature['weight']:.4f} | Importance: {feature['importance']:.4f}"
+        )
+        print(f" Feature: {feature['feature_text'][:80]}...")
+        print()
+
+
+if __name__ == "__main__":
+    sample = "sample"
+
+    # Run multimodal explanation
+    explain_multimodal_runner(sample)
+
+    print("\n" + "=" * 60 + "\n")
+
+    # Run audio-only explanation
+    explain_unimodal_runner(sample)
scripts/predict.py CHANGED
@@ -1,4 +1,7 @@
-from src.preprocessing.preprocessor import single_preprocessing, single_audio_preprocessing
+from src.preprocessing.preprocessor import (
+    single_preprocessing,
+    single_audio_preprocessing,
+)
 from src.spectttra.spectttra_trainer import spectttra_predict
 from src.llm2vectrain.model import load_llm2vec_model
 from src.llm2vectrain.llm2vec_trainer import l2vec_single_train, load_pca_model
@@ -55,7 +58,7 @@ def predict_multimodal(audio_file, lyrics):
     classifier = build_mlp(input_dim=results.shape[1], config=config)
 
     # 7.) Load trained weights
-    model_path = "models/mlp/mlp_best.pth"
+    model_path = "models/mlp/mlp_best_multimodal.pth"
     classifier.load_model(model_path)
     classifier.model.eval()
 
@@ -106,7 +109,7 @@ def predict_unimodal(audio_file):
     classifier = build_mlp(input_dim=audio_features.shape[1], config=config)
 
     # 6.) Load trained weights
-    model_path = "models/spectttra/mlp_best.pth"
+    model_path = "models/mlp/mlp_best_unimodal.pth"
     classifier.load_model(model_path)
     classifier.model.eval()
 
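The checkpoint split above means two model files must now ship with the app. A hypothetical pre-flight check for the renamed paths (the paths are taken from this diff; the helper itself is not part of the repo):

from pathlib import Path

# Both renamed MLP checkpoints referenced in this commit
for ckpt in (
    "models/mlp/mlp_best_multimodal.pth",
    "models/mlp/mlp_best_unimodal.pth",
):
    status = "ok" if Path(ckpt).is_file() else "MISSING"
    print(f"{ckpt}: {status}")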
scripts/predict_runner.py CHANGED
@@ -14,28 +14,33 @@ def predict_multimodal_runner(sample: str):
     with open(lyrics_path, "r", encoding="utf-8") as f:
         lyrics_text = f.read()
 
-    print("Running prediction pipeline...")
+    print("Running multimodal prediction pipeline...")
     prediction = predict_multimodal(audio_data, lyrics_text)
 
-    print(f"\n=== PREDICTION RESULT ===")
+    print("\n=== MULTIMODAL PREDICTION RESULT ===")
     print(f"Prediction: {prediction}")
 
 
 def predict_unimodal_runner(sample: str):
-    # Load test audio and lyrics
-    audio_path = f"data/raw/{sample}.mp3"
+    # Load test audio
+    audio_path = f"data/external/{sample}.mp3"
 
     # Load audio
     audio_data, sr = librosa.load(audio_path)
 
-    print("Running prediction pipeline...")
+    print("Running audio-only prediction pipeline...")
     prediction = predict_unimodal(audio_data)
 
-    print(f"\n=== PREDICTION RESULT ===")
+    print("\n=== AUDIO-ONLY PREDICTION RESULT ===")
     print(f"Prediction: {prediction}")
 
 
 if __name__ == "__main__":
-    sample = "fake_sunshine"
+    sample = "sample"
+
+    # Run both predictions
+    predict_multimodal_runner(sample)
+
+    print("\n" + "=" * 50 + "\n")
 
-    predict_unimodal_runner(sample)
+    predict_unimodal_runner(sample)
src/musiclime/explainer.py CHANGED
@@ -57,9 +57,15 @@ class MusicLIMEExplainer:
         num_samples=1000,
         labels=(1,),
         temporal_segments=10,
+        modality="both",
     ):
         """
-        Generate LIME explanations for a music instance using audio and lyrics.
+        Generate LIME explanations for a music instance using audio and/or lyrics.
+
+        This method creates local explanations by perturbing audio components (via source
+        separation) and/or lyrics lines, then analyzing their impact on model predictions.
+        Supports three modality modes: 'both' (multimodal), 'audio' (audio-only), and
+        'lyrical' (lyrics-only) following the original MusicLIME paper implementation.
 
         Parameters
         ----------
@@ -75,18 +81,26 @@ class MusicLIMEExplainer:
             Target labels to explain (0=AI-Generated, 1=Human-Composed)
         temporal_segments : int, default=10
             Number of temporal segments for audio factorization
+        modality : str, default='both'
+            Explanation modality: 'both' (multimodal), 'audio' (audio-only), or 'lyrical' (lyrics-only)
 
         Returns
         -------
         MusicLIMEExplanation
-            Explanation object containing feature importance weights
+            Explanation object containing feature importance weights and metadata
         """
+        # Validation for modality choice
+        if modality not in ["both", "audio", "lyrical"]:
+            raise ValueError("Set modality argument to 'both', 'audio', 'lyrical'.")
+
        # These are for debugging only I have to see THAT progress
         print("[MusicLIME] Starting MusicLIME explanation...")
         print(
             f"[MusicLIME] Audio length: {len(audio)/22050:.1f}s, Temporal segments: {temporal_segments}"
         )
         print(f"[MusicLIME] Lyrics lines: {len(lyrics.split(chr(10)))}")
+        print("[MusicLIME] Starting MusicLIME explanation...")
+        print(f"[MusicLIME] Modality: {modality}")
 
         # Create factorizations
         print("[MusicLIME] Creating audio factorization (source separation)...")
@@ -111,7 +125,7 @@ class MusicLIMEExplainer:
         # Generate perturbations and get predictions
         print(f"[MusicLIME] Generating {num_samples} perturbations...")
         data, predictions, distances = self._generate_neighborhood(
-            audio_factorization, text_factorization, predict_fn, num_samples
+            audio_factorization, text_factorization, predict_fn, num_samples, modality
         )
 
         # LIME fitting, create explanation object
@@ -140,33 +154,55 @@ class MusicLIMEExplainer:
 
         return explanation
 
-    def _generate_neighborhood(self, audio_fact, text_fact, predict_fn, num_samples):
+    def _generate_neighborhood(
+        self, audio_fact, text_fact, predict_fn, num_samples, modality="both"
+    ):
         """
-        Generate perturbed samples and predictions for LIME explanation.
+        Generate perturbed samples and predictions for LIME explanation based on modality.
+
+        Creates binary perturbation masks and generates corresponding perturbed audio-text
+        pairs. The perturbation strategy depends on the specified modality:
+        - 'both': Perturbs both audio components and lyrics lines independently
+        - 'audio': Perturbs only audio components, keeps original lyrics constant
+        - 'lyrical': Perturbs only lyrics lines, keeps original audio constant
 
         Parameters
         ----------
         audio_fact : OpenUnmixFactorization
-            Audio factorization object for source separation
+            Audio factorization object for source separation-based perturbations
         text_fact : LineIndexedString
-            Text factorization object for line-based perturbations
+            Text factorization object for line-based lyrics perturbations
         predict_fn : callable
-            Model prediction function
+            Model prediction function that processes (texts, audios) batches
         num_samples : int
-            Number of perturbations to generate
+            Number of perturbation samples to generate for LIME
+        modality : str, default='both'
+            Perturbation modality: 'both', 'audio', or 'lyrical'
 
         Returns
         -------
         data : ndarray
-            Binary perturbation masks (num_samples, total_features)
+            Binary perturbation masks of shape (num_samples, total_features)
         predictions : ndarray
-            Model predictions for perturbed instances
+            Model predictions for perturbed instances of shape (num_samples, n_classes)
         distances : ndarray
-            Cosine distances from original instance
+            Cosine distances from original instance of shape (num_samples,)
+
+        Notes
+        -----
+        The first sample (index 0) is always the original unperturbed instance.
+        Feature ordering: [audio_components, lyrics_lines] for 'both' modality.
         """
         n_audio = audio_fact.get_number_components()
         n_text = text_fact.num_words()
-        total_features = n_audio + n_text
+
+        # Set total features based on modality
+        if modality == "both":
+            total_features = n_audio + n_text
+        elif modality == "audio":
+            total_features = n_audio
+        elif modality == "lyrical":
+            total_features = n_text
 
         print(
             f"[MusicLIME] Total features: {total_features} ({n_audio} audio + {n_text} text)"
@@ -187,22 +223,46 @@ class MusicLIMEExplainer:
         texts = []
         audios = []
 
-        for i, row in enumerate(data):
-            # Progress check for every hundred samples
-            if i % 100 == 0:
-                print(f"[MusicLIME] Progress: {i}/{num_samples} samples")
+        for _, row in enumerate(data):
+            if modality == "both":
+                # Audio perturbation & reconstruction
+                audio_mask = row[:n_audio]
+                active_audio_components = np.where(audio_mask != 0)[0]
+                perturbed_audio = audio_fact.compose_model_input(
+                    active_audio_components
+                )
+                audios.append(perturbed_audio)
+
+                # Text perturbation & reconstruction
+                text_mask = row[n_audio:]
+                inactive_lines = np.where(text_mask == 0)[0]
+                perturbed_text = text_fact.inverse_removing(inactive_lines)
+                texts.append(perturbed_text)
+
+            elif modality == "audio":
+                # Audio perturbation, original lyrics
+                active_audio_components = np.where(row != 0)[0]
+                perturbed_audio = audio_fact.compose_model_input(
+                    active_audio_components
+                )
+                audios.append(perturbed_audio)
 
-            # Audio perturbation & reconstruction
-            audio_mask = row[:n_audio]
-            active_audio_components = np.where(audio_mask != 0)[0]
-            perturbed_audio = audio_fact.compose_model_input(active_audio_components)
-            audios.append(perturbed_audio)
+                # Use original lyrics (no perturbation)
+                perturbed_text = text_fact.inverse_removing(
+                    []
+                )  # Empty array = no removal
+                texts.append(perturbed_text)
 
-            # Text perturbation & reconstruction
-            text_mask = row[n_audio:]
-            inactive_lines = np.where(text_mask == 0)[0]
-            perturbed_text = text_fact.inverse_removing(inactive_lines)
-            texts.append(perturbed_text)
+            elif modality == "lyrical":
+                # Original audio, lyrics perturbation
+                all_audio_components = np.arange(n_audio)  # Use all audio components
+                perturbed_audio = audio_fact.compose_model_input(all_audio_components)
+                audios.append(perturbed_audio)
+
+                # Perturb lyrics
+                inactive_lines = np.where(row == 0)[0]
+                perturbed_text = text_fact.inverse_removing(inactive_lines)
+                texts.append(perturbed_text)
 
         perturbation_time = time.time() - start_time
         print(
@@ -221,7 +281,7 @@ class MusicLIMEExplainer:
         confidence = original_prediction[predicted_class]
 
         # Print original prediction
-        print(f"[MusicLIME] Original Prediction:")
+        print("[MusicLIME] Original Prediction:")
         print(
             f" Raw probabilities: [AI: {original_prediction[0]:.3f}, Human: {original_prediction[1]:.3f}]"
         )
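To make the mask arithmetic above concrete: with modality="both", each row of data gates the first n_audio entries over audio components and the remainder over lyrics lines. A toy-sized sketch of that slicing (sizes and the mask row are illustrative, not from a real run):

import numpy as np

n_audio, n_text = 4, 3
row = np.array([1, 0, 1, 1, 0, 1, 0])  # one binary perturbation mask

audio_mask = row[:n_audio]
active_audio_components = np.where(audio_mask != 0)[0]  # audio components kept
text_mask = row[n_audio:]
inactive_lines = np.where(text_mask == 0)[0]            # lyric lines removed

print(active_audio_components)  # [0 2 3]
print(inactive_lines)           # [0 2]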
src/musiclime/wrapper.py CHANGED
@@ -3,7 +3,10 @@ import joblib
 import numpy as np
 import torch
 
-from src.preprocessing.preprocessor import single_preprocessing
+from src.preprocessing.preprocessor import (
+    single_preprocessing,
+    single_audio_preprocessing,
+)
 from src.spectttra.spectttra_trainer import spectttra_train
 from src.llm2vectrain.llm2vec_trainer import l2vec_train
 from src.llm2vectrain.model import load_llm2vec_model
@@ -159,7 +162,7 @@ class MusicLIMEPredictor:
         self.classifier = build_mlp(
             input_dim=combined_features_batch.shape[1], config=self.config
         )
-        self.classifier.load_model("models/mlp/mlp_best.pth")
+        self.classifier.load_model("models/mlp/mlp_best_multimodal.pth")
 
         probabilities, predictions = self.classifier.predict(combined_features_batch)
 
@@ -172,7 +175,7 @@ class MusicLIMEPredictor:
         total_time = (
             preprocessing_time + audio_time + lyrics_time + scaling_time + mlp_time
         )
-        print(f"[MusicLIME] Batch processing complete!")
+        print("[MusicLIME] Batch processing complete!")
         print(
             green_bold(
                 f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Lyrics: {lyrics_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
@@ -180,3 +183,143 @@ class MusicLIMEPredictor:
         )
 
         return np.array(batch_results)
+
+
+class AudioOnlyPredictor:
+    """
+    Audio-only prediction wrapper for MusicLIME explanations.
+
+    Integrates the audio-only Bach or Bot pipeline (SpecTTTra + MLP) into a single
+    callable for LIME perturbation processing. Optimized for batch processing of
+    multiple perturbed audio samples while ignoring lyrics input. Mirrors the
+    multimodal MusicLIMEPredictor but processes only audio features.
+
+    This predictor is specifically designed for audio-only explainability where
+    lyrics are kept constant and only audio components are perturbed through
+    source separation techniques.
+
+    Attributes
+    ----------
+    classifier : MLPClassifier or None
+        Lazy-loaded MLP classifier for audio-only predictions
+    config : dict
+        Model configuration parameters loaded from config files
+    """
+
+    def __init__(self):
+        """
+        Initialize audio-only prediction wrapper.
+
+        Loads model configuration for batch processing of perturbed audio samples
+        during LIME explanation. The MLP classifier is lazy-loaded on first use
+        to optimize memory usage.
+        """
+        print("[MusicLIME] Loading models for Audio-Only MusicLIME...")
+        config = load_config("config/model_config.yml")
+        self.classifier = None
+        self.config = config
+
+    def __call__(self, texts, audios):
+        """
+        Batch prediction function for audio-only MusicLIME perturbations.
+
+        Processes multiple perturbed audio samples through the audio-only pipeline:
+        preprocessing -> SpecTTTra feature extraction -> scaling -> MLP prediction.
+        Text inputs are ignored as this is audio-only mode. Optimized for batch
+        processing of LIME perturbations with detailed timing analysis.
+
+        Parameters
+        ----------
+        texts : list of str
+            List of text strings (ignored in audio-only mode, kept for API compatibility)
+        audios : list of array-like
+            List of perturbed audio waveforms from LIME perturbations
+
+        Returns
+        -------
+        ndarray
+            Prediction probabilities in format [[P(AI), P(Human)], ...]
+            for each input audio sample, shape (n_samples, 2)
+        """
+        print(
+            f"[MusicLIME] Processing {len(audios)} samples with batch functions (audio-only mode)..."
+        )
+
+        # Step 1: Preprocess all audio samples
+        start_time = time.time()
+        print("[MusicLIME] Preprocessing audio samples...")
+        processed_audios = []
+
+        for audio in audios:
+            processed_audio = single_audio_preprocessing(audio)
+            processed_audios.append(processed_audio)
+
+        preprocessing_time = time.time() - start_time
+        print(
+            green_bold(
+                f"[MusicLIME] Audio preprocessing completed in {preprocessing_time:.2f}s"
+            )
+        )
+
+        # Step 2: Batch audio feature extraction
+        start_time = time.time()
+        print("[MusicLIME] Extracting audio features (batch)...")
+        audio_features_batch = spectttra_train(processed_audios)
+
+        # Clear GPU cache after audio processing
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        audio_time = time.time() - start_time
+        print(
+            green_bold(
+                f"[MusicLIME] Audio feature extraction completed in {audio_time:.2f}s"
+            )
+        )
+
+        # Step 3: Scale audio features in batch
+        start_time = time.time()
+        print("[MusicLIME] Scaling audio features (batch)...")
+
+        # Load the audio scaler
+        audio_scaler = joblib.load("models/fusion/audio_scaler.pkl")
+        scaled_audio_batch = audio_scaler.transform(audio_features_batch)
+
+        scaling_time = time.time() - start_time
+        print(green_bold(f"[MusicLIME] Audio scaling completed in {scaling_time:.2f}s"))
+
+        # Step 4: Audio-only MLP prediction
+        start_time = time.time()
+        print("[MusicLIME] Running audio-only MLP predictions (batch)...")
+
+        if self.classifier is None:
+            self.classifier = build_mlp(
+                input_dim=scaled_audio_batch.shape[1], config=self.config
+            )
+            self.classifier.load_model("models/mlp/mlp_best_unimodal.pth")
+
+        probabilities, predictions = self.classifier.predict(scaled_audio_batch)
+
+        # Clear GPU cache after MLP processing
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        # Convert to expected format
+        batch_results = [[1 - prob, prob] for prob in probabilities]
+        mlp_time = time.time() - start_time
+        print(
+            green_bold(
+                f"[MusicLIME] Audio-only MLP prediction completed in {mlp_time:.2f}s"
+            )
+        )
+
+        # Total time summary
+        total_time = preprocessing_time + audio_time + scaling_time + mlp_time
+        print("[MusicLIME] Audio-only batch processing complete!")
+        print(
+            green_bold(
+                f"[MusicLIME] Total time: {total_time:.2f}s (Preprocessing: {preprocessing_time:.2f}s, Audio: {audio_time:.2f}s, Scaling: {scaling_time:.2f}s, MLP: {mlp_time:.2f}s)"
+            )
+        )
+
+        return np.array(batch_results)
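Condensed from musiclime_unimodal in scripts/explain.py above, this is how AudioOnlyPredictor plugs into the explainer; the audio path is a placeholder and the trained checkpoints are assumed to be present:

import librosa
from src.musiclime.explainer import MusicLIMEExplainer
from src.musiclime.wrapper import AudioOnlyPredictor

audio, sr = librosa.load("data/external/sample.mp3")  # path assumed

explainer = MusicLIMEExplainer(random_state=42)
predictor = AudioOnlyPredictor()

explanation = explainer.explain_instance(
    audio=audio,
    lyrics="",            # ignored in audio-only mode
    predict_fn=predictor,
    num_samples=1000,
    labels=(1,),
    modality="audio",
)
print(explanation.get_explanation(label=1, num_features=10))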