Spaces:
Sleeping
Sleeping
| """ | |
| NAM Garden - Integrated Design Version | |
| Combines NAM processing with modern UI design | |
| """ | |
| import gradio as gr | |
| import json | |
| import os | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Dict, Any, Optional, Tuple, List | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
| import warnings | |
| import tempfile | |
| import zipfile | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import requests | |
| import re | |
# Headless server: select matplotlib's non-interactive Agg backend.
matplotlib.use('Agg')

# Keep console output readable by silencing library warnings.
warnings.filterwarnings('ignore')

# Pick the compute device once at startup; prefer CUDA when present.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Module-wide constants.
MODELS_DIR = Path("models")  # directory holding the bundled .nam model files
SAMPLE_RATE = 48000          # target processing sample rate in Hz
CHUNK_SIZE = 512             # audio chunk size in samples
| # ========== NAM MODEL CLASSES ========== | |
class Linear(nn.Module):
    """Simple linear model for NAM processing.

    A single fully connected layer mapping a window of
    ``receptive_field`` input samples to one output sample.
    """

    def __init__(self, receptive_field: int):
        super().__init__()
        # Number of input samples consumed per output sample.
        self.receptive_field = receptive_field
        self.fc = nn.Linear(receptive_field, 1, bias=True)

    def forward(self, x):
        """Apply the fully connected layer to ``x``."""
        return self.fc(x)
class LSTM(nn.Module):
    """LSTM model for NAM processing.

    Runs a batch-first LSTM over the sample sequence and projects each
    hidden state down to a single output sample.
    """

    def __init__(self, input_size=1, hidden_size=32, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return per-timestep predictions for ``x``.

        A 2-D (batch, time) input gains a trailing feature axis so the
        LSTM sees (batch, time, 1).
        """
        if x.dim() == 2:
            x = x.unsqueeze(-1)
        hidden_seq, _ = self.lstm(x)
        return self.fc(hidden_seq)
class SimpleWaveNet(nn.Module):
    """Simplified WaveNet for NAM processing.

    A stack of dilated 1-D convolutions with tanh activations followed by
    a 1x1 output convolution, whose result is scaled by ``head_scale``.
    """

    def __init__(self, layers_config=None, head_scale=0.02):
        super().__init__()
        self.layers = nn.ModuleList()
        self.head_scale = head_scale
        # Default stack: four 16-channel layers with doubling dilation.
        if layers_config is None:
            layers_config = [
                {"channels": 16, "kernel_size": 3, "dilation": 1},
                {"channels": 16, "kernel_size": 3, "dilation": 2},
                {"channels": 16, "kernel_size": 3, "dilation": 4},
                {"channels": 16, "kernel_size": 3, "dilation": 8},
            ]
        in_channels = 1
        for cfg in layers_config:
            channels = cfg.get("channels", 16)
            kernel_size = cfg.get("kernel_size", 3)
            dilation = cfg.get("dilation", 1)
            # "Same" padding so the sequence length is preserved.
            self.layers.append(
                nn.Conv1d(
                    in_channels,
                    channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=dilation * (kernel_size - 1) // 2,
                )
            )
            in_channels = channels
        self.output_layer = nn.Conv1d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        """Run the dilated conv stack; accepts 1-D, 2-D, or 3-D input."""
        # Normalize input to (batch, channels, time).
        if x.dim() == 2:
            x = x.unsqueeze(1)
        elif x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(1)
        for conv in self.layers:
            x = torch.tanh(conv(x))
        return self.output_layer(x) * self.head_scale
| # ========== NAM PROCESSOR ========== | |
class NAMProcessor:
    """Manage NAM models (local and cloud) and run audio through them.

    Local ``.nam`` files under ``MODELS_DIR`` are indexed at startup with
    their metadata only; the full model JSON is loaded lazily the first
    time a model is selected (see ``_get_model_data``). Cloud models are
    fetched eagerly from a URL and kept in memory.
    """

    def __init__(self):
        self.models = {}             # name -> {'path', 'data', 'metadata', 'source', ...}
        self.current_model = None    # instantiated torch module for the active model
        self.current_model_name = None
        self.processed_files = []    # info dicts for the last batch of outputs
        self.custom_models = {}      # user's cloud models, keyed by bare name
        self.loading_errors = []     # human-readable indexing errors
        self.load_available_models()

    def load_available_models(self):
        """Index metadata for available local NAM models (lazy data loading)."""
        self.loading_errors = []
        if not MODELS_DIR.exists():
            print(f"Creating models directory: {MODELS_DIR}")
            MODELS_DIR.mkdir(exist_ok=True)
        nam_files = list(MODELS_DIR.glob("*.nam"))
        print(f"Found {len(nam_files)} local NAM model files")
        if not nam_files:
            self.loading_errors.append("No .nam files found in 'models' directory.")
            return
        for nam_file in nam_files:
            try:
                # Parse once for the metadata block; the full payload is
                # re-read on demand by _get_model_data() when selected.
                with open(nam_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                model_name = nam_file.stem
                self.models[model_name] = {
                    'path': nam_file,
                    'data': None,  # filled in lazily by _get_model_data()
                    'metadata': data.get('metadata', {}),
                    'source': 'local'
                }
                print(f"Indexed local model: {model_name}")
            except Exception as e:
                error_msg = f"Error indexing {nam_file.name}: {e}"
                print(error_msg)
                self.loading_errors.append(error_msg)

    def _get_model_data(self, model_name: str):
        """Return the parsed model JSON, loading it from disk on first use.

        Bug fix: local models are indexed with ``data=None`` ("load on
        demand"), but the original code never performed the demand load,
        so every local model failed in ``load_model``. The loaded payload
        is cached back onto the registry entry.
        """
        info = self.models[model_name]
        if info.get('data') is None and info.get('path') is not None:
            with open(info['path'], 'r', encoding='utf-8') as f:
                info['data'] = json.load(f)
        return info.get('data')

    def download_from_url(self, url: str, model_name: str = None) -> bool:
        """Download a NAM model from a URL and register it.

        Supports direct links, Google Drive share links, and Dropbox links.

        Returns:
            True if the model was downloaded and registered, else False.
        """
        try:
            # Derive a name from the URL when none was supplied.
            if not model_name:
                model_name = Path(url).stem or f"cloud_model_{len(self.custom_models)}"
            final_url = url
            if 'drive.google.com' in url:
                # Rewrite a share link into a direct-download link.
                match = re.search(r'/d/([a-zA-Z0-9_-]+)', url)
                if match:
                    file_id = match.group(1)
                    final_url = f'https://drive.google.com/uc?id={file_id}&export=download'
            elif 'dropbox.com' in url:
                # dl=1 forces a direct download instead of the preview page.
                final_url = url.replace('?dl=0', '?dl=1') if '?dl=0' in url else url
            response = requests.get(final_url, allow_redirects=True, timeout=30)
            response.raise_for_status()
            data = response.json()  # .nam files are JSON documents
            self.custom_models[model_name] = {
                'data': data,
                'metadata': data.get('metadata', {}),
                'source': 'cloud',
                'url': url
            }
            # Expose the model in the main registry with a cloud prefix.
            self.models[f"[Cloud] {model_name}"] = self.custom_models[model_name]
            print(f"Successfully loaded cloud model: {model_name}")
            return True
        except Exception as e:
            print(f"Error downloading model from {url}: {e}")
            return False

    def load_folder_from_url(self, folder_url: str) -> int:
        """Load multiple NAM models from a cloud folder URL.

        Returns:
            Number of models successfully loaded.
        """
        loaded_count = 0
        try:
            if 'drive.google.com/drive/folders' in folder_url:
                # Listing a Drive folder needs the Google Drive API; users
                # must supply individual file links instead.
                return 0
            elif 'github.com' in folder_url:
                # Convert a repo "tree" URL to its raw-content equivalent.
                if '/tree/' in folder_url:
                    folder_url = folder_url.replace('github.com', 'raw.githubusercontent.com')
                    folder_url = folder_url.replace('/tree/', '/')
                # Expect a manifest.json listing the model filenames.
                manifest_url = folder_url.rstrip('/') + '/manifest.json'
                response = requests.get(manifest_url, timeout=30)
                if response.status_code == 200:
                    manifest = response.json()
                    for model_file in manifest.get('models', []):
                        model_url = folder_url.rstrip('/') + '/' + model_file
                        if self.download_from_url(model_url, Path(model_file).stem):
                            loaded_count += 1
            else:
                # Plain server directory listing: scrape .nam hrefs.
                response = requests.get(folder_url, timeout=30)
                if response.status_code == 200:
                    nam_links = re.findall(r'href="([^"]*\.nam)"', response.text)
                    for nam_file in nam_links:
                        full_url = folder_url.rstrip('/') + '/' + nam_file
                        if self.download_from_url(full_url, Path(nam_file).stem):
                            loaded_count += 1
        except Exception as e:
            print(f"Error loading folder from {folder_url}: {e}")
        return loaded_count

    def get_model_choices(self):
        """Build the dropdown choice list (section headers, models, errors)."""
        if not self.models:
            # Retry in case the initial indexing found nothing.
            self.load_available_models()
        choices = []
        local_models = [name for name, info in self.models.items() if info.get('source') == 'local']
        if local_models:
            choices.append("βββ πΈ Pre-loaded Models βββ")
            choices.extend(sorted(local_models))
        cloud_models = [name for name, info in self.models.items() if info.get('source') == 'cloud']
        if cloud_models:
            # The original appended the same header in both branches of an
            # if/else; a single append is equivalent.
            choices.append("βββ βοΈ Cloud Models βββ")
            choices.extend(sorted(cloud_models))
        if self.loading_errors:
            choices.append("βββ β οΈ Loading Errors βββ")
            for err in self.loading_errors:
                # Truncate long error messages for display.
                choices.append(err[:100] + '...' if len(err) > 100 else err)
        if not choices:
            return ["No models found - Add cloud models below"]
        return choices

    def clear_custom_models(self):
        """Remove every cloud-sourced model from the registry."""
        self.models = {k: v for k, v in self.models.items() if v.get('source') != 'cloud'}
        self.custom_models.clear()
        print("Cleared all cloud models")

    def load_model(self, model_name: str) -> bool:
        """Instantiate and activate the named model. Returns success."""
        if not model_name or model_name not in self.models:
            return False
        if model_name == self.current_model_name:
            return True  # already the active model
        try:
            # Lazy-load the model JSON (fixes local models indexed with
            # data=None that previously crashed here).
            model_data = self._get_model_data(model_name)
            if not model_data:
                return False
            architecture = model_data.get('architecture', 'Linear')
            config = model_data.get('config', {})
            # Create the model matching the declared architecture.
            if architecture == 'Linear':
                model = Linear(config.get('receptive_field', 32))
            elif architecture == 'LSTM':
                model = LSTM(hidden_size=config.get('hidden_size', 32))
            elif architecture == 'WaveNet':
                model = SimpleWaveNet(config.get('layers', None), config.get('head_scale', 0.02))
            else:
                return False
            # Load weights if available; only the Linear architecture has a
            # known flat layout (weight block followed by bias block).
            if 'weights' in model_data:
                try:
                    weights = model_data['weights']
                    if isinstance(weights, list):
                        weights = torch.tensor(weights, dtype=torch.float32)
                    if architecture == 'Linear' and hasattr(model, 'fc'):
                        weight_size = model.fc.weight.numel()
                        bias_size = model.fc.bias.numel()
                        if len(weights) >= weight_size + bias_size:
                            model.fc.weight.data = weights[:weight_size].reshape(model.fc.weight.shape)
                            model.fc.bias.data = weights[weight_size:weight_size + bias_size]
                except Exception as e:
                    print(f"Warning: Could not load weights: {e}")
            self.current_model = model.to(device)
            self.current_model_name = model_name
            self.current_model.eval()
            print(f"Loaded model: {model_name}")
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return False

    def process_audio(self, audio_data, sr, input_gain, output_gain, mix):
        """Run a mono waveform through the active model.

        Args:
            audio_data: 1-D numpy float array of samples (assumed in [-1, 1]).
            sr: sample rate of ``audio_data``; resampled to SAMPLE_RATE if different.
            input_gain: pre-model gain in dB, soft-clipped through tanh.
            output_gain: post-mix gain in dB.
            mix: wet percentage, 0 (fully dry) to 100 (fully processed).

        Returns:
            Processed 1-D numpy array clipped to [-1, 1], or None when no
            model is loaded or processing fails.
        """
        if self.current_model is None:
            return None
        try:
            # Resample to the model's expected rate.
            if sr != SAMPLE_RATE:
                audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)
                resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
                audio_tensor = resampler(audio_tensor)
                audio_data = audio_tensor.squeeze(0).numpy()
            # Input gain with tanh soft clipping.
            if input_gain != 0:
                gain_linear = 10 ** (input_gain / 20)
                audio_data = audio_data * gain_linear
                audio_data = np.tanh(audio_data)
            dry_signal = audio_data.copy()  # kept for the dry/wet blend
            audio_tensor = torch.from_numpy(audio_data).float().to(device)
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            with torch.no_grad():
                processed = self.current_model(audio_tensor)
            # Collapse any batch/channel dimensions back to 1-D.
            if processed.dim() == 3:
                processed = processed.squeeze(1)
            if processed.dim() == 2:
                processed = processed.squeeze(0)
            processed_audio = processed.cpu().numpy()
            # Dry/wet blend (mix is a percentage).
            mix_ratio = mix / 100.0
            processed_audio = dry_signal * (1 - mix_ratio) + processed_audio * mix_ratio
            # Output gain.
            if output_gain != 0:
                gain_linear = 10 ** (output_gain / 20)
                processed_audio = processed_audio * gain_linear
            # Hard clip to the valid sample range.
            processed_audio = np.clip(processed_audio, -1.0, 1.0)
            return processed_audio
        except Exception as e:
            print(f"Processing error: {e}")
            return None
# Instantiate the global processor; this indexes local .nam files on startup.
processor = NAMProcessor()
_local_count = len([m for m in processor.models.values() if m.get('source') == 'local'])
print(f"\nβ Loaded {_local_count} pre-loaded NAM models")
print(f"Available models: {list(processor.models.keys())}\n")
# ========== CUSTOM CSS ==========
# Styling injected into gr.Blocks(css=...): dark teal gradient theme with
# coral (#ff6b6b) accents. Locks the layout to the viewport (no body scroll),
# and styles the file panels, buttons, sliders, dropdown, and progress bar.
custom_css = """
/* Prevent body scrolling */
body {
overflow: hidden !important;
height: 100vh !important;
margin: 0 !important;
padding: 0 !important;
}
/* Main app container */
#component-0,
.gradio-container > div:first-child {
height: 100vh !important;
max-height: 100vh !important;
overflow: hidden !important;
}
/* Clean modern design - Dark teal/coral theme */
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #0f2027 0%, #203a43 50%, #2c5364 100%);
color: #fff;
height: 100vh;
max-height: 100vh;
overflow: hidden;
padding: 0 !important;
}
/* Make columns full height */
.gradio-container .gr-column {
height: 100%;
display: flex;
flex-direction: column;
}
/* Main row container */
.gradio-container .gr-row {
height: 100vh;
max-height: 100vh;
overflow: hidden;
}
/* File columns specific styling */
.file-column {
height: 100vh !important;
max-height: 100vh !important;
overflow: hidden !important;
display: flex !important;
flex-direction: column !important;
}
/* File panel container fix */
.file-panel-container {
height: 100vh !important;
min-height: 100vh !important;
max-height: 100vh !important;
overflow: hidden;
display: flex;
flex-direction: column;
}
.file-panel-container > div {
height: 100% !important;
min-height: 100% !important;
display: flex;
flex-direction: column;
}
/* File panel styling */
.file-panel {
background: rgba(255, 255, 255, 0.1) !important;
backdrop-filter: blur(10px) !important;
border-right: 1px solid rgba(255, 255, 255, 0.2);
height: 100vh !important;
min-height: 100vh !important;
overflow: hidden;
display: flex;
flex-direction: column;
}
.file-panel:last-child {
border-right: none;
border-left: 1px solid rgba(255, 255, 255, 0.2);
}
/* File panel header */
.file-panel-header {
background: rgba(255, 255, 255, 0.15);
padding: 15px;
border-bottom: 2px solid #ff6b6b;
display: flex;
align-items: center;
gap: 10px;
}
/* File list styling */
.file-list {
flex: 1 1 auto;
min-height: 0;
overflow-y: auto;
padding: 10px;
height: calc(100vh - 80px);
}
.file-item {
background: rgba(255, 255, 255, 0.1);
margin: 5px 0;
padding: 12px;
border-radius: 6px;
cursor: pointer;
transition: all 0.2s;
display: flex;
align-items: center;
gap: 10px;
backdrop-filter: blur(5px);
}
.file-item:hover {
background: rgba(255, 107, 107, 0.3);
transform: translateX(5px);
}
.file-item.selected {
background: rgba(255, 107, 107, 0.4);
border-left: 4px solid #ff6b6b;
}
/* Main content area */
.main-content {
padding: 30px;
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(20px);
height: 100vh;
overflow-y: auto;
overflow-x: hidden;
}
/* Control styling */
.control-group {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
padding: 20px;
border-radius: 10px;
margin: 20px 0;
border: 1px solid rgba(255, 255, 255, 0.2);
}
/* Button styling - Coral accent */
.gr-button {
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%) !important;
border: none !important;
color: #fff !important;
padding: 10px 20px !important;
border-radius: 6px !important;
font-weight: 600 !important;
transition: all 0.2s !important;
}
.gr-button:hover {
background: linear-gradient(135deg, #ee5a50 0%, #e04545 100%) !important;
transform: translateY(-2px) !important;
box-shadow: 0 5px 15px rgba(255, 107, 107, 0.4) !important;
}
/* Slider styling */
.gr-slider input {
accent-color: #ff6b6b !important;
}
/* Center slider values */
.gr-slider .gr-slider-value,
.gr-slider span[data-testid="number"],
.gr-slider input[type="number"] {
text-align: center !important;
display: block !important;
width: 100% !important;
}
/* Progress bar */
.progress-bar {
position: relative;
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 107, 107, 0.5);
border-radius: 6px;
padding: 0;
margin-top: 20px;
height: 44px;
overflow: hidden;
}
.progress-fill {
position: absolute;
top: 0;
left: 0;
height: 100%;
background: linear-gradient(90deg, #ff6b6b, #ee5a50);
width: 0%;
transition: width 0.3s ease;
}
.progress-fill.active {
animation: fillProgress 2s ease-out forwards;
}
@keyframes fillProgress {
from { width: 0%; }
to { width: 100%; }
}
.progress-text {
position: relative;
z-index: 1;
color: #fff;
font-size: 14px;
text-align: center;
line-height: 44px;
text-shadow: 0 0 4px rgba(0, 0, 0, 0.3);
font-weight: 500;
}
/* Dropdown styling */
.gr-dropdown {
background: rgba(255, 255, 255, 0.1) !important;
border: 1px solid rgba(255, 107, 107, 0.5) !important;
color: #fff !important;
user-select: none !important;
}
"""
# ========== FILE MANAGEMENT JAVASCRIPT ==========
# Client-side script injected once via gr.HTML(): tracks which input/processed
# file is selected, wires the custom Import button to the hidden Gradio file
# input, draws a pseudo-waveform preview on the 'waveform' canvas (seeded from
# the file name - it is NOT real audio data), and animates the progress bar.
file_management_js = """
<script>
let selectedInputFile = null;
let selectedProcessedFile = null;
let inputFiles = [];
let processedFiles = [];
// Wait for DOM to be ready
window.addEventListener('DOMContentLoaded', (event) => {
// Setup Import button click handler with multiple attempts
let attempts = 0;
const setupImportButton = () => {
const importBtn = document.getElementById('import-btn-header');
const fileInput = document.querySelector('input[type="file"][accept*="audio"]');
if (importBtn && fileInput) {
console.log('Setting up import button handler');
importBtn.addEventListener('click', (e) => {
e.preventDefault();
e.stopPropagation();
fileInput.click();
});
return true;
}
return false;
};
// Try immediately and then with delays
if (!setupImportButton() && attempts < 10) {
const interval = setInterval(() => {
attempts++;
if (setupImportButton() || attempts >= 10) {
clearInterval(interval);
}
}, 500);
}
});
function selectFile(type, index) {
if (type === 'input') {
document.querySelectorAll('#input-file-list .file-item').forEach(item => {
item.classList.remove('selected');
});
const selectedItem = document.getElementById(`input-file-${index}`);
if (selectedItem) {
selectedItem.classList.add('selected');
selectedInputFile = index;
// Update waveform
updateWaveformForFile(inputFiles[index]);
}
} else if (type === 'processed') {
document.querySelectorAll('#processed-file-list .file-item').forEach(item => {
item.classList.remove('selected');
});
const selectedItem = document.getElementById(`processed-file-${index}`);
if (selectedItem) {
selectedItem.classList.add('selected');
selectedProcessedFile = index;
}
}
}
function updateWaveformForFile(fileInfo) {
const canvas = document.getElementById('waveform');
if (canvas && fileInfo) {
const ctx = canvas.getContext('2d');
canvas.width = canvas.offsetWidth;
canvas.height = canvas.offsetHeight;
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.strokeStyle = '#ff6b6b';
ctx.lineWidth = 2;
ctx.beginPath();
const width = canvas.width;
const height = canvas.height;
const centerY = height / 2;
// Generate waveform based on file
let seed = 0;
if (fileInfo && fileInfo.name) {
for (let i = 0; i < fileInfo.name.length; i++) {
seed += fileInfo.name.charCodeAt(i);
}
}
for (let x = 0; x < width; x++) {
const y = centerY + Math.sin(x * 0.02 + seed * 0.01) * 30 * Math.sin(x * 0.001);
if (x === 0) ctx.moveTo(x, y);
else ctx.lineTo(x, y);
}
ctx.stroke();
ctx.fillStyle = '#ff6b6b';
ctx.font = '12px monospace';
ctx.fillText(fileInfo ? `${fileInfo.name}` : 'No file selected', 10, 20);
}
}
function saveSelectedFile() {
if (selectedProcessedFile !== null && processedFiles[selectedProcessedFile]) {
console.log('Exporting:', processedFiles[selectedProcessedFile].name);
// Trigger the download button in Gradio
const downloadBtn = document.querySelector('#download-btn');
if (downloadBtn) {
downloadBtn.click();
}
} else {
alert('Please select a processed file to export');
}
}
function saveAllFiles() {
if (processedFiles.length > 0) {
console.log('Exporting all processed files');
// Trigger the download button in Gradio
const downloadBtn = document.querySelector('#download-btn');
if (downloadBtn) {
downloadBtn.click();
}
} else {
alert('No processed files to export');
}
}
// Progress animation
function startProcessing() {
const progressBar = document.querySelector('.progress-fill');
const progressText = document.querySelector('.progress-text');
if (progressBar && progressText) {
progressText.textContent = 'Processing...';
progressBar.classList.add('active');
progressBar.style.width = '0%';
setTimeout(() => {
progressBar.classList.remove('active');
}, 2000);
}
}
</script>
"""
| # ========== UI HELPER FUNCTIONS ========== | |
def create_input_file_html(files):
    """Create HTML for the input files panel.

    Builds the left-hand panel markup plus a trailing <script> that
    publishes the file list to the page-level ``inputFiles`` variable and
    re-attaches the Import button's click handler after Gradio replaces
    the HTML.

    Args:
        files: uploaded file objects exposing a ``.name`` filesystem path,
            or a falsy value when no files are loaded.

    Returns:
        str: HTML for the whole panel.
    """
    if not files:
        file_items_html = """
        <div style='text-align: center; padding: 40px 20px; color: rgba(255,255,255,0.5);'>
            <div style='font-size: 48px; opacity: 0.3; margin-bottom: 10px;'>π</div>
            <div>No files loaded</div>
            <div style='font-size: 12px; margin-top: 10px;'>Import audio files to begin</div>
        </div>
        """
        file_data = []
    else:
        file_items_html = ""
        file_data = []
        for i, file in enumerate(files):
            file_name = Path(file.name).name
            file_size = os.path.getsize(file.name) / (1024 * 1024)  # bytes -> MB
            file_info = {
                'name': file_name,
                'path': file.name,
                'size': f"{file_size:.1f} MB"
            }
            file_data.append(file_info)
            file_items_html += f"""
            <div class='file-item' id='input-file-{i}' onclick='selectFile("input", {i})'>
                <span style='font-size: 20px;'>π΅</span>
                <div style='flex: 1;'>
                    <div style='font-weight: 500; color: #fff;'>{file_name}</div>
                    <div style='font-size: 11px; opacity: 0.7;'>{file_info['size']}</div>
                </div>
            </div>
            """
    # Escape "</" so the embedded JSON cannot terminate the <script> tag
    # early. Fixed: the original used "<\/", an invalid escape sequence in
    # Python literals (same resulting string, but a SyntaxWarning on 3.12+).
    safe_file_data = json.dumps(file_data).replace("</", "<\\/")
    return f"""
    <div class='file-panel'>
        <div class='file-panel-header'>
            <span style='font-size: 20px; font-weight: 600;'>Input Files</span>
            <div style='margin-left: auto; display: flex; gap: 8px;'>
                <button id='import-btn-header' style='
                    background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%);
                    border: none;
                    color: #fff;
                    padding: 8px 16px;
                    border-radius: 6px;
                    cursor: pointer;
                    font-weight: 500;
                    font-size: 13px;
                    min-width: 90px;
                '>Import</button>
            </div>
        </div>
        <div class='file-list' id='input-file-list'>
            {file_items_html}
        </div>
    </div>
    <script>
    inputFiles = {safe_file_data};
    // Re-attach the click handler after HTML update
    setTimeout(() => {{
        const importBtn = document.getElementById('import-btn-header');
        if (importBtn) {{
            importBtn.onclick = (e) => {{
                e.preventDefault();
                e.stopPropagation();
                const fileInput = document.querySelector('input[type="file"][accept*="audio"]');
                if (fileInput) {{
                    fileInput.click();
                }} else {{
                    // Fallback: try other selectors
                    const altInput = document.querySelector('#file-upload input[type="file"]');
                    if (altInput) altInput.click();
                }}
            }};
        }}
    }}, 100);
    </script>
    """
def create_processed_file_html(files):
    """Create HTML for the processed files panel.

    Args:
        files: list of dicts with keys ``name``, ``path`` and optionally
            ``size``, or a falsy value when nothing has been processed yet.

    Returns:
        str: HTML for the panel, including a <script> that publishes the
        list to the page-level ``processedFiles`` variable.
    """
    if not files:
        # Empty-state panel: same header/buttons, placeholder list body.
        return """
        <div class='file-panel'>
            <div class='file-panel-header'>
                <span style='font-size: 20px; font-weight: 600;'>Processed Files</span>
                <div style='margin-left: auto; display: flex; gap: 8px;'>
                    <button onclick='saveSelectedFile()' style='
                        background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%);
                        border: none;
                        color: #fff;
                        padding: 8px 16px;
                        border-radius: 6px;
                        cursor: pointer;
                        font-weight: 500;
                        font-size: 13px;
                        min-width: 90px;
                    '>Export</button>
                    <button onclick='saveAllFiles()' style='
                        background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%);
                        border: none;
                        color: #fff;
                        padding: 8px 16px;
                        border-radius: 6px;
                        cursor: pointer;
                        font-weight: 500;
                        font-size: 13px;
                        min-width: 90px;
                    '>Export All</button>
                </div>
            </div>
            <div class='file-list' id='processed-file-list'>
                <div style='text-align: center; padding: 40px 20px; color: rgba(255,255,255,0.5);'>
                    <div style='font-size: 48px; opacity: 0.3; margin-bottom: 10px;'>πΏ</div>
                    <div>No processed files</div>
                    <div style='font-size: 12px; margin-top: 10px;'>Process audio to see results here</div>
                </div>
            </div>
        </div>
        <script>processedFiles = [];</script>
        """
    file_items_html = ""
    file_data = []
    for i, file_info in enumerate(files):
        file_data.append(file_info)
        file_items_html += f"""
        <div class='file-item' id='processed-file-{i}' onclick='selectFile("processed", {i})'>
            <span style='font-size: 20px;'>πΏ</span>
            <div style='flex: 1;'>
                <div style='font-weight: 500; color: #fff;'>{file_info['name']}</div>
                <div style='font-size: 11px; opacity: 0.7;'>{file_info.get('size', 'Unknown size')}</div>
            </div>
        </div>
        """
    # Escape "</" so the embedded JSON cannot terminate the <script> tag
    # early. Fixed: the original used "<\/", an invalid escape sequence in
    # Python literals (same resulting string, but a SyntaxWarning on 3.12+).
    safe_file_data = json.dumps(file_data).replace("</", "<\\/")
    return f"""
    <div class='file-panel'>
        <div class='file-panel-header'>
            <span style='font-size: 20px; font-weight: 600;'>Processed Files</span>
            <div style='margin-left: auto; display: flex; gap: 8px;'>
                <button onclick='saveSelectedFile()' style='
                    background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%);
                    border: none;
                    color: #fff;
                    padding: 8px 16px;
                    border-radius: 6px;
                    cursor: pointer;
                    font-weight: 500;
                    font-size: 13px;
                    min-width: 90px;
                '>Export</button>
                <button onclick='saveAllFiles()' style='
                    background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%);
                    border: none;
                    color: #fff;
                    padding: 8px 16px;
                    border-radius: 6px;
                    cursor: pointer;
                    font-weight: 500;
                    font-size: 13px;
                    min-width: 90px;
                '>Export All</button>
            </div>
        </div>
        <div class='file-list' id='processed-file-list'>
            {file_items_html}
        </div>
    </div>
    <script>processedFiles = {safe_file_data};</script>
    """
def update_status(message, processing=False):
    """Render the status/progress bar.

    Args:
        message: text shown inside the bar.
        processing: when True the fill element gets the 'active' class,
            which triggers the CSS fill animation; otherwise the fill is
            pinned at 0% width.

    Returns:
        str: HTML snippet for the progress bar.
    """
    if processing:
        fill = "<div class='progress-fill active'></div>"
    else:
        fill = "<div class='progress-fill' style='width: 0%;'></div>"
    return f"""
    <div class='progress-bar'>
        {fill}
        <div class='progress-text'>{message}</div>
    </div>
    """
| # ========== MAIN PROCESSING FUNCTIONS ========== | |
def process_audio_file(file, profile, input_gain, output_gain, mix):
    """Process one audio file through the selected NAM profile.

    Args:
        file: uploaded file object with a ``.name`` path attribute.
        profile: model name loaded via the global ``processor``.
        input_gain: pre-model gain in dB.
        output_gain: post-mix gain in dB.
        mix: wet percentage (0-100).

    Returns:
        Tuple of (path to the processed .wav, status message); the path is
        None on any failure.
    """
    if not file:
        return None, "No file selected"
    if not processor.load_model(profile):
        return None, f"Failed to load model: {profile}"
    try:
        # Load audio as (channels, samples).
        audio_data, sr = torchaudio.load(file.name)
        audio_numpy = audio_data.numpy()
        # Down-mix to mono.
        if audio_numpy.shape[0] > 1:
            audio_numpy = np.mean(audio_numpy, axis=0)
        else:
            audio_numpy = audio_numpy[0]
        # Run through the NAM model.
        processed = processor.process_audio(audio_numpy, sr, input_gain, output_gain, mix)
        if processed is None:
            return None, "Processing failed"
        # NamedTemporaryFile replaces the deprecated, race-prone
        # tempfile.mktemp(); delete=False keeps the file for the caller.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            temp_path = tmp.name
        torchaudio.save(
            temp_path,
            torch.from_numpy(processed).unsqueeze(0),
            SAMPLE_RATE
        )
        return temp_path, "Processing complete!"
    except Exception as e:
        return None, f"Error: {str(e)}"
def process_all_files(files, profile, input_gain, output_gain, mix):
    """Process every uploaded file through the selected profile.

    Returns:
        Tuple of (list of result-info dicts, status message). Results are
        also stored on ``processor.processed_files`` for later download.
    """
    if not files:
        return [], "No files to process"
    results = []
    for upload in files:
        out_path, _status = process_audio_file(upload, profile, input_gain, output_gain, mix)
        if not out_path:
            continue  # skip files that failed; status per-file is discarded
        size_mb = os.path.getsize(out_path) / (1024 * 1024)
        results.append({
            'name': Path(upload.name).stem + "_processed.wav",
            'path': out_path,
            'size': f"{size_mb:.1f} MB"
        })
    processor.processed_files = results
    return results, f"Processed {len(results)} files"
def download_processed_files():
    """Return a downloadable path for the last batch of processed files.

    Returns:
        None when nothing was processed, the single wav's path when there
        is exactly one result, otherwise the path of a zip bundling all
        results (created in a fresh temporary directory).
    """
    outputs = processor.processed_files
    if not outputs:
        return None
    if len(outputs) == 1:
        return outputs[0]['path']
    # Several results: bundle them into one archive.
    zip_path = Path(tempfile.mkdtemp()) / "processed_audio.zip"
    with zipfile.ZipFile(zip_path, 'w') as archive:
        for info in outputs:
            archive.write(info['path'], info['name'])
    return str(zip_path)
# ========== GRADIO INTERFACE ==========
# Three-column layout: input files | controls | processed files.
with gr.Blocks(css=custom_css, theme=gr.themes.Base()) as demo:
    # Inject the JavaScript that powers the custom file panels.
    gr.HTML(file_management_js)

    # State management: per-session lists of uploaded / processed files.
    # NOTE(review): `processed_files` state is never wired to any event
    # below — the processed batch is kept on the global `processor` instead.
    uploaded_files = gr.State([])
    processed_files = gr.State([])

    with gr.Row():
        # Left panel - Input files
        with gr.Column(scale=1, min_width=250, elem_classes="file-column"):
            # Input files display (HTML panel refreshed by handle_upload).
            input_file_panel = gr.HTML(
                value=create_input_file_html([]),
                elem_classes="file-panel-container"
            )
            # File upload (hidden but accessible) — presumably the injected
            # panel JS targets this component via its elem_id; confirm in
            # file_management_js.
            file_upload = gr.File(
                label="Upload Audio Files",
                file_count="multiple",
                file_types=["audio"],
                visible=False,
                elem_id="file-upload"
            )

        # Center - Main content
        with gr.Column(scale=2, elem_classes="main-content"):
            # Header
            gr.Markdown("# πΈ NAM Garden Audio Processor", elem_classes="main-title")

            # Profile/IR Selection
            with gr.Group(elem_classes="control-group"):
                model_choices = processor.get_model_choices()
                # Select first non-header model if available; entries
                # containing "βββ" act as section-header rows and must not
                # be used as the default selection.
                default_model = None
                for choice in model_choices:
                    if choice and "βββ" not in choice:
                        default_model = choice
                        break
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_model,
                    label="Select Profile/IR",
                    container=False
                )

            # Cloud model loading section (separate group)
            with gr.Group(elem_classes="control-group"):
                with gr.Accordion("βοΈ Load Cloud Models", open=False):
                    gr.Markdown("""
                    **Add your own NAM models from cloud storage:**
                    - Direct `.nam` file URLs
                    - Google Drive links
                    - Dropbox links
                    - GitHub raw files
                    """)
                    cloud_url = gr.Textbox(
                        label="Model URL",
                        placeholder="https://drive.google.com/file/d/... or https://dropbox.com/...",
                        lines=1
                    )
                    with gr.Row():
                        load_cloud_btn = gr.Button("β¬ Load Model", size="sm")
                        load_folder_btn = gr.Button("π Load Folder", size="sm")
                        clear_cloud_btn = gr.Button("π Clear Cloud Models", size="sm")
                    cloud_status = gr.Textbox(
                        label="Cloud Status",
                        value="Ready to load cloud models",
                        interactive=False,
                        lines=1
                    )

            # Audio Controls: input/output trim in dB plus dry/wet mix.
            with gr.Group(elem_classes="control-group"):
                with gr.Row():
                    input_gain = gr.Slider(
                        minimum=-20,
                        maximum=20,
                        value=0,
                        step=1,
                        label="In (dB)",
                        container=True
                    )
                    output_gain = gr.Slider(
                        minimum=-20,
                        maximum=20,
                        value=0,
                        step=1,
                        label="Out (dB)",
                        container=True
                    )
                    # Presumably 100 % = fully processed signal — confirm
                    # against processor.process_audio's mix handling.
                    mix_slider = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=100,
                        step=1,
                        label="Mix (%)",
                        container=True
                    )

            # Waveform Display — empty canvas; presumably drawn into by the
            # injected JS (verify in file_management_js).
            with gr.Group(elem_classes="control-group"):
                gr.HTML("""
                <div style='
                    height: 200px;
                    background: rgba(255, 255, 255, 0.08);
                    backdrop-filter: blur(10px);
                    border-radius: 8px;
                    position: relative;
                    overflow: hidden;
                '>
                    <canvas id='waveform' style='width: 100%; height: 100%;'></canvas>
                </div>
                """)

            # Process Button
            with gr.Row():
                process_btn = gr.Button(
                    "β‘ Process Audio",
                    variant="primary",
                    size="lg"
                )

            # Status Bar
            status_display = gr.HTML(
                value=update_status("Ready"),
                elem_id="status-display"
            )

        # Right panel - Processed files
        with gr.Column(scale=1, min_width=250, elem_classes="file-column"):
            processed_file_panel = gr.HTML(
                value=create_processed_file_html([]),
                elem_classes="file-panel-container"
            )
            # Download button — hidden until a batch processes successfully.
            download_btn = gr.Button(
                "πΎ Download Processed",
                visible=False,
                elem_id="download-btn"
            )
            # Hidden file component; revealed by handle_download with the
            # single file / zip path as its value.
            download_file = gr.File(
                label="Download",
                visible=False
            )

    # Event handlers
    def handle_upload(files):
        """Refresh the input panel and session state after an upload.

        Returns (input panel HTML, file list for uploaded_files state,
        status-bar HTML).
        """
        if not files:
            return create_input_file_html([]), [], update_status("No files uploaded")
        return create_input_file_html(files), files, update_status(f"Loaded {len(files)} file(s)")

    def handle_process(files, profile, in_gain, out_gain, mix):
        """Run the batch and update the processed panel, status bar, and
        download-button visibility.

        Returns (processed panel HTML, status HTML, gr.update toggling
        download_btn).
        """
        if not files:
            return create_processed_file_html([]), update_status("No files to process"), gr.update(visible=False)
        # Skip if header is selected — "βββ" rows are decorative, not models.
        if profile and "βββ" in profile:
            return create_processed_file_html([]), update_status("Please select a valid model"), gr.update(visible=False)
        # Process files
        processed, status_msg = process_all_files(files, profile, in_gain, out_gain, mix)
        if processed:
            return (
                create_processed_file_html(processed),
                update_status(status_msg),
                gr.update(visible=True)
            )
        else:
            return (
                create_processed_file_html([]),
                update_status("Processing failed"),
                gr.update(visible=False)
            )

    def handle_download():
        """Reveal download_file with the path produced by
        download_processed_files(); hide it when there is nothing to serve."""
        file_path = download_processed_files()
        if file_path:
            return gr.update(value=file_path, visible=True)
        return gr.update(visible=False)

    def handle_load_cloud_model(url):
        """Download a single cloud model via processor.download_from_url and
        refresh the model dropdown. Returns (dropdown update, status text)."""
        if not url:
            return gr.update(), "Please enter a URL"
        success = processor.download_from_url(url)
        if success:
            return gr.update(choices=processor.get_model_choices()), f"β Loaded model from cloud"
        else:
            return gr.update(), f"β Failed to load model from URL"

    def handle_load_cloud_folder(url):
        """Bulk-load models from a cloud folder URL; count comes from
        processor.load_folder_from_url. Returns (dropdown update, status)."""
        if not url:
            return gr.update(), "Please enter a folder URL"
        count = processor.load_folder_from_url(url)
        if count > 0:
            return gr.update(choices=processor.get_model_choices()), f"β Loaded {count} models from folder"
        else:
            return gr.update(), f"β No models found in folder"

    def handle_clear_cloud():
        """Drop all cloud-loaded models and rebuild the dropdown choices."""
        processor.clear_custom_models()
        return gr.update(choices=processor.get_model_choices()), "π Cleared all cloud models"

    # Connect events
    file_upload.change(
        fn=handle_upload,
        inputs=[file_upload],
        outputs=[input_file_panel, uploaded_files, status_display]
    )
    process_btn.click(
        fn=handle_process,
        inputs=[uploaded_files, model_dropdown, input_gain, output_gain, mix_slider],
        outputs=[processed_file_panel, status_display, download_btn]
    )
    download_btn.click(
        fn=handle_download,
        outputs=download_file
    )
    # Cloud model loading events
    load_cloud_btn.click(
        fn=handle_load_cloud_model,
        inputs=[cloud_url],
        outputs=[model_dropdown, cloud_status]
    )
    load_folder_btn.click(
        fn=handle_load_cloud_folder,
        inputs=[cloud_url],
        outputs=[model_dropdown, cloud_status]
    )
    clear_cloud_btn.click(
        fn=handle_clear_cloud,
        outputs=[model_dropdown, cloud_status]
    )
# Launch the app
if __name__ == "__main__":
    # Bind to all interfaces on the conventional Gradio port so the app is
    # reachable when run inside a container (e.g. a Hugging Face Space).
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_kwargs)