# design / app.py
# harrytarlton
# Update app.py
# a690306
"""
NAM Garden - Integrated Design Version
Combines NAM processing with modern UI design
"""
import html
import json
import os
import re
import tempfile
import warnings
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional, Tuple, List
from urllib.parse import urljoin, urlparse

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torchaudio
# Select the non-interactive Agg backend so no display is required.
# NOTE(review): pyplot is imported before this call, and neither `plt` nor
# `matplotlib` is used anywhere else in the visible code; confirm they are needed.
matplotlib.use('Agg')
# Suppress warnings
# NOTE(review): this silences *every* warning process-wide, including
# deprecation warnings from torch/torchaudio/gradio; consider narrowing it.
warnings.filterwarnings('ignore')
# Check for GPU availability; models and input tensors are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Constants
# Directory scanned for local *.nam files (relative to the working directory).
MODELS_DIR = Path("models")
# Rate (Hz) input audio is resampled to, and the rate output WAVs are written at.
SAMPLE_RATE = 48000
# NOTE(review): not referenced anywhere in the visible code.
CHUNK_SIZE = 512
# ========== NAM MODEL CLASSES ==========
class Linear(nn.Module):
    """Causal FIR ("Linear") NAM model.

    The single ``nn.Linear(receptive_field, 1)`` layer stores the filter taps
    (``fc.weight``) and the bias (``fc.bias``). ``forward`` slides that filter
    over the whole signal. Each output sample is therefore the weighted sum of
    the current input sample and the previous ``receptive_field - 1`` samples.
    The start of the signal is zero-padded.
    """

    def __init__(self, receptive_field: int):
        super().__init__()
        self.receptive_field = receptive_field
        self.fc = nn.Linear(receptive_field, 1, bias=True)

    def forward(self, x):
        """Filter ``x``.

        Args:
            x: tensor of shape ``(samples,)`` or ``(batch, samples)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        was_1d = x.dim() == 1
        if was_1d:
            x = x.unsqueeze(0)
        # Convolving with the fc weights applies the filter at every sample
        # position, so any signal length works. Calling fc on the whole buffer
        # only works when len == receptive_field.
        padded = nn.functional.pad(x.unsqueeze(1), (self.receptive_field - 1, 0))
        out = nn.functional.conv1d(padded, self.fc.weight.unsqueeze(1), self.fc.bias)
        out = out.squeeze(1)
        return out.squeeze(0) if was_1d else out
class LSTM(nn.Module):
    """Recurrent NAM model: an LSTM stack followed by a linear read-out.

    A 2-D input ``(batch, samples)`` is treated as a one-feature sequence.
    Input of any other rank is handed to the LSTM unchanged. For a 2-D input
    the result has shape ``(batch, samples, 1)``.
    """

    def __init__(self, input_size=1, hidden_size=32, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(self, x):
        sequence = x.unsqueeze(-1) if x.dim() == 2 else x
        hidden_states = self.lstm(sequence)[0]
        return self.fc(hidden_states)
class SimpleWaveNet(nn.Module):
    """Simplified WaveNet for NAM processing.

    A stack of dilated ``Conv1d`` layers with ``tanh`` activations is followed
    by a 1x1 output convolution, whose result is multiplied by ``head_scale``.
    Every layer uses ``padding="same"``, so the output keeps the input's sample
    count for any kernel size / dilation, including even kernel sizes. The
    padding is centred, so the filter is non-causal: it also looks at
    upcoming samples.

    Args:
        layers_config: list of dicts with optional ``"channels"`` (default 16),
            ``"kernel_size"`` (default 3) and ``"dilation"`` (default 1) keys.
            ``None`` selects four 16-channel layers with dilations 1, 2, 4, 8.
        head_scale: multiplier applied to the output layer.
    """

    def __init__(self, layers_config=None, head_scale=0.02):
        super().__init__()
        self.layers = nn.ModuleList()
        self.head_scale = head_scale
        if layers_config is None:
            layers_config = [
                {"channels": 16, "kernel_size": 3, "dilation": 1},
                {"channels": 16, "kernel_size": 3, "dilation": 2},
                {"channels": 16, "kernel_size": 3, "dilation": 4},
                {"channels": 16, "kernel_size": 3, "dilation": 8},
            ]
        in_channels = 1
        for config in layers_config:
            channels = config.get("channels", 16)
            self.layers.append(
                nn.Conv1d(
                    in_channels,
                    channels,
                    kernel_size=config.get("kernel_size", 3),
                    dilation=config.get("dilation", 1),
                    # "same" adds the extra right-hand pad that even kernels need.
                    # A symmetric d*(k-1)//2 pad would shorten the output by one sample.
                    padding="same",
                )
            )
            in_channels = channels
        self.output_layer = nn.Conv1d(in_channels, 1, kernel_size=1)

    def forward(self, x):
        """Accept ``(samples,)``, ``(batch, samples)`` or ``(batch, 1, samples)``.

        Returns a tensor of shape ``(batch, 1, samples)``.
        """
        if x.dim() == 2:
            x = x.unsqueeze(1)
        elif x.dim() == 1:
            x = x.unsqueeze(0).unsqueeze(1)
        for layer in self.layers:
            x = torch.tanh(layer(x))
        return self.output_layer(x) * self.head_scale
# ========== NAM PROCESSOR ==========
class NAMProcessor:
    """Main processor for NAM models with cloud storage support.

    ``self.models`` maps a display name to a record dict with ``'data'``
    (parsed .nam JSON, or ``None`` until first use), ``'metadata'`` and
    ``'source'`` (``'local'`` or ``'cloud'``). Local records also carry
    ``'path'``; cloud records carry ``'url'``. The selected model is
    instantiated on demand and audio is run through it.
    """

    def __init__(self):
        self.models = {}
        self.current_model = None          # nn.Module built for current_model_name
        self.current_model_name = None
        self.processed_files = []
        self.custom_models = {}  # Store user's cloud models
        self.loading_errors = []
        self.load_available_models()

    def load_available_models(self):
        """Index the ``*.nam`` files in ``MODELS_DIR``.

        Each file is parsed once to read its ``metadata``. The parsed content
        is not kept (``'data': None``); ``load_model`` reads it again when the
        model is first selected. Problems are collected in
        ``self.loading_errors``.
        """
        self.loading_errors = []
        if not MODELS_DIR.exists():
            print(f"Creating models directory: {MODELS_DIR}")
            MODELS_DIR.mkdir(parents=True, exist_ok=True)
        nam_files = sorted(MODELS_DIR.glob("*.nam"))
        print(f"Found {len(nam_files)} local NAM model files")
        if not nam_files:
            self.loading_errors.append("No .nam files found in 'models' directory.")
            return
        for nam_file in nam_files:
            try:
                with open(nam_file, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                model_name = nam_file.stem
                self.models[model_name] = {
                    'path': nam_file,
                    'data': None,  # Will load on demand
                    'metadata': data.get('metadata', {}),
                    'source': 'local'
                }
                print(f"Indexed local model: {model_name}")
            except Exception as e:
                error_msg = f"Error indexing {nam_file.name}: {e}"
                print(error_msg)
                self.loading_errors.append(error_msg)

    @staticmethod
    def _normalize_cloud_url(url):
        """Rewrite Google Drive / Dropbox share links into direct-download URLs.

        Other URLs, and Drive links without a recognisable file id, are
        returned unchanged.
        """
        if 'drive.google.com' in url:
            # Accept both .../file/d/<id>/view and ...?id=<id> share formats.
            match = (re.search(r'/d/([a-zA-Z0-9_-]+)', url)
                     or re.search(r'[?&]id=([a-zA-Z0-9_-]+)', url))
            if match:
                return f'https://drive.google.com/uc?id={match.group(1)}&export=download'
            return url
        if 'dropbox.com' in url:
            # dl=1 forces a raw download; dl may be the first or a later query parameter.
            return re.sub(r'([?&])dl=0\b', r'\1dl=1', url)
        return url

    def _default_model_name(self, url):
        """Derive a model name from the URL path (query string ignored)."""
        return Path(urlparse(url).path).stem or f"cloud_model_{len(self.custom_models)}"

    def download_from_url(self, url: str, model_name: str = None) -> bool:
        """Download NAM model from URL (supports direct links, Google Drive, Dropbox).

        The model is registered as ``"[Cloud] <model_name>"``. Returns True on
        success and False on any download or parse failure (which is printed).
        """
        try:
            if not model_name:
                model_name = self._default_model_name(url)
            final_url = self._normalize_cloud_url(url)
            response = requests.get(final_url, allow_redirects=True, timeout=30)
            response.raise_for_status()
            data = response.json()
            if not isinstance(data, dict):
                raise ValueError("model file is not a JSON object")
            record = {
                'data': data,
                'metadata': data.get('metadata', {}),
                'source': 'cloud',
                'url': url
            }
            display_name = f"[Cloud] {model_name}"
            self.custom_models[model_name] = record
            self.models[display_name] = record
            # A re-download under the same name must not keep serving the old instance.
            if self.current_model_name == display_name:
                self.current_model = None
                self.current_model_name = None
            print(f"Successfully loaded cloud model: {model_name}")
            return True
        except Exception as e:
            print(f"Error downloading model from {url}: {e}")
            return False

    def load_folder_from_url(self, folder_url: str) -> int:
        """Load every NAM model listed at a cloud "folder" URL.

        * Google Drive folders are unsupported (they would need the Drive API), so 0 is returned.
        * GitHub: ``/tree/`` links are rewritten to raw.githubusercontent.com.
          A ``manifest.json`` whose ``"models"`` list names the files is then fetched.
        * Anything else is fetched as an HTML index page, and every
          ``href="....nam"`` link is downloaded.

        Returns the number of models successfully loaded.
        """
        loaded_count = 0
        try:
            if 'drive.google.com/drive/folders' in folder_url:
                # Note: Full folder downloading would require Google Drive API
                return 0
            is_github = 'github.com' in folder_url
            if is_github:
                if '/tree/' in folder_url:
                    folder_url = folder_url.replace('github.com', 'raw.githubusercontent.com')
                    folder_url = folder_url.replace('/tree/', '/')
                listing_url = folder_url.rstrip('/') + '/manifest.json'
            else:
                listing_url = folder_url
            response = requests.get(listing_url, timeout=30)
            if response.status_code != 200:
                return 0
            if is_github:
                model_files = response.json().get('models', [])
            else:
                model_files = re.findall(r'href="([^"]*\.nam)"', response.text)
            # The trailing slash makes urljoin keep the last folder segment.
            # Absolute and root-relative links are also resolved correctly.
            base_url = folder_url.rstrip('/') + '/'
            # dict.fromkeys drops duplicate links while keeping their order.
            for model_file in dict.fromkeys(model_files):
                if self.download_from_url(urljoin(base_url, model_file), Path(model_file).stem):
                    loaded_count += 1
        except Exception as e:
            print(f"Error loading folder from {folder_url}: {e}")
        return loaded_count

    def get_model_choices(self):
        """Build the dropdown entries: headed local, cloud and error sections.

        Header rows contain "━━━". Error rows are the (truncated) messages from
        ``self.loading_errors``. Neither is a key of ``self.models``.
        """
        if not self.models:
            # Try loading again in case models weren't loaded
            self.load_available_models()
        local_models = sorted(name for name, info in self.models.items()
                              if info.get('source') == 'local')
        cloud_models = sorted(name for name, info in self.models.items()
                              if info.get('source') == 'cloud')
        choices = []
        if local_models:
            choices.append("━━━ 🎸 Pre-loaded Models ━━━")
            choices.extend(local_models)
        if cloud_models:
            choices.append("━━━ ☁️ Cloud Models ━━━")
            choices.extend(cloud_models)
        if self.loading_errors:
            choices.append("━━━ ⚠️ Loading Errors ━━━")
            # Truncate long error messages for display
            choices.extend(err[:100] + '...' if len(err) > 100 else err
                           for err in self.loading_errors)
        return choices or ["No models found - Add cloud models below"]

    def clear_custom_models(self):
        """Forget every cloud model, and drop the active model if it was one of them."""
        self.models = {k: v for k, v in self.models.items() if v.get('source') != 'cloud'}
        self.custom_models.clear()
        if self.current_model_name is not None and self.current_model_name not in self.models:
            self.current_model = None
            self.current_model_name = None
        print("Cleared all cloud models")

    def _get_model_data(self, model_name):
        """Return the parsed .nam JSON for *model_name*, reading it from disk on first use.

        Raises:
            ValueError: the record has neither cached data nor a file path.
            OSError / json.JSONDecodeError: the file cannot be read or parsed.
        """
        record = self.models[model_name]
        if record.get('data') is None:
            path = record.get('path')
            if path is None:
                raise ValueError(f"Model {model_name!r} has no data and no file path")
            with open(path, 'r', encoding='utf-8') as f:
                record['data'] = json.load(f)
        return record['data']

    @staticmethod
    def _apply_linear_weights(model, weights):
        """Copy a flat ``[taps..., bias]`` list into ``model.fc``.

        Failures and short lists only print a warning, leaving the model's
        initial parameters in place.
        """
        try:
            flat = torch.as_tensor(weights, dtype=torch.float32).flatten()
            weight_size = model.fc.weight.numel()
            bias_size = model.fc.bias.numel()
            if flat.numel() < weight_size + bias_size:
                print(f"Warning: expected {weight_size + bias_size} weights, got {flat.numel()}")
                return
            with torch.no_grad():
                model.fc.weight.copy_(flat[:weight_size].reshape(model.fc.weight.shape))
                model.fc.bias.copy_(flat[weight_size:weight_size + bias_size])
        except Exception as e:
            print(f"Warning: Could not load weights: {e}")

    def load_model(self, model_name: str) -> bool:
        """Instantiate and activate the model registered as *model_name*.

        Returns True when the model is (or already was) active. Returns False
        when the name is unknown, the architecture is unsupported, or loading
        fails.
        """
        if not model_name or model_name not in self.models:
            return False
        if model_name == self.current_model_name:
            return True
        try:
            # Local models are indexed with data=None; this reads the file on demand.
            model_data = self._get_model_data(model_name)
            architecture = model_data.get('architecture', 'Linear')
            config = model_data.get('config', {})
            if architecture == 'Linear':
                model = Linear(config.get('receptive_field', 32))
            elif architecture == 'LSTM':
                model = LSTM(hidden_size=config.get('hidden_size', 32),
                             num_layers=config.get('num_layers', 1))
            elif architecture == 'WaveNet':
                model = SimpleWaveNet(config.get('layers', None), config.get('head_scale', 0.02))
            else:
                print(f"Unsupported architecture '{architecture}' in model {model_name}")
                return False
            if 'weights' in model_data:
                if architecture == 'Linear':
                    self._apply_linear_weights(model, model_data['weights'])
                else:
                    print(f"Warning: weights for {architecture} models are not applied; "
                          f"{model_name} will run with randomly initialised parameters")
            self.current_model = model.to(device)
            self.current_model_name = model_name
            self.current_model.eval()
            print(f"Loaded model: {model_name}")
            return True
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            return False

    def process_audio(self, audio_data, sr, input_gain, output_gain, mix):
        """Process audio through the current model.

        Args:
            audio_data: numpy sample array. ``process_audio_file`` passes a
                1-D mono array.
            sr: sample rate of *audio_data*. It is resampled to SAMPLE_RATE if different.
            input_gain / output_gain: gains in dB.
            mix: wet percentage (0 = dry only, 100 = model output only).

        Returns:
            The processed numpy array clipped to [-1, 1], or None when no model
            is loaded or processing fails (the error is printed).
        """
        if self.current_model is None:
            return None
        try:
            if sr != SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
                audio_data = resampler(torch.from_numpy(audio_data).unsqueeze(0)).squeeze(0).numpy()
            # NOTE(review): the tanh soft clip only runs when input_gain != 0.
            if input_gain != 0:
                audio_data = np.tanh(audio_data * 10 ** (input_gain / 20))
            dry_signal = audio_data.copy()
            audio_tensor = torch.from_numpy(audio_data).float().to(device)
            if audio_tensor.dim() == 1:
                audio_tensor = audio_tensor.unsqueeze(0)
            with torch.no_grad():
                processed = self.current_model(audio_tensor)
            # Models return (1, N), (1, 1, N) or (1, N, 1). Reshape to the dry
            # layout so the mix below never broadcasts into an (N, N) array.
            if processed.numel() != dry_signal.size:
                raise ValueError(
                    f"model returned {processed.numel()} samples for {dry_signal.size} input samples")
            processed_audio = processed.reshape(dry_signal.shape).cpu().numpy()
            # Apply mix (convert percentage to 0-1)
            mix_ratio = mix / 100.0
            processed_audio = dry_signal * (1 - mix_ratio) + processed_audio * mix_ratio
            if output_gain != 0:
                processed_audio = processed_audio * 10 ** (output_gain / 20)
            # Clip to prevent distortion
            return np.clip(processed_audio, -1.0, 1.0)
        except Exception as e:
            print(f"Processing error: {e}")
            return None
# Initialize processor (indexes ./models at import time)
processor = NAMProcessor()
print(f"\n✅ Loaded {sum(1 for m in processor.models.values() if m.get('source') == 'local')} pre-loaded NAM models")
print(f"Available models: {list(processor.models.keys())}\n")
# ========== CUSTOM CSS ==========
# Stylesheet passed to gr.Blocks(css=...).
# - Pins the page to the viewport: body and containers use 100vh with overflow hidden.
# - Styles the side file panels (.file-panel, .file-list, .file-item).
# - Defines the .progress-bar / .progress-fill / .progress-text classes that
#   update_status() emits.
# `.progress-fill.active` runs a fixed 2s CSS fill animation; it is decorative
# and not driven by real processing progress.
# NOTE(review): with overflow hidden on body and the file columns, any content
# taller than the viewport (e.g. components placed below the file panels)
# cannot be scrolled into view.
custom_css = """
/* Prevent body scrolling */
body {
overflow: hidden !important;
height: 100vh !important;
margin: 0 !important;
padding: 0 !important;
}
/* Main app container */
#component-0,
.gradio-container > div:first-child {
height: 100vh !important;
max-height: 100vh !important;
overflow: hidden !important;
}
/* Clean modern design - Dark teal/coral theme */
.gradio-container {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: linear-gradient(135deg, #0f2027 0%, #203a43 50%, #2c5364 100%);
color: #fff;
height: 100vh;
max-height: 100vh;
overflow: hidden;
padding: 0 !important;
}
/* Make columns full height */
.gradio-container .gr-column {
height: 100%;
display: flex;
flex-direction: column;
}
/* Main row container */
.gradio-container .gr-row {
height: 100vh;
max-height: 100vh;
overflow: hidden;
}
/* File columns specific styling */
.file-column {
height: 100vh !important;
max-height: 100vh !important;
overflow: hidden !important;
display: flex !important;
flex-direction: column !important;
}
/* File panel container fix */
.file-panel-container {
height: 100vh !important;
min-height: 100vh !important;
max-height: 100vh !important;
overflow: hidden;
display: flex;
flex-direction: column;
}
.file-panel-container > div {
height: 100% !important;
min-height: 100% !important;
display: flex;
flex-direction: column;
}
/* File panel styling */
.file-panel {
background: rgba(255, 255, 255, 0.1) !important;
backdrop-filter: blur(10px) !important;
border-right: 1px solid rgba(255, 255, 255, 0.2);
height: 100vh !important;
min-height: 100vh !important;
overflow: hidden;
display: flex;
flex-direction: column;
}
.file-panel:last-child {
border-right: none;
border-left: 1px solid rgba(255, 255, 255, 0.2);
}
/* File panel header */
.file-panel-header {
background: rgba(255, 255, 255, 0.15);
padding: 15px;
border-bottom: 2px solid #ff6b6b;
display: flex;
align-items: center;
gap: 10px;
}
/* File list styling */
.file-list {
flex: 1 1 auto;
min-height: 0;
overflow-y: auto;
padding: 10px;
height: calc(100vh - 80px);
}
.file-item {
background: rgba(255, 255, 255, 0.1);
margin: 5px 0;
padding: 12px;
border-radius: 6px;
cursor: pointer;
transition: all 0.2s;
display: flex;
align-items: center;
gap: 10px;
backdrop-filter: blur(5px);
}
.file-item:hover {
background: rgba(255, 107, 107, 0.3);
transform: translateX(5px);
}
.file-item.selected {
background: rgba(255, 107, 107, 0.4);
border-left: 4px solid #ff6b6b;
}
/* Main content area */
.main-content {
padding: 30px;
background: rgba(255, 255, 255, 0.05);
backdrop-filter: blur(20px);
height: 100vh;
overflow-y: auto;
overflow-x: hidden;
}
/* Control styling */
.control-group {
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
padding: 20px;
border-radius: 10px;
margin: 20px 0;
border: 1px solid rgba(255, 255, 255, 0.2);
}
/* Button styling - Coral accent */
.gr-button {
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%) !important;
border: none !important;
color: #fff !important;
padding: 10px 20px !important;
border-radius: 6px !important;
font-weight: 600 !important;
transition: all 0.2s !important;
}
.gr-button:hover {
background: linear-gradient(135deg, #ee5a50 0%, #e04545 100%) !important;
transform: translateY(-2px) !important;
box-shadow: 0 5px 15px rgba(255, 107, 107, 0.4) !important;
}
/* Slider styling */
.gr-slider input {
accent-color: #ff6b6b !important;
}
/* Center slider values */
.gr-slider .gr-slider-value,
.gr-slider span[data-testid="number"],
.gr-slider input[type="number"] {
text-align: center !important;
display: block !important;
width: 100% !important;
}
/* Progress bar */
.progress-bar {
position: relative;
background: rgba(255, 255, 255, 0.1);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 107, 107, 0.5);
border-radius: 6px;
padding: 0;
margin-top: 20px;
height: 44px;
overflow: hidden;
}
.progress-fill {
position: absolute;
top: 0;
left: 0;
height: 100%;
background: linear-gradient(90deg, #ff6b6b, #ee5a50);
width: 0%;
transition: width 0.3s ease;
}
.progress-fill.active {
animation: fillProgress 2s ease-out forwards;
}
@keyframes fillProgress {
from { width: 0%; }
to { width: 100%; }
}
.progress-text {
position: relative;
z-index: 1;
color: #fff;
font-size: 14px;
text-align: center;
line-height: 44px;
text-shadow: 0 0 4px rgba(0, 0, 0, 0.3);
font-weight: 500;
}
/* Dropdown styling */
.gr-dropdown {
background: rgba(255, 255, 255, 0.1) !important;
border: 1px solid rgba(255, 107, 107, 0.5) !important;
color: #fff !important;
user-select: none !important;
}
"""
# ========== FILE MANAGEMENT JAVASCRIPT ==========
# Client-side helpers rendered into the page through gr.HTML in the Blocks layout:
# - selectFile() highlights the clicked file row.
# - updateWaveformForFile() draws a synthetic sine pattern seeded from the file
#   name; it does not read the audio data.
# - saveSelectedFile() and saveAllFiles() both just click the #download-btn button,
#   so "Export" triggers the same download as "Export All".
# - startProcessing() restarts the decorative progress-bar animation.
# NOTE(review): gr.HTML most likely inserts markup without executing <script>
# tags, and DOMContentLoaded has usually fired before Gradio renders this
# component. Confirm these handlers actually run; gr.Blocks(head=...) or
# js=... may be required.
file_management_js = """
<script>
let selectedInputFile = null;
let selectedProcessedFile = null;
let inputFiles = [];
let processedFiles = [];
// Wait for DOM to be ready
window.addEventListener('DOMContentLoaded', (event) => {
// Setup Import button click handler with multiple attempts
let attempts = 0;
const setupImportButton = () => {
const importBtn = document.getElementById('import-btn-header');
const fileInput = document.querySelector('input[type="file"][accept*="audio"]');
if (importBtn && fileInput) {
console.log('Setting up import button handler');
importBtn.addEventListener('click', (e) => {
e.preventDefault();
e.stopPropagation();
fileInput.click();
});
return true;
}
return false;
};
// Try immediately and then with delays
if (!setupImportButton() && attempts < 10) {
const interval = setInterval(() => {
attempts++;
if (setupImportButton() || attempts >= 10) {
clearInterval(interval);
}
}, 500);
}
});
function selectFile(type, index) {
if (type === 'input') {
document.querySelectorAll('#input-file-list .file-item').forEach(item => {
item.classList.remove('selected');
});
const selectedItem = document.getElementById(`input-file-${index}`);
if (selectedItem) {
selectedItem.classList.add('selected');
selectedInputFile = index;
// Update waveform
updateWaveformForFile(inputFiles[index]);
}
} else if (type === 'processed') {
document.querySelectorAll('#processed-file-list .file-item').forEach(item => {
item.classList.remove('selected');
});
const selectedItem = document.getElementById(`processed-file-${index}`);
if (selectedItem) {
selectedItem.classList.add('selected');
selectedProcessedFile = index;
}
}
}
function updateWaveformForFile(fileInfo) {
const canvas = document.getElementById('waveform');
if (canvas && fileInfo) {
const ctx = canvas.getContext('2d');
canvas.width = canvas.offsetWidth;
canvas.height = canvas.offsetHeight;
ctx.clearRect(0, 0, canvas.width, canvas.height);
ctx.strokeStyle = '#ff6b6b';
ctx.lineWidth = 2;
ctx.beginPath();
const width = canvas.width;
const height = canvas.height;
const centerY = height / 2;
// Generate waveform based on file
let seed = 0;
if (fileInfo && fileInfo.name) {
for (let i = 0; i < fileInfo.name.length; i++) {
seed += fileInfo.name.charCodeAt(i);
}
}
for (let x = 0; x < width; x++) {
const y = centerY + Math.sin(x * 0.02 + seed * 0.01) * 30 * Math.sin(x * 0.001);
if (x === 0) ctx.moveTo(x, y);
else ctx.lineTo(x, y);
}
ctx.stroke();
ctx.fillStyle = '#ff6b6b';
ctx.font = '12px monospace';
ctx.fillText(fileInfo ? `${fileInfo.name}` : 'No file selected', 10, 20);
}
}
function saveSelectedFile() {
if (selectedProcessedFile !== null && processedFiles[selectedProcessedFile]) {
console.log('Exporting:', processedFiles[selectedProcessedFile].name);
// Trigger the download button in Gradio
const downloadBtn = document.querySelector('#download-btn');
if (downloadBtn) {
downloadBtn.click();
}
} else {
alert('Please select a processed file to export');
}
}
function saveAllFiles() {
if (processedFiles.length > 0) {
console.log('Exporting all processed files');
// Trigger the download button in Gradio
const downloadBtn = document.querySelector('#download-btn');
if (downloadBtn) {
downloadBtn.click();
}
} else {
alert('No processed files to export');
}
}
// Progress animation
function startProcessing() {
const progressBar = document.querySelector('.progress-fill');
const progressText = document.querySelector('.progress-text');
if (progressBar && progressText) {
progressText.textContent = 'Processing...';
progressBar.classList.add('active');
progressBar.style.width = '0%';
setTimeout(() => {
progressBar.classList.remove('active');
}, 2000);
}
}
</script>
"""
# ========== UI HELPER FUNCTIONS ==========
def create_input_file_html(files):
    """Create HTML for the "Input Files" panel.

    Args:
        files: uploaded files as delivered by ``gr.File``: either objects whose
            ``.name`` is the on-disk path, or plain path strings. A falsy
            value renders the empty-state placeholder.

    Returns:
        Panel markup followed by an inline ``<script>``. The script sets the
        page's ``inputFiles`` array and rebinds the header Import button to
        the hidden file input.
    """
    file_data = []
    if not files:
        file_items_html = """
<div style='text-align: center; padding: 40px 20px; color: rgba(255,255,255,0.5);'>
<div style='font-size: 48px; opacity: 0.3; margin-bottom: 10px;'>📁</div>
<div>No files loaded</div>
<div style='font-size: 12px; margin-top: 10px;'>Import audio files to begin</div>
</div>
"""
    else:
        item_chunks = []
        for i, file in enumerate(files):
            # Older Gradio passes tempfile wrappers (.name = path), newer passes path strings.
            file_path = str(getattr(file, "name", file))
            file_name = Path(file_path).name
            file_size = os.path.getsize(file_path) / (1024 * 1024)  # Convert to MB
            file_info = {
                'name': file_name,
                'path': file_path,
                'size': f"{file_size:.1f} MB"
            }
            file_data.append(file_info)
            # Upload names are user-controlled, so escape them before putting them in markup.
            item_chunks.append(f"""
<div class='file-item' id='input-file-{i}' onclick='selectFile("input", {i})'>
<span style='font-size: 20px;'>🎡</span>
<div style='flex: 1;'>
<div style='font-weight: 500; color: #fff;'>{html.escape(file_name)}</div>
<div style='font-size: 11px; opacity: 0.7;'>{file_info['size']}</div>
</div>
</div>
""")
        file_items_html = "".join(item_chunks)
    # json.dumps leaves "</" intact, so a name containing "</script>" could end the
    # inline script early. "<\/" is an equivalent JS string escape.
    safe_file_data = json.dumps(file_data).replace("</", "<\\/")
    return f"""
<div class='file-panel'>
<div class='file-panel-header'>
<span style='font-size: 20px; font-weight: 600;'>Input Files</span>
<div style='margin-left: auto; display: flex; gap: 8px;'>
<button id='import-btn-header' style='
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%);
border: none;
color: #fff;
padding: 8px 16px;
border-radius: 6px;
cursor: pointer;
font-weight: 500;
font-size: 13px;
min-width: 90px;
'>Import</button>
</div>
</div>
<div class='file-list' id='input-file-list'>
{file_items_html}
</div>
</div>
<script>
inputFiles = {safe_file_data};
// Re-attach the click handler after HTML update
setTimeout(() => {{
const importBtn = document.getElementById('import-btn-header');
if (importBtn) {{
importBtn.onclick = (e) => {{
e.preventDefault();
e.stopPropagation();
const fileInput = document.querySelector('input[type="file"][accept*="audio"]');
if (fileInput) {{
fileInput.click();
}} else {{
// Fallback: try other selectors
const altInput = document.querySelector('#file-upload input[type="file"]');
if (altInput) altInput.click();
}}
}};
}}
}}, 100);
</script>
"""
def create_processed_file_html(files):
    """Create HTML for the "Processed Files" panel.

    Args:
        files: list of dicts with ``'name'`` and optionally ``'size'`` (as built
            by ``process_all_files``). A falsy value renders the empty-state
            placeholder.

    Returns:
        Panel markup with Export / Export All buttons, followed by an inline
        ``<script>`` that sets the page's ``processedFiles`` array.
    """
    button_style = """
background: linear-gradient(135deg, #ff6b6b 0%, #ee5a50 100%);
border: none;
color: #fff;
padding: 8px 16px;
border-radius: 6px;
cursor: pointer;
font-weight: 500;
font-size: 13px;
min-width: 90px;
"""
    file_data = list(files) if files else []
    if file_data:
        # Names derive from uploaded file names (user-controlled), so escape them.
        file_items_html = "".join(
            f"""
<div class='file-item' id='processed-file-{i}' onclick='selectFile("processed", {i})'>
<span style='font-size: 20px;'>💿</span>
<div style='flex: 1;'>
<div style='font-weight: 500; color: #fff;'>{html.escape(str(file_info['name']))}</div>
<div style='font-size: 11px; opacity: 0.7;'>{html.escape(str(file_info.get('size', 'Unknown size')))}</div>
</div>
</div>
"""
            for i, file_info in enumerate(file_data)
        )
    else:
        file_items_html = """
<div style='text-align: center; padding: 40px 20px; color: rgba(255,255,255,0.5);'>
<div style='font-size: 48px; opacity: 0.3; margin-bottom: 10px;'>💿</div>
<div>No processed files</div>
<div style='font-size: 12px; margin-top: 10px;'>Process audio to see results here</div>
</div>
"""
    # Keep "</script>" inside a name from terminating the inline script.
    safe_file_data = json.dumps(file_data).replace("</", "<\\/")
    return f"""
<div class='file-panel'>
<div class='file-panel-header'>
<span style='font-size: 20px; font-weight: 600;'>Processed Files</span>
<div style='margin-left: auto; display: flex; gap: 8px;'>
<button onclick='saveSelectedFile()' style='{button_style}'>Export</button>
<button onclick='saveAllFiles()' style='{button_style}'>Export All</button>
</div>
</div>
<div class='file-list' id='processed-file-list'>
{file_items_html}
</div>
</div>
<script>processedFiles = {safe_file_data};</script>
"""
def update_status(message, processing=False):
    """Render the status-bar HTML.

    *message* is inserted verbatim (unescaped) into the text layer. When
    *processing* is true, the fill element carries the ``active`` class.
    Otherwise it is rendered with an explicit 0% width.
    """
    fill_element = (
        "<div class='progress-fill active'></div>"
        if processing
        else "<div class='progress-fill' style='width: 0%;'></div>"
    )
    lines = [
        "",
        "<div class='progress-bar'>",
        fill_element,
        f"<div class='progress-text'>{message}</div>",
        "</div>",
        "",
    ]
    return "\n".join(lines)
# ========== MAIN PROCESSING FUNCTIONS ==========
def process_audio_file(file, profile, input_gain, output_gain, mix):
    """Run one uploaded file through model *profile* and write the result to a temp WAV.

    Args:
        file: uploaded file (object with ``.name`` = path, or a path string).
        profile: key in ``processor.models``.
        input_gain / output_gain: dB; mix: wet percentage (see NAMProcessor.process_audio).

    Returns:
        ``(wav_path, "Processing complete!")`` on success, otherwise
        ``(None, <reason>)``.
    """
    if not file:
        return None, "No file selected"
    if not processor.load_model(profile):
        return None, f"Failed to load model: {profile}"
    try:
        audio_data, sr = torchaudio.load(str(getattr(file, "name", file)))
        audio_numpy = audio_data.numpy()
        # Down-mix to mono by averaging channels; mono input is used as-is.
        if audio_numpy.shape[0] > 1:
            audio_numpy = np.mean(audio_numpy, axis=0)
        else:
            audio_numpy = audio_numpy[0]
        processed = processor.process_audio(audio_numpy, sr, input_gain, output_gain, mix)
        if processed is None:
            return None, "Processing failed"
        # mkstemp creates the file atomically, unlike the deprecated, race-prone mktemp.
        fd, temp_path = tempfile.mkstemp(suffix='.wav')
        os.close(fd)
        try:
            torchaudio.save(
                temp_path,
                torch.from_numpy(processed).float().unsqueeze(0),
                SAMPLE_RATE
            )
        except Exception:
            os.remove(temp_path)
            raise
        return temp_path, "Processing complete!"
    except Exception as e:
        return None, f"Error: {str(e)}"
def process_all_files(files, profile, input_gain, output_gain, mix):
    """Process every uploaded file and remember the results for download.

    Returns:
        ``(processed, status)``. *processed* is a list of
        ``{'name', 'path', 'size'}`` dicts for the files that succeeded. It is
        also stored on ``processor.processed_files``. *status* is a summary
        that includes a failure count when some files failed.
    """
    if not files:
        return [], "No files to process"
    processed = []
    failed_count = 0
    for file in files:
        result, status = process_audio_file(file, profile, input_gain, output_gain, mix)
        source_path = Path(str(getattr(file, "name", file)))
        if result:
            file_size = os.path.getsize(result) / (1024 * 1024)
            processed.append({
                'name': source_path.stem + "_processed.wav",
                'path': result,
                'size': f"{file_size:.1f} MB"
            })
        else:
            failed_count += 1
            # Log details server-side; the status string ends up in unescaped HTML.
            print(f"Failed to process {source_path.name}: {status}")
    processor.processed_files = processed
    status_msg = f"Processed {len(processed)} files"
    if failed_count:
        status_msg += f" ({failed_count} failed)"
    return processed, status_msg
def _unique_archive_names(names):
    """Return *names* with repeats renamed ``stem_1.ext``, ``stem_2.ext``, ... (order kept)."""
    seen = set()
    unique = []
    for name in names:
        stem, suffix = Path(name).stem, Path(name).suffix
        candidate = name
        counter = 1
        while candidate in seen:
            candidate = f"{stem}_{counter}{suffix}"
            counter += 1
        seen.add(candidate)
        unique.append(candidate)
    return unique


def download_processed_files():
    """Return a downloadable path for the last batch of processed files.

    Returns None when nothing has been processed, and the WAV itself for a
    single file. Several files are returned as a deflated zip in a fresh
    temp directory.
    """
    if not processor.processed_files:
        return None
    if len(processor.processed_files) == 1:
        return processor.processed_files[0]['path']
    zip_path = Path(tempfile.mkdtemp()) / "processed_audio.zip"
    # Uploads from different folders can share a name; duplicate zip entries
    # would overwrite each other on extraction.
    arcnames = _unique_archive_names([info['name'] for info in processor.processed_files])
    with zipfile.ZipFile(zip_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        for file_info, arcname in zip(processor.processed_files, arcnames):
            zipf.write(file_info['path'], arcname)
    return str(zip_path)
# ========== GRADIO INTERFACE ==========
with gr.Blocks(css=custom_css, theme=gr.themes.Base()) as demo:
    # Add JavaScript
    # NOTE(review): gr.HTML most likely renders this without executing its
    # <script> tags; confirm, and consider gr.Blocks(head=...) instead.
    gr.HTML(file_management_js)

    # State management
    uploaded_files = gr.State([])
    processed_files = gr.State([])  # NOTE(review): not wired to any event below

    with gr.Row():
        # Left panel - Input files
        with gr.Column(scale=1, min_width=250, elem_classes="file-column"):
            # Input files display
            input_file_panel = gr.HTML(
                value=create_input_file_html([]),
                elem_classes="file-panel-container"
            )
            # File upload (hidden; the panel's Import button clicks its <input> via JS)
            # NOTE(review): in some Gradio versions, visible=False components are not
            # rendered at all, which would leave the Import button with no target.
            file_upload = gr.File(
                label="Upload Audio Files",
                file_count="multiple",
                file_types=["audio"],
                visible=False,
                elem_id="file-upload"
            )

        # Center - Main content
        with gr.Column(scale=2, elem_classes="main-content"):
            # Header
            gr.Markdown("# 🎸 NAM Garden Audio Processor", elem_classes="main-title")

            # Profile/IR Selection
            with gr.Group(elem_classes="control-group"):
                model_choices = processor.get_model_choices()
                # Default to the first real model. Header rows, error rows and the
                # "No models found" placeholder are not keys of processor.models.
                default_model = next(
                    (choice for choice in model_choices if choice in processor.models),
                    None
                )
                model_dropdown = gr.Dropdown(
                    choices=model_choices,
                    value=default_model,
                    label="Select Profile/IR",
                    container=False
                )

            # Cloud model loading section (separate group)
            with gr.Group(elem_classes="control-group"):
                with gr.Accordion("☁️ Load Cloud Models", open=False):
                    gr.Markdown("""
**Add your own NAM models from cloud storage:**
- Direct `.nam` file URLs
- Google Drive links
- Dropbox links
- GitHub raw files
""")
                    cloud_url = gr.Textbox(
                        label="Model URL",
                        placeholder="https://drive.google.com/file/d/... or https://dropbox.com/...",
                        lines=1
                    )
                    with gr.Row():
                        load_cloud_btn = gr.Button("⬇ Load Model", size="sm")
                        load_folder_btn = gr.Button("📂 Load Folder", size="sm")
                        clear_cloud_btn = gr.Button("🗑 Clear Cloud Models", size="sm")
                    cloud_status = gr.Textbox(
                        label="Cloud Status",
                        value="Ready to load cloud models",
                        interactive=False,
                        lines=1
                    )

            # Audio Controls
            with gr.Group(elem_classes="control-group"):
                with gr.Row():
                    input_gain = gr.Slider(
                        minimum=-20,
                        maximum=20,
                        value=0,
                        step=1,
                        label="In (dB)",
                        container=True
                    )
                    output_gain = gr.Slider(
                        minimum=-20,
                        maximum=20,
                        value=0,
                        step=1,
                        label="Out (dB)",
                        container=True
                    )
                    mix_slider = gr.Slider(
                        minimum=0,
                        maximum=100,
                        value=100,
                        step=1,
                        label="Mix (%)",
                        container=True
                    )

            # Waveform Display
            with gr.Group(elem_classes="control-group"):
                gr.HTML("""
<div style='
height: 200px;
background: rgba(255, 255, 255, 0.08);
backdrop-filter: blur(10px);
border-radius: 8px;
position: relative;
overflow: hidden;
'>
<canvas id='waveform' style='width: 100%; height: 100%;'></canvas>
</div>
""")

            # Process Button
            with gr.Row():
                process_btn = gr.Button(
                    "⚑ Process Audio",
                    variant="primary",
                    size="lg"
                )

            # Status Bar
            status_display = gr.HTML(
                value=update_status("Ready"),
                elem_id="status-display"
            )

        # Right panel - Processed files
        with gr.Column(scale=1, min_width=250, elem_classes="file-column"):
            processed_file_panel = gr.HTML(
                value=create_processed_file_html([]),
                elem_classes="file-panel-container"
            )
            # Download button (revealed by handle_process once files exist)
            download_btn = gr.Button(
                "💾 Download Processed",
                visible=False,
                elem_id="download-btn"
            )
            download_file = gr.File(
                label="Download",
                visible=False
            )

    # Event handlers
    def handle_upload(files):
        """Refresh the input panel and store the uploads in session state."""
        if not files:
            return create_input_file_html([]), [], update_status("No files uploaded")
        return create_input_file_html(files), files, update_status(f"Loaded {len(files)} file(s)")

    def handle_process(files, profile, in_gain, out_gain, mix):
        """Process the stored uploads; returns (panel HTML, status HTML, download button update)."""
        hide_download = gr.update(visible=False)
        if not files:
            return create_processed_file_html([]), update_status("No files to process"), hide_download
        # Headers, error rows and the placeholder are dropdown entries too;
        # only names registered in processor.models are real models.
        if not profile or profile not in processor.models:
            return create_processed_file_html([]), update_status("Please select a valid model"), hide_download
        processed, status_msg = process_all_files(files, profile, in_gain, out_gain, mix)
        if processed:
            return (
                create_processed_file_html(processed),
                update_status(status_msg),
                gr.update(visible=True)
            )
        return (
            create_processed_file_html([]),
            update_status("Processing failed"),
            hide_download
        )

    def handle_download():
        """Reveal the download widget with the WAV/zip, or keep it hidden."""
        file_path = download_processed_files()
        if file_path:
            return gr.update(value=file_path, visible=True)
        return gr.update(visible=False)

    def handle_load_cloud_model(url):
        """Handle loading a single cloud model"""
        url = (url or "").strip()
        if not url:
            return gr.update(), "Please enter a URL"
        if processor.download_from_url(url):
            return gr.update(choices=processor.get_model_choices()), "✅ Loaded model from cloud"
        return gr.update(), "❌ Failed to load model from URL"

    def handle_load_cloud_folder(url):
        """Handle loading a cloud folder"""
        url = (url or "").strip()
        if not url:
            return gr.update(), "Please enter a folder URL"
        count = processor.load_folder_from_url(url)
        if count > 0:
            return gr.update(choices=processor.get_model_choices()), f"✅ Loaded {count} models from folder"
        return gr.update(), "❌ No models found in folder"

    def handle_clear_cloud(current_model):
        """Clear cloud models, moving the selection off a removed model."""
        processor.clear_custom_models()
        choices = processor.get_model_choices()
        if current_model in processor.models:
            return gr.update(choices=choices), "🗑 Cleared all cloud models"
        fallback = next((choice for choice in choices if choice in processor.models), None)
        return gr.update(choices=choices, value=fallback), "🗑 Cleared all cloud models"

    # Connect events
    file_upload.change(
        fn=handle_upload,
        inputs=[file_upload],
        outputs=[input_file_panel, uploaded_files, status_display]
    )
    process_btn.click(
        fn=handle_process,
        inputs=[uploaded_files, model_dropdown, input_gain, output_gain, mix_slider],
        outputs=[processed_file_panel, status_display, download_btn]
    )
    download_btn.click(
        fn=handle_download,
        outputs=download_file
    )
    # Cloud model loading events
    load_cloud_btn.click(
        fn=handle_load_cloud_model,
        inputs=[cloud_url],
        outputs=[model_dropdown, cloud_status]
    )
    load_folder_btn.click(
        fn=handle_load_cloud_folder,
        inputs=[cloud_url],
        outputs=[model_dropdown, cloud_status]
    )
    clear_cloud_btn.click(
        fn=handle_clear_cloud,
        inputs=[model_dropdown],
        outputs=[model_dropdown, cloud_status]
    )
# Launch the app
if __name__ == "__main__":
    # Bind on every interface at port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)