File size: 9,276 Bytes
7fa3276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
# Configuration for TransAudio model with ~30M parameters
#
# This configuration uses STFT preprocessing with a U-Net architecture for audio processing.
# NOTE(review): several comments in this file still refer to 44.1 kHz, but the actual
# sample rate configured throughout (dataset.sample_rate, training.sampling_rate) is 48000 Hz.
# UNet parameters are tuned to achieve approximately 30M total parameters.

model:
  # STFT Configuration (applied to the waveform before the U-Net)
  stft_config:
    n_fft: 16              # FFT size; NOTE(review): 16 gives only n_fft//2+1 = 9 frequency bins — confirm this very small FFT is intended
    hop_length: 8          # Hop size - 1/2 of n_fft
    win_length: 16         # Window length - same as n_fft for a standard full-size window
    
  # UNet Configuration targeting ~30M parameters
  unet_config:
    in_ch: null              # Will be automatically calculated as n_audio_channels
    out_ch: null             # Will be automatically calculated as n_audio_channels
    base_ch: 64              # Base channel count to keep parameter count in check
    depth: 6                 # Depth of U-Net (with downsampling/upsampling)
    ch_mults: [2, 3, 4, 6, 6, 8]  # Per-level channel multipliers (one entry per depth level)
    k: 8                     # Kernel size for convolutions
    decoder_k: null          # Decoder kernel size; presumably falls back to k when null — verify in model code
    stride: 4                # Down/upsampling factor per level
    norm: 'weight'               # Normalization type: weight normalization
    act: 'snake'             # Activation function (snake)
    separable: false         # Use standard convolutions rather than depthwise separable
    use_deconv: true         # Use transposed convolutions for upsampling
    bottleneck_dilations: [1, 2, 4, 8]  # Dilated convolutions in bottleneck
    learnable_alpha: false    # Learnable residual scaling parameter
    alpha_init: 1.0          # Initial value for residual scaling
    use_lstm_bottleneck: true    # Add an LSTM to the bottleneck
    lstm_layers: 2               # Number of LSTM layers in the bottleneck
    skip_layer_indexes: [-1, -2]       # Skip connections addressed by index; negative values presumably count from the deepest level — verify in model code
    skip_residual_scales: [1.0, 0.1]   # Residual scales paired element-wise with skip_layer_indexes
    # iSTFT output head configuration (iSTFTNet-style synthesis)
    use_istft_head: false     # Enable iSTFT output head instead of direct waveform output
    istft_n_fft: 32         # FFT size for iSTFT synthesis
    istft_hop_length: 16     # Hop length for iSTFT
    istft_win_length: null    # Window length (null = same as istft_n_fft)
    phase_eps: 1.0e-8         # Epsilon for safe atan2 phase recovery
    
    
# Dataset configuration
dataset:
  sample_rate: 48000         # Target sample rate for audio (48 kHz — this value, not the 44.1 kHz mentioned in comments, is what training uses)
  chunk_size: 16384          # Audio chunk size in samples (~0.34 seconds at 48 kHz)
  mono: true                 # Convert to mono
  normalize: true            # Normalize to [-1, 1]
  file_extensions: [".wav", ".mp3", ".flac", ".aac", ".m4a", ".ogg"]  # Accepted audio file types
  cache_dir: "./audio_cache" # Directory to cache resampled files
  min_samples: 12000         # Minimum number of samples required for a file to be included

# Training Configuration
training:
  # Basic training parameters
  batch_size: 16
  num_epochs: 9999           # Effectively unbounded — rely on early stopping or manual stop
  learning_rate: 0.0001      # Generator learning rate (1e-4)
  discriminator_lr_multiplier: 1.0 # Multiplier on learning_rate for the discriminator (1.0 = same LR as generator)
  lr_warmup_steps: 2000      # Linear learning rate warmup over 2000 steps
  adam_b1: 0.8               # Adam beta1
  adam_b2: 0.99              # Adam beta2
  lr_decay: 0.999            # Learning rate decay per epoch
  seed: 1234                 # Random seed for reproducibility
  fp16_run: false            # Use mixed precision training (FP16)
  bf16_run: true             # Use BF16 training (mutually exclusive with fp16_run)
  gradient_clip: 1.5         # Default gradient clipping value (0.0 to disable)
  generator_gradient_clip: 1.0      # Generator gradient clip (defaults to gradient_clip if not set)
  discriminator_gradient_clip: 4.0  # Discriminator gradient clip (defaults to gradient_clip if not set)
  disc_loss_type: "hinge"    # Discriminator loss formulation


  # Adversarial training parameters
  use_adversarial: true       # Enable adversarial training
  pretrain_steps: 10000      # Number of steps to pretrain generator before adversarial training
  pretrain_reset: true       # Reset generator optimizer when switching to adversarial training
  use_se_blocks: false        # Enable Squeeze-and-Excitation blocks in discriminators
  enable_mpd: true             # Enable Multi-Period Discriminator
  enable_msd: false             # Enable Multi-Scale Discriminator
  enable_mbsd: true            # Enable Multi-Band Spectral Discriminators
  feature_matching_weight: 1.5 # Weight for feature matching loss
  disc_instance_noise_std: 0 # Gaussian noise std added to D inputs (prevents D overpowering, 0 to disable)
  gen_s_weight: 1.0          # Weight for multi-scale generator loss
  gen_f_weight: 1.0          # Weight for multi-period generator loss
  disc_loss_weight: 1.0      # Weight for discriminator loss
  
  # MultiBandSpec Discriminator parameters (part of unified Discriminator)
  mbsd_window_lengths: [2048, 1024, 512]  # Window lengths for each MBSD instance
  mbsd_hop_factor: 0.25                    # Hop factor as fraction of window length


  # Audio loss parameters (must agree with dataset.sample_rate)
  sampling_rate: 48000        # 48 kHz — matches dataset.sample_rate
  n_fft: 2048                # FFT size for the mel-spectrogram loss
  win_size: 2048             # Window size matching n_fft
  hop_size: 512              # Hop size - 1/4 of window size
  num_mels: 80               # Number of mel bands
  fmin: 0.0                  # Minimum frequency for mel
  fmax_for_loss: 22050.0     # NOTE(review): 22050 is Nyquist for 44.1 kHz; at 48 kHz Nyquist is 24000 — confirm intended

  # Mel loss weight
  mel_loss_weight: 15.0

  # Multi-scale mel loss parameters
  use_multi_scale_mel_loss: true
  multi_scale_mel_win_lengths: [512, 1024, 2048]   # Window length per scale
  multi_scale_mel_n_mels:      [40,   80,   128]   # Mel bands per scale (paired with win_lengths)
  multi_scale_mel_hop_divisor: 4                   # Hop = win_length / 4 at each scale
  multi_scale_mel_loss_mode: "charbonnier"
  multi_scale_mel_log_eps: 0.00001
  multi_scale_mel_l2_weight: 1.0
  multi_scale_mel_charbonnier_eps: 0.000001
  multi_scale_mel_f_min: 0.0
  multi_scale_mel_f_max: null                      # null presumably defaults to Nyquist — verify in loss code
  multi_scale_mel_power: 1.0
  multi_scale_mel_scale: "htk"
  multi_scale_mel_norm: null
  multi_scale_mel_clamp_min: null

  # MR-STFT loss parameters (currently disabled)
  use_mr_stft_loss: false
  mr_stft_n_ffts: [1024, 512, 256, 128]    # FFT size per resolution
  mr_stft_hop_sizes: [256, 128, 64, 32]    # Hop size per resolution (1/4 of FFT size)
  mr_stft_win_sizes: [1024, 512, 256, 128] # Window size per resolution
  mr_stft_use_charbonnier: true
  mr_stft_charbonnier_eps: 0.000001
  mr_stft_loss_weight: 1.0

  # Waveform-domain loss parameters
  use_waveform_loss: false               # Enable direct waveform loss
  waveform_loss_type: "mae"     # Loss type: "mse", "mae", or "charbonnier"
  waveform_loss_weight: 1.0             # Weight for waveform loss
  waveform_loss_charbonnier_eps: 0.000001  # Epsilon for Charbonnier loss

  # Pitch loss parameters
  use_pitch_loss: true
  pitch_loss_use_activation_loss: false
  pitch_loss_act_weight: 0.1
  pitch_loss_use_charbonnier: false
  pitch_loss_charbonnier_eps: 0.000001
  pitch_loss_tau: 0.7
  pitch_loss_wmin: 0.15
  pitch_loss_conf_clip_min: 0.05
  pitch_loss_conf_clip_max: 0.95
  pitch_loss_vuv_thresh: 0.5             # Voiced/unvoiced decision threshold
  pitch_loss_weight: 2.0
  pitch_loss_model: "mir-1k_g7"          # Pretrained pitch model identifier — presumably resolved by the pitch loss module; verify
  pitch_loss_step_size: 20.0



  # Loss configuration (specific to STFT domain processing)
  loss:
    # Add different loss functions as needed
    mse_weight: 1.0
    recon_weight: 2.0
    log_mag_weight: 1.0
    cos_phase_weight: 0.5
    sin_phase_weight: 0.5
    # Could add other losses like:
    # stft_loss_weight: 0.5
    # perceptual_loss_weight: 0.1

  # Data loading
  num_workers: 4
  shuffle: true
  pin_memory: true

  # Checkpointing
  checkpoint_interval: 20000  # Save checkpoint every N steps
  validation_interval: 5000  # Run validation every N steps
  save_best_only: true       # Only save checkpoints when validation loss improves

  # Logging
  log_interval: 100          # Log training progress every N steps
  tensorboard_log_dir: "./logs/transaudio_44khz"

  # Early stopping
  early_stopping_patience: 10  # Stop if validation loss doesn't improve for N validations
  early_stopping_min_delta: 0.001  # Minimum change to qualify as improvement

# Hardware configuration
hardware:
  num_gpus: 1                # Number of GPUs to use (0 for CPU only)
  cuda_visible_devices: "0"  # GPU indices to expose via CUDA_VISIBLE_DEVICES (comma separated)

# Paths (all relative to the working directory)
paths:
  audio_dir: "./audio_files"      # Directory containing training audio
  checkpoint_dir: "./checkpoints/transaudio_44khz" # Directory to save checkpoints
  output_dir: "./outputs"         # Directory for output files
  log_dir: "./logs/transaudio_44khz"               # Directory for logs

# Validation configuration
validation:
  batch_size: 2               # Small batch to limit memory during validation
  num_workers: 2              # Data-loader workers for the validation set

# Inference configuration
inference:
  chunk_size: 16384          # Chunk size in samples (matches dataset.chunk_size)
  overlap: 8192              # Overlap between chunks for seamless reconstruction (1/2 of chunk)
  batch_size: 1              # Usually 1 for inference

# Architecture notes:
# NOTE(review): the figures below describe an OLDER configuration (depth=5, base_ch=96,
# STFT n_fft=2048) and do not match the current settings above (depth=6, base_ch=64,
# stft_config.n_fft=16) — recompute before relying on them.
# Input channels after STFT follow 3 * (n_fft//2 + 1); with n_fft=2048: 3 * 1025 = 3075
# Old progression (depth=5, base_ch=96):
# Encoder: 3075 -> 96 -> 192 -> 384 -> 768 -> 1536 -> 3072
# Bottleneck: 3072 -> 3072 with dilated convolutions
# Decoder: 3072 -> 1536 -> 768 -> 384 -> 192 -> 96 -> 3075
# That configuration targeted approximately 30M parameters.