krislette commited on
Commit
f6feac1
·
1 Parent(s): 54f0d32

Made sample number configurable to prevent container rebuild

Browse files
Files changed (1) hide show
  1. scripts/explain.py +11 -6
scripts/explain.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  from datetime import datetime
3
  from src.musiclime.explainer import MusicLIMEExplainer
@@ -7,16 +8,20 @@ from src.musiclime.wrapper import MusicLIMEPredictor
7
  def musiclime(audio_data, lyrics_text):
8
  """
9
  MusicLIME wrapper for API usage.
10
-
11
  Args:
12
  audio_data: Audio array (from librosa.load or similar)
13
  lyrics_text: String containing lyrics
14
-
15
  Returns:
16
  dict: Structured explanation results
17
  """
18
  start_time = datetime.now()
19
 
 
 
 
 
 
 
20
  # Create musiclime instances
21
  explainer = MusicLIMEExplainer()
22
  predictor = MusicLIMEPredictor()
@@ -26,7 +31,7 @@ def musiclime(audio_data, lyrics_text):
26
  audio=audio_data,
27
  lyrics=lyrics_text,
28
  predict_fn=predictor,
29
- num_samples=1000,
30
  labels=(1,),
31
  )
32
 
@@ -35,8 +40,8 @@ def musiclime(audio_data, lyrics_text):
35
  predicted_class = np.argmax(original_prediction)
36
  confidence = float(np.max(original_prediction))
37
 
38
- # Get top 10 features
39
- top_features = explanation.get_explanation(label=1, num_features=10)
40
 
41
  # Calculate runtime
42
  end_time = datetime.now()
@@ -68,7 +73,7 @@ def musiclime(audio_data, lyrics_text):
68
  [f for f in top_features if f["type"] == "lyrics"]
69
  ),
70
  "runtime_seconds": runtime_seconds,
71
- "samples_generated": 1000,
72
  "timestamp": start_time.isoformat(),
73
  },
74
  }
 
1
+ import os
2
  import numpy as np
3
  from datetime import datetime
4
  from src.musiclime.explainer import MusicLIMEExplainer
 
8
  def musiclime(audio_data, lyrics_text):
9
  """
10
  MusicLIME wrapper for API usage.
 
11
  Args:
12
  audio_data: Audio array (from librosa.load or similar)
13
  lyrics_text: String containing lyrics
 
14
  Returns:
15
  dict: Structured explanation results
16
  """
17
  start_time = datetime.now()
18
 
19
+ # Get number of samples from environment variable, default to 1000
20
+ num_samples = int(os.getenv("MUSICLIME_NUM_SAMPLES", "1000"))
21
+ num_features = int(os.getenv("MUSICLIME_NUM_FEATURES", "10"))
22
+
23
+ print(f"[MusicLIME] Using num_samples={num_samples}, num_features={num_features}")
24
+
25
  # Create musiclime instances
26
  explainer = MusicLIMEExplainer()
27
  predictor = MusicLIMEPredictor()
 
31
  audio=audio_data,
32
  lyrics=lyrics_text,
33
  predict_fn=predictor,
34
+ num_samples=num_samples,
35
  labels=(1,),
36
  )
37
 
 
40
  predicted_class = np.argmax(original_prediction)
41
  confidence = float(np.max(original_prediction))
42
 
43
+ # Get top features (I also made this configurable to prevent rebuilding)
44
+ top_features = explanation.get_explanation(label=1, num_features=num_features)
45
 
46
  # Calculate runtime
47
  end_time = datetime.now()
 
73
  [f for f in top_features if f["type"] == "lyrics"]
74
  ),
75
  "runtime_seconds": runtime_seconds,
76
+ "samples_generated": num_samples,
77
  "timestamp": start_time.isoformat(),
78
  },
79
  }