Peter Shi committed
Commit 8752ef6 · 1 Parent(s): 3cc9650

feat: implement audio chunking with overlap and crossfading for long audio files, and add chunk duration control to the UI

Files changed (1): app.py +141 -18
app.py CHANGED
@@ -23,6 +23,11 @@ DEFAULT_MODEL = "sam-audio-small"
 EXAMPLES_DIR = "examples"
 EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")
 
+# Chunk processing settings
+DEFAULT_CHUNK_DURATION = 30  # seconds per chunk
+OVERLAP_DURATION = 2  # seconds of overlap between chunks
+MAX_DURATION_WITHOUT_CHUNKING = 60  # auto-chunk if longer than this
+
 # Global model cache
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 current_model_name = None
@@ -42,16 +47,77 @@ def load_model(model_name):
 
 load_model(DEFAULT_MODEL)
 
+def load_audio(file_path):
+    """Load audio from file (supports both audio and video files)."""
+    waveform, sample_rate = torchaudio.load(file_path)
+    # Convert to mono if stereo
+    if waveform.shape[0] > 1:
+        waveform = waveform.mean(dim=0, keepdim=True)
+    return waveform, sample_rate
+
+def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
+    """Split audio waveform into overlapping chunks."""
+    chunk_samples = int(chunk_duration * sample_rate)
+    overlap_samples = int(overlap_duration * sample_rate)
+    stride = chunk_samples - overlap_samples
+
+    chunks = []
+    total_samples = waveform.shape[1]
+
+    start = 0
+    while start < total_samples:
+        end = min(start + chunk_samples, total_samples)
+        chunk = waveform[:, start:end]
+        chunks.append(chunk)
+
+        if end >= total_samples:
+            break
+        start += stride
+
+    return chunks
+
+def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
+    """Merge audio chunks with crossfade on overlapping regions."""
+    if len(chunks) == 1:
+        return chunks[0]
+
+    overlap_samples = int(overlap_duration * sample_rate)
+    result = chunks[0]
+
+    for i in range(1, len(chunks)):
+        prev_chunk = result
+        next_chunk = chunks[i]
+
+        # Create fade curves
+        fade_out = torch.linspace(1.0, 0.0, overlap_samples).to(prev_chunk.device)
+        fade_in = torch.linspace(0.0, 1.0, overlap_samples).to(next_chunk.device)
+
+        # Get overlapping regions
+        prev_overlap = prev_chunk[:, -overlap_samples:]
+        next_overlap = next_chunk[:, :overlap_samples]
+
+        # Crossfade mix
+        crossfaded = prev_overlap * fade_out + next_overlap * fade_in
+
+        # Concatenate: non-overlap of prev + crossfaded + non-overlap of next
+        result = torch.cat([
+            prev_chunk[:, :-overlap_samples],
+            crossfaded,
+            next_chunk[:, overlap_samples:]
+        ], dim=1)
+
+    return result
+
 def save_audio(tensor, sample_rate):
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
         torchaudio.save(tmp.name, tensor, sample_rate)
     return tmp.name
 
 @spaces.GPU(duration=300)
-def separate_audio(model_name, file_path, text_prompt, progress=gr.Progress()):
+def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
     global model, processor
 
-    progress(0.1, desc="Checking inputs...")
+    progress(0.05, desc="Checking inputs...")
 
     if not file_path:
         return None, None, "❌ Please upload an audio or video file."
@@ -59,23 +125,70 @@ def separate_audio(model_name, file_path, text_prompt, progress=gr.Progress()):
         return None, None, "❌ Please enter a text prompt."
 
     try:
-        progress(0.2, desc="Loading model...")
+        progress(0.1, desc="Loading model...")
         load_model(model_name)
 
-        progress(0.4, desc="Processing audio...")
-        inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
-
-        progress(0.6, desc="Separating sounds...")
-        with torch.inference_mode():
-            result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
-
-        progress(0.8, desc="Saving results...")
-        sample_rate = processor.audio_sampling_rate
-        target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
-        residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
-
-        progress(1.0, desc="Done!")
-        return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
+        progress(0.15, desc="Loading audio...")
+        waveform, sample_rate = load_audio(file_path)
+        duration = waveform.shape[1] / sample_rate
+
+        # Decide whether to use chunking
+        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
+
+        if use_chunking:
+            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
+            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
+            num_chunks = len(chunks)
+
+            target_chunks = []
+            residual_chunks = []
+
+            for i, chunk in enumerate(chunks):
+                chunk_progress = 0.2 + (i / num_chunks) * 0.6
+                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")
+
+                # Save chunk to temp file for processor
+                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                    torchaudio.save(tmp.name, chunk, sample_rate)
+                    chunk_path = tmp.name
+
+                try:
+                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
+
+                    with torch.inference_mode():
+                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
+
+                    target_chunks.append(result.target[0].cpu())
+                    residual_chunks.append(result.residual[0].cpu())
+                finally:
+                    os.unlink(chunk_path)
+
+            progress(0.85, desc="Merging chunks...")
+            target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
+            residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)
+
+            progress(0.95, desc="Saving results...")
+            target_path = save_audio(target_merged.unsqueeze(0), sample_rate)
+            residual_path = save_audio(residual_merged.unsqueeze(0), sample_rate)
+
+            progress(1.0, desc="Done!")
+            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
+        else:
+            # Process without chunking
+            progress(0.3, desc="Processing audio...")
+            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
+
+            progress(0.6, desc="Separating sounds...")
+            with torch.inference_mode():
+                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
+
+            progress(0.9, desc="Saving results...")
+            sample_rate = processor.audio_sampling_rate
+            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
+            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
+
+            progress(1.0, desc="Done!")
+            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
     except Exception as e:
         import traceback
         traceback.print_exc()
@@ -98,6 +211,16 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
         label="Model"
     )
 
+    with gr.Accordion("⚙️ Advanced Options", open=False):
+        chunk_duration_slider = gr.Slider(
+            minimum=10,
+            maximum=60,
+            value=DEFAULT_CHUNK_DURATION,
+            step=5,
+            label="Chunk Duration (seconds)",
+            info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split"
+        )
+
     gr.Markdown("#### Upload Audio")
     input_audio = gr.Audio(label="Audio File", type="filepath")
 
@@ -128,13 +251,13 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
     example_btn3 = gr.Button("🎵 Background Music")
 
     # Main process button
-    def process(model_name, audio_path, video_path, prompt, progress=gr.Progress()):
+    def process(model_name, audio_path, video_path, prompt, chunk_duration, progress=gr.Progress()):
         file_path = video_path if video_path else audio_path
-        return separate_audio(model_name, file_path, prompt, progress)
+        return separate_audio(model_name, file_path, prompt, chunk_duration, progress)
 
     run_btn.click(
         fn=process,
-        inputs=[model_selector, input_audio, input_video, text_prompt],
+        inputs=[model_selector, input_audio, input_video, text_prompt, chunk_duration_slider],
         outputs=[output_target, output_residual, status_output]
     )
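For reference, a minimal standalone sketch (hypothetical, not part of this commit) of the chunking arithmetic the new helpers implement: it splits a synthetic 90 s mono signal using the default 30 s chunks and 2 s overlap, then re-merges with the same linear crossfade, checking that the merged length equals the input length and that a constant signal passes through unchanged, since the fade-out and fade-in ramps sum to 1 across every seam.

import torch

# Sketch only: verify the split/merge round trip with the commit's defaults.
sample_rate = 16_000
chunk_samples = 30 * sample_rate       # DEFAULT_CHUNK_DURATION
overlap_samples = 2 * sample_rate      # OVERLAP_DURATION
stride = chunk_samples - overlap_samples

wave = torch.ones(1, 90 * sample_rate)  # constant 90 s signal, shape [channels, samples]

# Split: same loop shape as split_audio_into_chunks above
chunks, start = [], 0
while start < wave.shape[1]:
    end = min(start + chunk_samples, wave.shape[1])
    chunks.append(wave[:, start:end])
    if end >= wave.shape[1]:
        break
    start += stride
assert len(chunks) == 4  # chunks cover 0-30 s, 28-58 s, 56-86 s, 84-90 s

# Merge: same linear crossfade as merge_chunks_with_crossfade above
merged = chunks[0]
for nxt in chunks[1:]:
    fade_out = torch.linspace(1.0, 0.0, overlap_samples)
    fade_in = torch.linspace(0.0, 1.0, overlap_samples)
    seam = merged[:, -overlap_samples:] * fade_out + nxt[:, :overlap_samples] * fade_in
    merged = torch.cat([merged[:, :-overlap_samples], seam, nxt[:, overlap_samples:]], dim=1)

# Each overlap is counted once, so the lengths match, and fade_out + fade_in == 1
# at every sample, so the constant signal reconstructs exactly.
assert merged.shape == wave.shape
assert torch.allclose(merged, wave)

Because the two linear ramps sum to 1, the crossfade preserves the level of correlated content across seams (an equal-power crossfade would use sine or square-root curves instead). Note that the merge helpers index chunks as 2-D [channels, samples] tensors, and that the chunked path merges and saves at the file's native sample rate while the non-chunked path saves at processor.audio_sampling_rate; the sketch sidesteps both points by using one 2-D tensor and a single rate throughout.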