lllindsey0615 committed on
Commit
31bdbd1
·
1 Parent(s): 1035bfa

added model and stem selection

Browse files
Files changed (1) hide show
  1. app.py +63 -39
app.py CHANGED
@@ -3,85 +3,109 @@ import torchaudio
3
  import gradio as gr
4
  from demucs import pretrained
5
  from demucs.apply import apply_model
6
- from pyharp import ModelCard, LabelList, build_endpoint, save_audio
7
  from audiotools import AudioSignal
8
 
9
-
10
  DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]
11
 
12
- def separate_instrumental(audio_file_path: str, model_name: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # Load Demucs model
14
  model = pretrained.get_model(model_name)
15
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
16
  model.eval()
17
 
18
- # Load audio file (waveform shape: (channels, samples))
19
  waveform, sr = torchaudio.load(audio_file_path)
20
 
21
- # Check if the input is mono
22
  is_mono = waveform.shape[0] == 1
23
-
24
- # If mono, duplicate to stereo
25
  if is_mono:
26
- waveform = waveform.repeat(2, 1)
27
 
28
- # Run Demucs — returns a list (batch, stems, channels, samples)
29
  with torch.no_grad():
30
  stems_batch = apply_model(
31
  model,
32
- waveform.unsqueeze(0), # shape: (batch=1, channels, samples)
33
  overlap=0.2,
34
  shifts=1,
35
  split=True
36
  )
37
 
38
- stems = stems_batch[0] # Extract stems from batch
 
39
 
40
- # Extract the instrumental stem (stems[0] is vocals in most models)
41
- instrumental = stems[0]
42
 
 
 
 
 
 
 
 
43
  if is_mono:
44
- instrumental = instrumental.mean(dim=0, keepdim=True) # Stereo → Mono
45
 
46
- # Convert to an AudioSignal object
47
- instrumental_signal = AudioSignal(instrumental.cpu().numpy(), sample_rate=sr)
48
- return instrumental_signal
49
 
50
 
51
- def process_fn_instrumental(audio_file_path: str):
52
  """
53
- pyharp instrumentals_fn:
54
- - Receives an audio file path
55
- - Separates the audio into a instrumental stem with Demucs
56
- - Saves the instrumental stem as a .wav file
57
- - Returns the file path of the instrumental stem and an empty LabelList
58
  """
59
- # Separate instrumental using the hardcoded model
60
- instrumental_signal = separate_instrumental(audio_file_path, model_name='mdx_extra_q')
61
-
62
- # Save the instrumental stem to a .wav file
63
- instrumental_path= save_audio(instrumental_signal, "instrumental.wav")
64
-
65
- # Return the instrumental file path and an empty LabelList
66
- return instrumental_path, LabelList(labels=[])
67
 
68
 
69
  # Define the model card
70
  model_card = ModelCard(
71
- name="Demucs Vocal Separator",
72
- description="Uses Demucs to separate a music track into a vocal stem.",
73
  author="Alexandre Défossez, Nicolas Usunier, Léon Bottou, Francis Bach",
74
- tags=["demucs", "source-separation", "pyharp", "vocals"]
75
  )
76
 
77
- # Build Gradio interface
78
  with gr.Blocks() as demo:
79
- # Build the Gradio endpoint (only audio input, no dropdown)
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  app = build_endpoint(
81
  model_card=model_card,
82
- components=[],
83
- process_fn=process_fn_instrumental
84
  )
85
 
86
  demo.queue()
87
- demo.launch(share=True, show_error=True)
 
3
  import gradio as gr
4
  from demucs import pretrained
5
  from demucs.apply import apply_model
6
+ from pyharp import *
7
  from audiotools import AudioSignal
8
 
9
# Available Demucs models (names passed to demucs.pretrained.get_model)
DEMUX_MODELS = ["mdx_extra_q", "mdx_extra", "htdemucs", "mdx_q"]

# Maps the label shown in the stem dropdown to the index of that stem in the
# model output. The "instrumental" entry is a sentinel, not an index: that
# choice is handled separately by summing the non-vocal stems.
# NOTE(review): the indices assume the model emits its sources in the order
# drums, bass, other, vocals — confirm against `model.sources` for every
# entry in DEMUX_MODELS.
STEM_CHOICES = {
    "Vocals": 3,
    "Drums": 0,
    "Bass": 1,
    "Other": 2,
    "Instrumental (No Vocals)": "instrumental"
}
19
+
20
+
21
def separate_stem(audio_file_path: str, model_name: str, stem_choice: str):
    """
    Separate one stem from an audio file using a pretrained Demucs model.

    Args:
        audio_file_path: Path of the audio file to load with torchaudio.
        model_name: Name of the pretrained Demucs model to load.
        stem_choice: A key of STEM_CHOICES, e.g. "Vocals" or
            "Instrumental (No Vocals)".

    Returns:
        An AudioSignal holding the selected stem as float32 at the input
        file's sample rate. A mono input yields a mono stem.

    Raises:
        ValueError: If stem_choice is not a key of STEM_CHOICES.
    """
    # Fail fast on a bad choice instead of after model load and inference.
    if stem_choice not in STEM_CHOICES:
        raise ValueError(
            f"Unknown stem choice {stem_choice!r}; expected one of {list(STEM_CHOICES)}"
        )

    # Load Demucs model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = pretrained.get_model(model_name)
    model.to(device)
    model.eval()

    # Load the audio file (waveform is indexed as channels first)
    waveform, sr = torchaudio.load(audio_file_path)

    # Demucs is fed two channels, so a mono input is duplicated
    is_mono = waveform.shape[0] == 1
    if is_mono:
        waveform = waveform.repeat(2, 1)

    # The pretrained weights expect audio at the model's own sample rate;
    # resample on the way in and restore the original rate on the way out.
    model_sr = model.samplerate
    if sr != model_sr:
        waveform = torchaudio.functional.resample(waveform, sr, model_sr)

    # Apply Demucs model. `device` must be passed explicitly: without it
    # apply_model computes on the input tensor's device (CPU here), which
    # silently bypasses the GPU the model was moved to.
    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.unsqueeze(0),
            overlap=0.2,
            shifts=1,
            split=True,
            device=device,
        )

    # Drop the batch dimension that unsqueeze(0) added
    stems = stems_batch[0]

    print(f"Model '{model_name}' extracted stems shape: {stems.shape}")

    if stem_choice == "Instrumental (No Vocals)":
        stem = stems[0] + stems[1] + stems[2]  # Drums + Bass + Other
    else:
        stem_index = STEM_CHOICES[stem_choice]
        stem = stems[stem_index]

    if sr != model_sr:
        stem = torchaudio.functional.resample(stem, model_sr, sr)

    # Convert back to mono if the input was originally mono
    if is_mono:
        stem = stem.mean(dim=0, keepdim=True)  # Stereo → Mono

    # Convert to AudioSignal with float32 dtype
    stem_signal = AudioSignal(stem.cpu().numpy().astype('float32'), sample_rate=sr)
    return stem_signal
67
 
68
 
69
def process_fn_stem(audio_file_path: str, demucs_model: str, stem_choice: str):
    """
    PyHARP process function for stem separation.

    Runs separate_stem on the input file with the selected model and stem,
    writes the result through save_audio under a file name derived from the
    stem label (lower-cased, spaces replaced by underscores, ".wav" suffix),
    and returns what save_audio gives back together with an empty LabelList.
    """
    separated = separate_stem(
        audio_file_path,
        model_name=demucs_model,
        stem_choice=stem_choice,
    )

    # e.g. "Vocals" -> "vocals.wav"
    output_name = f"{stem_choice.lower().replace(' ', '_')}.wav"
    saved_path = save_audio(separated, output_name)

    no_labels = LabelList(labels=[])
    return saved_path, no_labels
 
 
 
 
 
78
 
79
 
80
# Model card passed to build_endpoint to describe this app
model_card = ModelCard(
    name="Demucs Stem Separator",
    description="Uses Demucs to separate a music track into a selected stem.",
    author="Alexandre Défossez, Nicolas Usunier, Léon Bottou, Francis Bach",
    tags=["demucs", "source-separation", "pyharp", "stems"],
)

# Gradio interface: one dropdown for the model, one for the stem
with gr.Blocks() as demo:
    # Created in this order so their values line up with the
    # (demucs_model, stem_choice) parameters of process_fn_stem.
    model_dropdown = gr.Dropdown(
        choices=DEMUX_MODELS,
        value="mdx_extra_q",
        label="Select Demucs Model",
    )
    stem_dropdown = gr.Dropdown(
        choices=list(STEM_CHOICES),
        value="Vocals",
        label="Select Stem to Separate",
    )
    components = [model_dropdown, stem_dropdown]

    app = build_endpoint(
        model_card=model_card,
        components=components,
        process_fn=process_fn_stem,
    )

    demo.queue()
    demo.launch(share=True, show_error=True)