NeoPy committed on
Commit
a5b1635
·
verified ·
1 Parent(s): f901b88

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +505 -0
app.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced Audio Separator Demo
3
+ Advanced audio source separation with latest models and features
4
+ """
5
+ import os
6
+ import json
7
+ import torch
8
+ import logging
9
+ import traceback
10
+ from typing import Dict, List, Optional
11
+ import time
12
+ from datetime import datetime
13
+
14
+ import gradio as gr
15
+ import numpy as np
16
+ import librosa
17
+ import soundfile as sf
18
+ from pydub import AudioSegment
19
+ from audio_separator.separator import Separator
20
+ from audio_separator.separator import architectures
21
+
22
+
23
class AudioSeparatorDemo:
    """Stateful wrapper around python-audio-separator's ``Separator``.

    Tracks the currently loaded model, keeps a processing history, and
    provides the helpers used by the Gradio UI: single-file separation,
    model listing, and multi-model comparison.
    """

    def __init__(self):
        self.separator = None          # lazily-created Separator instance
        self.available_models = {}
        self.current_model = None      # name of the model currently loaded
        self.processing_history = []   # list of dict entries, newest last
        self.setup_logging()

    def setup_logging(self):
        """Configure logging and attach a module-scoped logger."""
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _resolve_path(audio_file):
        """Return a filesystem path for *audio_file*.

        Gradio components are inconsistent: ``gr.Audio(type="filepath")``
        yields a plain ``str`` while ``gr.File`` yields tempfile wrappers
        exposing ``.name``.  Accept both (the original ``.name`` access
        crashed on plain string paths).
        """
        return audio_file if isinstance(audio_file, str) else audio_file.name

    def get_system_info(self):
        """Get system information for hardware acceleration."""
        info = {
            "pytorch_version": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
            "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
            "device": "cuda" if torch.cuda.is_available() else ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu")
        }
        return info

    def initialize_separator(self, model_name: str = None, **kwargs):
        """Create a fresh ``Separator`` and load *model_name*.

        Returns ``(success, message)``.  On failure the separator state is
        cleared so ``current_model`` never points at a model that failed
        to load (the original left stale state behind).
        """
        try:
            # Release the previous separator (and any GPU memory it held).
            if self.separator is not None:
                del self.separator
                self.separator = None
                if torch.cuda.is_available():
                    # empty_cache() is a CUDA-only API; guard for CPU/MPS hosts.
                    torch.cuda.empty_cache()

            # Default to the bundled BS-Roformer vocal model.
            if model_name is None:
                model_name = "model_bs_roformer_ep_317_sdr_12.9755.ckpt"

            self.separator = Separator(
                output_format="WAV",
                use_autocast=True,
                use_soundfile=True,
                **kwargs
            )

            self.separator.load_model(model_name)
            self.current_model = model_name

            return True, f"Successfully initialized with model: {model_name}"

        except Exception as e:
            # Clear partial state so a later call can retry cleanly.
            self.separator = None
            self.current_model = None
            self.logger.error(f"Error initializing separator: {str(e)}")
            return False, f"Error initializing separator: {str(e)}"

    def get_available_models(self):
        """Return the simplified model list ({name: info}) or {} on error."""
        try:
            if self.separator is None:
                # info_only avoids downloading/initialising any model.
                self.separator = Separator(info_only=True)

            return self.separator.get_simplified_model_list()
        except Exception as e:
            self.logger.error(f"Error getting available models: {str(e)}")
            return {}

    def process_audio(self, audio_file, model_name, output_format="WAV", quality_preset="Standard",
                      custom_params=None):
        """Separate *audio_file* with *model_name*.

        Returns ``(stems, message)`` where ``stems`` maps stem name to a
        ``(sample_rate, ndarray)`` tuple, or ``(None, error_message)``.
        """
        if audio_file is None:
            return None, "No audio file provided"

        if model_name is None:
            return None, "No model selected"

        # (Re)initialize only when the requested model changed.
        if self.separator is None or self.current_model != model_name:
            success, message = self.initialize_separator(model_name)
            if not success:
                return None, message

        try:
            start_time = time.time()

            if custom_params is None:
                custom_params = {}

            # Quality presets override whatever custom_params carried in.
            if quality_preset == "Fast":
                custom_params.update({
                    "mdx_params": {"batch_size": 4, "overlap": 0.1, "segment_size": 128},
                    "vr_params": {"batch_size": 8, "aggression": 3},
                    "demucs_params": {"shifts": 1, "overlap": 0.1},
                    "mdxc_params": {"batch_size": 4, "overlap": 4}
                })
            elif quality_preset == "High Quality":
                custom_params.update({
                    "mdx_params": {"batch_size": 1, "overlap": 0.5, "segment_size": 512, "enable_denoise": True},
                    "vr_params": {"batch_size": 1, "aggression": 8, "enable_tta": True, "enable_post_process": True},
                    "demucs_params": {"shifts": 4, "overlap": 0.5, "segments_enabled": False},
                    "mdxc_params": {"batch_size": 1, "overlap": 16, "pitch_shift": 0}
                })

            # Push the per-architecture parameter dicts onto the separator.
            for key, value in custom_params.items():
                if hasattr(self.separator, key):
                    setattr(self.separator, key, value)

            # Accept both str paths and file-like objects (see _resolve_path).
            output_files = self.separator.separate(self._resolve_path(audio_file))

            processing_time = time.time() - start_time

            # Load each produced stem into memory, then delete the file.
            output_audio = {}
            for file_path in output_files:
                if os.path.exists(file_path):
                    stem_name = os.path.splitext(os.path.basename(file_path))[0]
                    audio_data, sample_rate = sf.read(file_path)
                    output_audio[stem_name] = (sample_rate, audio_data)

                    os.remove(file_path)

            if not output_audio:
                return None, "No output files generated"

            history_entry = {
                "timestamp": datetime.now().isoformat(),
                "model": model_name,
                "processing_time": processing_time,
                "output_files": list(output_audio.keys())
            }
            self.processing_history.append(history_entry)

            return output_audio, f"Processing completed in {processing_time:.2f} seconds with model: {model_name}"

        except Exception as e:
            error_msg = f"Error processing audio: {str(e)}"
            self.logger.error(f"{error_msg}\n{traceback.format_exc()}")
            return None, error_msg

    def get_processing_history(self):
        """Format the last 10 history entries for display."""
        if not self.processing_history:
            return "No processing history available"

        history_text = "Processing History:\n\n"
        for i, entry in enumerate(self.processing_history[-10:], 1):  # Show last 10 entries
            history_text += f"{i}. {entry['timestamp']}\n"
            history_text += f"   Model: {entry['model']}\n"
            history_text += f"   Time: {entry['processing_time']:.2f}s\n"
            history_text += f"   Stems: {', '.join(entry['output_files'])}\n\n"

        return history_text

    def reset_history(self):
        """Reset processing history"""
        self.processing_history = []
        return "Processing history cleared"

    def compare_models(self, audio_file, model_list):
        """Run each model in *model_list* on the same input and collect
        timing/output metrics per model.  Per-model failures are recorded
        in the result dict rather than aborting the comparison."""
        if audio_file is None or not model_list:
            return None, "No audio file or models selected for comparison"

        results = {}
        input_path = self._resolve_path(audio_file)  # hoisted out of the loop

        for model_name in model_list:
            try:
                start_time = time.time()
                success, message = self.initialize_separator(model_name)

                if not success:
                    results[model_name] = {"error": message}
                    continue

                output_files = self.separator.separate(input_path)
                processing_time = time.time() - start_time

                # Use the first output file for basic quality metrics.
                if output_files and os.path.exists(output_files[0]):
                    audio_data, sample_rate = sf.read(output_files[0])

                    results[model_name] = {
                        "processing_time": processing_time,
                        "output_files": len(output_files),
                        "sample_rate": sample_rate,
                        "duration": len(audio_data) / sample_rate,
                        "status": "Success"
                    }

                    # Clean up the stems this model produced.
                    for file_path in output_files:
                        if os.path.exists(file_path):
                            os.remove(file_path)
                else:
                    results[model_name] = {"status": "Failed", "error": "No output files generated"}

            except Exception as e:
                results[model_name] = {"status": "Error", "error": str(e)}

        return results, f"Model comparison completed for {len(model_list)} models"
230
+
231
+
232
# Module-level singleton; all Gradio callbacks below delegate to it.
demo = AudioSeparatorDemo()
234
+
235
def create_interface():
    """Build and return the Gradio Blocks UI.

    This function only wires components to the module-level ``demo``
    singleton; all separation work happens in ``AudioSeparatorDemo``.
    """

    with gr.Blocks(
        title="Enhanced Audio Separator",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="slate",
            font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
        ),
        css="""
        .container {max-width: 1200px;}
        .model-card {border: 1px solid #e5e7eb; border-radius: 8px; padding: 12px; margin: 8px 0;}
        .quality-advanced {display: none;}
        .processing-status {padding: 8px; border-radius: 4px; margin: 8px 0;}
        """
    ) as interface:

        # Header
        gr.Markdown(
            """
            # 🎡 Enhanced Audio Separator

            Advanced audio source separation powered by the latest python-audio-separator library.
            Support for MDX-Net, VR Arch, Demucs, MDXC, and Roformer models with hardware acceleration.
            """,
            elem_classes=["container"]
        )

        # System Information
        with gr.Accordion("πŸ–₯️ System Information", open=False):
            system_info = demo.get_system_info()
            info_text = f"""
            **PyTorch Version:** {system_info['pytorch_version']}

            **Hardware Acceleration:** {system_info['device'].upper()}

            **CUDA Available:** {system_info['cuda_available']} (Version: {system_info['cuda_version']})

            **Apple Silicon (MPS):** {system_info['mps_available']}
            """
            gr.Markdown(info_text)

        with gr.Row():
            with gr.Column(scale=2):
                # Main audio input; type="filepath" hands callbacks a str path.
                audio_input = gr.Audio(
                    label="Upload Audio File",
                    type="filepath",
                    format="wav"
                )

                # Model selection (queried once at UI build time).
                model_list = demo.get_available_models()

                model_dropdown = gr.Dropdown(
                    choices=list(model_list.keys()) if model_list else [],
                    value="model_bs_roformer_ep_317_sdr_12.9755.ckpt" if model_list else None,
                    label="Select Model",
                    info="Choose an AI model for audio separation"
                )

                # Quality preset
                quality_preset = gr.Radio(
                    choices=["Fast", "Standard", "High Quality"],
                    value="Standard",
                    label="Quality Preset",
                    info="Choose processing quality vs speed trade-off"
                )

                # Advanced parameters (collapsible)
                with gr.Accordion("πŸ”§ Advanced Parameters", open=False):
                    with gr.Row():
                        batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
                        segment_size = gr.Slider(64, 1024, value=256, step=64, label="Segment Size")
                        overlap = gr.Slider(0.1, 0.5, value=0.25, step=0.05, label="Overlap")

                    denoise = gr.Checkbox(label="Enable Denoise", value=False)
                    tta = gr.Checkbox(label="Enable TTA (Test-Time Augmentation)", value=False)
                    post_process = gr.Checkbox(label="Enable Post-Processing", value=False)

                # Process button
                process_btn = gr.Button("🎡 Separate Audio", variant="primary", size="lg")

            with gr.Column(scale=2):
                # Results
                status_output = gr.Textbox(label="Status", lines=3)

                # One tab per commonly-produced stem.
                with gr.Tabs():
                    with gr.TabItem("Vocals"):
                        vocals_output = gr.Audio(label="Vocals")

                    with gr.TabItem("Instrumental"):
                        instrumental_output = gr.Audio(label="Instrumental")

                    with gr.TabItem("Drums"):
                        drums_output = gr.Audio(label="Drums")

                    with gr.TabItem("Bass"):
                        bass_output = gr.Audio(label="Bass")

                    with gr.TabItem("Other Stems"):
                        other_output = gr.Audio(label="Other")

                # Download section
                with gr.Accordion("πŸ“₯ Batch Processing", open=False):
                    gr.Markdown("Upload multiple files for batch processing:")
                    batch_files = gr.File(file_count="multiple", file_types=["audio"], label="Batch Audio Files")
                    batch_btn = gr.Button("Process Batch")
                    batch_output = gr.File(label="Download Separated Files")

        # Model Information and Comparison
        with gr.Tabs():
            with gr.TabItem("πŸ“Š Model Information"):
                gr.Markdown("## Available Models and Performance Metrics")

                # Model information display
                model_info = gr.JSON(value=demo.get_available_models(), label="Model Details")
                refresh_models_btn = gr.Button("πŸ”„ Refresh Models")

                with gr.Row():
                    compare_models_dropdown = gr.Dropdown(
                        choices=list(model_list.keys())[:5] if model_list else [],
                        multiselect=True,
                        label="Select Models to Compare",
                        info="Choose up to 5 models for comparison"
                    )
                    compare_btn = gr.Button("πŸ” Compare Models")

                comparison_results = gr.JSON(label="Model Comparison Results")

            with gr.TabItem("πŸ“ˆ Processing History"):
                history_output = gr.Textbox(label="History", lines=10)
                refresh_history_btn = gr.Button("πŸ”„ Refresh History")
                reset_history_btn = gr.Button("πŸ—‘οΈ Clear History")

        # Event handlers
        def process_audio(audio_file, model_name, quality_preset, batch_size, segment_size,
                          overlap, denoise, tta, post_process):
            """Run one separation and route stems to the per-stem players."""
            if not audio_file or not model_name:
                return None, None, None, None, None, "Please upload an audio file and select a model"

            # Translate the advanced sliders into per-architecture params.
            custom_params = {
                "mdx_params": {
                    "batch_size": int(batch_size),
                    "segment_size": int(segment_size),
                    "overlap": float(overlap),
                    "enable_denoise": denoise
                },
                "vr_params": {
                    "batch_size": int(batch_size),
                    "enable_tta": tta,
                    "enable_post_process": post_process
                },
                "demucs_params": {
                    "overlap": float(overlap)
                },
                "mdxc_params": {
                    "batch_size": int(batch_size),
                    "overlap": int(overlap * 10)
                }
            }

            output_audio, status = demo.process_audio(
                audio_file, model_name, quality_preset=quality_preset, custom_params=custom_params
            )

            if output_audio is None:
                return None, None, None, None, None, status

            # Route stems to the matching player by name substring.
            vocals = None
            instrumental = None
            drums = None
            bass = None
            other = None

            for stem_name, (sample_rate, audio_data) in output_audio.items():
                if "vocal" in stem_name.lower():
                    vocals = (sample_rate, audio_data)
                elif "instrumental" in stem_name.lower():
                    instrumental = (sample_rate, audio_data)
                elif "drum" in stem_name.lower():
                    drums = (sample_rate, audio_data)
                elif "bass" in stem_name.lower():
                    bass = (sample_rate, audio_data)
                else:
                    other = (sample_rate, audio_data)

            return vocals, instrumental, drums, bass, other, status

        def refresh_models():
            """Re-query the model list and refresh both dropdown and JSON view."""
            updated_models = demo.get_available_models()
            return (
                gr.update(choices=list(updated_models.keys())),
                updated_models
            )

        def refresh_history():
            return demo.get_processing_history()

        def clear_history():
            return demo.reset_history()

        def compare_selected_models(audio_file, model_list):
            if not audio_file or not model_list:
                return "Please upload an audio file and select models to compare"

            results, status = demo.compare_models(audio_file, model_list)
            return results

        # Wire up event handlers
        process_btn.click(
            process_audio,
            inputs=[
                audio_input, model_dropdown, quality_preset,
                batch_size, segment_size, overlap, denoise, tta, post_process
            ],
            outputs=[
                vocals_output, instrumental_output, drums_output,
                bass_output, other_output, status_output
            ]
        )

        refresh_models_btn.click(refresh_models, outputs=[model_dropdown, model_info])
        refresh_history_btn.click(refresh_history, outputs=[history_output])
        reset_history_btn.click(clear_history, outputs=[history_output])
        compare_btn.click(compare_selected_models, inputs=[audio_input, compare_models_dropdown], outputs=[comparison_results])

        # Batch processing
        def process_batch(batch_files, model_name):
            """Separate every uploaded file and return the path to a zip
            of the stems.

            Fixes vs. the original: gr.File output requires a filesystem
            path (a BytesIO object is not downloadable), and each stem is
            encoded to real WAV bytes via soundfile instead of dumping raw
            ndarray bytes into a ".wav" entry.
            """
            if not batch_files or not model_name:
                return None, "Please upload batch files and select a model"

            import io
            import zipfile
            import tempfile

            zip_buffer = io.BytesIO()
            with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
                for file_info in batch_files:
                    output_audio, _status = demo.process_audio(file_info, model_name)
                    if output_audio is None:
                        # Skip files that failed; keep processing the rest.
                        continue

                    # gr.File entries may be str paths or tempfile wrappers.
                    src_path = file_info if isinstance(file_info, str) else file_info.name
                    base = os.path.splitext(os.path.basename(src_path))[0]

                    for stem_name, (sample_rate, audio_data) in output_audio.items():
                        wav_buf = io.BytesIO()
                        sf.write(wav_buf, audio_data, sample_rate, format="WAV")
                        zip_file.writestr(f"{base}_{stem_name}.wav", wav_buf.getvalue())

            # Persist the archive so Gradio can serve it by path.
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".zip")
            tmp.write(zip_buffer.getvalue())
            tmp.close()
            return tmp.name, f"Batch processing completed for {len(batch_files)} files"

        batch_btn.click(
            process_batch,
            inputs=[batch_files, model_dropdown],
            outputs=[batch_output, status_output]
        )

    return interface
495
+
496
+
497
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces (container-friendly).
    interface = create_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
        # NOTE(review): show_tips was removed from launch() in Gradio 4.x
        # and raised TypeError; dropped for compatibility.
    )