Anvita Pandit commited on
Commit
2440fbf
·
1 Parent(s): 80334fc

Fall back to librosa onset detection when madmom is unavailable

Browse files

madmom can't build on HF Spaces (requires Cython at build time with
--no-build-isolation). onset_mask now tries madmom first, and falls back
to librosa's onset_detect which is already installed. Removed madmom
from requirements.txt since it can't be pip-installed with build isolation.

Made-with: Cursor

Files changed (2) hide show
  1. requirements.txt +0 -1
  2. vampnet/vampnet/mask.py +33 -12
requirements.txt CHANGED
@@ -14,4 +14,3 @@ wavebeat @ git+https://github.com/hugofloresgarcia/wavebeat
14
  lac @ git+https://github.com/hugofloresgarcia/lac.git
15
  descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
16
  pyharp
17
- madmom
 
14
  lac @ git+https://github.com/hugofloresgarcia/lac.git
15
  descript-audiotools @ git+https://github.com/hugofloresgarcia/audiotools.git
16
  pyharp
 
vampnet/vampnet/mask.py CHANGED
@@ -189,6 +189,26 @@ def time_stretch_mask(
189
  mask = periodic_mask(x, stretch_factor, width=1)
190
  return mask
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  def onset_mask(
193
  sig: AudioSignal,
194
  z: torch.Tensor,
@@ -196,24 +216,26 @@ def onset_mask(
196
  width: int = 1
197
  ):
198
  import librosa
199
- import madmom
200
- from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
201
  import tempfile
202
  import numpy as np
203
 
 
 
 
 
 
 
 
 
 
204
  with tempfile.NamedTemporaryFile(suffix='.wav') as f:
205
  sig = sig.clone()
206
  sig.write(f.name)
207
 
208
- proc = RNNOnsetProcessor(online=False)
209
- onsetproc = OnsetPeakPickingProcessor(threshold=0.3,
210
- fps=sig.sample_rate/interface.codec.hop_length)
211
-
212
- act = proc(f.name)
213
- onset_times = onsetproc(act)
214
-
215
- # convert to indices for z array
216
- onset_indices = librosa.time_to_frames(onset_times, sr=sig.sample_rate, hop_length=interface.codec.hop_length)
217
 
218
  if onset_indices.shape[0] == 0:
219
  mask = empty_mask(z)
@@ -223,7 +245,6 @@ def onset_mask(
223
  print("onset indices: ", onset_indices)
224
  print("onset times: ", onset_times)
225
 
226
- # create a mask, set onset
227
  mask = torch.ones_like(z)
228
  n_timesteps = z.shape[-1]
229
 
 
189
  mask = periodic_mask(x, stretch_factor, width=1)
190
  return mask
191
 
192
+ def _onset_times_madmom(wav_path, sample_rate, hop_length):
193
+ from madmom.features.onsets import RNNOnsetProcessor, OnsetPeakPickingProcessor
194
+ proc = RNNOnsetProcessor(online=False)
195
+ onsetproc = OnsetPeakPickingProcessor(
196
+ threshold=0.3, fps=sample_rate / hop_length
197
+ )
198
+ act = proc(wav_path)
199
+ return onsetproc(act)
200
+
201
+
202
+ def _onset_times_librosa(wav_path, sample_rate, hop_length):
203
+ import librosa
204
+ y, sr = librosa.load(wav_path, sr=sample_rate)
205
+ onset_env = librosa.onset.onset_strength(y=y, sr=sr, hop_length=hop_length)
206
+ onset_frames = librosa.onset.onset_detect(
207
+ onset_envelope=onset_env, sr=sr, hop_length=hop_length, backtrack=False
208
+ )
209
+ return librosa.frames_to_time(onset_frames, sr=sr, hop_length=hop_length)
210
+
211
+
212
  def onset_mask(
213
  sig: AudioSignal,
214
  z: torch.Tensor,
 
216
  width: int = 1
217
  ):
218
  import librosa
 
 
219
  import tempfile
220
  import numpy as np
221
 
222
+ try:
223
+ import madmom # noqa: F401
224
+ _get_onset_times = _onset_times_madmom
225
+ except ImportError:
226
+ print("madmom not installed, falling back to librosa for onset detection")
227
+ _get_onset_times = _onset_times_librosa
228
+
229
+ hop_length = interface.codec.hop_length
230
+
231
  with tempfile.NamedTemporaryFile(suffix='.wav') as f:
232
  sig = sig.clone()
233
  sig.write(f.name)
234
 
235
+ onset_times = _get_onset_times(f.name, sig.sample_rate, hop_length)
236
+ onset_indices = librosa.time_to_frames(
237
+ onset_times, sr=sig.sample_rate, hop_length=hop_length
238
+ )
 
 
 
 
 
239
 
240
  if onset_indices.shape[0] == 0:
241
  mask = empty_mask(z)
 
245
  print("onset indices: ", onset_indices)
246
  print("onset times: ", onset_times)
247
 
 
248
  mask = torch.ones_like(z)
249
  n_timesteps = z.shape[-1]
250