# audio-separator / app.py
# Last change by NeoPy: "Update app.py" (commit bd5bd8e, verified)
import io
import json
import logging
import os
import re
import tempfile
import threading
import time
import traceback
import zipfile
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from audio_separator.separator import Separator
from audio_separator.separator import architectures
from pydub import AudioSegment
class AudioSeparatorD:
    """Backend state and logic for the audio-separator Gradio UI.

    Wraps an ``audio_separator`` ``Separator`` instance, layers heuristic audio
    analysis and model scoring on top of its model catalogue, and keeps an
    in-memory history of processing runs.
    """

    # UI parameter group -> key of ``Separator.arch_specific_params``.
    # NOTE(review): mirrors the audio-separator Separator constructor arguments
    # (mdx_params / vr_params / demucs_params / mdxc_params); confirm on upgrade.
    _ARCH_PARAM_KEYS = {
        "mdx_params": "MDX",
        "vr_params": "VR",
        "demucs_params": "Demucs",
        "mdxc_params": "MDXC",
    }

    # Keys get_available_models() adds to each catalogue entry. They hold this
    # app's own descriptive text, so stem matching in auto_select_model skips them.
    _ENHANCED_KEYS = frozenset({
        "friendly_name",
        "use_cases",
        "performance_characteristics",
        "architecture_type",
        "recommended_for",
    })

    # Model-file extensions stripped when building display names.
    _MODEL_EXTENSIONS = (".ckpt", ".yaml", ".onnx", ".pth", ".th")

    # audio_type -> (use-case label that earns a bonus, bonus points).
    _AUDIO_TYPE_BONUS = {
        "bass_heavy": ("🎸 Bass Extraction", 8),
        "bright_crisp": ("🎤 Vocal Isolation", 8),
        "upbeat": ("🎹 Instrumental", 6),
    }

    def __init__(self):
        self.separator = None
        self.available_models = {}
        self.current_model = None
        # Parameter dict the loaded model was built with; infer() reloads the
        # model when the requested parameters differ from it.
        self.current_arch_params = {}
        self.processing_history = []
        self.model_performance_cache = {}
        self.model_recommendations = {}
        self.setup_logging()
        # Re-entrant so a nested acquisition (e.g. a catalogue lookup while
        # (re)initializing) can never deadlock.
        self.model_lock = threading.RLock()

    def setup_logging(self):
        """Setup logging for the application"""
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def get_system_info(self):
        """Return PyTorch version and hardware-acceleration details.

        ``memory_total`` / ``memory_allocated`` are CUDA byte counts for device 0,
        or 0 when CUDA is unavailable.
        """
        cuda_available = torch.cuda.is_available()
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if cuda_available:
            device = "cuda"
        elif mps_available:
            device = "mps"
        else:
            device = "cpu"
        return {
            "pytorch_version": torch.__version__,
            "cuda_available": cuda_available,
            "cuda_version": torch.version.cuda if cuda_available else "N/A",
            "mps_available": mps_available,
            "device": device,
            "memory_total": torch.cuda.get_device_properties(0).total_memory if cuda_available else 0,
            "memory_allocated": torch.cuda.memory_allocated() if cuda_available else 0,
        }

    def analyze_audio_characteristics(self, audio_file: str) -> Dict:
        """Analyze an audio file's characteristics for smart model selection.

        Args:
            audio_file: Path of a file librosa can load (native sample rate).

        Returns:
            Feature dict (duration, sample_rate, tempo, mean spectral centroid /
            rolloff / zero-crossing rate, RMS spread as ``dynamic_range`` and a
            coarse ``audio_type``), or ``{"audio_type": "unknown", "error": ...}``
            when analysis fails.
        """
        try:
            y, sr = librosa.load(audio_file, sr=None)
            if y.size == 0:
                raise ValueError("audio file contains no samples")
            duration = len(y) / sr
            # Spectral characteristics
            spectral_centroids = librosa.feature.spectral_centroid(y=y, sr=sr)[0]
            spectral_rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)[0]
            zero_crossing_rate = librosa.feature.zero_crossing_rate(y)[0]
            # Tempo: newer librosa may return a 1-element array, and float() on a
            # non-0-d array is deprecated in NumPy, so take the element explicitly.
            tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
            tempo = float(np.atleast_1d(tempo)[0])
            # Spread of frame RMS energy, used as a rough dynamic-range measure
            rms = librosa.feature.rms(y=y)[0]
            dynamic_range = float(np.std(rms))
            avg_centroid = float(np.mean(spectral_centroids))
            return {
                "duration": duration,
                "sample_rate": sr,
                "tempo": tempo,
                "avg_spectral_centroid": avg_centroid,
                "avg_spectral_rolloff": float(np.mean(spectral_rolloff)),
                "avg_zero_crossing_rate": float(np.mean(zero_crossing_rate)),
                "dynamic_range": dynamic_range,
                "audio_type": self._classify_audio_type(avg_centroid, tempo, dynamic_range),
            }
        except Exception as e:
            self.logger.error(f"Error analyzing audio: {str(e)}")
            return {"audio_type": "unknown", "error": str(e)}

    def _classify_audio_type(self, spectral_centroid: float, tempo: float, dynamic_range: float) -> str:
        """Classify audio by the first matching rule: centroid < 1000 Hz, > 4000 Hz,
        tempo > 120 BPM, RMS spread > 0.1, otherwise "balanced"."""
        if spectral_centroid < 1000:
            return "bass_heavy"
        if spectral_centroid > 4000:
            return "bright_crisp"
        if tempo > 120:
            return "upbeat"
        if dynamic_range > 0.1:
            return "dynamic"
        return "balanced"

    def get_available_models(self):
        """Return ``{model_filename: info}`` for every supported model, or ``{}``.

        Each entry is the Separator's simplified catalogue entry plus this app's
        heuristic fields (friendly name, use cases, performance, architecture,
        recommendations).
        """
        try:
            with self.model_lock:
                if self.separator is None:
                    # A metadata-only instance is enough to query the catalogue.
                    self.separator = Separator(info_only=True)
                simplified_models = self.separator.get_simplified_model_list()
            return {
                model_name: self._enhance_model_info(model_name, model_info)
                for model_name, model_info in simplified_models.items()
            }
        except Exception as e:
            self.logger.error(f"Error getting available models: {str(e)}")
            return {}

    def _enhance_model_info(self, model_name: str, model_info: Dict) -> Dict:
        """Merge the catalogue entry with this app's heuristic descriptions."""
        return {
            **model_info,
            "friendly_name": self._generate_friendly_name(model_name, model_info),
            "use_cases": self._determine_use_cases(model_name, model_info),
            "performance_characteristics": self._estimate_performance(model_name),
            "architecture_type": self._get_architecture_type(model_name),
            "recommended_for": self._get_recommendations(model_name, model_info),
        }

    def _generate_friendly_name(self, model_name: str, model_info: Dict) -> str:
        """Generate a user-friendly display name from a model filename."""
        root, ext = os.path.splitext(model_name)
        base_name = root if ext.lower() in self._MODEL_EXTENSIONS else model_name
        clean_name = base_name.replace('model_', '')
        lowered = model_name.lower()
        if 'roformer' in lowered:
            return f"🎵 Roformer {clean_name.split('_')[-1] if '_' in clean_name else ''}".strip()
        if 'demucs' in lowered:
            suffix = clean_name.replace('htdemucs', '').replace('_', ' ').strip()
            return f"🥁 Demucs {suffix}".strip()
        if 'mdx' in lowered:
            # Show a trailing numeric token (e.g. "UVR_MDXNET_9482" -> 9482) if any.
            tail = clean_name.rsplit('_', 1)[-1]
            return f"🎤 MDX-Net {tail if tail.isdigit() else ''}".strip()
        return ' '.join(word.capitalize() for word in clean_name.replace('_', ' ').split())

    def _determine_use_cases(self, model_name: str, model_info: Dict) -> List[str]:
        """Return up to three use-case labels from stem keywords and architecture."""
        info_text = str(model_info).lower()
        lowered = model_name.lower()
        use_cases = []
        # Output-stem keywords found anywhere in the catalogue entry
        if 'vocals' in info_text:
            use_cases.append("🎤 Vocal Isolation")
        if 'drums' in info_text:
            use_cases.append("🥁 Drum Separation")
        if 'bass' in info_text:
            use_cases.append("🎸 Bass Extraction")
        if 'instrumental' in info_text:
            use_cases.append("🎹 Instrumental")
        if 'guitar' in info_text or 'piano' in info_text:
            use_cases.append("🎸 Specific Instruments")
        # Architecture-based use cases
        if 'roformer' in lowered:
            use_cases.append("⚡ High Quality")
        elif 'demucs' in lowered:
            use_cases.append("🎛️ Multi-stem")
        elif 'mdx' in lowered:
            use_cases.append("🎵 Fast Processing")
        return use_cases[:3]

    def _estimate_performance(self, model_name: str) -> Dict:
        """Estimate speed / quality / memory ratings from the architecture name."""
        lowered = model_name.lower()
        perf = {"speed_rating": "medium", "quality_rating": "medium", "memory_usage": "medium"}
        if 'roformer' in lowered or 'demucs' in lowered:
            perf.update({"speed_rating": "slow", "quality_rating": "high", "memory_usage": "high"})
        elif 'mdx' in lowered:
            perf.update({"speed_rating": "fast", "quality_rating": "medium", "memory_usage": "low"})
        return perf

    def _get_architecture_type(self, model_name: str) -> str:
        """Guess the architecture label from substrings of the model filename."""
        lowered = model_name.lower()
        if 'roformer' in lowered:
            return "🎵 Roformer (MDXC)"
        if 'demucs' in lowered:
            return "🥁 Demucs"
        if 'mdx' in lowered:
            return "🎤 MDX-Net"
        if 'vr' in lowered:
            return "🎛️ VR Arch"
        return "🔧 Unknown"

    def _get_recommendations(self, model_name: str, model_info: Dict) -> Dict:
        """Return best-for / avoid-for text and tips for the model's architecture."""
        lowered = model_name.lower()
        recommendations = {"best_for": "General use", "avoid_for": "None", "tips": []}
        if 'roformer' in lowered:
            recommendations.update({
                "best_for": "High-quality vocal isolation",
                "avoid_for": "Real-time processing",
                "tips": ["Best results with longer audio files", "Higher memory usage", "Excellent for final mastering"]
            })
        elif 'demucs' in lowered:
            recommendations.update({
                "best_for": "Multi-stem separation (drums, bass, vocals)",
                "avoid_for": "Simple vocal/instrumental separation",
                "tips": ["Creates multiple output files", "Good for music production", "Slower but comprehensive"]
            })
        elif 'mdx' in lowered:
            recommendations.update({
                "best_for": "Fast vocal isolation",
                "avoid_for": "Multi-instrument separation",
                "tips": ["Quick processing", "Good for demos", "Lower memory requirements"]
            })
        return recommendations

    def auto_select_model(self, audio_characteristics: Dict, desired_stems: List[str],
                          priority: str = "quality") -> Optional[str]:
        """Return the highest-scoring model filename, or None.

        Args:
            audio_characteristics: Output of analyze_audio_characteristics.
            desired_stems: Stem keywords (e.g. "vocals"); +5 per keyword found in
                the model filename or its catalogue metadata.
            priority: "quality" or "speed" add rating/architecture bonuses; any
                other value (e.g. "balanced") adds none.

        Ties go to the first model in catalogue order.
        """
        try:
            models = self.get_available_models()
            if not models:
                return None
            quality_points = {"high": 10, "medium": 5}
            speed_points = {"fast": 10, "medium": 5}
            audio_type = audio_characteristics.get('audio_type', 'balanced')
            bonus_label, bonus_points = self._AUDIO_TYPE_BONUS.get(audio_type, (None, 0))
            model_scores = {}
            for model_name, model_info in models.items():
                perf_chars = model_info.get('performance_characteristics', {})
                score = 0
                if priority == "quality":
                    score += quality_points.get(perf_chars.get('quality_rating'), 0)
                elif priority == "speed":
                    score += speed_points.get(perf_chars.get('speed_rating'), 0)
                if bonus_label is not None and bonus_label in model_info.get('use_cases', []):
                    score += bonus_points
                # Match stems against the model's own metadata only: the generated
                # recommendation text mentions "vocals"/"bass" for whole families.
                own_info = {k: v for k, v in model_info.items() if k not in self._ENHANCED_KEYS}
                stem_text = f"{model_name} {own_info}".lower()
                score += 5 * sum(1 for stem in desired_stems if stem.lower() in stem_text)
                arch_type = model_info.get('architecture_type', '')
                if priority == "quality" and "Roformer" in arch_type:
                    score += 15
                elif priority == "speed" and "MDX-Net" in arch_type:
                    score += 15
                model_scores[model_name] = score
            return max(model_scores, key=model_scores.get) if model_scores else None
        except Exception as e:
            self.logger.error(f"Error in auto-select: {str(e)}")
            return None

    def compare_models(self, audio_file: str, model_list: List[str]) -> Dict:
        """Run each model on ``audio_file`` and summarise the results.

        Returns a dict with ``audio_analysis``, per-model ``model_results``, a
        ``summary`` and ``recommendations``; or ``{"error": ...}`` for bad input.
        Temporary stem files are deleted after each model run.
        """
        if not audio_file or not model_list:
            return {"error": "Please provide audio file and select models to compare"}
        comparison_results = {
            "audio_analysis": self.analyze_audio_characteristics(audio_file),
            "model_results": {},
            "summary": {},
            "recommendations": []
        }
        # Fetch the catalogue once rather than once per compared model.
        model_db = self.get_available_models()
        for model_name in model_list:
            comparison_results["model_results"][model_name] = self._compare_single_model(
                audio_file, model_name, model_db
            )
        comparison_results["summary"] = self._generate_comparison_summary(comparison_results["model_results"])
        comparison_results["recommendations"] = self._generate_recommendations(
            comparison_results["audio_analysis"],
            comparison_results["model_results"]
        )
        return comparison_results

    def _compare_single_model(self, audio_file: str, model_name: str, model_db: Dict) -> Dict:
        """Run one model for compare_models; ``processing_time`` includes model loading."""
        start_time = time.time()
        try:
            success, message = self.initialize_separator(model_name)
            if not success:
                return {"status": "Failed", "error": message, "processing_time": 0}
            output_files = self.separator.separate(audio_file) or []
            processing_time = time.time() - start_time
            paths = [self._resolve_output_path(path) for path in output_files]
            try:
                if not paths or not os.path.exists(paths[0]):
                    return {
                        "status": "Failed",
                        "error": "No output files generated",
                        "processing_time": processing_time
                    }
                audio_data, sample_rate = sf.read(paths[0])
                return {
                    "status": "Success",
                    "processing_time": processing_time,
                    "output_files": len(output_files),
                    "sample_rate": sample_rate,
                    "duration": len(audio_data) / sample_rate,
                    "quality_metrics": self._calculate_quality_metrics(audio_data, sample_rate),
                    "output_stems": [os.path.basename(path) for path in output_files],
                    "model_info": model_db.get(model_name, {})
                }
            finally:
                # Remove the stems even if reading/analysing them failed.
                self._remove_files(paths)
        except Exception as e:
            return {"status": "Error", "error": str(e), "processing_time": 0}

    def _calculate_quality_metrics(self, audio_data: np.ndarray, sample_rate: int) -> Dict:
        """Return RMS, peak, peak-to-RMS ratio in dB and mean spectral centroid.

        Multi-channel input (``(frames, channels)`` as read by soundfile) is mixed
        to mono for the spectral centroid; level metrics use all samples.
        """
        try:
            samples = np.asarray(audio_data, dtype=np.float64)
            if samples.size == 0:
                return {"error": "Audio contains no samples"}
            rms = np.sqrt(np.mean(samples ** 2))
            peak = np.max(np.abs(samples))
            # Silent audio would give log10(0) = -inf, which is not valid JSON.
            dynamic_range = 20 * np.log10(peak / (rms + 1e-10)) if peak > 0 else 0.0
            mono = samples.mean(axis=1) if samples.ndim > 1 else samples
            spectral_centroid = np.mean(librosa.feature.spectral_centroid(y=mono, sr=sample_rate))
            return {
                "rms_level": float(rms),
                "peak_level": float(peak),
                "dynamic_range": float(dynamic_range),
                "spectral_centroid": float(spectral_centroid),
                "length_samples": len(samples),
                "length_seconds": len(samples) / sample_rate
            }
        except Exception as e:
            return {"error": str(e)}

    def _generate_comparison_summary(self, model_results: Dict) -> Dict:
        """Summarise successful runs: counts, fastest/slowest model, mean time."""
        successful = {name: result for name, result in model_results.items() if result.get("status") == "Success"}
        if not successful:
            return {"message": "No successful model runs to compare"}
        times = {name: result.get("processing_time", 0) for name, result in successful.items()}
        return {
            "total_models": len(model_results),
            "successful_models": len(successful),
            "fastest_model": min(times, key=times.get),
            "slowest_model": max(times, key=times.get),
            "best_quality": None,  # no quality ranking is computed yet
            "average_processing_time": float(np.mean(list(times.values())))
        }

    def _generate_recommendations(self, audio_analysis: Dict, model_results: Dict) -> List[str]:
        """Build human-readable tips from the comparison results and audio type."""
        recommendations = []
        successful_models = {k: v for k, v in model_results.items() if v.get("status") == "Success"}
        if successful_models:
            fastest_model = min(successful_models.items(),
                                key=lambda x: x[1].get("processing_time", float('inf')))
            recommendations.append(f"⚡ Fastest: {fastest_model[0]} ({fastest_model[1]['processing_time']:.2f}s)")
            most_outputs = max(successful_models.items(),
                               key=lambda x: x[1].get("output_files", 0))
            recommendations.append(f"🎛️ Most stems: {most_outputs[0]} ({most_outputs[1]['output_files']} files)")
        audio_type = audio_analysis.get('audio_type', 'unknown')
        if audio_type == 'bass_heavy':
            recommendations.append("🎸 Consider models with bass separation capabilities")
        elif audio_type == 'bright_crisp':
            recommendations.append("🎤 Models optimized for vocal clarity work best")
        elif audio_type == 'upbeat':
            recommendations.append("🎹 Fast processing models recommended for energetic tracks")
        return recommendations

    def _release_separator(self):
        """Drop the current separator so its model memory can be reclaimed."""
        self.separator = None
        self.current_model = None
        self.current_arch_params = {}
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _copy_params(params: Optional[Dict]) -> Dict:
        """Copy a ``{group: {name: value}}`` parameter dict two levels deep."""
        if not params:
            return {}
        return {key: dict(value) if isinstance(value, dict) else value for key, value in params.items()}

    def _apply_arch_params(self, separator, arch_params: Dict) -> None:
        """Overlay UI parameter groups onto ``separator.arch_specific_params``.

        Must run before ``load_model``. Each group is merged into a new dict so
        the Separator's own default dicts are never mutated in place.
        NOTE(review): relies on audio-separator reading ``arch_specific_params``
        inside ``load_model``; confirm against the installed version.
        """
        target = getattr(separator, "arch_specific_params", None)
        if not isinstance(target, dict):
            self.logger.warning("Separator has no arch_specific_params; architecture parameters ignored")
            return
        for group, arch in self._ARCH_PARAM_KEYS.items():
            overrides = arch_params.get(group)
            if overrides:
                target[arch] = {**target.get(arch, {}), **overrides}

    def initialize_separator(self, model_name: str = None, arch_params: Optional[Dict] = None, **kwargs):
        """(Re)create the Separator and load ``model_name``.

        Args:
            model_name: Model filename; defaults to the first catalogue entry.
            arch_params: Optional ``{"mdx_params": {...}, ...}`` overrides merged
                into the architecture defaults before the model is loaded.
            **kwargs: Extra Separator constructor options (override the defaults
                output_format="WAV", use_autocast=True, use_soundfile=True).

        Returns:
            ``(True, message)`` on success or ``(False, error_message)``. On
            failure no separator is left loaded.
        """
        # Resolve the default model before taking the lock, because
        # get_available_models() takes the same lock.
        if model_name is None:
            models = self.get_available_models()
            if not models:
                return False, "No models available"
            model_name = next(iter(models))
        try:
            with self.model_lock:
                self._release_separator()
                separator_options = {"output_format": "WAV", "use_autocast": True, "use_soundfile": True, **kwargs}
                separator = Separator(**separator_options)
                if arch_params:
                    self._apply_arch_params(separator, arch_params)
                separator.load_model(model_name)
                # Publish only a fully loaded separator.
                self.separator = separator
                self.current_model = model_name
                self.current_arch_params = self._copy_params(arch_params)
            return True, f"Successfully initialized with model: {model_name}"
        except Exception as e:
            self.logger.error(f"Error initializing separator: {str(e)}")
            return False, f"Error initializing separator: {str(e)}"

    @staticmethod
    def _preset_params(quality_preset: str) -> Dict:
        """Return the parameter groups a quality preset replaces ({} for others)."""
        if quality_preset == "Fast":
            return {
                "mdx_params": {"batch_size": 4, "overlap": 0.1, "segment_size": 128},
                "vr_params": {"batch_size": 8, "aggression": 3},
                "demucs_params": {"shifts": 1, "overlap": 0.1},
                "mdxc_params": {"batch_size": 4, "overlap": 4}
            }
        if quality_preset == "High Quality":
            return {
                "mdx_params": {"batch_size": 1, "overlap": 0.5, "segment_size": 512, "enable_denoise": True},
                "vr_params": {"batch_size": 1, "aggression": 8, "enable_tta": True, "enable_post_process": True},
                "demucs_params": {"shifts": 4, "overlap": 0.5, "segments_enabled": False},
                "mdxc_params": {"batch_size": 1, "overlap": 16, "pitch_shift": 0}
            }
        return {}

    def _resolve_output_path(self, file_path: str) -> str:
        """Locate a file reported by ``separate()``, trying the separator's output_dir."""
        output_dir = getattr(self.separator, "output_dir", None)
        if output_dir and not os.path.isabs(file_path):
            candidate = os.path.join(output_dir, file_path)
            if os.path.exists(candidate):
                return candidate
        return file_path

    def _remove_files(self, paths: List[str]) -> None:
        """Best-effort deletion of temporary stem files (failures are logged)."""
        for path in paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as e:
                self.logger.warning(f"Could not remove temporary file {path}: {e}")

    def _read_and_remove_outputs(self, output_files: List[str]) -> Dict:
        """Load each existing output into ``{stem_name: (sample_rate, data)}``, then delete all."""
        paths = [self._resolve_output_path(path) for path in output_files]
        output_audio = {}
        try:
            for path in paths:
                if os.path.exists(path):
                    stem_name = os.path.splitext(os.path.basename(path))[0]
                    audio_data, sample_rate = sf.read(path)
                    output_audio[stem_name] = (sample_rate, audio_data)
        finally:
            self._remove_files(paths)
        return output_audio

    def infer(self, audio_file: str, model_name: str, output_format: str = "WAV",
              quality_preset: str = "Standard", custom_params: Dict = None,
              enable_auto_optimize: bool = True):
        """Separate ``audio_file`` with ``model_name`` and return the stems in memory.

        Args:
            audio_file: Path of the input audio.
            model_name: Model filename. The model is (re)loaded when it, or the
                effective parameters, differ from what is currently loaded.
            output_format: Accepted for backward compatibility; intermediate
                stems are written in the separator's format, read back, deleted.
            quality_preset: "Fast" / "High Quality" replace whole parameter
                groups; any other value keeps ``custom_params`` as given.
            custom_params: ``{"mdx_params": {...}, "vr_params": {...}, ...}``
                overrides; other keys set matching Separator attributes. Never
                mutated.
            enable_auto_optimize: Analyse the audio and adjust parameters first.

        Returns:
            ``({stem_name: (sample_rate, data)}, status)`` on success, otherwise
            ``(None, error_message)``.
        """
        if audio_file is None:
            return None, "No audio file provided"
        if model_name is None:
            return None, "No model selected"

        params = self._copy_params(custom_params)
        audio_analysis = {}
        if enable_auto_optimize:
            # Analysed once; reused for the history entry below.
            audio_analysis = self.analyze_audio_characteristics(audio_file)
            params = self._optimize_parameters_for_audio(audio_analysis, params)
        params.update(self._preset_params(quality_preset))

        # Architecture parameters are consumed when the model is loaded, so a
        # parameter change needs a reload just like a model change.
        if (self.separator is None or self.current_model != model_name
                or self.current_arch_params != params):
            success, message = self.initialize_separator(model_name, arch_params=params)
            if not success:
                return None, message

        try:
            start_time = time.time()
            # Non-architecture keys are plain Separator attributes, applied directly.
            for key, value in params.items():
                if key not in self._ARCH_PARAM_KEYS and hasattr(self.separator, key):
                    setattr(self.separator, key, value)
            output_files = self.separator.separate(audio_file) or []
            processing_time = time.time() - start_time
            output_audio = self._read_and_remove_outputs(output_files)
            if not output_audio:
                return None, "No output files generated"
            self.processing_history.append({
                "timestamp": datetime.now().isoformat(),
                "model": model_name,
                "processing_time": processing_time,
                "output_files": list(output_audio.keys()),
                "audio_analysis": audio_analysis,
                "quality_preset": quality_preset
            })
            return output_audio, f"Processing completed in {processing_time:.2f}s with model: {model_name}"
        except Exception as e:
            error_msg = f"Error processing audio: {str(e)}"
            self.logger.error(f"{error_msg}\n{traceback.format_exc()}")
            return None, error_msg

    def _optimize_parameters_for_audio(self, audio_analysis: Dict, custom_params: Dict) -> Dict:
        """Return a copy of ``custom_params`` adjusted for the analysed audio.

        Rules: duration > 300 s sets MDX/VR batch_size to 2; "bass_heavy" sets VR
        aggression 7; "bright_crisp" enables VR post-processing; dynamic_range >
        0.1 enables VR TTA. The input dict is not modified.
        """
        params = self._copy_params(custom_params)
        duration = audio_analysis.get('duration', 0)
        audio_type = audio_analysis.get('audio_type', 'balanced')
        if duration > 300:
            params.setdefault('mdx_params', {})['batch_size'] = 2
            params.setdefault('vr_params', {})['batch_size'] = 2
        if audio_type == 'bass_heavy':
            params.setdefault('vr_params', {})['aggression'] = 7
        if audio_type == 'bright_crisp':
            params.setdefault('vr_params', {})['enable_post_process'] = True
        if audio_analysis.get('dynamic_range', 0) > 0.1:
            params.setdefault('vr_params', {})['enable_tta'] = True
        return params

    def get_phistory(self):
        """Format the last 10 processing-history entries as Markdown-ish text."""
        if not self.processing_history:
            return "No processing history available"
        parts = ["🎵 Enhanced Processing History\n\n"]
        for i, entry in enumerate(self.processing_history[-10:], 1):
            parts.append(f"**{i}. {entry['timestamp'][:19]}**\n")
            parts.append(f" Model: {entry['model']}\n")
            parts.append(f" Time: {entry['processing_time']:.2f}s\n")
            parts.append(f" Stems: {', '.join(entry['output_files'])}\n")
            analysis = entry.get('audio_analysis')
            if analysis:
                audio_type = analysis.get('audio_type', 'unknown')
                duration = analysis.get('duration', 0)
                parts.append(f" Audio: {audio_type} ({duration:.1f}s)\n")
            if 'quality_preset' in entry:
                parts.append(f" Preset: {entry['quality_preset']}\n")
            parts.append("\n")
        return "".join(parts)

    def reset_history(self):
        """Reset processing history"""
        self.processing_history = []
        return "Processing history cleared"
# Initialize the enhanced demo: one backend instance shared by every UI callback below.
demo1 = AudioSeparatorD()

# The comparison tab promises "up to 5" models; larger selections are rejected.
MAX_COMPARE_MODELS = 5

# "Filter by Use Case" choice -> substring searched for in a model's use_cases labels.
USE_CASE_FILTERS = {
    "Vocals": "Vocal",
    "Instrumental": "Instrumental",
    "Drums": "Drum",
    "Bass": "Bass",
    "Multi-stem": "Multi-stem",
}

# "Filter by Priority" choice -> (performance_characteristics field, required value).
PRIORITY_FILTERS = {
    "Quality": ("quality_rating", "high"),
    "Speed": ("speed_rating", "fast"),
    "Memory Efficient": ("memory_usage", "low"),
}


def uploaded_path(file_info):
    """Return the filesystem path of a gr.File upload (path string or tempfile wrapper)."""
    return getattr(file_info, "name", file_info)


def stem_label(stem_name):
    """Return the stem label embedded in a separated file name.

    Uses the last "(...)" group (e.g. ``song_(Vocals)_model`` -> ``Vocals``) so
    words in the input or model file name (such as ``Kim_Vocal_2``) cannot decide
    the routing. Falls back to the whole name when there is no such group.
    NOTE(review): assumes audio-separator's ``<input>_(<Stem>)_<model>`` naming.
    """
    labels = re.findall(r"\(([^()]+)\)", stem_name)
    return labels[-1] if labels else stem_name


def stem_slot(stem_name):
    """Map a separated file name to one of the UI output slots."""
    label = stem_label(stem_name).lower()
    if "instrumental" in label or label.startswith("no vocal"):
        return "instrumental"
    if "vocal" in label:
        return "vocals"
    if "drum" in label:
        return "drums"
    if "bass" in label:
        return "bass"
    return "other"


# Create the Gradio interface directly
with gr.Blocks(theme="NeoPy/Soft", title="🎵 Enhanced Audio Separator") as app:
    gr.Markdown(
        """
# 🎵 Audio Separator Web UI
**Smart AI-Powered Audio Source Separation with Auto-Selection & Advanced Model Comparison**
✨ **Features**: Auto model selection, performance analytics, smart parameter optimization, and comprehensive model comparison
"""
    )

    # System Information
    with gr.Accordion("🖥️ System Information", open=False):
        system_info = demo1.get_system_info()
        info_text = f"""
**PyTorch Version:** {system_info['pytorch_version']}
**Hardware Acceleration:** {system_info['device'].upper()}
**CUDA Available:** {system_info['cuda_available']} (Version: {system_info['cuda_version']})
**Apple Silicon (MPS):** {system_info['mps_available']}
**GPU Memory:** {system_info['memory_allocated'] // 1024**2}MB / {system_info['memory_total'] // 1024**2}MB
"""
        gr.Markdown(info_text)

    # Model catalogue, fetched once and shared by every widget that lists models.
    model_list = demo1.get_available_models()

    with gr.Row():
        with gr.Column():
            # Main audio input
            audio_input = gr.Audio(
                label="🎵 Upload Audio File",
                type="filepath"
            )
            gr.Markdown("*Upload audio for intelligent analysis and separation*")

            # Auto-analyze button and its result panel (revealed after analysis)
            analyze_btn = gr.Button("🔍 Analyze Audio", variant="secondary")
            audio_analysis_output = gr.JSON(label="Audio Analysis Results", visible=False)

            # Model selection
            model_dropdown = gr.Dropdown(
                choices=list(model_list.keys()) if model_list else [],
                value=list(model_list.keys())[0] if model_list else None,
                label="🤖 AI Model Selection",
                elem_id="model_dropdown"
            )
            gr.Markdown("*Choose an AI model or use auto-selection*")

            # Auto-selection controls
            with gr.Row():
                auto_select_btn = gr.Button("🎯 Auto-Select Best Model", variant="primary")
                priority_radio = gr.Radio(
                    choices=["Quality", "Speed", "Balanced"],
                    value="Quality",
                    label="Selection Priority"
                )
            gr.Markdown("*What matters most for model selection?*")

            # Model info display
            model_info_display = gr.JSON(label="📊 Selected Model Information")

            # Quality preset and optimization
            with gr.Row():
                quality_preset = gr.Radio(
                    choices=["Fast", "Standard", "High Quality", "Custom"],
                    value="Standard",
                    label="⚡ Processing Quality"
                )
                auto_optimize = gr.Checkbox(
                    label="🧠 Auto-Optimize Parameters",
                    value=True
                )
            gr.Markdown("*Automatically optimize parameters based on audio analysis*")

            # Advanced parameters
            with gr.Accordion("🔧 Advanced Parameters", open=False):
                with gr.Row():
                    batch_size = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
                    segment_size = gr.Slider(64, 1024, value=256, step=64, label="Segment Size")
                    overlap = gr.Slider(0.1, 0.5, value=0.25, step=0.05, label="Overlap")
                with gr.Row():
                    denoise = gr.Checkbox(label="Enable Denoise", value=False)
                    tta = gr.Checkbox(label="Enable TTA", value=False)
                    post_process = gr.Checkbox(label="Enable Post-Processing", value=False)
                pitch_shift = gr.Slider(-12, 12, value=0, step=1, label="Pitch Shift (semitones)")

            # Process button
            process_btn = gr.Button("🎵 Smart Separate Audio", variant="primary", size="lg")

        with gr.Column():
            # Status and results
            status_output = gr.Textbox(label="📋 Status", lines=4)

            with gr.Tabs():
                with gr.Tab("🎤 Vocals"):
                    vocals_output = gr.Audio(label="Vocals")
                with gr.Tab("🎹 Instrumental"):
                    instrumental_output = gr.Audio(label="Instrumental")
                with gr.Tab("🥁 Drums"):
                    drums_output = gr.Audio(label="Drums")
                with gr.Tab("🎸 Bass"):
                    bass_output = gr.Audio(label="Bass")
                with gr.Tab("🎛️ Other Stems"):
                    other_output = gr.Audio(label="Other Stems")

            # Performance metrics (revealed after a successful run)
            performance_metrics = gr.JSON(label="📈 Performance Metrics", visible=False)

            # Batch processing and download
            with gr.Accordion("📥 Batch & Download", open=False):
                gr.Markdown("### 🔄 Batch Processing")
                batch_files = gr.File(
                    file_count="multiple",
                    file_types=[".wav", ".mp3", ".flac", ".m4a"],
                    label="Batch Audio Files"
                )
                with gr.Row():
                    batch_btn = gr.Button("⚡ Process Batch")
                    auto_batch_btn = gr.Button("🎯 Auto-Select & Batch")
                batch_output = gr.File(label="📦 Download Batch Results")

    # Model management tabs
    with gr.Tabs():
        with gr.Tab("🔍 Model Explorer"):
            gr.Markdown("## 🧠 Intelligent Model Comparison & Selection")
            model_info = gr.JSON(value=model_list, label="📊 Model Database")
            refresh_models_btn = gr.Button("🔄 Refresh Models")
            with gr.Row():
                filter_architecture = gr.Dropdown(
                    choices=["All", "MDX-Net", "Demucs", "Roformer", "VR Arch"],
                    value="All",
                    label="Filter by Architecture"
                )
                filter_use_case = gr.Dropdown(
                    choices=["All", "Vocals", "Instrumental", "Drums", "Bass", "Multi-stem"],
                    value="All",
                    label="Filter by Use Case"
                )
                filter_priority = gr.Dropdown(
                    choices=["All", "Quality", "Speed", "Memory Efficient"],
                    value="All",
                    label="Filter by Priority"
                )
            filtered_models = gr.Dropdown(
                choices=list(model_list.keys())[:10] if model_list else [],
                multiselect=True,
                label="🎯 Models for Comparison"
            )
            gr.Markdown("*Select up to 5 models for detailed comparison*")
            compare_btn = gr.Button("🔬 Advanced Model Comparison")
            comparison_results = gr.JSON(label="📊 Comparison Results")

        with gr.Tab("📈 Analytics & History"):
            history_output = gr.Textbox(label="📜 Processing History", lines=15)
            with gr.Row():
                refresh_history_btn = gr.Button("🔄 Refresh History")
                reset_history_btn = gr.Button("🗑️ Clear History", variant="stop")
                export_history_btn = gr.Button("📊 Export Analytics")
            analytics_output = gr.JSON(label="📊 Analytics Dashboard")

        with gr.Tab("🎯 Smart Recommendations"):
            gr.Markdown("## 🤖 AI-Powered Model Recommendations")
            recommendation_status = gr.Textbox(label="Recommendation Status", lines=3)
            with gr.Row():
                get_recommendations_btn = gr.Button("🎯 Get Smart Recommendations")
                apply_recommendation_btn = gr.Button("✨ Apply Best Recommendation")
            recommendations_display = gr.JSON(label="🎯 Personalized Recommendations")

    # Event handlers
    def analyze_audio(audio_file):
        """Analyse the upload; show the raw JSON and a readable summary."""
        if not audio_file:
            return None, "No audio file provided"
        analysis = demo1.analyze_audio_characteristics(audio_file)
        # The JSON panel starts hidden, so reveal it along with the value.
        analysis_panel = gr.update(value=analysis, visible=True)
        if "error" not in analysis:
            formatted_analysis = f"""
**Audio Type:** {analysis.get('audio_type', 'Unknown').title().replace('_', ' ')}
**Duration:** {analysis.get('duration', 0):.1f} seconds
**Sample Rate:** {analysis.get('sample_rate', 0)} Hz
**Tempo:** {analysis.get('tempo', 0):.1f} BPM
**Spectral Characteristics:** {analysis.get('avg_spectral_centroid', 0):.0f} Hz (centroid)
**Dynamic Range:** {analysis.get('dynamic_range', 0):.3f}
"""
            return analysis_panel, formatted_analysis
        return analysis_panel, f"Analysis failed: {analysis['error']}"

    def auto_select_model(audio_file, priority):
        """Pick a model for the upload; returns (model, status text, model info)."""
        if not audio_file:
            return None, "No audio file provided", None
        audio_analysis = demo1.analyze_audio_characteristics(audio_file)
        # Desired stems: vocals always, plus bass or drums depending on the audio.
        desired_stems = ["vocals"]
        if audio_analysis.get('audio_type') == 'bass_heavy':
            desired_stems.append("bass")
        elif audio_analysis.get('tempo', 0) > 120:
            desired_stems.append("drums")
        selected_model = demo1.auto_select_model(
            audio_analysis, desired_stems, priority.lower()
        )
        if not selected_model:
            return None, "Auto-selection failed - no suitable model found", None
        model_info = demo1.get_available_models().get(selected_model, {})
        return (
            selected_model,
            f"🎯 Auto-selected: {model_info.get('friendly_name', selected_model)}\n"
            f"Architecture: {model_info.get('architecture_type', 'Unknown')}\n"
            f"Best for: {', '.join(model_info.get('use_cases', [])[:2])}",
            model_info
        )

    def update_model_info(model_name):
        """Return a readable summary of the selected model's catalogue entry."""
        if not model_name:
            return None
        model_info = demo1.get_available_models().get(model_name, {})
        if not model_info:
            return {"error": "Model information not available"}
        return {
            "🤖 Friendly Name": model_info.get('friendly_name', model_name),
            "🏗️ Architecture": model_info.get('architecture_type', 'Unknown'),
            "💡 Best For": model_info.get('use_cases', []),
            "⚡ Performance": model_info.get('performance_characteristics', {}),
            "🎯 Recommendations": model_info.get('recommended_for', {}),
            "📊 Technical Details": {
                "Filename": model_name,
                # NOTE(review): assumes the simplified catalogue entry carries a
                # "Stems" field; shows "Unknown" otherwise.
                "Supported Stems": model_info.get("Stems", "Unknown")
            }
        }

    def infer(audio_file, model_name, quality_preset, batch_size, segment_size,
              overlap, denoise, tta, post_process, pitch_shift, auto_optimize):
        """Run separation and route each stem to its output player."""
        if not audio_file or not model_name:
            return None, None, None, None, None, "Please upload an audio file and select a model", None
        custom_params = {
            "mdx_params": {
                "batch_size": int(batch_size),
                "segment_size": int(segment_size),
                "overlap": float(overlap),
                "enable_denoise": denoise
            },
            "vr_params": {
                "batch_size": int(batch_size),
                "enable_tta": tta,
                "enable_post_process": post_process,
                "aggression": 5  # Default
            },
            "demucs_params": {
                "overlap": float(overlap)
            },
            "mdxc_params": {
                "batch_size": int(batch_size),
                # MDXC takes an integer overlap factor rather than a fraction.
                "overlap": int(overlap * 10),
                "pitch_shift": int(pitch_shift)
            }
        }
        output_audio, status = demo1.infer(
            audio_file, model_name,
            quality_preset=quality_preset,
            custom_params=custom_params,
            enable_auto_optimize=auto_optimize
        )
        if output_audio is None:
            return None, None, None, None, None, status, None
        slots = {"vocals": None, "instrumental": None, "drums": None, "bass": None, "other": None}
        for stem_name, audio in output_audio.items():
            # One player per kind: a later stem of the same kind replaces an earlier one.
            slots[stem_slot(stem_name)] = audio
        metrics = {
            "Model": model_name,
            "Quality Preset": quality_preset,
            "Output Stems": len(output_audio),
            "Processing": "Completed Successfully"
        }
        return (
            slots["vocals"], slots["instrumental"], slots["drums"], slots["bass"], slots["other"],
            status, gr.update(value=metrics, visible=True)
        )

    def compare_models_advanced(audio_file, model_list):
        """Compare the selected models (at most MAX_COMPARE_MODELS) on the upload."""
        if not audio_file or not model_list:
            return {"error": "Please upload an audio file and select models to compare"}
        if len(model_list) > MAX_COMPARE_MODELS:
            return {"error": f"Please select at most {MAX_COMPARE_MODELS} models to compare"}
        return demo1.compare_models(audio_file, model_list)

    def filter_model_choices(architecture, use_case, priority):
        """Restrict the comparison dropdown to models matching all three filters."""
        matches = []
        for name, info in demo1.get_available_models().items():
            if architecture != "All" and architecture not in info.get("architecture_type", ""):
                continue
            keyword = USE_CASE_FILTERS.get(use_case)
            if keyword and not any(keyword in label for label in info.get("use_cases", [])):
                continue
            field_and_value = PRIORITY_FILTERS.get(priority)
            if field_and_value:
                field, wanted = field_and_value
                if info.get("performance_characteristics", {}).get(field) != wanted:
                    continue
            matches.append(name)
        return gr.update(choices=matches, value=[])

    def get_smart_recommendations(audio_file):
        """Suggest quality- and speed-oriented models plus tips for the upload."""
        if not audio_file:
            return "Please upload an audio file first", {}
        audio_analysis = demo1.analyze_audio_characteristics(audio_file)
        models = demo1.get_available_models()
        quality_models = []
        speed_models = []
        for model_name, model_info in models.items():
            perf_chars = model_info.get('performance_characteristics', {})
            if perf_chars.get('quality_rating') == 'high':
                quality_models.append({
                    'model': model_name,
                    'name': model_info.get('friendly_name', model_name),
                    'reason': 'High quality output'
                })
            if perf_chars.get('speed_rating') == 'fast':
                speed_models.append({
                    'model': model_name,
                    'name': model_info.get('friendly_name', model_name),
                    'reason': 'Fast processing'
                })
        recommendations = {
            "audio_analysis": audio_analysis,
            "recommended_models": {
                "🎯 For Best Quality": quality_models[:3],
                "⚡ For Speed": speed_models[:3]
            },
            "tips": []
        }
        audio_type = audio_analysis.get('audio_type', 'balanced')
        if audio_type == 'bass_heavy':
            recommendations["tips"].append("🎸 Models with bass separation work best")
        elif audio_type == 'bright_crisp':
            recommendations["tips"].append("🎤 Post-processing enabled for vocal clarity")
        elif audio_type == 'upbeat':
            recommendations["tips"].append("🥁 Consider drum isolation for energetic tracks")
        status = f"✅ Generated recommendations for {audio_analysis.get('audio_type', 'unknown')} audio"
        return status, recommendations

    def apply_best_recommendation(audio_file):
        """Select the best quality-priority vocal model for the upload."""
        if not audio_file:
            return None, "Please upload an audio file first", None
        audio_analysis = demo1.analyze_audio_characteristics(audio_file)
        selected_model = demo1.auto_select_model(audio_analysis, ["vocals"], "quality")
        if not selected_model:
            return None, "Could not generate recommendations", None
        model_info = demo1.get_available_models().get(selected_model, {})
        return (
            selected_model,
            f"✨ Applied recommendation: {model_info.get('friendly_name', selected_model)}",
            model_info
        )

    def export_analytics():
        """Aggregate the processing history into run, timing and usage statistics."""
        history = demo1.processing_history
        if not history:
            return {"message": "No processing history available"}
        times = [entry["processing_time"] for entry in history]
        model_usage = defaultdict(int)
        preset_usage = defaultdict(int)
        for entry in history:
            model_usage[entry["model"]] += 1
            preset_usage[entry.get("quality_preset", "Unknown")] += 1
        return {
            "total_runs": len(history),
            "total_processing_time": round(sum(times), 2),
            "average_processing_time": round(sum(times) / len(times), 2),
            "fastest_run": round(min(times), 2),
            "slowest_run": round(max(times), 2),
            "model_usage": dict(model_usage),
            "preset_usage": dict(preset_usage)
        }

    def batch_inf(batch_files, model_name):
        """Separate every uploaded file with ``model_name`` and zip all stems.

        Returns ``(zip_path_or_None, status_message)``.
        """
        if not batch_files or not model_name:
            return None, "Please upload batch files and select a model"
        stems_written = 0
        failed_files = 0
        # gr.File outputs need a path on disk, not an in-memory buffer.
        with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as zip_tmp:
            zip_path = zip_tmp.name
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zip_file:
            for file_info in batch_files:
                file_path = uploaded_path(file_info)
                output_audio, _ = demo1.infer(file_path, model_name)
                if output_audio is None:
                    failed_files += 1
                    continue
                base_name = os.path.splitext(os.path.basename(file_path))[0]
                for stem_name, (sample_rate, audio_data) in output_audio.items():
                    # Encode in memory instead of round-tripping a temp file per stem.
                    wav_buffer = io.BytesIO()
                    sf.write(wav_buffer, audio_data, sample_rate, format="WAV")
                    zip_file.writestr(f"{base_name}_{stem_name}.wav", wav_buffer.getvalue())
                    stems_written += 1
        if stems_written == 0:
            os.remove(zip_path)
            return None, "Batch processing failed: no stems were produced"
        status = f"Batch processing completed for {len(batch_files)} files"
        if failed_files:
            status += f" ({failed_files} failed)"
        return zip_path, status

    def auto_batch_process(batch_files, priority):
        """Pick a model from the first file's analysis, then batch-process all files."""
        if not batch_files:
            return None, "Please upload batch files"
        audio_analysis = demo1.analyze_audio_characteristics(uploaded_path(batch_files[0]))
        selected_model = demo1.auto_select_model(audio_analysis, ["vocals"], priority.lower())
        if selected_model:
            return batch_inf(batch_files, selected_model)
        return None, "Auto-selection failed"

    # Wire up event handlers
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input],
        outputs=[audio_analysis_output, recommendation_status]
    )
    auto_select_btn.click(
        fn=auto_select_model,
        inputs=[audio_input, priority_radio],
        outputs=[model_dropdown, recommendation_status, model_info_display]
    )
    model_dropdown.change(
        fn=update_model_info,
        inputs=[model_dropdown],
        outputs=[model_info_display]
    )
    process_btn.click(
        fn=infer,
        inputs=[
            audio_input, model_dropdown, quality_preset,
            batch_size, segment_size, overlap, denoise, tta, post_process,
            pitch_shift, auto_optimize
        ],
        outputs=[
            vocals_output, instrumental_output, drums_output,
            bass_output, other_output, status_output, performance_metrics
        ]
    )
    compare_btn.click(
        fn=compare_models_advanced,
        inputs=[audio_input, filtered_models],
        outputs=[comparison_results]
    )
    for filter_control in (filter_architecture, filter_use_case, filter_priority):
        filter_control.change(
            fn=filter_model_choices,
            inputs=[filter_architecture, filter_use_case, filter_priority],
            outputs=[filtered_models]
        )
    refresh_models_btn.click(
        fn=lambda: demo1.get_available_models(),
        outputs=[model_info]
    )
    refresh_history_btn.click(
        fn=lambda: demo1.get_phistory(),
        outputs=[history_output]
    )
    reset_history_btn.click(
        fn=lambda: demo1.reset_history(),
        outputs=[history_output]
    )
    export_history_btn.click(
        fn=export_analytics,
        outputs=[analytics_output]
    )
    get_recommendations_btn.click(
        fn=get_smart_recommendations,
        inputs=[audio_input],
        outputs=[recommendation_status, recommendations_display]
    )
    apply_recommendation_btn.click(
        fn=apply_best_recommendation,
        inputs=[audio_input],
        outputs=[model_dropdown, recommendation_status, model_info_display]
    )
    batch_btn.click(
        fn=batch_inf,
        inputs=[batch_files, model_dropdown],
        outputs=[batch_output, status_output]
    )
    auto_batch_btn.click(
        fn=auto_batch_process,
        inputs=[batch_files, priority_radio],
        outputs=[batch_output, status_output]
    )

app.launch(
    server_port=7860,
    share=True,
    ssr_mode=True
)