peoplepilot committed on
Commit
3afa406
Β·
1 Parent(s): f2bdfc3

chore: initial setup

Browse files
Files changed (2) hide show
  1. app.py +290 -4
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,7 +1,293 @@
 
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
  import gradio as gr
3
+ import torch
4
+ import torchaudio
5
+ import tempfile
6
+ import warnings
7
+ import os
8
+ warnings.filterwarnings("ignore")
9
 
10
+ from sam_audio import SAMAudio, SAMAudioProcessor
 
11
 
12
# Available models. Keys are the labels shown in the UI dropdown; values are
# the Hugging Face repo ids. The "-tv" variants additionally accept visual
# conditioning (per their labels) -- TODO confirm against sam_audio docs.
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}

DEFAULT_MODEL = "sam-audio-small"
EXAMPLES_DIR = "audio"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "PromoterClipMono.wav")

# Chunk processing settings
DEFAULT_CHUNK_DURATION = 30  # seconds per chunk
OVERLAP_DURATION = 2  # seconds of overlap between chunks
MAX_DURATION_WITHOUT_CHUNKING = 60  # auto-chunk if longer than this

# Global model cache -- populated/mutated by load_model() below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None  # UI label of the model currently loaded
model = None  # SAMAudio instance, or None before first load
processor = None  # SAMAudioProcessor paired with `model`
36
+
37
def load_model(model_name):
    """Load the SAM-Audio model named *model_name* into the module cache.

    No-op when the requested model is already loaded. Unknown names fall
    back to DEFAULT_MODEL (the cache key is still *model_name*, matching
    the original behavior so repeat calls with the same name stay cached).

    Side effects: rebinds the module globals ``model``, ``processor`` and
    ``current_model_name``.
    """
    global current_model_name, model, processor
    # Fast path first: skip the registry lookup entirely when the requested
    # model is already resident (the lookup is invariant w.r.t. the cache).
    if current_model_name == model_name and model is not None:
        return
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    print(f"Loading {model_id}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    print(f"Model {model_id} loaded on {device}.")

# Warm the cache at import time so the first request doesn't pay load cost.
load_model(DEFAULT_MODEL)
49
+
50
def load_audio(file_path):
    """Read *file_path* via torchaudio; return (mono_waveform, sample_rate).

    Works for any container torchaudio can decode (audio or video). The
    returned waveform always has shape [1, samples].
    """
    audio, rate = torchaudio.load(file_path)
    # Downmix by channel-averaging whenever more than one channel is present.
    mono = audio.mean(dim=0, keepdim=True) if audio.shape[0] > 1 else audio
    return mono, rate
57
+
58
def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split a waveform into overlapping chunks.

    Args:
        waveform: 2D tensor [channels, samples].
        sample_rate: sampling rate in Hz, used to convert seconds to samples.
        chunk_duration: target chunk length in seconds.
        overlap_duration: overlap between consecutive chunks in seconds.

    Returns:
        List of [channels, <=chunk_samples] tensors covering the whole input
        (empty list for a zero-length waveform). Consecutive chunks share
        ``overlap_duration`` seconds of audio for later crossfade merging.
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    # BUG FIX: if overlap >= chunk length the stride would be <= 0 and the
    # loop below would never advance. Clamp to at least one sample.
    stride = max(chunk_samples - overlap_samples, 1)

    chunks = []
    total_samples = waveform.shape[1]

    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunks.append(waveform[:, start:end])
        if end >= total_samples:
            break  # final (possibly short) chunk emitted
        start += stride

    return chunks
78
+
79
def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Stitch processed chunks back together, crossfading the overlaps.

    Each chunk may be 1D [samples] or 2D [channels, samples]; the result is
    always 2D. Consecutive chunks are blended over ``overlap_duration``
    seconds with a linear fade; when no overlap is possible they are simply
    concatenated.
    """
    def as_2d(t):
        # Normalize every chunk to [channels, samples].
        return t.unsqueeze(0) if t.dim() == 1 else t

    if len(chunks) == 1:
        return as_2d(chunks[0])

    overlap_samples = int(overlap_duration * sample_rate)
    normalized = [as_2d(c) for c in chunks]

    merged = normalized[0]
    for nxt in normalized[1:]:
        # The usable fade length is capped by both pieces' actual sizes.
        fade_len = min(overlap_samples, merged.shape[1], nxt.shape[1])

        if fade_len <= 0:
            # Nothing to blend -- plain concatenation.
            merged = torch.cat([merged, nxt], dim=1)
            continue

        # Linear ramps: outgoing audio fades 1 -> 0, incoming 0 -> 1.
        ramp_down = torch.linspace(1.0, 0.0, fade_len).to(merged.device)
        ramp_up = torch.linspace(0.0, 1.0, fade_len).to(nxt.device)

        blended = merged[:, -fade_len:] * ramp_down + nxt[:, :fade_len] * ramp_up

        # prefix of previous + crossfaded region + suffix of next.
        merged = torch.cat(
            [merged[:, :-fade_len], blended, nxt[:, fade_len:]], dim=1
        )

    return merged
130
+
131
def save_audio(tensor, sample_rate):
    """Write *tensor* to a temporary WAV file and return the file's path.

    The file is created with delete=False so it outlives this call; the
    caller (Gradio) serves it from disk.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    with tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
    return tmp.name
135
+
136
@spaces.GPU(duration=300)
def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    """Isolate the sound described by *text_prompt* from the file at *file_path*.

    Decorated with @spaces.GPU(duration=300) to request a GPU slot. Inputs
    longer than MAX_DURATION_WITHOUT_CHUNKING seconds are split into
    overlapping chunks, separated one chunk at a time, and merged back with
    a crossfade; shorter inputs are processed in a single pass.

    Returns:
        (target_path, residual_path, status_markdown). Both paths are None
        when validation fails or an exception is caught; the status string
        then carries the error message.
    """
    global model, processor

    progress(0.05, desc="Checking inputs...")

    # Input validation: report problems via the status string, never raise.
    if not file_path:
        return None, None, "❌ Please upload an audio file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."

    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)

        progress(0.15, desc="Loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate

        # Decide whether to use chunking
        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING

        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)

            target_chunks = []
            residual_chunks = []

            for i, chunk in enumerate(chunks):
                # Progress for the per-chunk stage spans the 0.2 -> 0.8 range.
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")

                # Save chunk to temp file for processor
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name

                try:
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)

                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

                    # Move results to CPU immediately to free GPU memory
                    # before the next chunk.
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    # Always remove the temp chunk file, even on failure.
                    os.unlink(chunk_path)

            progress(0.85, desc="Merging chunks...")
            target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)

            progress(0.95, desc="Saving results...")
            # merged tensors are already 2D [channels, samples]
            # NOTE(review): this path saves at the *input* sample_rate while
            # the non-chunked path below saves at processor.audio_sampling_rate.
            # If the model resamples internally, the chunked output would be
            # written at the wrong rate -- confirm against sam_audio docs.
            target_path = save_audio(target_merged, sample_rate)
            residual_path = save_audio(residual_merged, sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"βœ… Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            # Process without chunking
            progress(0.3, desc="Processing audio...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)

            progress(0.6, desc="Separating sounds...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

            progress(0.9, desc="Saving results...")
            # Model outputs are 1D here; add the channel dim for torchaudio.save.
            sample_rate = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"βœ… Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        # Top-level boundary: surface the error in the UI instead of crashing.
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"
217
+
218
# Build Interface
with gr.Blocks(title="SAM-Audio Test") as demo:
    gr.Markdown(
        """
        # 🎡 SAM-Audio: Segment Anything for Audio
        Isolate specific sounds from audio or video using natural language prompts.
        """
    )

    with gr.Row():
        # Left column: model choice, inputs and the run button.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )

            with gr.Accordion("βš™οΈ Advanced Options", open=False):
                chunk_duration_slider = gr.Slider(
                    minimum=10,
                    maximum=60,
                    value=DEFAULT_CHUNK_DURATION,
                    step=5,
                    label="Chunk Duration (seconds)",
                    info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split"
                )

            gr.Markdown("#### Upload Audio")
            input_audio = gr.Audio(label="Audio File", type="filepath")

            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'guitar', 'voice'"
            )

            run_btn = gr.Button("🎯 Isolate Sound", variant="primary", size="lg")
            status_output = gr.Markdown("")

        # Right column: separated outputs.
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    gr.Markdown("---")
    gr.Markdown("### 🎬 Demo Examples")
    gr.Markdown("Click to load example audio and prompt:")

    with gr.Row():
        # Only offer example buttons when the bundled clip exists.
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("🎀 Man Speaking")
            example_btn2 = gr.Button("🎀 Woman Speaking")
            example_btn3 = gr.Button("🎡 Background Music")

    # Main process button
    def process(model_name, audio_path, prompt, chunk_duration, progress=gr.Progress()):
        return separate_audio(model_name, audio_path, prompt, chunk_duration, progress)

    run_btn.click(
        fn=process,
        inputs=[model_selector, input_audio, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_output]
    )

    # Example buttons - just fill the audio + prompt; user clicks the run
    # button to process.
    # BUG FIX: prompts now match the button labels (previously "Guitar" and
    # "Voice"), and the third button is actually wired up (it previously had
    # no click handler at all).
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "Man speaking"),
            outputs=[input_audio, text_prompt]
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "Woman speaking"),
            outputs=[input_audio, text_prompt]
        )
        example_btn3.click(
            fn=lambda: (EXAMPLE_FILE, "Background music"),
            outputs=[input_audio, text_prompt]
        )

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio>=5.0.0
2
+ torch
3
+ transformers
4
+ huggingface_hub
5
+ spaces
6
+ torchaudio
7
+ scipy
8
+ git+https://github.com/hx23840/sam-audio.git