# NeMo_Canary/examples/audio/conf/score_based_generative.yaml
name: score_based_generative_model
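
# A minimal launch sketch (assuming NeMo's example training script
# examples/audio/audio_to_audio_train.py; manifest paths are illustrative and
# should be adjusted to your setup):
#   python audio_to_audio_train.py --config-path=conf --config-name=score_based_generative \
#     model.train_ds.manifest_filepath=/data/train_manifest.json \
#     model.validation_ds.manifest_filepath=/data/val_manifest.json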

model:
  type: score_based
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1
  normalize_input: true
  max_utts_evaluation_metrics: 50 # metric calculation needs full inference and is slow, so we limit it to the first few files

  train_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    audio_duration: 2.04 # Number of STFT time frames = 1 + audio_duration * sample_rate // encoder.hop_length = 256
    random_offset: true
    normalization_signal: input_signal
    batch_size: 8 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true
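
  # Example manifest line (paths are illustrative): each line of the manifest is a
  # JSON object whose keys match input_key / target_key above; a duration field is
  # assumed here, as in typical NeMo manifests:
  # {"noisy_filepath": "/data/noisy/utt0001.wav", "clean_filepath": "/data/clean/utt0001.wav", "duration": 3.42}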

  validation_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    normalize_input: false # load data as is for validation, the model will normalize it for inference
    batch_size: 4
    shuffle: false
    num_workers: 4
    pin_memory: true

  encoder:
    _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
    hop_length: 128
    magnitude_power: 0.5
    scale: 0.33
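
  # Sanity check of the encoded shape: 2.04 s * 16000 Hz = 32640 samples, giving
  # 1 + 32640 // 128 = 256 time frames and 510 // 2 + 1 = 256 subbands, i.e., a
  # square 256 x 256 spectrogram per training example.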

  decoder:
    _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}
    magnitude_power: ${model.encoder.magnitude_power}
    scale: ${model.encoder.scale}

  estimator:
    _target_: nemo.collections.audio.parts.submodules.ncsnpp.SpectrogramNoiseConditionalScoreNetworkPlusPlus
    in_channels: 2 # concatenation of single-channel perturbed and noisy
    out_channels: 1 # single-channel score estimate
    conditioned_on_time: true
    num_res_blocks: 3 # increased number of res blocks
    pad_time_to: 64 # pad to 64 frames for the time dimension
    pad_dimension_to: 0 # no padding in the frequency dimension
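
  # Note: padding the time dimension to a multiple of 64 presumably keeps it
  # compatible with the U-Net's repeated 2x down/upsampling (an assumption about
  # the NCSN++ topology); the frequency dimension is already a power of two (256).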

  sde:
    _target_: nemo.collections.audio.parts.submodules.diffusion.OrnsteinUhlenbeckVarianceExplodingSDE
    stiffness: 1.5
    std_min: 0.05
    std_max: 0.5
    num_steps: 1000
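
  # Rough intuition, following the SGMSE-style formulation this SDE appears to be
  # based on: the Ornstein-Uhlenbeck drift pulls the state toward the noisy
  # observation at rate `stiffness`, while the (variance-exploding) diffusion's
  # standard deviation grows geometrically from std_min to std_max.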

  sampler:
    _target_: nemo.collections.audio.parts.submodules.diffusion.PredictorCorrectorSampler
    predictor: reverse_diffusion
    corrector: annealed_langevin_dynamics
    num_steps: 50
    num_corrector_steps: 1
    snr: 0.5
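
  # Each of the 50 predictor steps is followed by num_corrector_steps Langevin
  # corrections, so inference costs roughly 50 * (1 + 1) = 100 score-network
  # evaluations per utterance.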

  loss:
    _target_: nemo.collections.audio.losses.MSELoss
    ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time)

  metrics:
    val:
      sisdr: # output SI-SDR
        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
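      # Further metrics can be added under `val` in the same way, e.g. wide-band
      # PESQ (a sketch; the torchmetrics implementation needs the optional `pesq` package):
      # pesq:
      #   _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality
      #   fs: ${model.sample_rate}
      #   mode: wb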

  optim:
    name: adam
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.999]
    weight_decay: 0.0
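    # A learning-rate scheduler could be attached here as well (a sketch, assuming
    # NeMo's usual `sched` sub-config convention; values are illustrative):
    # sched:
    #   name: CosineAnnealing
    #   warmup_steps: 1000
    #   min_lr: 1e-6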

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32 # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 25 # logging interval
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run before training as a sanity check; setting to 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # provided by exp_manager
  logger: false # provided by exp_manager
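  # Note: with strategy: ddp, the effective global batch size is
  # model.train_ds.batch_size * num_gpus * accumulate_grad_batches.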

exp_manager:
  exp_dir: null
  name: ${name}

  # use an exponential moving average (EMA) of the model parameters
  ema:
    enable: true
    decay: 0.999 # decay rate
    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
    every_n_steps: 1 # how often to update the EMA weights
    validate_original_weights: false # if true, validation uses the original (non-EMA) weights

  # logging
  create_tensorboard_logger: true

  # checkpointing
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: val_sisdr
    mode: max
    save_top_k: 5
    always_save_nemo: true # also save .nemo files alongside the PTL checkpoints

  # early stopping
  create_early_stopping_callback: true
  early_stopping_callback_params:
    monitor: val_sisdr
    mode: max
    min_delta: 0.0
    patience: 20 # patience measured in validation runs (see check_val_every_n_epoch)
    verbose: true
    strict: false # should be false to avoid a runtime error where EarlyStopping reports the monitor as unavailable, which can happen with resumed training

  resume_from_checkpoint: null # path to a checkpoint file to continue training from; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
  # set both of the following to true to resume training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # you may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
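  # To enable W&B logging, flip create_wandb_logger to true and fill in the kwargs,
  # e.g. via Hydra overrides on the command line (names are illustrative):
  #   exp_manager.create_wandb_logger=true \
  #   exp_manager.wandb_logger_kwargs.name=score-based-run1 \
  #   exp_manager.wandb_logger_kwargs.project=speech-enhancement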