---
# Configuration for the TransAudio model (STFT front end + U-Net) on 48 kHz audio.
#
# - STFT preprocessing feeds a U-Net that predicts the output waveform.
# - dataset.sample_rate and training.sampling_rate are both 48000 Hz; several
#   names/paths below still say "44khz" from an earlier 44.1 kHz version.
# - The U-Net settings target roughly 30M parameters; this has not been
#   re-verified for the current depth/ch_mults combination.
#
# NOTE(review): this file arrived with its line breaks collapsed (the whole
# first line was a comment, so `model:` was silently dropped). The nesting
# below was reconstructed from the section comments; confirm against the
# config loader, in particular that the istft_* keys live in
# model.unet_config and that loss / data-loading / checkpointing / logging /
# early-stopping keys live under `training`.

model:
  # STFT front-end configuration
  stft_config:
    n_fft: 16        # FFT size: 16 // 2 + 1 = 9 frequency bins; very short, favors time resolution
    hop_length: 8    # Hop size: 1/2 of n_fft (50% frame overlap)
    win_length: 16   # Window length: equal to n_fft

  # U-Net configuration (target: ~30M parameters, see note in header)
  unet_config:
    in_ch: null              # null = derived automatically (from n_audio_channels)
    out_ch: null             # null = derived automatically (from n_audio_channels)
    base_ch: 64              # Base channel count; keeps the parameter count in check
    depth: 6                 # Number of U-Net levels (downsampling/upsampling stages)
    ch_mults: [2, 3, 4, 6, 6, 8]  # Presumably per-level multipliers of base_ch (one per level)
    k: 8                     # Convolution kernel size
    decoder_k: null          # Decoder kernel size; null presumably means "reuse k"
    # NOTE(review): if stride applies at every level, total downsampling is
    # 4**6 = 4096; confirm this fits the input length (chunk_size / hop_length
    # is about 2048 STFT frames per training chunk).
    stride: 4                # Down/upsampling stride
    norm: 'weight'           # Normalization type ('weight' - presumably weight normalization)
    act: 'snake'             # Activation function (presumably the Snake periodic activation)
    separable: false         # Standard convolutions rather than depthwise-separable ones
    use_deconv: true         # Transposed convolutions for upsampling
    bottleneck_dilations: [1, 2, 4, 8]  # Dilation rates of the bottleneck convolutions
    learnable_alpha: false   # Whether the residual scaling parameter alpha is learnable
    alpha_init: 1.0          # Initial value of the residual scaling parameter alpha
    use_lstm_bottleneck: true  # Add an LSTM in the bottleneck
    lstm_layers: 2           # Number of LSTM layers in the bottleneck
    skip_layer_indexes: [-1, -2]     # Levels that get skip connections (negative = from the end)
    skip_residual_scales: [1.0, 0.1]  # Scale per entry of skip_layer_indexes

    # iSTFT output head (iSTFTNet-style synthesis)
    use_istft_head: false    # true = iSTFT output head; false = direct waveform output
    istft_n_fft: 32          # FFT size for iSTFT synthesis
    istft_hop_length: 16     # Hop length for iSTFT synthesis
    istft_win_length: null   # Window length (null = same as istft_n_fft)
    phase_eps: 1.0e-8        # Epsilon for safe atan2 phase recovery

# Dataset configuration
dataset:
  sample_rate: 48000   # Target sample rate (Hz)
  chunk_size: 16384    # Training chunk length in samples (about 0.34 s at 48 kHz)
  mono: true           # Convert to mono
  normalize: true      # Normalize to [-1, 1]
  file_extensions: [".wav", ".mp3", ".flac", ".aac", ".m4a", ".ogg"]
  cache_dir: "./audio_cache"  # Directory for cached resampled files
  min_samples: 12000   # Minimum file length in samples to be included (0.25 s at 48 kHz)

# Training configuration
training:
  # Basic training parameters
  batch_size: 16
  num_epochs: 9999
  learning_rate: 0.0001            # Keep at 1e-4 as in the original setup
  discriminator_lr_multiplier: 1.0  # Discriminator LR = multiplier * learning_rate (1.0 = same as G)
  lr_warmup_steps: 2000            # Linear learning-rate warmup over 2000 steps
  adam_b1: 0.8                     # Adam beta1 (lower than the common 0.9 default)
  adam_b2: 0.99                    # Adam beta2
  lr_decay: 0.999                  # Learning-rate decay per epoch
  seed: 1234                       # Random seed for reproducibility
  fp16_run: false                  # Mixed-precision FP16 training
  bf16_run: true                   # BF16 training (mutually exclusive with fp16_run)
  gradient_clip: 1.5               # Default gradient clipping value (0.0 disables)
  generator_gradient_clip: 1.0     # Generator clip (falls back to gradient_clip if unset)
  discriminator_gradient_clip: 4.0  # Discriminator clip (falls back to gradient_clip if unset)
  disc_loss_type: "hinge"          # Adversarial loss formulation for the discriminator

  # Adversarial training parameters
  use_adversarial: true            # Enable adversarial training
  pretrain_steps: 10000            # Generator-only steps before adversarial training starts
  pretrain_reset: true             # Reset the generator optimizer when adversarial training starts
  use_se_blocks: false             # Squeeze-and-Excitation blocks in the discriminators
  enable_mpd: true                 # Multi-Period Discriminator
  enable_msd: false                # Multi-Scale Discriminator
  enable_mbsd: true                # Multi-Band Spectral Discriminators
  feature_matching_weight: 1.5     # Weight of the feature-matching loss
  disc_instance_noise_std: 0       # Std of Gaussian noise added to D inputs (0 disables)
  gen_s_weight: 1.0                # Weight of the multi-scale generator loss
  gen_f_weight: 1.0                # Weight of the multi-period generator loss
  disc_loss_weight: 1.0            # Weight of the discriminator loss

  # Multi-Band Spectral Discriminator parameters (part of the unified discriminator)
  mbsd_window_lengths: [2048, 1024, 512]  # Window length of each MBSD instance
  mbsd_hop_factor: 0.25            # Hop as a fraction of window length (hops 512/256/128)

  # Audio / mel loss parameters
  sampling_rate: 48000             # Sample rate used by the losses (matches dataset.sample_rate)
  n_fft: 2048                      # FFT size for the mel loss
  win_size: 2048                   # Window size, equal to n_fft
  hop_size: 512                    # Hop size: 1/4 of win_size
  num_mels: 80                     # Number of mel bands
  fmin: 0.0                        # Lowest mel frequency (Hz)
  # NOTE(review): 22050 Hz is the Nyquist frequency of 44.1 kHz audio. At
  # 48 kHz the Nyquist frequency is 24000 Hz, so the top ~2 kHz is excluded
  # from this loss. Confirm this is intended.
  fmax_for_loss: 22050.0           # Highest mel frequency used by the loss (Hz)

  # Mel loss weight
  mel_loss_weight: 15.0

  # Multi-scale mel loss parameters
  use_multi_scale_mel_loss: true
  multi_scale_mel_win_lengths: [512, 1024, 2048]
  multi_scale_mel_n_mels: [40, 80, 128]   # Mel bands per window length
  multi_scale_mel_hop_divisor: 4          # Hop = win_length / 4 (128/256/512)
  multi_scale_mel_loss_mode: "charbonnier"
  multi_scale_mel_log_eps: 0.00001
  multi_scale_mel_l2_weight: 1.0
  multi_scale_mel_charbonnier_eps: 0.000001
  multi_scale_mel_f_min: 0.0
  multi_scale_mel_f_max: null             # null presumably means Nyquist
  multi_scale_mel_power: 1.0
  multi_scale_mel_scale: "htk"
  multi_scale_mel_norm: null
  multi_scale_mel_clamp_min: null

  # Multi-resolution STFT loss parameters (hop = n_fft / 4 at every resolution)
  use_mr_stft_loss: false
  mr_stft_n_ffts: [1024, 512, 256, 128]
  mr_stft_hop_sizes: [256, 128, 64, 32]
  mr_stft_win_sizes: [1024, 512, 256, 128]
  mr_stft_use_charbonnier: true
  mr_stft_charbonnier_eps: 0.000001
  mr_stft_loss_weight: 1.0

  # Waveform-domain loss parameters
  use_waveform_loss: false                # Enable a direct waveform loss
  waveform_loss_type: "mae"               # Loss type: "mse", "mae", or "charbonnier"
  waveform_loss_weight: 1.0               # Weight of the waveform loss
  waveform_loss_charbonnier_eps: 0.000001  # Epsilon for the Charbonnier variant

  # Pitch loss parameters
  use_pitch_loss: true
  pitch_loss_use_activation_loss: false
  pitch_loss_act_weight: 0.1
  pitch_loss_use_charbonnier: false
  pitch_loss_charbonnier_eps: 0.000001
  pitch_loss_tau: 0.7
  pitch_loss_wmin: 0.15
  pitch_loss_conf_clip_min: 0.05
  pitch_loss_conf_clip_max: 0.95
  pitch_loss_vuv_thresh: 0.5
  pitch_loss_weight: 2.0
  pitch_loss_model: "mir-1k_g7"           # Presumably the pitch-estimator checkpoint name
  pitch_loss_step_size: 20.0

  # Loss configuration (specific to STFT-domain processing)
  loss:
    # Add further loss terms here as needed
    mse_weight: 1.0
    recon_weight: 2.0
    log_mag_weight: 1.0
    cos_phase_weight: 0.5
    sin_phase_weight: 0.5
    # Possible additions:
    # stft_loss_weight: 0.5
    # perceptual_loss_weight: 0.1

  # Data loading
  num_workers: 4
  shuffle: true
  pin_memory: true

  # Checkpointing
  checkpoint_interval: 20000   # Save a checkpoint every N steps
  validation_interval: 5000    # Run validation every N steps
  save_best_only: true         # Only save checkpoints when the validation loss improves

  # Logging
  log_interval: 100            # Log training progress every N steps
  tensorboard_log_dir: "./logs/transaudio_44khz"

  # Early stopping
  early_stopping_patience: 10       # Stop after N validations without improvement
  early_stopping_min_delta: 0.001   # Minimum change that counts as an improvement

# Hardware configuration
hardware:
  num_gpus: 1                # Number of GPUs to use (0 = CPU only)
  cuda_visible_devices: "0"  # Comma-separated GPU ids to use

# Paths
paths:
  audio_dir: "./audio_files"                         # Training audio directory
  checkpoint_dir: "./checkpoints/transaudio_44khz"   # Checkpoint output directory
  output_dir: "./outputs"                            # Output file directory
  log_dir: "./logs/transaudio_44khz"                 # Log directory

# Validation configuration
validation:
  batch_size: 2   # Validation batch size (kept small)
  num_workers: 2

# Inference configuration
inference:
  chunk_size: 16384   # Samples per inference chunk
  overlap: 8192       # Overlap between consecutive chunks (1/2 of chunk_size)
  batch_size: 1       # Usually 1 for inference

# Architecture notes (derived from the values above; confirm against the model code):
# - STFT front end: n_fft=16 gives 16 // 2 + 1 = 9 frequency bins, with one
#   frame every hop_length=8 samples. An earlier version of these notes used
#   3 features per bin (e.g. magnitude plus cos/sin phase); if that still
#   holds, a mono input has 3 * 9 = 27 channels.
# - U-Net: depth=6, base_ch=64, ch_mults=[2, 3, 4, 6, 6, 8]. If the multipliers
#   scale base_ch per level, the encoder widths are
#   128 -> 192 -> 256 -> 384 -> 384 -> 512, and the decoder mirrors them back.
# - The bottleneck uses dilated convolutions (dilations 1, 2, 4, 8) followed
#   by a 2-layer LSTM.
# - The ~30M parameter target has not been re-verified for this configuration.