NeoPy committed on
Commit
c841e1f
·
verified ·
1 Parent(s): 18f8581

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +876 -195
app.py CHANGED
@@ -1,15 +1,13 @@
1
- """
2
- Enhanced Audio Separator Demo
3
- Advanced audio source separation with latest models and features
4
- """
5
  import os
6
  import json
7
  import torch
8
  import logging
9
  import traceback
10
- from typing import Dict, List, Optional
11
  import time
12
  from datetime import datetime
 
 
13
 
14
  import gradio as gr
15
  import numpy as np
@@ -20,13 +18,16 @@ from audio_separator.separator import Separator
20
  from audio_separator.separator import architectures
21
 
22
 
23
- class AudioSeparatorDemo:
24
  def __init__(self):
25
  self.separator = None
26
  self.available_models = {}
27
  self.current_model = None
28
  self.processing_history = []
 
 
29
  self.setup_logging()
 
30
 
31
  def setup_logging(self):
32
  """Setup logging for the application"""
@@ -40,63 +41,468 @@ class AudioSeparatorDemo:
40
  "cuda_available": torch.cuda.is_available(),
41
  "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
42
  "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
43
- "device": "cuda" if torch.cuda.is_available() else ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu")
 
 
44
  }
45
  return info
46
 
47
- def initialize_separator(self, model_name: str = None, **kwargs):
48
- """Initialize the separator with specified parameters"""
49
  try:
50
- # Clean up previous separator if exists
51
- if self.separator is not None:
52
- del self.separator
53
- torch.cuda.empty_cache()
54
-
55
- # Set default model if not specified
56
- if model_name is None:
57
- model_name = "model_bs_roformer_ep_317_sdr_12.9755.ckpt"
58
-
59
- # Initialize separator with updated parameters
60
- self.separator = Separator(
61
- output_format="WAV",
62
- use_autocast=True,
63
- use_soundfile=True,
64
- **kwargs
65
- )
66
 
67
- # Load the model
68
- self.separator.load_model(model_name)
69
- self.current_model = model_name
 
70
 
71
- return True, f"Successfully initialized with model: {model_name}"
 
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  except Exception as e:
74
- self.logger.error(f"Error initializing separator: {str(e)}")
75
- return False, f"Error initializing separator: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def get_available_models(self):
78
- """Get list of available models with descriptions and scores"""
79
  try:
80
- if self.separator is None:
81
- self.separator = Separator(info_only=True)
82
-
83
- models = self.separator.list_supported_model_files()
84
- simplified_models = self.separator.get_simplified_model_list()
85
-
86
- return simplified_models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  except Exception as e:
88
  self.logger.error(f"Error getting available models: {str(e)}")
89
  return {}
90
 
91
- def process_audio(self, audio_file, model_name, output_format="WAV", quality_preset="Standard",
92
- custom_params=None):
93
- """Process audio with the selected model"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  if audio_file is None:
95
  return None, "No audio file provided"
96
 
97
  if model_name is None:
98
  return None, "No model selected"
99
 
 
 
 
 
 
100
  if self.separator is None or self.current_model != model_name:
101
  success, message = self.initialize_separator(model_name)
102
  if not success:
@@ -105,11 +511,10 @@ class AudioSeparatorDemo:
105
  try:
106
  start_time = time.time()
107
 
108
- # Update separator parameters based on quality preset
109
  if custom_params is None:
110
  custom_params = {}
111
 
112
- # Apply quality preset
113
  if quality_preset == "Fast":
114
  custom_params.update({
115
  "mdx_params": {"batch_size": 4, "overlap": 0.1, "segment_size": 128},
@@ -155,28 +560,72 @@ class AudioSeparatorDemo:
155
  "timestamp": datetime.now().isoformat(),
156
  "model": model_name,
157
  "processing_time": processing_time,
158
- "output_files": list(output_audio.keys())
 
 
159
  }
160
  self.processing_history.append(history_entry)
161
 
162
- return output_audio, f"Processing completed in {processing_time:.2f} seconds with model: {model_name}"
163
 
164
  except Exception as e:
165
  error_msg = f"Error processing audio: {str(e)}"
166
  self.logger.error(f"{error_msg}\n{traceback.format_exc()}")
167
  return None, error_msg
168
 
169
- def get_processing_history(self):
170
- """Get processing history for display"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  if not self.processing_history:
172
  return "No processing history available"
173
 
174
- history_text = "Processing History:\n\n"
175
- for i, entry in enumerate(self.processing_history[-10:], 1): # Show last 10 entries
176
- history_text += f"{i}. {entry['timestamp']}\n"
 
 
177
  history_text += f" Model: {entry['model']}\n"
178
  history_text += f" Time: {entry['processing_time']:.2f}s\n"
179
- history_text += f" Stems: {', '.join(entry['output_files'])}\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  return history_text
182
 
@@ -184,72 +633,27 @@ class AudioSeparatorDemo:
184
  """Reset processing history"""
185
  self.processing_history = []
186
  return "Processing history cleared"
187
-
188
- def compare_models(self, audio_file, model_list):
189
- """Compare multiple models on the same audio"""
190
- if audio_file is None or not model_list:
191
- return None, "No audio file or models selected for comparison"
192
-
193
- results = {}
194
-
195
- for model_name in model_list:
196
- try:
197
- start_time = time.time()
198
- success, message = self.initialize_separator(model_name)
199
-
200
- if not success:
201
- results[model_name] = {"error": message}
202
- continue
203
-
204
- output_files = self.separator.separate(audio_file)
205
- processing_time = time.time() - start_time
206
-
207
- # Analyze the first output file for basic metrics
208
- if output_files and os.path.exists(output_files[0]):
209
- audio_data, sample_rate = sf.read(output_files[0])
210
-
211
- results[model_name] = {
212
- "processing_time": processing_time,
213
- "output_files": len(output_files),
214
- "sample_rate": sample_rate,
215
- "duration": len(audio_data) / sample_rate,
216
- "status": "Success"
217
- }
218
-
219
- # Clean up
220
- for file_path in output_files:
221
- if os.path.exists(file_path):
222
- os.remove(file_path)
223
- else:
224
- results[model_name] = {"status": "Failed", "error": "No output files generated"}
225
-
226
- except Exception as e:
227
- results[model_name] = {"status": "Error", "error": str(e)}
228
-
229
- return results, f"Model comparison completed for {len(model_list)} models"
230
 
231
 
232
- # Initialize the demo
233
- demo = AudioSeparatorDemo()
234
 
235
  def create_interface():
236
- """Create the Gradio interface"""
237
 
238
- with gr.Blocks(title="Audio Separator") as interface:
239
-
240
- # Header
241
  gr.Markdown(
242
  """
243
- # 🎵 Audio Separator
 
 
244
 
245
- Advanced audio source separation powered by the latest python-audio-separator library.
246
- Support for MDX-Net, VR Arch, Demucs, MDXC, and Roformer models with hardware acceleration.
247
  """
248
  )
249
 
250
  # System Information
251
  with gr.Accordion("🖥️ System Information", open=False):
252
- system_info = demo.get_system_info()
253
  info_text = f"""
254
  **PyTorch Version:** {system_info['pytorch_version']}
255
 
@@ -258,6 +662,8 @@ def create_interface():
258
  **CUDA Available:** {system_info['cuda_available']} (Version: {system_info['cuda_version']})
259
 
260
  **Apple Silicon (MPS):** {system_info['mps_available']}
 
 
261
  """
262
  gr.Markdown(info_text)
263
 
@@ -265,101 +671,253 @@ def create_interface():
265
  with gr.Column(scale=2):
266
  # Main audio input
267
  audio_input = gr.Audio(
268
- label="Upload Audio File",
269
- type="filepath"
 
270
  )
271
 
272
- # Model selection
273
- model_list = demo.get_available_models()
 
 
 
274
 
 
 
 
 
275
  model_dropdown = gr.Dropdown(
276
  choices=list(model_list.keys()) if model_list else [],
277
  value="model_bs_roformer_ep_317_sdr_12.9755.ckpt" if model_list else None,
278
- label="Select Model",
279
- info="Choose an AI model for audio separation"
 
280
  )
281
 
282
- # Quality preset
283
- quality_preset = gr.Radio(
284
- choices=["Fast", "Standard", "High Quality"],
285
- value="Standard",
286
- label="Quality Preset",
287
- info="Choose processing quality vs speed trade-off"
288
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
- # Advanced parameters (collapsible)
291
  with gr.Accordion("🔧 Advanced Parameters", open=False):
292
  with gr.Row():
293
  batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
294
  segment_size = gr.Slider(64, 1024, value=256, step=64, label="Segment Size")
295
  overlap = gr.Slider(0.1, 0.5, value=0.25, step=0.05, label="Overlap")
296
 
297
- denoise = gr.Checkbox(label="Enable Denoise", value=False)
298
- tta = gr.Checkbox(label="Enable TTA (Test-Time Augmentation)", value=False)
299
- post_process = gr.Checkbox(label="Enable Post-Processing", value=False)
 
 
300
 
301
  # Process button
302
- process_btn = gr.Button("🎵 Separate Audio", variant="primary", scale=1)
303
 
304
  with gr.Column(scale=2):
305
- # Results
306
- status_output = gr.Textbox(label="Status", lines=3)
307
 
308
- # Output audio components
309
  with gr.Tabs():
310
- with gr.Tab("Vocals"):
311
  vocals_output = gr.Audio(label="Vocals")
312
 
313
- with gr.Tab("Instrumental"):
314
  instrumental_output = gr.Audio(label="Instrumental")
315
 
316
- with gr.Tab("Drums"):
317
  drums_output = gr.Audio(label="Drums")
318
 
319
- with gr.Tab("Bass"):
320
  bass_output = gr.Audio(label="Bass")
321
 
322
- with gr.Tab("Other Stems"):
323
- other_output = gr.Audio(label="Other")
 
 
 
324
 
325
  # Download section
326
- with gr.Accordion("📥 Batch Processing", open=False):
327
- gr.Markdown("Upload multiple files for batch processing:")
328
- batch_files = gr.File(file_count="multiple", file_types=[".wav", ".mp3", ".flac", ".m4a"], label="Batch Audio Files")
329
- batch_btn = gr.Button("Process Batch")
330
- batch_output = gr.File(label="Download Separated Files")
 
 
 
 
 
 
 
 
331
 
332
- # Model Information and Comparison
333
  with gr.Tabs():
334
- with gr.Tab("📊 Model Information"):
335
- gr.Markdown("## Available Models and Performance Metrics")
336
 
337
- # Model information display
338
- model_info = gr.JSON(value=demo.get_available_models(), label="Model Details")
339
  refresh_models_btn = gr.Button("🔄 Refresh Models")
340
 
 
341
  with gr.Row():
342
- compare_models_dropdown = gr.Dropdown(
343
- choices=list(model_list.keys())[:5] if model_list else [],
344
- multiselect=True,
345
- label="Select Models to Compare",
346
- info="Choose up to 5 models for comparison"
 
 
 
 
 
 
 
 
 
347
  )
348
- compare_btn = gr.Button("🔍 Compare Models")
349
 
350
- comparison_results = gr.JSON(label="Model Comparison Results")
 
 
 
 
 
 
 
 
351
 
352
- with gr.Tab("📈 Processing History"):
353
- history_output = gr.Textbox(label="History", lines=10)
 
354
  with gr.Row():
355
  refresh_history_btn = gr.Button("🔄 Refresh History")
356
  reset_history_btn = gr.Button("🗑️ Clear History", variant="stop")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
 
358
  # Event handlers
359
- def process_audio(audio_file, model_name, quality_preset, batch_size, segment_size,
360
- overlap, denoise, tta, post_process):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  if not audio_file or not model_name:
362
- return None, None, None, None, None, "Please upload an audio file and select a model"
363
 
364
  # Prepare custom parameters
365
  custom_params = {
@@ -372,23 +930,28 @@ def create_interface():
372
  "vr_params": {
373
  "batch_size": int(batch_size),
374
  "enable_tta": tta,
375
- "enable_post_process": post_process
 
376
  },
377
  "demucs_params": {
378
  "overlap": float(overlap)
379
  },
380
  "mdxc_params": {
381
  "batch_size": int(batch_size),
382
- "overlap": int(overlap * 10)
 
383
  }
384
  }
385
 
386
- output_audio, status = demo.process_audio(
387
- audio_file, model_name, quality_preset=quality_preset, custom_params=custom_params
 
 
 
388
  )
389
 
390
  if output_audio is None:
391
- return None, None, None, None, None, status
392
 
393
  # Extract different stems
394
  vocals = None
@@ -409,80 +972,177 @@ def create_interface():
409
  else:
410
  other = (sample_rate, audio_data)
411
 
412
- return vocals, instrumental, drums, bass, other, status
413
-
414
- def refresh_models():
415
- updated_models = demo.get_available_models()
416
- return (
417
- gr.Dropdown(choices=list(updated_models.keys())),
418
- gr.JSON(value=updated_models)
419
- )
420
-
421
- def refresh_history():
422
- return demo.get_processing_history()
423
-
424
- def clear_history():
425
- demo.reset_history()
426
- return "Processing history cleared"
427
 
428
- def compare_selected_models(audio_file, model_list):
429
  if not audio_file or not model_list:
430
  return {"error": "Please upload an audio file and select models to compare"}
431
 
432
- results, status = demo.compare_models(audio_file, model_list)
433
  return results
434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  # Wire up event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  process_btn.click(
437
- fn=process_audio,
438
  inputs=[
439
  audio_input, model_dropdown, quality_preset,
440
- batch_size, segment_size, overlap, denoise, tta, post_process
 
441
  ],
442
  outputs=[
443
  vocals_output, instrumental_output, drums_output,
444
- bass_output, other_output, status_output
445
  ]
446
  )
447
 
 
 
 
 
 
 
448
  refresh_models_btn.click(
449
- fn=refresh_models,
450
- outputs=[model_dropdown, model_info]
451
  )
452
 
453
  refresh_history_btn.click(
454
- fn=refresh_history,
455
  outputs=[history_output]
456
  )
457
 
458
  reset_history_btn.click(
459
- fn=clear_history,
460
  outputs=[history_output]
461
  )
462
 
463
- compare_btn.click(
464
- fn=compare_selected_models,
465
- inputs=[audio_input, compare_models_dropdown],
466
- outputs=[comparison_results]
 
 
 
 
 
 
467
  )
468
 
469
  # Batch processing
470
- def process_batch(batch_files, model_name):
471
  if not batch_files or not model_name:
472
  return None, "Please upload batch files and select a model"
473
 
474
- # Create zip file for downloads
475
  import zipfile
476
  import io
477
 
478
  zip_buffer = io.BytesIO()
479
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
480
  for file_info in batch_files:
481
- output_audio, _ = demo.process_audio(file_info, model_name)
482
  if output_audio is not None:
483
- # Add files to zip
484
  for stem_name, (sample_rate, audio_data) in output_audio.items():
485
- # Convert numpy array to audio file in memory
486
  import tempfile
487
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
488
  sf.write(tmp_file.name, audio_data, sample_rate)
@@ -494,10 +1154,30 @@ def create_interface():
494
  return gr.File(value=zip_buffer, visible=True), f"Batch processing completed for {len(batch_files)} files"
495
 
496
  batch_btn.click(
497
- fn=process_batch,
498
  inputs=[batch_files, model_dropdown],
499
  outputs=[batch_output, status_output]
500
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
 
502
  return interface
503
 
@@ -505,7 +1185,8 @@ def create_interface():
505
  if __name__ == "__main__":
506
  interface = create_interface()
507
  interface.launch(
508
- theme='terastudio/yellow',
509
  server_port=7860,
510
- share=True
 
 
511
  )
 
 
 
 
 
1
  import os
2
  import json
3
  import torch
4
  import logging
5
  import traceback
6
+ from typing import Dict, List, Optional, Tuple
7
  import time
8
  from datetime import datetime
9
+ import threading
10
+ from collections import defaultdict
11
 
12
  import gradio as gr
13
  import numpy as np
 
18
  from audio_separator.separator import architectures
19
 
20
 
21
+ class AudioSeparatorD:
22
  def __init__(self):
23
  self.separator = None
24
  self.available_models = {}
25
  self.current_model = None
26
  self.processing_history = []
27
+ self.model_performance_cache = {}
28
+ self.model_recommendations = {}
29
  self.setup_logging()
30
+ self.model_lock = threading.Lock()
31
 
32
  def setup_logging(self):
33
  """Setup logging for the application"""
 
41
  "cuda_available": torch.cuda.is_available(),
42
  "cuda_version": torch.version.cuda if torch.cuda.is_available() else "N/A",
43
  "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
44
+ "device": "cuda" if torch.cuda.is_available() else ("mps" if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else "cpu"),
45
+ "memory_total": torch.cuda.get_device_properties(0).total_memory if torch.cuda.is_available() else 0,
46
+ "memory_allocated": torch.cuda.memory_allocated() if torch.cuda.is_available() else 0
47
  }
48
  return info
49
 
50
def analyze_audio_characteristics(self, audio_file: str) -> Dict:
    """Inspect an audio file and extract features used for smart model selection.

    Returns a dict of duration, sample rate, tempo, spectral averages,
    dynamic-range estimate and a coarse "audio_type" label; on any failure
    returns {"audio_type": "unknown", "error": <message>} instead.
    """
    try:
        # Decode at the file's native sample rate for analysis.
        samples, rate = librosa.load(audio_file, sr=None)
        total_seconds = len(samples) / rate

        # Per-frame spectral features.
        centroid_frames = librosa.feature.spectral_centroid(y=samples, sr=rate)[0]
        rolloff_frames = librosa.feature.spectral_rolloff(y=samples, sr=rate)[0]
        zcr_frames = librosa.feature.zero_crossing_rate(samples)[0]

        # Rhythm estimate (beats per minute).
        bpm, _ = librosa.beat.beat_track(y=samples, sr=rate)

        # Spread of frame loudness as a proxy for dynamic range.
        rms_frames = librosa.feature.rms(y=samples)[0]
        loudness_spread = np.std(rms_frames)

        return {
            "duration": total_seconds,
            "sample_rate": rate,
            "tempo": float(bpm),
            "avg_spectral_centroid": float(np.mean(centroid_frames)),
            "avg_spectral_rolloff": float(np.mean(rolloff_frames)),
            "avg_zero_crossing_rate": float(np.mean(zcr_frames)),
            "dynamic_range": float(loudness_spread),
            "audio_type": self._classify_audio_type(
                np.mean(centroid_frames),
                float(bpm),
                loudness_spread
            )
        }
    except Exception as e:
        self.logger.error(f"Error analyzing audio: {str(e)}")
        return {"audio_type": "unknown", "error": str(e)}
89
+
90
+ def _classify_audio_type(self, spectral_centroid: float, tempo: float, dynamic_range: float) -> str:
91
+ """Classify audio type based on spectral and temporal features"""
92
+ if spectral_centroid < 1000:
93
+ return "bass_heavy"
94
+ elif spectral_centroid > 4000:
95
+ return "bright_crisp"
96
+ elif tempo > 120:
97
+ return "upbeat"
98
+ elif dynamic_range > 0.1:
99
+ return "dynamic"
100
+ else:
101
+ return "balanced"
102
 
103
def get_available_models(self):
    """Return the supported-model catalog enriched with display metadata.

    Each entry carries the library-provided fields plus a friendly name,
    inferred use cases, rough performance ratings, architecture label and
    usage recommendations. Returns {} on any error.
    """
    try:
        with self.model_lock:
            # Lazily create an info-only separator for catalog queries.
            if self.separator is None:
                self.separator = Separator(info_only=True)

            # NOTE(review): the return value of this call is unused here;
            # it is kept for any side effects it may have — confirm.
            models = self.separator.list_supported_model_files()
            catalog = self.separator.get_simplified_model_list()

            enriched = {}
            for name, info in catalog.items():
                enriched[name] = {
                    **info,
                    "friendly_name": self._generate_friendly_name(name, info),
                    "use_cases": self._determine_use_cases(name, info),
                    "performance_characteristics": self._estimate_performance(name),
                    "architecture_type": self._get_architecture_type(name),
                    "recommended_for": self._get_recommendations(name, info)
                }

            return enriched
    except Exception as e:
        self.logger.error(f"Error getting available models: {str(e)}")
        return {}
138
 
139
+ def _generate_friendly_name(self, model_name: str, model_info: Dict) -> str:
140
+ """Generate user-friendly model names"""
141
+ # Remove common prefixes and suffixes
142
+ clean_name = model_name.replace('model_', '').replace('.ckpt', '').replace('.yaml', '')
143
+
144
+ # Handle specific known models
145
+ if 'roformer' in model_name.lower():
146
+ return f"🎵 Roformer {clean_name.split('_')[-1] if '_' in clean_name else ''}".strip()
147
+ elif 'demucs' in model_name.lower():
148
+ return f"🥁 Demucs {clean_name.replace('htdemucs', '').replace('_', ' ')}".strip()
149
+ elif 'mdx' in model_name.lower():
150
+ return f"🎤 MDX-Net {clean_name[-3:] if clean_name[-3:].isdigit() else ''}".strip()
151
+ else:
152
+ # Capitalize words
153
+ words = clean_name.replace('_', ' ').split()
154
+ return ' '.join(word.capitalize() for word in words)
155
+
156
+ def _determine_use_cases(self, model_name: str, model_info: Dict) -> List[str]:
157
+ """Determine what this model is best for"""
158
+ use_cases = []
159
+
160
+ # Check output stems
161
+ if 'vocals' in str(model_info).lower():
162
+ use_cases.append("🎤 Vocal Isolation")
163
+ if 'drums' in str(model_info).lower():
164
+ use_cases.append("🥁 Drum Separation")
165
+ if 'bass' in str(model_info).lower():
166
+ use_cases.append("🎸 Bass Extraction")
167
+ if 'instrumental' in str(model_info).lower():
168
+ use_cases.append("🎹 Instrumental")
169
+ if 'guitar' in str(model_info).lower() or 'piano' in str(model_info).lower():
170
+ use_cases.append("🎸 Specific Instruments")
171
+
172
+ # Architecture-based use cases
173
+ if 'roformer' in model_name.lower():
174
+ use_cases.append("⚡ High Quality")
175
+ elif 'demucs' in model_name.lower():
176
+ use_cases.append("🎛️ Multi-stem")
177
+ elif 'mdx' in model_name.lower():
178
+ use_cases.append("🎵 Fast Processing")
179
+
180
+ return use_cases[:3] # Limit to top 3
181
+
182
+ def _estimate_performance(self, model_name: str) -> Dict:
183
+ """Estimate performance characteristics"""
184
+ perf = {
185
+ "speed_rating": "medium",
186
+ "quality_rating": "medium",
187
+ "memory_usage": "medium"
188
+ }
189
+
190
+ if 'roformer' in model_name.lower():
191
+ perf.update({"speed_rating": "slow", "quality_rating": "high", "memory_usage": "high"})
192
+ elif 'demucs' in model_name.lower():
193
+ perf.update({"speed_rating": "slow", "quality_rating": "high", "memory_usage": "high"})
194
+ elif 'mdx' in model_name.lower():
195
+ perf.update({"speed_rating": "fast", "quality_rating": "medium", "memory_usage": "low"})
196
+
197
+ return perf
198
+
199
+ def _get_architecture_type(self, model_name: str) -> str:
200
+ """Extract architecture type from model name"""
201
+ if 'roformer' in model_name.lower():
202
+ return "🎵 Roformer (MDXC)"
203
+ elif 'demucs' in model_name.lower():
204
+ return "🥁 Demucs"
205
+ elif 'mdx' in model_name.lower():
206
+ return "🎤 MDX-Net"
207
+ elif 'vr' in model_name.lower():
208
+ return "🎛️ VR Arch"
209
+ else:
210
+ return "🔧 Unknown"
211
+
212
+ def _get_recommendations(self, model_name: str, model_info: Dict) -> Dict:
213
+ """Get specific recommendations for model usage"""
214
+ recommendations = {
215
+ "best_for": "General use",
216
+ "avoid_for": "None",
217
+ "tips": []
218
+ }
219
+
220
+ if 'roformer' in model_name.lower():
221
+ recommendations.update({
222
+ "best_for": "High-quality vocal isolation",
223
+ "avoid_for": "Real-time processing",
224
+ "tips": ["Best results with longer audio files", "Higher memory usage", "Excellent for final mastering"]
225
+ })
226
+ elif 'demucs' in model_name.lower():
227
+ recommendations.update({
228
+ "best_for": "Multi-stem separation (drums, bass, vocals)",
229
+ "avoid_for": "Simple vocal/instrumental separation",
230
+ "tips": ["Creates multiple output files", "Good for music production", "Slower but comprehensive"]
231
+ })
232
+ elif 'mdx' in model_name.lower():
233
+ recommendations.update({
234
+ "best_for": "Fast vocal isolation",
235
+ "avoid_for": "Multi-instrument separation",
236
+ "tips": ["Quick processing", "Good for demos", "Lower memory requirements"]
237
+ })
238
+
239
+ return recommendations
240
+
241
def auto_select_model(self, audio_characteristics: Dict, desired_stems: List[str],
                      priority: str = "quality") -> Optional[str]:
    """Pick the best available model for the given audio and stem requirements.

    *priority* is "quality" or "speed". Returns the highest-scoring model
    name, or None when no models are available or an error occurs.
    """
    try:
        catalog = self.get_available_models()
        if not catalog:
            return None

        wanted_type = audio_characteristics.get('audio_type', 'balanced')

        def score_of(info: Dict) -> int:
            """Score one catalog entry against the caller's criteria."""
            pts = 0

            # Reward the performance rating that matches the priority.
            perf = info.get('performance_characteristics', {})
            if priority == "quality":
                rating = perf.get('quality_rating')
                if rating == 'high':
                    pts += 10
                elif rating == 'medium':
                    pts += 5
            elif priority == "speed":
                rating = perf.get('speed_rating')
                if rating == 'fast':
                    pts += 10
                elif rating == 'medium':
                    pts += 5

            # Match the detected audio character to advertised use cases.
            cases = info.get('use_cases', [])
            if wanted_type == 'bass_heavy' and '🎸 Bass Extraction' in cases:
                pts += 8
            elif wanted_type == 'bright_crisp' and '🎤 Vocal Isolation' in cases:
                pts += 8
            elif wanted_type == 'upbeat' and '🎹 Instrumental' in cases:
                pts += 6

            # Naive stem compatibility via substring search of the metadata.
            info_text = str(info).lower()
            for stem in desired_stems:
                if stem.lower() in info_text:
                    pts += 5

            # Architecture bias per priority.
            arch = info.get('architecture_type', '')
            if priority == "quality" and "Roformer" in arch:
                pts += 15
            elif priority == "speed" and "MDX-Net" in arch:
                pts += 15

            return pts

        scores = {name: score_of(info) for name, info in catalog.items()}
        if scores:
            return max(scores.items(), key=lambda item: item[1])[0]
        return None

    except Exception as e:
        self.logger.error(f"Error in auto-select: {str(e)}")
        return None
305
+
306
def compare_models(self, audio_file: str, model_list: List[str]) -> Dict:
    """Run each requested model on *audio_file* and collect timing/quality data.

    Returns a dict with the audio analysis, per-model results, a summary and
    textual recommendations; output stem files are deleted after measurement.
    """
    if not audio_file or not model_list:
        return {"error": "Please provide audio file and select models to compare"}

    # Analyze the source once, up front, so all models share the same baseline.
    analysis = self.analyze_audio_characteristics(audio_file)
    per_model = {}

    for name in model_list:
        try:
            t0 = time.time()
            ok, message = self.initialize_separator(name)
            if not ok:
                per_model[name] = {
                    "status": "Failed",
                    "error": message,
                    "processing_time": 0
                }
                continue

            stems = self.separator.separate(audio_file)
            elapsed = time.time() - t0

            if stems and os.path.exists(stems[0]):
                samples, rate = sf.read(stems[0])
                per_model[name] = {
                    "status": "Success",
                    "processing_time": elapsed,
                    "output_files": len(stems),
                    "sample_rate": rate,
                    "duration": len(samples) / rate,
                    "quality_metrics": self._calculate_quality_metrics(samples, rate),
                    "output_stems": [os.path.basename(p) for p in stems],
                    "model_info": self.get_available_models().get(name, {})
                }
                # Remove the temporary stem files once measured.
                for path in stems:
                    if os.path.exists(path):
                        os.remove(path)
            else:
                per_model[name] = {
                    "status": "Failed",
                    "error": "No output files generated",
                    "processing_time": elapsed
                }

        except Exception as e:
            per_model[name] = {
                "status": "Error",
                "error": str(e),
                "processing_time": 0
            }

    return {
        "audio_analysis": analysis,
        "model_results": per_model,
        "summary": self._generate_comparison_summary(per_model),
        "recommendations": self._generate_recommendations(analysis, per_model)
    }
381
+
382
def _calculate_quality_metrics(self, audio_data: np.ndarray, sample_rate: int) -> Dict:
    """Basic level and spectral statistics for a rendered stem.

    Returns {"error": <message>} if any computation fails.
    """
    try:
        # Overall energy and absolute peak.
        rms_level = np.sqrt(np.mean(audio_data**2))
        peak_level = np.max(np.abs(audio_data))

        # Crest factor in dB; the epsilon guards against division by zero.
        crest_db = 20 * np.log10(peak_level / (rms_level + 1e-10))

        # Mean spectral centroid as a brightness proxy.
        centroid = np.mean(librosa.feature.spectral_centroid(y=audio_data, sr=sample_rate))

        return {
            "rms_level": float(rms_level),
            "peak_level": float(peak_level),
            "dynamic_range": float(crest_db),
            "spectral_centroid": float(centroid),
            "length_samples": len(audio_data),
            "length_seconds": len(audio_data) / sample_rate
        }
    except Exception as e:
        return {"error": str(e)}
405
+
406
+ def _generate_comparison_summary(self, model_results: Dict) -> Dict:
407
+ """Generate summary statistics from model comparison"""
408
+ successful_results = {k: v for k, v in model_results.items() if v.get("status") == "Success"}
409
+
410
+ if not successful_results:
411
+ return {"message": "No successful model runs to compare"}
412
+
413
+ summary = {
414
+ "total_models": len(model_results),
415
+ "successful_models": len(successful_results),
416
+ "fastest_model": None,
417
+ "slowest_model": None,
418
+ "best_quality": None,
419
+ "average_processing_time": 0
420
+ }
421
+
422
+ # Find fastest and slowest
423
+ if successful_results:
424
+ times = {k: v.get("processing_time", 0) for k, v in successful_results.items()}
425
+ summary["fastest_model"] = min(times.items(), key=lambda x: x[1])[0]
426
+ summary["slowest_model"] = max(times.items(), key=lambda x: x[1])[0]
427
+ summary["average_processing_time"] = np.mean(list(times.values()))
428
+
429
+ return summary
430
+
431
+ def _generate_recommendations(self, audio_analysis: Dict, model_results: Dict) -> List[str]:
432
+ """Generate intelligent recommendations based on comparison"""
433
+ recommendations = []
434
+
435
+ # Find best performing model
436
+ successful_models = {k: v for k, v in model_results.items() if v.get("status") == "Success"}
437
+
438
+ if successful_models:
439
+ # Find fastest successful model
440
+ fastest_model = min(successful_models.items(),
441
+ key=lambda x: x[1].get("processing_time", float('inf')))
442
+ recommendations.append(f"⚡ Fastest: {fastest_model[0]} ({fastest_model[1]['processing_time']:.2f}s)")
443
+
444
+ # Find model with most outputs
445
+ most_outputs = max(successful_models.items(),
446
+ key=lambda x: x[1].get("output_files", 0))
447
+ recommendations.append(f"🎛️ Most stems: {most_outputs[0]} ({most_outputs[1]['output_files']} files)")
448
+
449
+ # Audio-based recommendations
450
+ audio_type = audio_analysis.get('audio_type', 'unknown')
451
+ if audio_type == 'bass_heavy':
452
+ recommendations.append("🎸 Consider models with bass separation capabilities")
453
+ elif audio_type == 'bright_crisp':
454
+ recommendations.append("🎤 Models optimized for vocal clarity work best")
455
+ elif audio_type == 'upbeat':
456
+ recommendations.append("🎹 Fast processing models recommended for energetic tracks")
457
+
458
+ return recommendations
459
+
460
+ def initialize_separator(self, model_name: str = None, **kwargs):
461
+ """Initialize the separator with specified parameters"""
462
+ try:
463
+ with self.model_lock:
464
+ # Clean up previous separator if exists
465
+ if self.separator is not None:
466
+ del self.separator
467
+ torch.cuda.empty_cache()
468
+
469
+ # Set default model if not specified
470
+ if model_name is None:
471
+ model_name = "model_bs_roformer_ep_317_sdr_12.9755.ckpt"
472
+
473
+ # Initialize separator with updated parameters
474
+ self.separator = Separator(
475
+ output_format="WAV",
476
+ use_autocast=True,
477
+ use_soundfile=True,
478
+ **kwargs
479
+ )
480
+
481
+ # Load the model
482
+ self.separator.load_model(model_name)
483
+ self.current_model = model_name
484
+
485
+ return True, f"Successfully initialized with model: {model_name}"
486
+
487
+ except Exception as e:
488
+ self.logger.error(f"Error initializing separator: {str(e)}")
489
+ return False, f"Error initializing separator: {str(e)}"
490
+
491
+ def infer(self, audio_file: str, model_name: str, output_format: str = "WAV",
492
+ quality_preset: str = "Standard", custom_params: Dict = None,
493
+ enable_auto_optimize: bool = True):
494
+ """Enhanced audio processing with auto-optimization"""
495
  if audio_file is None:
496
  return None, "No audio file provided"
497
 
498
  if model_name is None:
499
  return None, "No model selected"
500
 
501
+ # Auto-optimize parameters if enabled
502
+ if enable_auto_optimize:
503
+ audio_analysis = self.analyze_audio_characteristics(audio_file)
504
+ custom_params = self._optimize_parameters_for_audio(audio_analysis, custom_params)
505
+
506
  if self.separator is None or self.current_model != model_name:
507
  success, message = self.initialize_separator(model_name)
508
  if not success:
 
511
  try:
512
  start_time = time.time()
513
 
514
+ # Apply quality preset
515
  if custom_params is None:
516
  custom_params = {}
517
 
 
518
  if quality_preset == "Fast":
519
  custom_params.update({
520
  "mdx_params": {"batch_size": 4, "overlap": 0.1, "segment_size": 128},
 
560
  "timestamp": datetime.now().isoformat(),
561
  "model": model_name,
562
  "processing_time": processing_time,
563
+ "output_files": list(output_audio.keys()),
564
+ "audio_analysis": self.analyze_audio_characteristics(audio_file) if enable_auto_optimize else {},
565
+ "quality_preset": quality_preset
566
  }
567
  self.processing_history.append(history_entry)
568
 
569
+ return output_audio, f"Processing completed in {processing_time:.2f}s with model: {model_name}"
570
 
571
  except Exception as e:
572
  error_msg = f"Error processing audio: {str(e)}"
573
  self.logger.error(f"{error_msg}\n{traceback.format_exc()}")
574
  return None, error_msg
575
 
576
+ def _optimize_parameters_for_audio(self, audio_analysis: Dict, custom_params: Dict) -> Dict:
577
+ """Automatically optimize parameters based on audio characteristics"""
578
+ if custom_params is None:
579
+ custom_params = {}
580
+
581
+ # Adjust parameters based on audio characteristics
582
+ duration = audio_analysis.get('duration', 0)
583
+ audio_type = audio_analysis.get('audio_type', 'balanced')
584
+
585
+ # For longer audio, increase batch size for efficiency
586
+ if duration > 300: # 5 minutes
587
+ custom_params.setdefault('mdx_params', {})['batch_size'] = 2
588
+ custom_params.setdefault('vr_params', {})['batch_size'] = 2
589
+
590
+ # For bass-heavy audio, increase aggression
591
+ if audio_type == 'bass_heavy':
592
+ custom_params.setdefault('vr_params', {})['aggression'] = 7
593
+
594
+ # For bright/crisp audio, enable post-processing
595
+ if audio_type == 'bright_crisp':
596
+ custom_params.setdefault('vr_params', {})['enable_post_process'] = True
597
+
598
+ # For dynamic audio, enable TTA for better quality
599
+ if audio_analysis.get('dynamic_range', 0) > 0.1:
600
+ custom_params.setdefault('vr_params', {})['enable_tta'] = True
601
+
602
+ return custom_params
603
+
604
+ def (self):
605
+ """Get enhanced processing history with analytics"""
606
  if not self.processing_history:
607
  return "No processing history available"
608
 
609
+ history_text = "🎵 Enhanced Processing History\n\n"
610
+
611
+ # Show recent entries with details
612
+ for i, entry in enumerate(self.processing_history[-10:], 1):
613
+ history_text += f"**{i}. {entry['timestamp'][:19]}**\n"
614
  history_text += f" Model: {entry['model']}\n"
615
  history_text += f" Time: {entry['processing_time']:.2f}s\n"
616
+ history_text += f" Stems: {', '.join(entry['output_files'])}\n"
617
+
618
+ # Add audio analysis if available
619
+ if 'audio_analysis' in entry and entry['audio_analysis']:
620
+ audio_type = entry['audio_analysis'].get('audio_type', 'unknown')
621
+ duration = entry['audio_analysis'].get('duration', 0)
622
+ history_text += f" Audio: {audio_type} ({duration:.1f}s)\n"
623
+
624
+ # Add quality preset info
625
+ if 'quality_preset' in entry:
626
+ history_text += f" Preset: {entry['quality_preset']}\n"
627
+
628
+ history_text += "\n"
629
 
630
  return history_text
631
 
 
633
  """Reset processing history"""
634
  self.processing_history = []
635
  return "Processing history cleared"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
 
637
 
638
# Initialize the enhanced demo.
# Fix: the original read `AudioSeparatorD()`, a truncated name that is not
# defined anywhere (the class above is `AudioSeparatorDemo`) and raised
# NameError at import time.
demo1 = AudioSeparatorDemo()
640
 
641
  def create_interface():
 
642
 
643
+ with gr.Blocks(title="🎵 Enhanced Audio Separator") as interface:
 
 
644
  gr.Markdown(
645
  """
646
+ # 🎵 Audio Separator Web UI
647
+
648
+ **Smart AI-Powered Audio Source Separation with Auto-Selection & Advanced Model Comparison**
649
 
650
+ **Features**: Auto model selection, performance analytics, smart parameter optimization, and comprehensive model comparison
 
651
  """
652
  )
653
 
654
  # System Information
655
  with gr.Accordion("🖥️ System Information", open=False):
656
+ system_info = demo1.get_system_info()
657
  info_text = f"""
658
  **PyTorch Version:** {system_info['pytorch_version']}
659
 
 
662
  **CUDA Available:** {system_info['cuda_available']} (Version: {system_info['cuda_version']})
663
 
664
  **Apple Silicon (MPS):** {system_info['mps_available']}
665
+
666
+ **GPU Memory:** {system_info['memory_allocated'] // 1024**2}MB / {system_info['memory_total'] // 1024**2}MB
667
  """
668
  gr.Markdown(info_text)
669
 
 
671
  with gr.Column(scale=2):
672
  # Main audio input
673
  audio_input = gr.Audio(
674
+ label="🎵 Upload Audio File",
675
+ type="filepath",
676
+ info="Upload audio for intelligent analysis and separation"
677
  )
678
 
679
+ # Auto-analyze button
680
+ analyze_btn = gr.Button("🔍 Analyze Audio", variant="secondary")
681
+
682
+ # Audio analysis output
683
+ audio_analysis_output = gr.JSON(label="Audio Analysis Results", visible=False)
684
 
685
+ # Enhanced model selection
686
+ model_list = demo1.get_available_models()
687
+
688
+ # Model dropdown with enhanced display
689
  model_dropdown = gr.Dropdown(
690
  choices=list(model_list.keys()) if model_list else [],
691
  value="model_bs_roformer_ep_317_sdr_12.9755.ckpt" if model_list else None,
692
+ label="🤖 AI Model Selection",
693
+ info="Choose an AI model or use auto-selection",
694
+ elem_id="model_dropdown"
695
  )
696
 
697
+ # Auto-selection controls
698
+ with gr.Row():
699
+ auto_select_btn = gr.Button("🎯 Auto-Select Best Model", variant="primary")
700
+ priority_radio = gr.Radio(
701
+ choices=["Quality", "Speed", "Balanced"],
702
+ value="Quality",
703
+ label="Selection Priority",
704
+ info="What matters most for model selection?"
705
+ )
706
+
707
+ # Model info display
708
+ model_info_display = gr.JSON(label="📊 Selected Model Information")
709
+
710
+ # Quality preset and optimization
711
+ with gr.Row():
712
+ quality_preset = gr.Radio(
713
+ choices=["Fast", "Standard", "High Quality", "Custom"],
714
+ value="Standard",
715
+ label="⚡ Processing Quality",
716
+ info="Choose processing quality vs speed trade-off"
717
+ )
718
+
719
+ auto_optimize = gr.Checkbox(
720
+ label="🧠 Auto-Optimize Parameters",
721
+ value=True,
722
+ info="Automatically optimize parameters based on audio analysis"
723
+ )
724
 
725
+ # Enhanced advanced parameters
726
  with gr.Accordion("🔧 Advanced Parameters", open=False):
727
  with gr.Row():
728
  batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
729
  segment_size = gr.Slider(64, 1024, value=256, step=64, label="Segment Size")
730
  overlap = gr.Slider(0.1, 0.5, value=0.25, step=0.05, label="Overlap")
731
 
732
+ with gr.Row():
733
+ denoise = gr.Checkbox(label="Enable Denoise", value=False)
734
+ tta = gr.Checkbox(label="Enable TTA", value=False)
735
+ post_process = gr.Checkbox(label="Enable Post-Processing", value=False)
736
+ pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="Pitch Shift (semitones)")
737
 
738
  # Process button
739
+ process_btn = gr.Button("🎵 Smart Separate Audio", variant="primary", size="lg")
740
 
741
  with gr.Column(scale=2):
742
+ # Status and results
743
+ status_output = gr.Textbox(label="📋 Status", lines=4)
744
 
745
+ # Enhanced output tabs
746
  with gr.Tabs():
747
+ with gr.Tab("🎤 Vocals"):
748
  vocals_output = gr.Audio(label="Vocals")
749
 
750
+ with gr.Tab("🎹 Instrumental"):
751
  instrumental_output = gr.Audio(label="Instrumental")
752
 
753
+ with gr.Tab("🥁 Drums"):
754
  drums_output = gr.Audio(label="Drums")
755
 
756
+ with gr.Tab("🎸 Bass"):
757
  bass_output = gr.Audio(label="Bass")
758
 
759
+ with gr.Tab("🎛️ Other Stems"):
760
+ other_output = gr.Audio(label="Other Stems")
761
+
762
+ # Performance metrics
763
+ performance_metrics = gr.JSON(label="📈 Performance Metrics", visible=False)
764
 
765
  # Download section
766
+ with gr.Accordion("📥 Batch & Download", open=False):
767
+ gr.Markdown("### 🔄 Batch Processing")
768
+ batch_files = gr.File(
769
+ file_count="multiple",
770
+ file_types=[".wav", ".mp3", ".flac", ".m4a"],
771
+ label="Batch Audio Files"
772
+ )
773
+
774
+ with gr.Row():
775
+ batch_btn = gr.Button("⚡ Process Batch")
776
+ auto_batch_btn = gr.Button("🎯 Auto-Select & Batch")
777
+
778
+ batch_output = gr.File(label="📦 Download Batch Results")
779
 
780
+ # Enhanced Model Management Tabs
781
  with gr.Tabs():
782
+ with gr.Tab("🔍 Model Explorer"):
783
+ gr.Markdown("## 🧠 Intelligent Model Comparison & Selection")
784
 
785
+ # Enhanced model information
786
+ model_info = gr.JSON(value=demo1.get_available_models(), label="📊 Model Database")
787
  refresh_models_btn = gr.Button("🔄 Refresh Models")
788
 
789
+ # Advanced model filtering
790
  with gr.Row():
791
+ filter_architecture = gr.Dropdown(
792
+ choices=["All", "MDX-Net", "Demucs", "Roformer", "VR Arch"],
793
+ value="All",
794
+ label="Filter by Architecture"
795
+ )
796
+ filter_use_case = gr.Dropdown(
797
+ choices=["All", "Vocals", "Instrumental", "Drums", "Bass", "Multi-stem"],
798
+ value="All",
799
+ label="Filter by Use Case"
800
+ )
801
+ filter_priority = gr.Dropdown(
802
+ choices=["All", "Quality", "Speed", "Memory Efficient"],
803
+ value="All",
804
+ label="Filter by Priority"
805
  )
 
806
 
807
+ filtered_models = gr.Dropdown(
808
+ choices=list(model_list.keys())[:10] if model_list else [],
809
+ multiselect=True,
810
+ label="🎯 Models for Comparison",
811
+ info="Select up to 5 models for detailed comparison"
812
+ )
813
+
814
+ compare_btn = gr.Button("🔬 Advanced Model Comparison")
815
+ comparison_results = gr.JSON(label="📊 Comparison Results")
816
 
817
+ with gr.Tab("📈 Analytics & History"):
818
+ history_output = gr.Textbox(label="📜 Processing History", lines=15)
819
+
820
  with gr.Row():
821
  refresh_history_btn = gr.Button("🔄 Refresh History")
822
  reset_history_btn = gr.Button("🗑️ Clear History", variant="stop")
823
+ export_history_btn = gr.Button("📊 Export Analytics")
824
+
825
+ analytics_output = gr.JSON(label="📊 Analytics Dashboard")
826
+
827
+ with gr.Tab("🎯 Smart Recommendations"):
828
+ gr.Markdown("## 🤖 AI-Powered Model Recommendations")
829
+
830
+ recommendation_status = gr.Textbox(label="Recommendation Status", lines=3)
831
+
832
+ with gr.Row():
833
+ get_recommendations_btn = gr.Button("🎯 Get Smart Recommendations")
834
+ apply_recommendation_btn = gr.Button("✨ Apply Best Recommendation")
835
+
836
+ recommendations_display = gr.JSON(label="🎯 Personalized Recommendations")
837
 
838
  # Event handlers
839
        def analyze_audio(audio_file):
            """UI handler: analyze the upload and build a display summary.

            Returns a (raw analysis dict, formatted text) pair for the
            ``audio_analysis_output`` and ``recommendation_status`` widgets.
            """
            if not audio_file:
                return None, "No audio file provided"

            analysis = demo1.analyze_audio_characteristics(audio_file)

            # Format analysis for display (markdown-style bold labels).
            # NOTE(review): the f-string interior below was flush-left in
            # the reviewed diff -- confirm original indentation.
            if "error" not in analysis:
                formatted_analysis = f"""
**Audio Type:** {analysis.get('audio_type', 'Unknown').title().replace('_', ' ')}
**Duration:** {analysis.get('duration', 0):.1f} seconds
**Sample Rate:** {analysis.get('sample_rate', 0)} Hz
**Tempo:** {analysis.get('tempo', 0):.1f} BPM
**Spectral Characteristics:** {analysis.get('avg_spectral_centroid', 0):.0f} Hz (centroid)
**Dynamic Range:** {analysis.get('dynamic_range', 0):.3f}
"""

                return analysis, formatted_analysis
            else:
                # Analysis failed: surface the error text to the status box.
                return analysis, f"Analysis failed: {analysis['error']}"
859
+
860
        def auto_select_model(audio_file, priority):
            """UI handler: pick the best model for the uploaded audio.

            Returns (model filename, status text, model info dict) feeding
            the model dropdown, status box and model-info display.
            """
            if not audio_file:
                return None, "No audio file provided", None

            # Analyze audio first
            audio_analysis = demo1.analyze_audio_characteristics(audio_file)

            # Determine desired stems based on audio analysis
            desired_stems = ["vocals"]  # Default
            if audio_analysis.get('audio_type') == 'bass_heavy':
                desired_stems.append("bass")
            elif audio_analysis.get('tempo', 0) > 120:
                # Fast tempo taken as a proxy for drum-heavy material.
                desired_stems.append("drums")

            # Auto-select model
            selected_model = demo1.auto_select_model(
                audio_analysis, desired_stems, priority.lower()
            )

            if selected_model:
                models = demo1.get_available_models()
                model_info = models.get(selected_model, {})

                return (
                    selected_model,
                    f"🎯 Auto-selected: {model_info.get('friendly_name', selected_model)}\n"
                    f"Architecture: {model_info.get('architecture_type', 'Unknown')}\n"
                    f"Best for: {', '.join(model_info.get('use_cases', [])[:2])}",
                    model_info
                )
            else:
                return None, "Auto-selection failed - no suitable model found", None
892
+
893
        def update_model_info(model_name):
            """UI handler: build a friendly info card for the chosen model.

            Returns a display-ready dict for the model-info JSON widget,
            or ``None`` when no model is selected.
            """
            if not model_name:
                return None

            models = demo1.get_available_models()
            model_info = models.get(model_name, {})

            if model_info:
                # Format model information for the JSON display widget.
                friendly_info = {
                    "🤖 Friendly Name": model_info.get('friendly_name', model_name),
                    "🏗️ Architecture": model_info.get('architecture_type', 'Unknown'),
                    "💡 Best For": model_info.get('use_cases', []),
                    "⚡ Performance": model_info.get('performance_characteristics', {}),
                    "🎯 Recommendations": model_info.get('recommended_for', {}),
                    "📊 Technical Details": {
                        "Filename": model_name,
                        # NOTE(review): len(str(...)) // 10 is only a crude
                        # heuristic, not an actual stem count -- worth fixing.
                        "Supported Stems": len(str(model_info)) // 10  # Rough estimate
                    }
                }
                return friendly_info

            return {"error": "Model information not available"}
916
+
917
+ def infer(audio_file, model_name, quality_preset, batch_size, segment_size,
918
+ overlap, denoise, tta, post_process, pitch_shift, auto_optimize):
919
  if not audio_file or not model_name:
920
+ return None, None, None, None, None, "Please upload an audio file and select a model", None
921
 
922
  # Prepare custom parameters
923
  custom_params = {
 
930
  "vr_params": {
931
  "batch_size": int(batch_size),
932
  "enable_tta": tta,
933
+ "enable_post_process": post_process,
934
+ "aggression": 5 # Default
935
  },
936
  "demucs_params": {
937
  "overlap": float(overlap)
938
  },
939
  "mdxc_params": {
940
  "batch_size": int(batch_size),
941
+ "overlap": int(overlap * 10),
942
+ "pitch_shift": int(pitch_shift)
943
  }
944
  }
945
 
946
+ output_audio, status = demo1.infer(
947
+ audio_file, model_name,
948
+ quality_preset=quality_preset,
949
+ custom_params=custom_params,
950
+ enable_auto_optimize=auto_optimize
951
  )
952
 
953
  if output_audio is None:
954
+ return None, None, None, None, None, status, None
955
 
956
  # Extract different stems
957
  vocals = None
 
972
  else:
973
  other = (sample_rate, audio_data)
974
 
975
+ # Generate performance metrics
976
+ performance_metrics = {
977
+ "Model": model_name,
978
+ "Quality Preset": quality_preset,
979
+ "Output Stems": len(output_audio),
980
+ "Processing": "Completed Successfully"
981
+ }
982
+
983
+ return vocals, instrumental, drums, bass, other, status, performance_metrics
 
 
 
 
 
 
984
 
985
+ def compare_models_advanced(audio_file, model_list):
986
  if not audio_file or not model_list:
987
  return {"error": "Please upload an audio file and select models to compare"}
988
 
989
+ results = demo1.compare_models(audio_file, model_list)
990
  return results
991
 
992
        def get_smart_recommendations(audio_file):
            """UI handler: build quality/speed model recommendations.

            Returns (status string, recommendations dict) for the status
            box and the recommendations JSON display.
            """
            if not audio_file:
                return "Please upload an audio file first", {}

            # Analyze audio
            audio_analysis = demo1.analyze_audio_characteristics(audio_file)
            models = demo1.get_available_models()

            # Generate recommendations
            recommendations = {
                "audio_analysis": audio_analysis,
                "recommended_models": [],
                "tips": []
            }

            # Quality-focused recommendations
            quality_models = []
            speed_models = []

            # Bucket every known model by its advertised ratings; a model
            # can land in both buckets.
            for model_name, model_info in models.items():
                perf_chars = model_info.get('performance_characteristics', {})

                if perf_chars.get('quality_rating') == 'high':
                    quality_models.append({
                        'model': model_name,
                        'name': model_info.get('friendly_name', model_name),
                        'reason': 'High quality output'
                    })

                if perf_chars.get('speed_rating') == 'fast':
                    speed_models.append({
                        'model': model_name,
                        'name': model_info.get('friendly_name', model_name),
                        'reason': 'Fast processing'
                    })

            # Top three per bucket only, to keep the display compact.
            recommendations["recommended_models"] = {
                "🎯 For Best Quality": quality_models[:3],
                "⚡ For Speed": speed_models[:3]
            }

            # Audio-specific tips
            audio_type = audio_analysis.get('audio_type', 'balanced')
            if audio_type == 'bass_heavy':
                recommendations["tips"].append("🎸 Models with bass separation work best")
            elif audio_type == 'bright_crisp':
                recommendations["tips"].append("🎤 Post-processing enabled for vocal clarity")
            elif audio_type == 'upbeat':
                recommendations["tips"].append("🥁 Consider drum isolation for energetic tracks")

            status = f"✅ Generated recommendations for {audio_analysis.get('audio_type', 'unknown')} audio"
            return status, recommendations
1044
+
1045
        def apply_best_recommendation(audio_file):
            """UI handler: auto-select the best-quality model and apply it.

            Returns (model filename, status text, model info) so the model
            dropdown is updated directly with the recommendation.
            """
            if not audio_file:
                return None, "Please upload an audio file first", None

            # Get auto-selection with quality priority
            audio_analysis = demo1.analyze_audio_characteristics(audio_file)
            selected_model = demo1.auto_select_model(
                audio_analysis, ["vocals"], "quality"
            )

            if selected_model:
                models = demo1.get_available_models()
                model_info = models.get(selected_model, {})

                return (
                    selected_model,
                    f"✨ Applied recommendation: {model_info.get('friendly_name', selected_model)}",
                    model_info
                )
            else:
                return None, "Could not generate recommendations", None
1066
+
1067
  # Wire up event handlers
1068
+ analyze_btn.click(
1069
+ fn=analyze_audio,
1070
+ inputs=[audio_input],
1071
+ outputs=[audio_analysis_output, recommendation_status]
1072
+ )
1073
+
1074
+ auto_select_btn.click(
1075
+ fn=auto_select_model,
1076
+ inputs=[audio_input, priority_radio],
1077
+ outputs=[model_dropdown, recommendation_status, model_info_display]
1078
+ )
1079
+
1080
+ model_dropdown.change(
1081
+ fn=update_model_info,
1082
+ inputs=[model_dropdown],
1083
+ outputs=[model_info_display]
1084
+ )
1085
+
1086
  process_btn.click(
1087
+ fn=infer,
1088
  inputs=[
1089
  audio_input, model_dropdown, quality_preset,
1090
+ batch_size, segment_size, overlap, denoise, tta, post_process,
1091
+ pitch_shift, auto_optimize
1092
  ],
1093
  outputs=[
1094
  vocals_output, instrumental_output, drums_output,
1095
+ bass_output, other_output, status_output, performance_metrics
1096
  ]
1097
  )
1098
 
1099
+ compare_btn.click(
1100
+ fn=compare_models_advanced,
1101
+ inputs=[audio_input, filtered_models],
1102
+ outputs=[comparison_results]
1103
+ )
1104
+
1105
  refresh_models_btn.click(
1106
+ fn=lambda: demo1.get_available_models(),
1107
+ outputs=[model_info]
1108
  )
1109
 
1110
  refresh_history_btn.click(
1111
+ fn=lambda: demo1.get_phistory(),
1112
  outputs=[history_output]
1113
  )
1114
 
1115
  reset_history_btn.click(
1116
+ fn=lambda: demo1.reset_history(),
1117
  outputs=[history_output]
1118
  )
1119
 
1120
+ get_recommendations_btn.click(
1121
+ fn=get_smart_recommendations,
1122
+ inputs=[audio_input],
1123
+ outputs=[recommendation_status, recommendations_display]
1124
+ )
1125
+
1126
+ apply_recommendation_btn.click(
1127
+ fn=apply_best_recommendation,
1128
+ inputs=[audio_input],
1129
+ outputs=[model_dropdown, recommendation_status, model_info_display]
1130
  )
1131
 
1132
  # Batch processing
1133
+ def batch_inf(batch_files, model_name):
1134
  if not batch_files or not model_name:
1135
  return None, "Please upload batch files and select a model"
1136
 
 
1137
  import zipfile
1138
  import io
1139
 
1140
  zip_buffer = io.BytesIO()
1141
  with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
1142
  for file_info in batch_files:
1143
+ output_audio, _ = demo1.infer(file_info, model_name)
1144
  if output_audio is not None:
 
1145
  for stem_name, (sample_rate, audio_data) in output_audio.items():
 
1146
  import tempfile
1147
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file:
1148
  sf.write(tmp_file.name, audio_data, sample_rate)
 
1154
  return gr.File(value=zip_buffer, visible=True), f"Batch processing completed for {len(batch_files)} files"
1155
 
1156
  batch_btn.click(
1157
+ fn=batch_inf,
1158
  inputs=[batch_files, model_dropdown],
1159
  outputs=[batch_output, status_output]
1160
  )
1161
+
1162
        def auto_batch_process(batch_files, priority):
            """UI handler: auto-select a model, then batch-process all files.

            The first file's analysis is used as representative for the
            whole batch when choosing the model.
            """
            if not batch_files:
                return None, "Please upload batch files"

            # Auto-select best model for first file as representative
            if batch_files:
                audio_analysis = demo1.analyze_audio_characteristics(batch_files[0])
                selected_model = demo1.auto_select_model(audio_analysis, ["vocals"], priority.lower())

                if selected_model:
                    # Delegate to the manual batch path with the chosen model.
                    return batch_inf(batch_files, selected_model)

            return None, "Auto-selection failed"
1175
+
1176
+ auto_batch_btn.click(
1177
+ fn=auto_batch_process,
1178
+ inputs=[batch_files, priority_radio],
1179
+ outputs=[batch_output, status_output]
1180
+ )
1181
 
1182
  return interface
1183
 
 
1185
if __name__ == "__main__":
    interface = create_interface()
    # Fix: `launch()` has no `theme` parameter -- themes belong on
    # gr.Blocks(theme=...) when the interface is built. The stray
    # `theme=" NeoPy/Soft"` kwarg raised a TypeError at startup.
    interface.launch(
        server_name="0.0.0.0",  # bind all interfaces (container/Spaces) -- TODO confirm against original
        server_port=7860,
        share=True,
        debug=True
    )
  )