lllindsey0615 committed
Commit 1035bfa · 1 Parent(s): f750bef

Supports mono audio input

Files changed (1): app.py (+15 -6)
app.py CHANGED
@@ -7,10 +7,9 @@ from pyharp import ModelCard, LabelList, build_endpoint, save_audio
 from audiotools import AudioSignal
 
 
-def separate_instrumental(audio_file_path: str, model_name: str = 'mdx_extra_q'):
-    """
-    Separates an audio file into a instrumental stem using a Demucs model.
-    """
+DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
+
+def separate_instrumental(audio_file_path: str, model_name: str):
     # Load Demucs model
     model = pretrained.get_model(model_name)
     model.to('cuda' if torch.cuda.is_available() else 'cpu')
@@ -19,6 +18,13 @@ def separate_instrumental(audio_file_path: str, model_name: str = 'mdx_extra_q'
     # Load audio file (waveform shape: (channels, samples))
     waveform, sr = torchaudio.load(audio_file_path)
 
+    # Check if the input is mono
+    is_mono = waveform.shape[0] == 1
+
+    # If mono, duplicate to stereo
+    if is_mono:
+        waveform = waveform.repeat(2, 1)
+
     # Run Demucs — returns a list (batch, stems, channels, samples)
     with torch.no_grad():
         stems_batch = apply_model(
@@ -28,12 +34,15 @@ def separate_instrumental(audio_file_path: str, model_name: str = 'mdx_extra_q'
             shifts=1,
             split=True
         )
-
+
     stems = stems_batch[0]  # Extract stems from batch
 
-    # Extract the vocal stem (stems[0] is vocals in most models)
+    # Extract the instrumental stem (stems[0] is vocals in most models)
     instrumental = stems[0]
 
+    if is_mono:
+        instrumental = instrumental.mean(dim=0, keepdim=True)  # Stereo → Mono
+
     # Convert to an AudioSignal object
     instrumental_signal = AudioSignal(instrumental.cpu().numpy(), sample_rate=sr)
     return instrumental_signal
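
For readers following along, here is a minimal standalone sketch of the mono-handling round trip this commit introduces: a single-channel waveform is duplicated to two channels so the stereo Demucs model can process it, then the separated stem is averaged back down to one channel. The tensor sizes and the helper name mono_round_trip_demo are illustrative assumptions, not part of the commit.

import torch

def mono_round_trip_demo():
    # Stand-in for torchaudio.load(): a fake 1-channel (mono) waveform,
    # shape (channels, samples); one second at 44.1 kHz of random samples.
    waveform = torch.randn(1, 44100)

    is_mono = waveform.shape[0] == 1  # True for this input

    # Duplicate the single channel so Demucs sees a stereo input.
    if is_mono:
        waveform = waveform.repeat(2, 1)  # (1, N) -> (2, N)
    assert waveform.shape == (2, 44100)

    # ... apply_model() would run here; the separated stem keeps (2, N) ...
    stem = waveform

    # Fold the stereo stem back to mono so the output matches the input layout.
    if is_mono:
        stem = stem.mean(dim=0, keepdim=True)  # (2, N) -> (1, N)
    assert stem.shape == (1, 44100)
    return stem

mono_round_trip_demo()

With the patch applied, a call such as separate_instrumental("input.wav", "mdx_extra_q") (filename hypothetical) should therefore return a one-channel AudioSignal for a one-channel input file, matching the input's channel layout.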