ray-006 committed
Commit daa94ed · verified · 1 parent: 6bdfb90

Update app.py

Files changed (1):
  1. app.py +299 -59
app.py CHANGED
@@ -1,71 +1,311 @@
  import gradio as gr
  import torch
  import torchaudio
  import os
  import tempfile
- from huggingface_hub import hf_hub_download
- from sam_audio import SAMAudio, SAMAudioProcessor

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = SAMAudio.from_pretrained("facebook/sam-audio-large",
-                                  token=os.environ.get("HF_TOKEN")).to(device).eval()
- processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
-
- def separate_audio(audio_path, description, reranking_candidates):
-     if audio_path is None or not description:
-         return None, None
-
-     # Process inputs
-     inputs = processor(audios=[audio_path], descriptions=[description]).to(device)
-
-     with torch.inference_mode():
-         # Using reranking if candidates > 1
-         result = model.separate(
-             inputs,
-             predict_spans=True,
-             reranking_candidates=int(reranking_candidates)
          )

-     # Use temporary files to store the results for Gradio
-     target_path = os.path.join(tempfile.gettempdir(), "target.wav")
-     residual_path = os.path.join(tempfile.gettempdir(), "residual.wav")
-
-     # Save target and residual
-     torchaudio.save(target_path, result.target[0].unsqueeze(0).cpu(), processor.audio_sampling_rate)
-     torchaudio.save(residual_path, result.residual[0].unsqueeze(0).cpu(), processor.audio_sampling_rate)
-
-     return target_path, residual_path
-
- # --- UI Design ---
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
-     gr.Markdown("# 🎵 SAM-Audio Separation")
-     gr.Markdown("Upload an audio file and describe the specific sound you want to isolate (e.g., 'A dog barking' or 'A man speaking').")
-
-     with gr.Row():
-         with gr.Column():
-             input_audio = gr.Audio(label="Input Audio", type="filepath")
-             description = gr.Textbox(
-                 label="What do you want to isolate?",
-                 placeholder="e.g. A person laughing"
-             )
-             rerank_slider = gr.Slider(
-                 minimum=1,
-                 maximum=16,
-                 value=1,
-                 step=1,
-                 label="Reranking Candidates",
-                 info="Higher values improve quality but increase processing time."
-             )
-             btn = gr.Button("Separate Sound", variant="primary")
-
-         with gr.Column():
-             output_target = gr.Audio(label="Isolated (Target) Audio")
-             output_residual = gr.Audio(label="Residual Audio")
-
-     btn.click(
-         fn=separate_audio,
-         inputs=[input_audio, description, rerank_slider],
-         outputs=[output_target, output_residual]
      )

  if __name__ == "__main__":
 
  import gradio as gr
  import torch
  import torchaudio
+ import numpy as np
  import os
  import tempfile
+ import spaces
+
+ from typing import Iterable
+ from gradio.themes import Soft
+ from gradio.themes.utils import colors, fonts, sizes
+
+ # --- Custom Theme Configuration ---
+ class MidnightTheme(Soft):
+     def __init__(self):
+         super().__init__(
+             # Map the app's brand and slate colors onto the theme palettes
+             primary_hue=colors.Color(
+                 name="brand",
+                 c50="#eef2ff", c100="#e0e7ff", c200="#c7d2fe", c300="#a5b4fc",
+                 c400="#818cf8", c500="#5248e9", c600="#4f46e5", c700="#4338ca",
+                 c800="#3730a3", c900="#312e81", c950="#1e1b4b"
+             ),
+             neutral_hue=colors.Color(
+                 name="dark_slate",
+                 c50="#f8fafc", c100="#f1f5f9", c200="#e2e8f0", c300="#cbd5e1",
+                 c400="#94a3b8", c500="#64748b", c600="#51748c", c700="#334155",  # c600 is the secondary text color
+                 c800="#20293c", c900="#10172b", c950="#030617"  # c800-c950 are the background/button darks
+             ),
+             font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
+         )
+         super().set(
+             # Backgrounds
+             body_background_fill="#030617",
+             block_background_fill="#10172b",
+             block_border_color="#20293c",
+
+             # Text colors
+             body_text_color="#cdd6e2",
+             block_label_text_color="#51748c",
+             block_title_text_color="#cdd6e2",
+
+             # Buttons
+             button_primary_background_fill="#5248e9",
+             button_primary_text_color="white",
+             button_secondary_background_fill="#20293c",
+             button_secondary_text_color="#cdd6e2",
+
+             # Inputs
+             input_background_fill="#030617",
+             input_border_color="#20293c",
+         )
+
+ midnight_theme = MidnightTheme()
+
+ # --- CSS for Layout Polish ---
+ css = """
+ #container { max-width: 1000px; margin: auto; padding-top: 2rem; }
+ #title-area { text-align: center; margin-bottom: 2rem; }
+ .gradio-container { background-color: #030617 !important; }
+ .output-audio { background-color: #030617 !important; }
+ """
+
+ try:
+     from sam_audio import SAMAudio, SAMAudioProcessor
+ except ImportError as e:
+     print(f"Warning: 'sam_audio' library not found. Please install it to use this app. Error: {e}")
+
+ MODEL_ID = "facebook/sam-audio-large"
+ DEFAULT_CHUNK_DURATION = 30.0
+ OVERLAP_DURATION = 2.0
+ MAX_DURATION_WITHOUT_CHUNKING = 30.0

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Loading {MODEL_ID} on {device}...")
+
+ # Initialize to None so the guard in process_audio works even if loading fails
+ model = None
+ processor = None
+ try:
+     model = SAMAudio.from_pretrained(MODEL_ID, token=os.environ.get("HF_TOKEN")).to(device).eval()
+     processor = SAMAudioProcessor.from_pretrained(MODEL_ID)
+     print("✅ SAM-Audio loaded successfully.")
+ except Exception as e:
+     print(f"❌ Error loading SAM-Audio: {e}")
+
+ def load_audio(file_path):
+     """Load audio from file (supports both audio and video files)."""
+     waveform, sample_rate = torchaudio.load(file_path)
+     if waveform.shape[0] > 1:
+         # Downmix multi-channel audio to mono by averaging channels
+         waveform = waveform.mean(dim=0, keepdim=True)
+     return waveform, sample_rate
+
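As a quick sanity check of the mono downmix above, a minimal sketch (assuming `load_audio` is in scope; `stereo_clip.wav` is a hypothetical scratch file created only for the demonstration):

```python
# Write 1 s of fake stereo, then confirm load_audio averages it down to mono.
import torch
import torchaudio

torchaudio.save("stereo_clip.wav", torch.randn(2, 16000), 16000)
wave, sr = load_audio("stereo_clip.wav")
print(wave.shape, sr)  # torch.Size([1, 16000]) 16000
```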
+ def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
+     """Split audio waveform into overlapping chunks."""
+     chunk_samples = int(chunk_duration * sample_rate)
+     overlap_samples = int(overlap_duration * sample_rate)
+     stride = chunk_samples - overlap_samples
+
+     chunks = []
+     total_samples = waveform.shape[1]
+
+     if total_samples <= chunk_samples:
+         return [waveform]
+
+     start = 0
+     while start < total_samples:
+         end = min(start + chunk_samples, total_samples)
+         chunk = waveform[:, start:end]
+         chunks.append(chunk)
+         if end >= total_samples:
+             break
+         start += stride
+
+     return chunks
+
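To make the stride arithmetic concrete: the stride is chunk minus overlap, so 30 s chunks with a 2 s overlap step forward 28 s at a time, and a 75 s clip splits into chunks starting at 0 s, 28 s, and 56 s. A minimal sketch, assuming the helper above is in scope:

```python
# Chunking arithmetic check: stride = 30 - 2 = 28 s on a 75 s clip at 16 kHz.
import torch

sr = 16000
wave = torch.zeros(1, 75 * sr)
chunks = split_audio_into_chunks(wave, sr, chunk_duration=30.0, overlap_duration=2.0)
print([c.shape[1] / sr for c in chunks])  # [30.0, 30.0, 19.0]
```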
+ def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
+     """Merge audio chunks with crossfade on overlapping regions."""
+     if len(chunks) == 1:
+         chunk = chunks[0]
+         if chunk.dim() == 1:
+             chunk = chunk.unsqueeze(0)
+         return chunk
+
+     overlap_samples = int(overlap_duration * sample_rate)
+
+     processed_chunks = []
+     for chunk in chunks:
+         if chunk.dim() == 1:
+             chunk = chunk.unsqueeze(0)
+         processed_chunks.append(chunk)
+
+     result = processed_chunks[0]
+
+     for i in range(1, len(processed_chunks)):
+         prev_chunk = result
+         next_chunk = processed_chunks[i]
+
+         actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
+
+         if actual_overlap <= 0:
+             result = torch.cat([prev_chunk, next_chunk], dim=1)
+             continue
+
+         fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
+         fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)
+
+         prev_overlap = prev_chunk[:, -actual_overlap:]
+         next_overlap = next_chunk[:, :actual_overlap]
+
+         crossfaded = prev_overlap * fade_out + next_overlap * fade_in
+
+         result = torch.cat([
+             prev_chunk[:, :-actual_overlap],
+             crossfaded,
+             next_chunk[:, actual_overlap:]
+         ], dim=1)
+
+     return result
+
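Because the linear fade-out and fade-in weights sum to 1 at every sample of the overlap, splitting and immediately re-merging an unmodified signal should reproduce it up to float rounding. A round-trip sketch under that assumption, using the two helpers above:

```python
# Split + crossfade-merge of an untouched signal is (near-)lossless:
# fade_out + fade_in == 1 across the whole overlap region.
import torch

sr = 16000
wave = torch.randn(1, 75 * sr)
chunks = split_audio_into_chunks(wave, sr, 30.0, 2.0)
merged = merge_chunks_with_crossfade(chunks, sr, 2.0)
assert merged.shape == wave.shape
assert torch.allclose(merged, wave, atol=1e-6)
```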
+ def save_audio(tensor, sample_rate):
+     """Save a tensor to a temporary WAV file and return its path."""
+     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+         tensor = tensor.cpu()
+         if tensor.dim() == 1:
+             tensor = tensor.unsqueeze(0)
+         torchaudio.save(tmp.name, tensor, sample_rate)
+         return tmp.name
+
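The temp file is created with `delete=False`, so it outlives the `with` block (Gradio needs the returned path to stay readable) and cleanup is left to the caller. A small usage sketch, assuming `save_audio` is in scope; the printed path is illustrative:

```python
# save_audio accepts 1-D or 2-D tensors; 1-D input is unsqueezed to (1, N).
import os
import torch

path = save_audio(torch.zeros(16000), 16000)
print(path)       # e.g. /tmp/tmpab12cd34.wav (name varies)
os.remove(path)   # caller removes the file once it is no longer served
```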
+ @spaces.GPU(duration=120)
+ def process_audio(file_path, text_prompt, chunk_duration_val, progress=gr.Progress()):
+     global model, processor
+
+     if model is None or processor is None:
+         return None, None, "❌ Model not loaded correctly. Check logs."
+
+     progress(0.05, desc="Checking inputs...")
+
+     if not file_path:
+         return None, None, "❌ Please upload an audio or video file."
+     if not text_prompt or not text_prompt.strip():
+         return None, None, "❌ Please enter a text prompt."
+
+     try:
+         progress(0.15, desc="Loading audio...")
+         waveform, sample_rate = load_audio(file_path)
+         duration = waveform.shape[1] / sample_rate
+
+         c_dur = chunk_duration_val if chunk_duration_val else DEFAULT_CHUNK_DURATION
+         use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
+
+         if use_chunking:
+             progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
+             chunks = split_audio_into_chunks(waveform, sample_rate, c_dur, OVERLAP_DURATION)
+             num_chunks = len(chunks)
+
+             target_chunks = []
+             residual_chunks = []
+
+             for i, chunk in enumerate(chunks):
+                 chunk_progress = 0.2 + (i / num_chunks) * 0.6
+                 progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")
+
+                 with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+                     torchaudio.save(tmp.name, chunk, sample_rate)
+                     chunk_path = tmp.name
+
+                 try:
+                     inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
+
+                     with torch.inference_mode():
+                         result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
+
+                     target_chunks.append(result.target[0].detach().cpu())
+                     residual_chunks.append(result.residual[0].detach().cpu())
+                 finally:
+                     if os.path.exists(chunk_path):
+                         os.unlink(chunk_path)
+
+             progress(0.85, desc="Merging chunks...")
+             target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
+             residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)
+
+             progress(0.95, desc="Saving results...")
+             target_path = save_audio(target_merged, sample_rate)
+             residual_path = save_audio(residual_merged, sample_rate)
+
+             progress(1.0, desc="Done!")
+             return target_path, residual_path, f"✅ Isolated '{text_prompt}' ({num_chunks} chunks)"
+
+         else:
+             progress(0.3, desc="Processing audio...")
+             inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
+
+             progress(0.6, desc="Separating sounds...")
+             with torch.inference_mode():
+                 result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
+
+             progress(0.9, desc="Saving results...")
+             sr = processor.audio_sampling_rate
+             target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sr)
+             residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sr)
+
+             progress(1.0, desc="Done!")
+             return target_path, residual_path, f"✅ Isolated '{text_prompt}'"
+
+     except Exception as e:
+         import traceback
+         traceback.print_exc()
+         return None, None, f"❌ Error: {str(e)}"
+
+ def dummy_process(file, text, duration):  # Unused placeholder; run_btn binds process_audio below
+     return None, None, "Processing..."
+
+ with gr.Blocks(theme=midnight_theme, css=css) as demo:
+     with gr.Column(elem_id="container"):
+         # Header section
+         gr.Markdown(
+             """
+             # 🎙️ SAM-Audio Segmenter
+             ### Isolate specific sounds using natural language descriptions.
+             """,
+             elem_id="title-area"
          )

+         with gr.Row(equal_height=True):
+             # Left side: inputs
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     gr.Markdown("### 1. Upload & Describe")
+                     input_file = gr.Audio(label="Input Audio Source", type="filepath")
+                     text_prompt = gr.Textbox(
+                         label="Target Sound",
+                         placeholder="e.g. 'electric guitar solo' or 'birds chirping'",
+                         info="What sound should we isolate from the background?"
+                     )
+
+                 with gr.Accordion("Advanced Processing Settings", open=False):
+                     chunk_duration_slider = gr.Slider(
+                         minimum=10, maximum=60, value=30, step=5,
+                         label="Chunk Duration (s)",
+                         info="Shorter chunks save memory for long files."
+                     )
+
+                 run_btn = gr.Button("🚀 Start Separation", variant="primary")
+
+             # Right side: outputs
+             with gr.Column(scale=1):
+                 with gr.Group():
+                     gr.Markdown("### 2. Results")
+                     output_target = gr.Audio(label="Isolated Result", type="filepath")
+                     output_residual = gr.Audio(label="Background / Remainder", type="filepath")
+                     status_out = gr.Textbox(label="Status Log", interactive=False, lines=2)
+
+         # Examples section at the bottom
+         gr.Markdown("---")
+         gr.Examples(
+             examples=[
+                 ["example_audio/speech.mp3", "Music", 30],
+                 ["example_audio/song.mp3", "Drum", 30]
+             ],
+             inputs=[input_file, text_prompt, chunk_duration_slider],
+             label="Try an Example"
+         )
+
+         # Event binding
+         run_btn.click(
+             fn=process_audio,
+             inputs=[input_file, text_prompt, chunk_duration_slider],
+             outputs=[output_target, output_residual, status_out]
          )

  if __name__ == "__main__":