| # Configuration for TransAudio model with ~30M parameters for 44.1kHz audio | |
| # NOTE: this file is labeled 44.1kHz, but dataset.sample_rate and the training sampling_rate are both set to 48000. | |
| # | |
| # This configuration uses STFT preprocessing with a U-Net architecture for audio processing. | |
| # STFT parameters (n_fft, hop_length, win_length) trade frequency resolution against temporal resolution. | |
| # UNet parameters are tuned to achieve approximately 30M total parameters. | |
| model: | |
| # STFT preprocessing configuration (all sizes are in samples) | |
| stft_config: | |
| n_fft: 16 # FFT size - a very short window that favors temporal over frequency resolution | |
| hop_length: 8 # Hop size - 1/2 of n_fft (50% overlap between consecutive frames) | |
| win_length: 16 # Window length - same as n_fft for standard hann window | |
| # UNet Configuration targeting ~30M parameters | |
| unet_config: | |
| in_ch: null # Will be automatically calculated as n_audio_channels | |
| out_ch: null # Will be automatically calculated as n_audio_channels | |
| base_ch: 64 # Base channel count to keep parameter count in check | |
| depth: 6 # Depth of U-Net (with downsampling/upsampling) | |
| # Channel multipliers: 6 entries, matching depth (presumably applied to base_ch per level - confirm) | |
| ch_mults: [2, 3, 4, 6, 6, 8] | |
| k: 8 # Kernel size for convolutions | |
| decoder_k: null | |
| stride: 4 | |
| norm: 'weight' # Normalization type ('weight' is selected; presumably weight normalization, not group norm - confirm) | |
| act: 'snake' | |
| separable: false # Use standard convolutions rather than depthwise separable | |
| use_deconv: true # Use transposed convolutions for upsampling | |
| bottleneck_dilations: [1, 2, 4, 8] # Dilated convolutions in bottleneck | |
| learnable_alpha: false # Learnable residual scaling parameter | |
| alpha_init: 1.0 # Initial value for residual scaling | |
| use_lstm_bottleneck: true | |
| lstm_layers: 2 | |
| # NOTE(review): skip_layer_indexes and skip_residual_scales both have 2 entries; presumably paired element-wise - confirm | |
| skip_layer_indexes: [-1, -2] | |
| skip_residual_scales: [1.0, 0.1] | |
| # iSTFT output head configuration (iSTFTNet-style synthesis) | |
| use_istft_head: false # Enable iSTFT output head instead of direct waveform (disabled here) | |
| istft_n_fft: 32 # FFT size for iSTFT synthesis | |
| istft_hop_length: 16 # Hop length for iSTFT | |
| istft_win_length: null # Window length (null = same as n_fft) | |
| phase_eps: 1.0e-8 # Epsilon for safe atan2 phase recovery | |
| # Dataset configuration | |
| dataset: | |
| sample_rate: 48000 # Target sample rate for audio | |
| chunk_size: 16384 # Audio chunk size in samples (about 0.34 seconds at the configured 48 kHz) | |
| mono: true # Convert to mono | |
| normalize: true # Normalize to [-1, 1] | |
| file_extensions: [".wav", ".mp3", ".flac", ".aac", ".m4a", ".ogg"] | |
| cache_dir: "./audio_cache" # Directory to cache resampled files | |
| min_samples: 12000 # Minimum number of samples required for a file to be included | |
| # Training Configuration | |
| training: | |
| # Basic training parameters | |
| batch_size: 16 | |
| num_epochs: 9999 | |
| learning_rate: 0.0001 # Keep at 1e-4 as in original | |
| discriminator_lr_multiplier: 1.0 # Discriminator LR as a multiple of the generator LR (1.0 = same LR) | |
| lr_warmup_steps: 2000 # Linear learning rate warmup over 2000 steps | |
| adam_b1: 0.8 # Adam beta1 (0.8 is lower than the common 0.9 default) | |
| adam_b2: 0.99 # Adam beta2 | |
| lr_decay: 0.999 # Learning rate decay per epoch | |
| seed: 1234 # Random seed for reproducibility | |
| fp16_run: false # Use mixed precision training (FP16) | |
| bf16_run: true # Use BF16 training (mutually exclusive with fp16_run) | |
| gradient_clip: 1.5 # Default gradient clipping value (0.0 to disable) | |
| generator_gradient_clip: 1.0 # Generator gradient clip (defaults to gradient_clip if not set) | |
| discriminator_gradient_clip: 4.0 # Discriminator gradient clip (defaults to gradient_clip if not set) | |
| disc_loss_type: "hinge" | |
| # Adversarial training parameters | |
| use_adversarial: true # Enable adversarial training | |
| pretrain_steps: 10000 # Number of steps to pretrain generator before adversarial training | |
| pretrain_reset: true # Reset generator optimizer when switching to adversarial training | |
| use_se_blocks: false # Enable Squeeze-and-Excitation blocks in discriminators | |
| enable_mpd: true # Enable Multi-Period Discriminator | |
| enable_msd: false # Enable Multi-Scale Discriminator | |
| enable_mbsd: true # Enable Multi-Band Spectral Discriminators | |
| feature_matching_weight: 1.5 # Weight for feature matching loss | |
| disc_instance_noise_std: 0 # Gaussian noise std added to D inputs (prevents D overpowering, 0 to disable) | |
| gen_s_weight: 1.0 # Weight for multi-scale generator loss | |
| gen_f_weight: 1.0 # Weight for multi-period generator loss | |
| disc_loss_weight: 1.0 # Weight for discriminator loss | |
| # MultiBandSpec Discriminator parameters (part of unified Discriminator) | |
| mbsd_window_lengths: [2048, 1024, 512] # Window lengths for each MBSD instance | |
| mbsd_hop_factor: 0.25 # Hop factor as fraction of window length | |
| # Audio loss parameters (sampling_rate is 48 kHz; fmax_for_loss still reflects 44.1 kHz) | |
| sampling_rate: 48000 # Sample rate for the audio losses (48 kHz, same value as dataset sample_rate) | |
| n_fft: 2048 # FFT size in samples | |
| win_size: 2048 # Window size matching n_fft | |
| hop_size: 512 # Hop size - 1/4 of window size | |
| num_mels: 80 # Number of mel bands | |
| fmin: 0.0 # Minimum frequency for mel | |
| fmax_for_loss: 22050.0 # Maximum frequency for loss (22.05 kHz is half of 44.1 kHz; half of the configured 48 kHz would be 24 kHz - NOTE(review): confirm intent) | |
| # Mel loss weight | |
| mel_loss_weight: 15.0 | |
| # Multi-scale mel loss parameters | |
| use_multi_scale_mel_loss: true | |
| multi_scale_mel_win_lengths: [512, 1024, 2048] | |
| multi_scale_mel_n_mels: [40, 80, 128] | |
| multi_scale_mel_hop_divisor: 4 | |
| multi_scale_mel_loss_mode: "charbonnier" | |
| multi_scale_mel_log_eps: 0.00001 | |
| multi_scale_mel_l2_weight: 1.0 | |
| multi_scale_mel_charbonnier_eps: 0.000001 | |
| multi_scale_mel_f_min: 0.0 | |
| multi_scale_mel_f_max: null | |
| multi_scale_mel_power: 1.0 | |
| multi_scale_mel_scale: "htk" | |
| multi_scale_mel_norm: null | |
| multi_scale_mel_clamp_min: null | |
| # MR-STFT loss parameters (currently disabled via use_mr_stft_loss: false) | |
| use_mr_stft_loss: false | |
| mr_stft_n_ffts: [1024, 512, 256, 128] # FFT size per resolution | |
| mr_stft_hop_sizes: [256, 128, 64, 32] # Hop size per resolution (1/4 of the matching FFT size) | |
| mr_stft_win_sizes: [1024, 512, 256, 128] # Window size per resolution (same as the matching FFT size) | |
| mr_stft_use_charbonnier: true | |
| mr_stft_charbonnier_eps: 0.000001 | |
| mr_stft_loss_weight: 1.0 | |
| # Waveform-domain loss parameters | |
| use_waveform_loss: false # Enable direct waveform loss | |
| waveform_loss_type: "mae" # Loss type: "mse", "mae", or "charbonnier" | |
| waveform_loss_weight: 1.0 # Weight for waveform loss | |
| waveform_loss_charbonnier_eps: 0.000001 # Epsilon for Charbonnier loss | |
| # Pitch loss parameters | |
| use_pitch_loss: true | |
| pitch_loss_use_activation_loss: false | |
| pitch_loss_act_weight: 0.1 | |
| pitch_loss_use_charbonnier: false | |
| pitch_loss_charbonnier_eps: 0.000001 | |
| pitch_loss_tau: 0.7 | |
| pitch_loss_wmin: 0.15 | |
| pitch_loss_conf_clip_min: 0.05 | |
| pitch_loss_conf_clip_max: 0.95 | |
| pitch_loss_vuv_thresh: 0.5 | |
| pitch_loss_weight: 2.0 | |
| pitch_loss_model: "mir-1k_g7" | |
| pitch_loss_step_size: 20.0 | |
| # Loss configuration (specific to STFT domain processing) | |
| loss: | |
| # Add different loss functions as needed | |
| mse_weight: 1.0 | |
| recon_weight: 2.0 | |
| log_mag_weight: 1.0 | |
| cos_phase_weight: 0.5 | |
| sin_phase_weight: 0.5 | |
| # Could add other losses like: | |
| # stft_loss_weight: 0.5 | |
| # perceptual_loss_weight: 0.1 | |
| # Data loading | |
| num_workers: 4 | |
| shuffle: true | |
| pin_memory: true | |
| # Checkpointing | |
| checkpoint_interval: 20000 # Save checkpoint every N steps | |
| validation_interval: 5000 # Run validation every N steps | |
| save_best_only: true # Only save checkpoints when validation loss improves | |
| # Logging | |
| log_interval: 100 # Log training progress every N steps | |
| tensorboard_log_dir: "./logs/transaudio_44khz" | |
| # Early stopping | |
| early_stopping_patience: 10 # Stop if validation loss doesn't improve for N validations | |
| early_stopping_min_delta: 0.001 # Minimum change to qualify as improvement | |
| # Hardware configuration | |
| hardware: | |
| num_gpus: 1 # Number of GPUs to use (0 for CPU only) | |
| cuda_visible_devices: "0" # Comma-separated GPU ids to use; "0" lists a single device, consistent with num_gpus: 1 | |
| # Paths (relative paths; presumably resolved against the working directory - confirm) | |
| paths: | |
| audio_dir: "./audio_files" # Directory containing training audio | |
| checkpoint_dir: "./checkpoints/transaudio_44khz" # Directory to save checkpoints | |
| output_dir: "./outputs" # Directory for output files | |
| log_dir: "./logs/transaudio_44khz" # Directory for logs (same path as tensorboard_log_dir) | |
| # Validation configuration | |
| validation: | |
| batch_size: 2 # Validation batch size (smaller than the training batch_size of 16) | |
| num_workers: 2 | |
| # Inference configuration | |
| inference: | |
| chunk_size: 16384 # Inference chunk size in samples (same value as the dataset chunk_size) | |
| overlap: 8192 # Overlap between chunks for seamless reconstruction (1/2 of chunk) | |
| batch_size: 1 # Usually 1 for inference | |
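| # Architecture notes: | |
| # NOTE: the figures below were written for a different setup (n_fft=2048, depth=5, base_ch=96) | |
| # and do not match the values configured in this file (stft_config n_fft=16, depth=6, base_ch=64, | |
| # ch_mults=[2, 3, 4, 6, 6, 8]); recompute them before relying on the channel counts. | |
| # Input channels after STFT: 3 * (2048//2 + 1) = 3 * 1025 = 3075 | |
| # With depth=5 and base_ch=96, the channel progression is: | |
| # Encoder: 3075 -> 96 -> 192 -> 384 -> 768 -> 1536 -> 3072 | |
| # Bottleneck: 3072 -> 3072 with dilated convolutions | |
| # Decoder: 3072 -> 1536 -> 768 -> 384 -> 192 -> 96 -> 3075 | |
| # This configuration should provide approximately 30M parameters while being more manageable | |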