name: "predictive_model"

model:
  type: predictive
  sample_rate: 16000
  skip_nan_grad: false
  num_outputs: 1
  normalize_input: true # normalize the input signal to 0 dBFS

  train_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    audio_duration: 2.04 # Number of STFT time frames = 1 + (audio_duration * sample_rate) // encoder.hop_length = 256
    random_offset: true
    normalization_signal: input_signal
    batch_size: 8 # batch size may be increased based on the available memory
    shuffle: true
    num_workers: 8
    pin_memory: true
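
  # Worked example: a 2.04 s clip at 16 kHz contains 2.04 * 16000 = 32640 samples,
  # so the encoder below produces 1 + 32640 // 128 = 256 STFT time frames per example.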

  validation_ds:
    manifest_filepath: ???
    input_key: noisy_filepath
    target_key: clean_filepath
    batch_size: 8
    shuffle: false
    num_workers: 4
    pin_memory: true

  encoder:
    _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
    fft_length: 510 # Number of subbands in the STFT = fft_length // 2 + 1 = 256
    hop_length: 128
    magnitude_power: 0.5
    scale: 0.33
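  # With these settings, a 2.04 s training example yields a 510 // 2 + 1 = 256 subband
  # by 256 frame spectrogram. magnitude_power: 0.5 applies square-root magnitude
  # compression and scale: 0.33 rescales the result; the decoder below inverts both.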

  decoder:
    _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
    fft_length: ${model.encoder.fft_length}
    hop_length: ${model.encoder.hop_length}
    magnitude_power: ${model.encoder.magnitude_power}
    scale: ${model.encoder.scale}
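  # The ${model.encoder.*} interpolations keep the inverse STFT consistent with the
  # encoder, so changing an encoder parameter automatically updates the decoder.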

  estimator:
    _target_: nemo.collections.audio.parts.submodules.ncsnpp.SpectrogramNoiseConditionalScoreNetworkPlusPlus
    in_channels: 1 # single-channel noisy input
    out_channels: 1 # single-channel estimate
    num_res_blocks: 3 # increased number of res blocks
    pad_time_to: 64 # pad the time dimension to a multiple of 64 frames
    pad_dimension_to: 0 # no padding in the frequency dimension
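  # Note: at the training duration above, the input already spans 256 = 4 * 64 frames,
  # so pad_time_to adds no padding during training; it matters at inference, where
  # inputs of arbitrary length must be padded so the NCSN++ down/upsampling path
  # divides evenly.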

  loss:
    _target_: nemo.collections.audio.losses.MSELoss # computed in the time domain

  metrics:
    val:
      sisdr: # output SI-SDR
        _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio
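  # Note: this metric is expected to be logged as val_sisdr (dataset prefix val_ plus
  # the metric key), which is the quantity the checkpointing and early-stopping
  # callbacks in exp_manager monitor.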

  optim:
    name: adam
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.999]
    weight_decay: 0.0
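
# Any value in this config can be overridden from the command line via Hydra at launch
# time. The script and config names below are assumptions; substitute your actual
# entry point, e.g.:
#   python audio_to_audio_train.py --config-name=predictive \
#     model.train_ds.manifest_filepath=/data/train.json \
#     model.validation_ds.manifest_filepath=/data/val.json \
#     model.optim.lr=5e-5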

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an int for a number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32 # set to 16 to enable AMP (O1 and O2 optimization levels)
  log_every_n_steps: 25 # logging interval
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run before training as a sanity check; 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # provided by exp_manager
  logger: false # provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}

  # use an exponential moving average (EMA) of the model parameters
  ema:
    enable: true
    decay: 0.999 # decay rate
    cpu_offload: false # offload EMA parameters to CPU to save GPU memory
    every_n_steps: 1 # how often to update the EMA weights
    validate_original_weights: false # if true, validate with the original weights instead of the EMA weights

  # logging
  create_tensorboard_logger: true

  # checkpointing
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # if there are multiple validation sets, the first one is used
    monitor: val_sisdr
    mode: max
    save_top_k: 5
    always_save_nemo: true # save checkpoints as .nemo files instead of PTL checkpoints

  # early stopping
  create_early_stopping_callback: true
  early_stopping_callback_params:
    monitor: val_sisdr
    mode: max
    min_delta: 0.0
    patience: 20 # patience in units of check_val_every_n_epoch
    verbose: true
    strict: false # should be false to avoid a runtime error where EarlyStopping reports the monitor as unavailable, which can happen when resuming training
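  # With check_val_every_n_epoch: 1 in the trainer, patience: 20 stops training after
  # 20 consecutive epochs without an improvement in val_sisdr.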

  resume_from_checkpoint: null # path to a checkpoint to continue training from; restores the full state, including epoch, step, LR schedulers, apex, etc.
  # set both of the following to true to resume training
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # this section can be used to create a Weights & Biases logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
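
# Example launch with W&B logging enabled (the script name is an assumption;
# substitute your actual entry point):
#   python audio_to_audio_train.py --config-name=predictive \
#     exp_manager.create_wandb_logger=true \
#     exp_manager.wandb_logger_kwargs.name=predictive_run1 \
#     exp_manager.wandb_logger_kwargs.project=speech_enhancement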