""" NAM Garden - Integrated Design Version Combines NAM processing with modern UI design """ import gradio as gr import json import os from pathlib import Path from datetime import datetime from typing import Dict, Any, Optional, Tuple, List import numpy as np import torch import torch.nn as nn import torchaudio import warnings import tempfile import zipfile import matplotlib.pyplot as plt import matplotlib import requests import re matplotlib.use('Agg') # Suppress warnings warnings.filterwarnings('ignore') # Check for GPU availability device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") # Constants MODELS_DIR = Path("models") SAMPLE_RATE = 48000 CHUNK_SIZE = 512 # ========== NAM MODEL CLASSES ========== class Linear(nn.Module): """Simple linear model for NAM processing""" def __init__(self, receptive_field: int): super().__init__() self.receptive_field = receptive_field self.fc = nn.Linear(receptive_field, 1, bias=True) def forward(self, x): return self.fc(x) class LSTM(nn.Module): """LSTM model for NAM processing""" def __init__(self, input_size=1, hidden_size=32, num_layers=1): super().__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, 1) def forward(self, x): if x.dim() == 2: x = x.unsqueeze(-1) lstm_out, _ = self.lstm(x) return self.fc(lstm_out) class SimpleWaveNet(nn.Module): """Simplified WaveNet for NAM processing""" def __init__(self, layers_config=None, head_scale=0.02): super().__init__() self.layers = nn.ModuleList() self.head_scale = head_scale if layers_config is None: layers_config = [ {"channels": 16, "kernel_size": 3, "dilation": 1}, {"channels": 16, "kernel_size": 3, "dilation": 2}, {"channels": 16, "kernel_size": 3, "dilation": 4}, {"channels": 16, "kernel_size": 3, "dilation": 8}, ] in_channels = 1 for config in layers_config: self.layers.append( nn.Conv1d( in_channels, config.get("channels", 16), kernel_size=config.get("kernel_size", 3), dilation=config.get("dilation", 1), padding=config.get("dilation", 1) * (config.get("kernel_size", 3) - 1) // 2 ) ) in_channels = config.get("channels", 16) self.output_layer = nn.Conv1d(in_channels, 1, kernel_size=1) def forward(self, x): if x.dim() == 2: x = x.unsqueeze(1) elif x.dim() == 1: x = x.unsqueeze(0).unsqueeze(1) for layer in self.layers: x = torch.tanh(layer(x)) return self.output_layer(x) * self.head_scale # ========== NAM PROCESSOR ========== class NAMProcessor: """Main processor for NAM models with cloud storage support""" def __init__(self): self.models = {} self.current_model = None self.current_model_name = None self.processed_files = [] self.custom_models = {} # Store user's cloud models self.loading_errors = [] self.load_available_models() def load_available_models(self): """Load metadata for available NAM models (lazy loading)""" self.loading_errors = [] if not MODELS_DIR.exists(): print(f"Creating models directory: {MODELS_DIR}") MODELS_DIR.mkdir(exist_ok=True) nam_files = list(MODELS_DIR.glob("*.nam")) print(f"Found {len(nam_files)} local NAM model files") if not nam_files: self.loading_errors.append("No .nam files found in 'models' directory.") return for nam_file in nam_files: try: # Only load metadata, not the full model data with open(nam_file, 'r', encoding='utf-8') as f: data = json.load(f) model_name = nam_file.stem # Store path and metadata only, load data on demand self.models[model_name] = { 'path': nam_file, 'data': None, # Will load on demand 'metadata': data.get('metadata', {}), 'source': 'local' } print(f"Indexed local model: {model_name}") except Exception as e: error_msg = f"Error indexing {nam_file.name}: {e}" print(error_msg) self.loading_errors.append(error_msg) def download_from_url(self, url: str, model_name: str = None) -> bool: """Download NAM model from URL (supports direct links, Google Drive, Dropbox)""" try: # Generate model name if not provided if not model_name: model_name = Path(url).stem or f"cloud_model_{len(self.custom_models)}" # Handle different cloud services final_url = url # Google Drive handling if 'drive.google.com' in url: # Extract file ID from Google Drive URL match = re.search(r'/d/([a-zA-Z0-9_-]+)', url) if match: file_id = match.group(1) final_url = f'https://drive.google.com/uc?id={file_id}&export=download' # Dropbox handling elif 'dropbox.com' in url: final_url = url.replace('?dl=0', '?dl=1') if '?dl=0' in url else url # Download the file response = requests.get(final_url, allow_redirects=True, timeout=30) response.raise_for_status() # Parse JSON content data = response.json() # Store in custom models self.custom_models[model_name] = { 'data': data, 'metadata': data.get('metadata', {}), 'source': 'cloud', 'url': url } # Also add to main models dict self.models[f"[Cloud] {model_name}"] = self.custom_models[model_name] print(f"Successfully loaded cloud model: {model_name}") return True except Exception as e: print(f"Error downloading model from {url}: {e}") return False def load_folder_from_url(self, folder_url: str) -> int: """Load multiple NAM models from a cloud folder""" loaded_count = 0 try: # For Google Drive folders if 'drive.google.com/drive/folders' in folder_url: # Note: Full folder downloading would require Google Drive API # For now, users need to provide individual file links return 0 # For GitHub repos/folders elif 'github.com' in folder_url: # Convert to raw GitHub URL if needed if '/tree/' in folder_url: folder_url = folder_url.replace('github.com', 'raw.githubusercontent.com') folder_url = folder_url.replace('/tree/', '/') # Try to fetch a manifest or list file manifest_url = folder_url.rstrip('/') + '/manifest.json' response = requests.get(manifest_url) if response.status_code == 200: manifest = response.json() for model_file in manifest.get('models', []): model_url = folder_url.rstrip('/') + '/' + model_file if self.download_from_url(model_url, Path(model_file).stem): loaded_count += 1 # For direct server folders with index else: response = requests.get(folder_url) if response.status_code == 200: # Look for .nam file links in the response nam_links = re.findall(r'href="([^"]*\.nam)"', response.text) for nam_file in nam_links: full_url = folder_url.rstrip('/') + '/' + nam_file if self.download_from_url(full_url, Path(nam_file).stem): loaded_count += 1 except Exception as e: print(f"Error loading folder from {folder_url}: {e}") return loaded_count def get_model_choices(self): """Get model choices for dropdown""" if not self.models: # Try loading again in case models weren't loaded self.load_available_models() choices = [] # Add local models first local_models = [name for name, info in self.models.items() if info.get('source') == 'local'] if local_models: choices.append("āāā šø Pre-loaded Models āāā") choices.extend(sorted(local_models)) # Add cloud models cloud_models = [name for name, info in self.models.items() if info.get('source') == 'cloud'] if cloud_models: if local_models: # Add separator only if there are local models choices.append("āāā āļø Cloud Models āāā") else: choices.append("āāā āļø Cloud Models āāā") choices.extend(sorted(cloud_models)) if self.loading_errors: choices.append("āāā ā ļø Loading Errors āāā") for err in self.loading_errors: # Truncate long error messages for display choices.append(err[:100] + '...' if len(err) > 100 else err) if not choices: return ["No models found - Add cloud models below"] return choices def clear_custom_models(self): """Clear all custom cloud models""" # Remove cloud models from main dict self.models = {k: v for k, v in self.models.items() if v.get('source') != 'cloud'} self.custom_models.clear() print("Cleared all cloud models") def load_model(self, model_name: str) -> bool: """Load a NAM model by name""" if not model_name or model_name not in self.models: return False if model_name == self.current_model_name: return True try: model_data = self.models[model_name]['data'] architecture = model_data.get('architecture', 'Linear') config = model_data.get('config', {}) # Create model based on architecture if architecture == 'Linear': model = Linear(config.get('receptive_field', 32)) elif architecture == 'LSTM': model = LSTM(hidden_size=config.get('hidden_size', 32)) elif architecture == 'WaveNet': model = SimpleWaveNet(config.get('layers', None), config.get('head_scale', 0.02)) else: return False # Load weights if available if 'weights' in model_data: try: weights = model_data['weights'] if isinstance(weights, list): weights = torch.tensor(weights, dtype=torch.float32) if architecture == 'Linear' and hasattr(model, 'fc'): weight_size = model.fc.weight.numel() bias_size = model.fc.bias.numel() if len(weights) >= weight_size + bias_size: model.fc.weight.data = weights[:weight_size].reshape(model.fc.weight.shape) model.fc.bias.data = weights[weight_size:weight_size + bias_size] except Exception as e: print(f"Warning: Could not load weights: {e}") self.current_model = model.to(device) self.current_model_name = model_name self.current_model.eval() print(f"Loaded model: {model_name}") return True except Exception as e: print(f"Error loading model {model_name}: {e}") return False def process_audio(self, audio_data, sr, input_gain, output_gain, mix): """Process audio through the current model""" if self.current_model is None: return None try: # Resample if needed if sr != SAMPLE_RATE: audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE) audio_tensor = resampler(audio_tensor) audio_data = audio_tensor.squeeze(0).numpy() # Apply input gain if input_gain != 0: gain_linear = 10 ** (input_gain / 20) audio_data = audio_data * gain_linear audio_data = np.tanh(audio_data) # Store dry signal dry_signal = audio_data.copy() # Process through model audio_tensor = torch.from_numpy(audio_data).float().to(device) if audio_tensor.dim() == 1: audio_tensor = audio_tensor.unsqueeze(0) with torch.no_grad(): processed = self.current_model(audio_tensor) if processed.dim() == 3: processed = processed.squeeze(1) if processed.dim() == 2: processed = processed.squeeze(0) processed_audio = processed.cpu().numpy() # Apply mix (convert percentage to 0-1) mix_ratio = mix / 100.0 processed_audio = dry_signal * (1 - mix_ratio) + processed_audio * mix_ratio # Apply output gain if output_gain != 0: gain_linear = 10 ** (output_gain / 20) processed_audio = processed_audio * gain_linear # Clip to prevent distortion processed_audio = np.clip(processed_audio, -1.0, 1.0) return processed_audio except Exception as e: print(f"Processing error: {e}") return None # Initialize processor processor = NAMProcessor() print(f"\nā Loaded {len([m for m in processor.models.values() if m.get('source') == 'local'])} pre-loaded NAM models") print(f"Available models: {list(processor.models.keys())}\n") # ========== CUSTOM CSS ========== custom_css = """ /* Prevent body scrolling */ body { overflow: hidden !important; height: 100vh !important; margin: 0 !important; padding: 0 !important; } /* Main app container */ #component-0, .gradio-container > div:first-child { height: 100vh !important; max-height: 100vh !important; overflow: hidden !important; } /* Clean modern design - Dark teal/coral theme */ .gradio-container { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: linear-gradient(135deg, #0f2027 0%, #203a43 50%, #2c5364 100%); color: #fff; height: 100vh; max-height: 100vh; overflow: hidden; padding: 0 !important; } /* Make columns full height */ .gradio-container .gr-column { height: 100%; display: flex; flex-direction: column; } /* Main row container */ .gradio-container .gr-row { height: 100vh; max-height: 100vh; overflow: hidden; } /* File columns specific styling */ .file-column { height: 100vh !important; max-height: 100vh !important; overflow: hidden !important; display: flex !important; flex-direction: column !important; } /* File panel container fix */ .file-panel-container { height: 100vh !important; min-height: 100vh !important; max-height: 100vh !important; overflow: hidden; display: flex; flex-direction: column; } .file-panel-container > div { height: 100% !important; min-height: 100% !important; display: flex; flex-direction: column; } /* File panel styling */ .file-panel { background: rgba(255, 255, 255, 0.1) !important; backdrop-filter: blur(10px) !important; border-right: 1px solid rgba(255, 255, 255, 0.2); height: 100vh !important; min-height: 100vh !important; overflow: hidden; display: flex; flex-direction: column; } .file-panel:last-child { border-right: none; border-left: 1px solid rgba(255, 255, 255, 0.2); } /* File panel header */ .file-panel-header { background: rgba(255, 255, 255, 0.15); padding: 15px; border-bottom: 2px solid #ff6b6b; display: flex; align-items: center; gap: 10px; } /* File list styling */ .file-list { flex: 1 1 auto; min-height: 0; overflow-y: auto; padding: 10px; height: calc(100vh - 80px); } .file-item { background: rgba(255, 255, 255, 0.1); margin: 5px 0; padding: 12px; border-radius: 6px; cursor: pointer; transition: all 0.2s; display: flex; align-items: center; gap: 10px; backdrop-filter: blur(5px); } .file-item:hover { background: rgba(255, 107, 107, 0.3); transform: translateX(5px); } .file-item.selected { background: rgba(255, 107, 107, 0.4); border-left: 4px solid #ff6b6b; } /* Main content area */ .main-content { padding: 30px; background: rgba(255, 255, 255, 0.05); backdrop-filter: blur(20px); height: 100vh; overflow-y: auto; overflow-x: hidden; } /* Control styling */ .control-group { background: rgba(255, 255, 255, 0.1); backdrop-filter: blur(10px); padding: 20px; border-radius: 10px; margin: 20px 0; border: 1px solid rgba(255, 255, 255, 0.2); } /* Button styling - Coral accent */ .gr-button { background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%) !important; border: none !important; color: #fff !important; padding: 10px 20px !important; border-radius: 6px !important; font-weight: 600 !important; transition: all 0.2s !important; } .gr-button:hover { background: linear-gradient(135deg, #ee5a50 0%, #e04545 100%) !important; transform: translateY(-2px) !important; box-shadow: 0 5px 15px rgba(255, 107, 107, 0.4) !important; } /* Slider styling */ .gr-slider input { accent-color: #ff6b6b !important; } /* Center slider values */ .gr-slider .gr-slider-value, .gr-slider span[data-testid="number"], .gr-slider input[type="number"] { text-align: center !important; display: block !important; width: 100% !important; } /* Progress bar */ .progress-bar { position: relative; background: rgba(255, 255, 255, 0.1); backdrop-filter: blur(10px); border: 1px solid rgba(255, 107, 107, 0.5); border-radius: 6px; padding: 0; margin-top: 20px; height: 44px; overflow: hidden; } .progress-fill { position: absolute; top: 0; left: 0; height: 100%; background: linear-gradient(90deg, #ff6b6b, #ee5a50); width: 0%; transition: width 0.3s ease; } .progress-fill.active { animation: fillProgress 2s ease-out forwards; } @keyframes fillProgress { from { width: 0%; } to { width: 100%; } } .progress-text { position: relative; z-index: 1; color: #fff; font-size: 14px; text-align: center; line-height: 44px; text-shadow: 0 0 4px rgba(0, 0, 0, 0.3); font-weight: 500; } /* Dropdown styling */ .gr-dropdown { background: rgba(255, 255, 255, 0.1) !important; border: 1px solid rgba(255, 107, 107, 0.5) !important; color: #fff !important; user-select: none !important; } """ # ========== FILE MANAGEMENT JAVASCRIPT ========== file_management_js = """ """ # ========== UI HELPER FUNCTIONS ========== def create_input_file_html(files): """Create HTML for input files panel""" if not files: file_items_html = """