# ZDisket's picture
# Upload folder using huggingface_hub
# 7fa3276 verified
---
# Configuration for TransAudio model with ~30M parameters
#
# NOTE(review): comments in this file reference 44.1kHz, but dataset.sample_rate
# and training.sampling_rate are both 48000 — confirm the intended target rate.
# This configuration uses STFT preprocessing with U-Net architecture for audio processing
# STFT parameters are chosen to balance frequency and temporal resolution
# UNet parameters are tuned to achieve approximately 30M total parameters
model:
  # STFT front-end configuration
  stft_config:
    n_fft: 16  # FFT size — NOTE(review): 16 is unusually small for 44.1/48kHz audio; confirm against the model code
    hop_length: 8  # Hop size — 1/2 of n_fft (original comment said 1/4, which does not match the values)
    win_length: 16  # Window length — same as n_fft for standard hann window
  # UNet configuration targeting ~30M parameters
  unet_config:
    in_ch: null  # Will be automatically calculated as n_audio_channels
    out_ch: null  # Will be automatically calculated as n_audio_channels
    base_ch: 64  # Base channel count to keep parameter count in check
    depth: 6  # Depth of U-Net (with downsampling/upsampling)
    ch_mults: [2, 3, 4, 6, 6, 8]  # Per-level channel multipliers (one entry per depth level)
    k: 8  # Kernel size for convolutions
    decoder_k: null  # Decoder kernel size (null — presumably falls back to k; verify in model code)
    stride: 4  # Down/upsampling stride per level
    norm: 'weight'  # Normalization type ('weight' = weight normalization; original comment wrongly said "group norm")
    act: 'snake'  # Activation function
    separable: false  # Use standard convolutions rather than depthwise separable
    use_deconv: true  # Use transposed convolutions for upsampling
    bottleneck_dilations: [1, 2, 4, 8]  # Dilated convolutions in bottleneck
    learnable_alpha: false  # Learnable residual scaling parameter
    alpha_init: 1.0  # Initial value for residual scaling
    use_lstm_bottleneck: true  # Add an LSTM stage in the bottleneck
    lstm_layers: 2  # Number of LSTM layers in the bottleneck
    skip_layer_indexes: [-1, -2]  # Skip connections addressed by index (negative = from the deep end; verify semantics in model code)
    skip_residual_scales: [1.0, 0.1]  # Scale factors paired with skip_layer_indexes
    # iSTFT output head configuration (iSTFTNet-style synthesis)
    use_istft_head: false  # Enable iSTFT output head instead of direct waveform
    istft_n_fft: 32  # FFT size for iSTFT synthesis
    istft_hop_length: 16  # Hop length for iSTFT
    istft_win_length: null  # Window length (null = same as n_fft)
    phase_eps: 1.0e-8  # Epsilon for safe atan2 phase recovery
# Dataset configuration
dataset:
  sample_rate: 48000  # Target sample rate for audio (Hz)
  chunk_size: 16384  # Audio chunk size in samples (~0.34 s at 48kHz; original comment's "1.5 seconds" was wrong)
  mono: true  # Convert to mono
  normalize: true  # Normalize to [-1, 1]
  file_extensions: [".wav", ".mp3", ".flac", ".aac", ".m4a", ".ogg"]
  cache_dir: "./audio_cache"  # Directory to cache resampled files
  min_samples: 12000  # Minimum number of samples required for a file to be included
# Training configuration
training:
  # Basic training parameters
  batch_size: 16
  num_epochs: 9999
  learning_rate: 0.0001  # Generator learning rate (1e-4)
  discriminator_lr_multiplier: 1.0  # Discriminator LR = multiplier * learning_rate (1.0 = same as G; original "4x" comment was stale)
  lr_warmup_steps: 2000  # Linear learning-rate warmup steps (original comment said 7500, value is 2000)
  adam_b1: 0.8  # Adam beta1
  adam_b2: 0.99  # Adam beta2
  lr_decay: 0.999  # Learning rate decay per epoch
  seed: 1234  # Random seed for reproducibility
  fp16_run: false  # Use mixed precision training (FP16)
  bf16_run: true  # Use BF16 training (mutually exclusive with fp16_run)
  gradient_clip: 1.5  # Default gradient clipping value (0.0 to disable)
  generator_gradient_clip: 1.0  # Generator gradient clip (defaults to gradient_clip if not set)
  discriminator_gradient_clip: 4.0  # Discriminator gradient clip (defaults to gradient_clip if not set)
  disc_loss_type: "hinge"  # Discriminator loss formulation
  # Adversarial training parameters
  use_adversarial: true  # Enable adversarial training
  pretrain_steps: 10000  # Steps to pretrain generator before adversarial training
  pretrain_reset: true  # Reset generator optimizer when switching to adversarial training
  use_se_blocks: false  # Enable Squeeze-and-Excitation blocks in discriminators
  enable_mpd: true  # Enable Multi-Period Discriminator
  enable_msd: false  # Enable Multi-Scale Discriminator
  enable_mbsd: true  # Enable Multi-Band Spectral Discriminators
  feature_matching_weight: 1.5  # Weight for feature matching loss
  disc_instance_noise_std: 0  # Gaussian noise std added to D inputs (prevents D overpowering, 0 to disable)
  gen_s_weight: 1.0  # Weight for multi-scale generator loss
  gen_f_weight: 1.0  # Weight for multi-period generator loss
  disc_loss_weight: 1.0  # Weight for discriminator loss
  # MultiBandSpec Discriminator parameters (part of unified Discriminator)
  mbsd_window_lengths: [2048, 1024, 512]  # Window lengths for each MBSD instance
  mbsd_hop_factor: 0.25  # Hop factor as fraction of window length
  # Audio loss parameters
  sampling_rate: 48000  # Sample rate used by the audio losses (Hz; original "44.1kHz" comment was stale)
  n_fft: 2048  # FFT size for the mel loss
  win_size: 2048  # Window size matching n_fft
  hop_size: 512  # Hop size — 1/4 of window size
  num_mels: 80  # Number of mel bands
  fmin: 0.0  # Minimum frequency for mel
  fmax_for_loss: 22050.0  # NOTE(review): 22050 = half of 44.1kHz, but sampling_rate is 48000 (half would be 24000) — confirm
  # Mel loss weight
  mel_loss_weight: 15.0
  # Multi-scale mel loss parameters
  use_multi_scale_mel_loss: true
  multi_scale_mel_win_lengths: [512, 1024, 2048]
  multi_scale_mel_n_mels: [40, 80, 128]
  multi_scale_mel_hop_divisor: 4
  multi_scale_mel_loss_mode: "charbonnier"
  multi_scale_mel_log_eps: 0.00001
  multi_scale_mel_l2_weight: 1.0
  multi_scale_mel_charbonnier_eps: 0.000001
  multi_scale_mel_f_min: 0.0
  multi_scale_mel_f_max: null
  multi_scale_mel_power: 1.0
  multi_scale_mel_scale: "htk"
  multi_scale_mel_norm: null
  multi_scale_mel_clamp_min: null
  # MR-STFT loss parameters
  use_mr_stft_loss: false
  mr_stft_n_ffts: [1024, 512, 256, 128]
  mr_stft_hop_sizes: [256, 128, 64, 32]
  mr_stft_win_sizes: [1024, 512, 256, 128]
  mr_stft_use_charbonnier: true
  mr_stft_charbonnier_eps: 0.000001
  mr_stft_loss_weight: 1.0
  # Waveform-domain loss parameters
  use_waveform_loss: false  # Enable direct waveform loss
  waveform_loss_type: "mae"  # Loss type: "mse", "mae", or "charbonnier"
  waveform_loss_weight: 1.0  # Weight for waveform loss
  waveform_loss_charbonnier_eps: 0.000001  # Epsilon for Charbonnier loss
  # Pitch loss parameters
  use_pitch_loss: true
  pitch_loss_use_activation_loss: false
  pitch_loss_act_weight: 0.1
  pitch_loss_use_charbonnier: false
  pitch_loss_charbonnier_eps: 0.000001
  pitch_loss_tau: 0.7
  pitch_loss_wmin: 0.15
  pitch_loss_conf_clip_min: 0.05
  pitch_loss_conf_clip_max: 0.95
  pitch_loss_vuv_thresh: 0.5
  pitch_loss_weight: 2.0
  pitch_loss_model: "mir-1k_g7"
  pitch_loss_step_size: 20.0
  # Loss configuration (specific to STFT domain processing)
  loss:
    # Add different loss functions as needed
    mse_weight: 1.0
    recon_weight: 2.0
    log_mag_weight: 1.0
    cos_phase_weight: 0.5
    sin_phase_weight: 0.5
    # Could add other losses like:
    # stft_loss_weight: 0.5
    # perceptual_loss_weight: 0.1
  # Data loading
  num_workers: 4
  shuffle: true
  pin_memory: true
  # Checkpointing
  checkpoint_interval: 20000  # Save checkpoint every N steps
  validation_interval: 5000  # Run validation every N steps
  save_best_only: true  # Only save checkpoints when validation loss improves
  # Logging
  log_interval: 100  # Log training progress every N steps
  tensorboard_log_dir: "./logs/transaudio_44khz"
  # Early stopping
  early_stopping_patience: 10  # Stop if validation loss doesn't improve for N validations
  early_stopping_min_delta: 0.001  # Minimum change to qualify as improvement
# Hardware configuration
hardware:
  num_gpus: 1  # Number of GPUs to use (0 for CPU only)
  cuda_visible_devices: "0"  # Which GPUs to use (comma separated)
# Paths
paths:
  audio_dir: "./audio_files"  # Directory containing training audio
  checkpoint_dir: "./checkpoints/transaudio_44khz"  # Directory to save checkpoints
  output_dir: "./outputs"  # Directory for output files
  log_dir: "./logs/transaudio_44khz"  # Directory for logs
# Validation configuration
validation:
  batch_size: 2  # Reduced batch size for validation
  num_workers: 2
# Inference configuration
inference:
  chunk_size: 16384  # Chunk size in samples for chunked inference
  overlap: 8192  # Overlap between chunks for seamless reconstruction (1/2 of chunk; original "1/8" comment was wrong)
  batch_size: 1  # Usually 1 for inference
# Architecture notes:
# NOTE(review): the original notes here assumed n_fft=2048, depth=5, base_ch=96,
# none of which match the current config (stft_config.n_fft=16, depth=6, base_ch=64).
# If input channels follow 3 * (n_fft // 2 + 1) as the old notes implied, the
# current STFT would give 3 * (16 // 2 + 1) = 27 input channels — confirm against
# the model code. The channel progression is determined by base_ch and ch_mults;
# recompute the parameter count for the current settings before relying on the
# ~30M estimate stated at the top of this file.