neonwatty committed on
Commit
a02a0e8
·
verified ·
1 Parent(s): af3deb4

Switch to Demucs for vocal separation (SAM Audio incompatible with ZeroGPU)

Browse files
Files changed (1) hide show
  1. app.py +62 -96
app.py CHANGED
@@ -1,12 +1,10 @@
1
  """
2
- SAM Audio Source Separation - Gradio Backend
3
  Runs on Hugging Face Spaces with ZeroGPU
 
4
  """
5
 
6
  import os
7
- # Set CUDA debugging before any torch imports
8
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
9
-
10
  import spaces
11
  import gradio as gr
12
  import torch
@@ -14,34 +12,23 @@ import torchaudio
14
  import tempfile
15
  import warnings
16
  import numpy as np
17
- from huggingface_hub import login
18
 
19
  warnings.filterwarnings("ignore")
20
 
21
- # Login to HuggingFace if token is available (for gated models)
22
- hf_token = os.environ.get("HF_TOKEN")
23
- if hf_token:
24
- login(token=hf_token)
25
- print("Logged in to HuggingFace Hub")
26
-
27
- # DO NOT import sam_audio here - it initializes CUDA
28
- # Import inside GPU function to avoid ZeroGPU CUDA fork issues
29
-
30
- MODEL_ID = "facebook/sam-audio-small"
31
- print(f"Model ID: {MODEL_ID} (will load on first GPU request)")
32
 
33
 
34
  @spaces.GPU(duration=120)
35
  def run_separation_gpu(
36
  waveform_np: np.ndarray,
37
  sample_rate: int,
38
- description: str,
39
- predict_spans: bool,
40
- reranking_candidates: int
41
  ):
42
- """Run separation on GPU with numpy waveform input."""
43
- # Import sam_audio inside GPU function to avoid CUDA fork issues
44
- from sam_audio import SAMAudio, SAMAudioProcessor
 
45
 
46
  print(f"[GPU] run_separation_gpu called")
47
  print(f"[GPU] waveform shape: {waveform_np.shape}, sample_rate: {sample_rate}")
@@ -49,81 +36,76 @@ def run_separation_gpu(
49
  device = "cuda" if torch.cuda.is_available() else "cpu"
50
  print(f"[GPU] Using device: {device}")
51
 
52
- # Load model fresh each time (ZeroGPU workers don't persist state)
53
- print(f"[GPU] Loading model to {device}...")
54
- print(f"[GPU] CUDA available: {torch.cuda.is_available()}")
55
  if torch.cuda.is_available():
56
  print(f"[GPU] CUDA device: {torch.cuda.get_device_name(0)}")
57
  print(f"[GPU] CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
58
 
59
- processor = SAMAudioProcessor.from_pretrained(MODEL_ID)
60
- print(f"[GPU] Processor loaded")
61
-
62
- model = SAMAudio.from_pretrained(MODEL_ID)
63
- print(f"[GPU] Model loaded to CPU")
64
-
65
- # Clear CUDA cache before moving model
66
- if torch.cuda.is_available():
67
- torch.cuda.empty_cache()
68
- print(f"[GPU] CUDA cache cleared")
69
-
70
- model = model.to(device)
71
- print(f"[GPU] Model moved to {device}")
72
-
73
  model.eval()
74
- print(f"[GPU] Model in eval mode")
75
 
76
- # Convert numpy to tensor and save to temp file for processor
77
- # Gradio passes audio as (samples, channels), torchaudio expects (channels, samples)
78
  waveform = torch.from_numpy(waveform_np).float()
79
- if waveform.dim() == 2:
80
- waveform = waveform.T # Transpose to (channels, samples)
81
- elif waveform.dim() == 1:
82
- waveform = waveform.unsqueeze(0) # Add channel dimension for mono
83
 
84
- temp_dir = tempfile.mkdtemp()
85
- input_path = os.path.join(temp_dir, "input.wav")
86
- torchaudio.save(input_path, waveform, sample_rate)
87
- print(f"[GPU] Saved input to: {input_path}")
 
 
88
 
89
- # Process audio
90
- batch = processor(
91
- audios=[input_path],
92
- descriptions=[description],
93
- ).to(device)
 
 
 
 
 
 
94
 
95
  # Run separation
 
96
  with torch.inference_mode():
97
- result = model.separate(
98
- batch,
99
- predict_spans=predict_spans,
100
- reranking_candidates=reranking_candidates
101
- )
 
 
 
 
 
 
 
 
102
 
103
  # Save outputs
104
- output_sample_rate = processor.audio_sampling_rate
105
- target_path = os.path.join(temp_dir, "target.wav")
106
- residual_path = os.path.join(temp_dir, "residual.wav")
107
 
108
- torchaudio.save(target_path, result.target.cpu(), output_sample_rate)
109
- torchaudio.save(residual_path, result.residual.cpu(), output_sample_rate)
 
110
 
111
- print(f"[GPU] Saved outputs to {target_path} and {residual_path}")
112
- return target_path, residual_path
113
 
114
 
115
- def separate_audio(
116
- audio_tuple,
117
- description: str,
118
- predict_spans: bool = True,
119
- reranking_candidates: int = 1
120
- ):
121
  """
122
  Wrapper that receives numpy audio from Gradio and calls GPU function.
123
  audio_tuple is (sample_rate, numpy_array) when type="numpy"
124
  """
125
  print(f"[Main] separate_audio called")
126
- print(f"[Main] audio_tuple type: {type(audio_tuple)}")
127
 
128
  if audio_tuple is None:
129
  raise gr.Error("Please upload an audio file")
@@ -131,38 +113,22 @@ def separate_audio(
131
  sample_rate, audio_data = audio_tuple
132
  print(f"[Main] sample_rate: {sample_rate}, audio_data shape: {audio_data.shape}")
133
 
134
- if not description:
135
- raise gr.Error("Please enter a description of the sound to isolate")
136
-
137
- # Call the GPU function with numpy data
138
- return run_separation_gpu(
139
- audio_data,
140
- sample_rate,
141
- description,
142
- predict_spans,
143
- reranking_candidates
144
- )
145
 
146
 
147
- # Create Gradio interface with type="numpy" to avoid file path issues
148
  demo = gr.Interface(
149
  fn=separate_audio,
150
  inputs=[
151
  gr.Audio(label="Upload Audio", type="numpy"),
152
- gr.Textbox(
153
- label="Sound to Isolate",
154
- value="singing voice, vocals, human voice",
155
- placeholder="e.g., 'singing voice, vocals, human voice'"
156
- ),
157
- gr.Checkbox(label="Auto-detect timing", value=True),
158
- gr.Slider(label="Quality", minimum=1, maximum=3, step=1, value=1)
159
  ],
160
  outputs=[
161
- gr.Audio(label="Isolated Sound (Target)"),
162
- gr.Audio(label="Background (Residual)")
163
  ],
164
  title="Forgot The Words - API Backend",
165
- description="Remove vocals from songs using [Meta SAM Audio](https://github.com/facebookresearch/sam-audio).",
166
  api_name="separate_audio",
167
  allow_flagging="never"
168
  )
 
1
  """
2
+ Demucs Audio Source Separation - Gradio Backend
3
  Runs on Hugging Face Spaces with ZeroGPU
4
+ Uses Meta's Demucs model for vocal separation
5
  """
6
 
7
  import os
 
 
 
8
  import spaces
9
  import gradio as gr
10
  import torch
 
12
  import tempfile
13
  import warnings
14
  import numpy as np
 
15
 
16
  warnings.filterwarnings("ignore")
17
 
18
+ # Demucs model - htdemucs is the best quality model
19
+ MODEL_NAME = "htdemucs"
20
+ print(f"Model: {MODEL_NAME} (will load on first GPU request)")
 
 
 
 
 
 
 
 
21
 
22
 
23
@spaces.GPU(duration=120)
def run_separation_gpu(
    waveform_np: np.ndarray,
    sample_rate: int,
):
    """Run Demucs vocal separation on GPU.

    Args:
        waveform_np: Audio from Gradio as ``(samples,)`` mono or
            ``(samples, channels)`` multi-channel. Integer PCM (e.g. int16,
            the usual Gradio ``type="numpy"`` payload) or float arrays are
            both accepted.
        sample_rate: Sample rate of the input audio in Hz.

    Returns:
        Tuple ``(vocals_path, instrumental_path)`` of WAV files written at
        the model's native sample rate.
    """
    # Import demucs inside GPU function to avoid CUDA issues
    from demucs.pretrained import get_model
    from demucs.apply import apply_model

    print(f"[GPU] run_separation_gpu called")
    print(f"[GPU] waveform shape: {waveform_np.shape}, sample_rate: {sample_rate}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[GPU] Using device: {device}")

    if torch.cuda.is_available():
        print(f"[GPU] CUDA device: {torch.cuda.get_device_name(0)}")
        print(f"[GPU] CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Load Demucs model (fresh each call - ZeroGPU workers don't persist state)
    print(f"[GPU] Loading Demucs model: {MODEL_NAME}")
    model = get_model(MODEL_NAME)
    model.to(device)
    model.eval()
    print(f"[GPU] Model loaded and moved to {device}")

    # FIX: Gradio type="numpy" usually delivers integer PCM (e.g. int16).
    # Demucs expects float waveforms in [-1, 1]; feeding raw +/-32768 values
    # produces garbage output, so normalize integer input here.
    if np.issubdtype(waveform_np.dtype, np.integer):
        peak = float(np.iinfo(waveform_np.dtype).max)
        waveform = torch.from_numpy(waveform_np.astype(np.float32) / peak)
    else:
        waveform = torch.from_numpy(waveform_np).float()

    # Gradio passes audio as (samples, channels); Demucs wants (batch, channels, samples)
    if waveform.dim() == 1:
        # Mono: (samples,) -> (1, 1, samples)
        waveform = waveform.unsqueeze(0).unsqueeze(0)
    elif waveform.dim() == 2:
        # Multi-channel from Gradio: (samples, channels) -> (1, channels, samples)
        waveform = waveform.T.unsqueeze(0)

    # FIX: htdemucs is a stereo model (model.audio_channels == 2) and rejects
    # mono input - replicate a single channel; down-mix then replicate when
    # the input has more channels than the model supports.
    if waveform.shape[1] != model.audio_channels:
        if waveform.shape[1] > model.audio_channels:
            waveform = waveform.mean(dim=1, keepdim=True)
        waveform = waveform.expand(-1, model.audio_channels, -1).contiguous()

    print(f"[GPU] Waveform tensor shape: {waveform.shape}")

    # Resample to model's expected sample rate (44100 Hz for Demucs)
    model_sr = model.samplerate
    if sample_rate != model_sr:
        print(f"[GPU] Resampling from {sample_rate} to {model_sr}")
        resampler = torchaudio.transforms.Resample(sample_rate, model_sr)
        waveform = resampler(waveform)

    # Move to device
    waveform = waveform.to(device)

    # Run separation
    print(f"[GPU] Running separation...")
    with torch.inference_mode():
        # sources shape: (batch, num_sources, channels, samples)
        sources = apply_model(model, waveform, device=device, progress=False)

    sources = sources.squeeze(0)  # Remove batch dimension
    print(f"[GPU] Sources shape: {sources.shape}")

    # FIX: look the vocals stem up by name instead of hard-coding index 3 -
    # stem order is model-specific (htdemucs: drums, bass, other, vocals).
    vocals_idx = model.sources.index("vocals")
    vocals = sources[vocals_idx]
    # Instrumental = everything except vocals
    instrumental = sources.sum(dim=0) - vocals

    print(f"[GPU] Vocals shape: {vocals.shape}, Instrumental shape: {instrumental.shape}")

    # Save outputs at the model's sample rate
    temp_dir = tempfile.mkdtemp()
    vocals_path = os.path.join(temp_dir, "vocals.wav")
    instrumental_path = os.path.join(temp_dir, "instrumental.wav")

    torchaudio.save(vocals_path, vocals.cpu(), model_sr)
    torchaudio.save(instrumental_path, instrumental.cpu(), model_sr)

    print(f"[GPU] Saved outputs to {vocals_path} and {instrumental_path}")
    return vocals_path, instrumental_path
101
 
102
 
103
def separate_audio(audio_tuple):
    """
    Wrapper that receives numpy audio from Gradio and calls GPU function.
    audio_tuple is (sample_rate, numpy_array) when type="numpy"
    """
    print("[Main] separate_audio called")

    # Guard clause: Gradio sends None when no file was provided
    if audio_tuple is None:
        raise gr.Error("Please upload an audio file")

    sr, samples = audio_tuple
    print(f"[Main] sample_rate: {sr}, audio_data shape: {samples.shape}")

    # Hand the raw numpy payload off to the ZeroGPU worker
    return run_separation_gpu(samples, sr)
 
 
 
 
 
 
 
 
 
118
 
119
 
120
# Create Gradio interface
# Input: one audio upload delivered as (sample_rate, numpy_array).
_audio_input = gr.Audio(label="Upload Audio", type="numpy")

# Outputs: the two separated stems produced by separate_audio.
_stem_outputs = [
    gr.Audio(label="Vocals"),
    gr.Audio(label="Instrumental (Karaoke)"),
]

demo = gr.Interface(
    fn=separate_audio,
    inputs=[_audio_input],
    outputs=_stem_outputs,
    title="Forgot The Words - API Backend",
    description="Remove vocals from songs using [Meta Demucs](https://github.com/facebookresearch/demucs). Upload a song and get the vocals and instrumental tracks separated.",
    api_name="separate_audio",
    allow_flagging="never"
)