---
# Configuration for TransAudio model with ~30M parameters for 44.1kHz audio
#
# This configuration uses STFT preprocessing with U-Net architecture for audio processing
# STFT parameters are optimized for 44.1kHz audio to balance frequency and temporal resolution
# UNet parameters are tuned to achieve approximately 30M total parameters
model:
# STFT Configuration for 44.1kHz audio processing
stft_config:
    n_fft: 16 # FFT size — NOTE(review): 16 yields only 9 frequency bins, which is very coarse for 44.1kHz audio; confirm this is intentional
hop_length: 8 # Hop size - 1/4 of n_fft for good temporal resolution
win_length: 16 # Window length - same as n_fft for standard hann window
# UNet Configuration targeting ~30M parameters
unet_config:
in_ch: null # Will be automatically calculated as n_audio_channels
out_ch: null # Will be automatically calculated as n_audio_channels
base_ch: 64 # Base channel count to keep parameter count in check
depth: 6 # Depth of U-Net (with downsampling/upsampling)
ch_mults: [2, 3, 4, 6, 6, 8]
k: 8 # Kernel size for convolutions
decoder_k: null
stride: 4
    norm: 'weight' # Normalization type: weight norm
act: 'snake'
separable: false # Use standard convolutions rather than depthwise separable
use_deconv: true # Use transposed convolutions for upsampling
bottleneck_dilations: [1, 2, 4, 8] # Dilated convolutions in bottleneck
learnable_alpha: false # Learnable residual scaling parameter
alpha_init: 1.0 # Initial value for residual scaling
use_lstm_bottleneck: true
lstm_layers: 2
skip_layer_indexes: [-1, -2]
skip_residual_scales: [1.0, 0.1]
# iSTFT output head configuration (iSTFTNet-style synthesis)
use_istft_head: false # Enable iSTFT output head instead of direct waveform
istft_n_fft: 32 # FFT size for iSTFT synthesis
istft_hop_length: 16 # Hop length for iSTFT
istft_win_length: null # Window length (null = same as n_fft)
phase_eps: 1.0e-8 # Epsilon for safe atan2 phase recovery
# Dataset configuration
dataset:
  sample_rate: 48000 # Target sample rate for audio — NOTE(review): file header says 44.1kHz; confirm which rate is intended
  chunk_size: 16384 # Audio chunk size in samples (~0.34 s at 48kHz, ~0.37 s at 44.1kHz)
mono: true # Convert to mono
normalize: true # Normalize to [-1, 1]
file_extensions: [".wav", ".mp3", ".flac", ".aac", ".m4a", ".ogg"]
cache_dir: "./audio_cache" # Directory to cache resampled files
min_samples: 12000 # Minimum number of samples required for a file to be included
# Training Configuration
training:
# Basic training parameters
batch_size: 16
num_epochs: 9999
learning_rate: 0.0001 # Keep at 1e-4 as in original
  discriminator_lr_multiplier: 1.0 # Multiplier applied to the generator LR for the discriminator (1.0 = same LR)
  lr_warmup_steps: 2000 # Linear learning rate warmup over 2000 steps
adam_b1: 0.8 # Adam beta1, slightly higher for stability
adam_b2: 0.99 # Adam beta2
lr_decay: 0.999 # Learning rate decay per epoch
seed: 1234 # Random seed for reproducibility
fp16_run: false # Use mixed precision training (FP16)
bf16_run: true # Use BF16 training (mutually exclusive with fp16_run)
gradient_clip: 1.5 # Default gradient clipping value (0.0 to disable)
generator_gradient_clip: 1.0 # Generator gradient clip (defaults to gradient_clip if not set)
discriminator_gradient_clip: 4.0 # Discriminator gradient clip (defaults to gradient_clip if not set)
disc_loss_type: "hinge"
# Adversarial training parameters
use_adversarial: true # Enable adversarial training
pretrain_steps: 10000 # Number of steps to pretrain generator before adversarial training
pretrain_reset: true # Reset generator optimizer when switching to adversarial training
use_se_blocks: false # Enable Squeeze-and-Excitation blocks in discriminators
enable_mpd: true # Enable Multi-Period Discriminator
enable_msd: false # Enable Multi-Scale Discriminator
enable_mbsd: true # Enable Multi-Band Spectral Discriminators
feature_matching_weight: 1.5 # Weight for feature matching loss
disc_instance_noise_std: 0 # Gaussian noise std added to D inputs (prevents D overpowering, 0 to disable)
gen_s_weight: 1.0 # Weight for multi-scale generator loss
gen_f_weight: 1.0 # Weight for multi-period generator loss
disc_loss_weight: 1.0 # Weight for discriminator loss
# MultiBandSpec Discriminator parameters (part of unified Discriminator)
mbsd_window_lengths: [2048, 1024, 512] # Window lengths for each MBSD instance
mbsd_hop_factor: 0.25 # Hop factor as fraction of window length
# Audio loss parameters (adapted for 44.1kHz)
  sampling_rate: 48000 # Sample rate used by the audio losses — NOTE(review): comment previously said 44.1kHz but the value is 48kHz; confirm
n_fft: 2048 # Increased for 44.1kHz audio
win_size: 2048 # Window size matching n_fft
hop_size: 512 # Hop size - 1/4 of window size
num_mels: 80 # Number of mel bands
fmin: 0.0 # Minimum frequency for mel
  fmax_for_loss: 22050.0 # Maximum frequency for loss — NOTE(review): this is half of 44.1kHz, not half of the configured 48kHz (24000.0); verify
# Mel loss weight
mel_loss_weight: 15.0
# Multi-scale mel loss parameters
use_multi_scale_mel_loss: true
multi_scale_mel_win_lengths: [512, 1024, 2048]
multi_scale_mel_n_mels: [40, 80, 128]
multi_scale_mel_hop_divisor: 4
multi_scale_mel_loss_mode: "charbonnier"
multi_scale_mel_log_eps: 0.00001
multi_scale_mel_l2_weight: 1.0
multi_scale_mel_charbonnier_eps: 0.000001
multi_scale_mel_f_min: 0.0
multi_scale_mel_f_max: null
multi_scale_mel_power: 1.0
multi_scale_mel_scale: "htk"
multi_scale_mel_norm: null
multi_scale_mel_clamp_min: null
# MR-STFT loss parameters (updated for 44.1kHz)
use_mr_stft_loss: false
mr_stft_n_ffts: [1024, 512, 256, 128] # Updated for 44.1kHz
mr_stft_hop_sizes: [256, 128, 64, 32] # Updated for 44.1kHz
mr_stft_win_sizes: [1024, 512, 256, 128] # Updated for 44.1kHz
mr_stft_use_charbonnier: true
mr_stft_charbonnier_eps: 0.000001
mr_stft_loss_weight: 1.0
# Waveform-domain loss parameters
use_waveform_loss: false # Enable direct waveform loss
waveform_loss_type: "mae" # Loss type: "mse", "mae", or "charbonnier"
waveform_loss_weight: 1.0 # Weight for waveform loss
waveform_loss_charbonnier_eps: 0.000001 # Epsilon for Charbonnier loss
# Pitch loss parameters
use_pitch_loss: true
pitch_loss_use_activation_loss: false
pitch_loss_act_weight: 0.1
pitch_loss_use_charbonnier: false
pitch_loss_charbonnier_eps: 0.000001
pitch_loss_tau: 0.7
pitch_loss_wmin: 0.15
pitch_loss_conf_clip_min: 0.05
pitch_loss_conf_clip_max: 0.95
pitch_loss_vuv_thresh: 0.5
pitch_loss_weight: 2.0
pitch_loss_model: "mir-1k_g7"
pitch_loss_step_size: 20.0
# Loss configuration (specific to STFT domain processing)
loss:
# Add different loss functions as needed
mse_weight: 1.0
recon_weight: 2.0
log_mag_weight: 1.0
cos_phase_weight: 0.5
sin_phase_weight: 0.5
# Could add other losses like:
# stft_loss_weight: 0.5
# perceptual_loss_weight: 0.1
# Data loading
num_workers: 4
shuffle: true
pin_memory: true
# Checkpointing
checkpoint_interval: 20000 # Save checkpoint every N steps
validation_interval: 5000 # Run validation every N steps
save_best_only: true # Only save checkpoints when validation loss improves
# Logging
log_interval: 100 # Log training progress every N steps
tensorboard_log_dir: "./logs/transaudio_44khz"
# Early stopping
early_stopping_patience: 10 # Stop if validation loss doesn't improve for N validations
early_stopping_min_delta: 0.001 # Minimum change to qualify as improvement
# Hardware configuration
hardware:
num_gpus: 1 # Number of GPUs to use (0 for CPU only)
cuda_visible_devices: "0" # Which GPUs to use (comma separated)
# Paths
paths:
audio_dir: "./audio_files" # Directory containing training audio
checkpoint_dir: "./checkpoints/transaudio_44khz" # Directory to save checkpoints
output_dir: "./outputs" # Directory for output files
log_dir: "./logs/transaudio_44khz" # Directory for logs
# Validation configuration
validation:
batch_size: 2 # Reduced for 44.1kHz audio processing
num_workers: 2
# Inference configuration
inference:
chunk_size: 16384 # Increased for 44.1kHz audio
  overlap: 8192 # Overlap between chunks for seamless reconstruction (1/2 of chunk)
batch_size: 1 # Usually 1 for inference
# Architecture notes:
# NOTE(review): the notes below are STALE — they assume n_fft=2048, depth=5, and
# base_ch=96, but this config sets stft_config.n_fft: 16, unet_config.depth: 6,
# and unet_config.base_ch: 64 with ch_mults [2, 3, 4, 6, 6, 8]. Recompute the
# channel progression and parameter count against the current settings before
# relying on these numbers.
# (Stale example with n_fft=2048: input channels after STFT would be
# 3 * (2048//2 + 1) = 3075, with encoder/decoder progression
# 3075 -> 96 -> 192 -> 384 -> 768 -> 1536 -> 3072 and back, targeting ~30M params.)