# ############################################################################
# Hyperparameters for L2-ARCTIC mispronunciation detection (CTC over SSL
# features). HyperPyYAML format: !ref <key> cross-references, !new:/!name:
# object construction.
#
# NOTE(review): the <...> targets of every !ref in this file were stripped by
# a text-extraction step. They have been reconstructed from the surrounding
# key names; entries marked TODO below are the only ambiguous ones — confirm
# them against the original recipe.
# ############################################################################

# Hyperparameters toggles
prefix: ""

## SSL features Selection
pretrained_models_path: pretrained_models/
# pretrained_models:
#   {
#     "wav2vec2_base": "facebook/wav2vec2-base",  # 768
#     "hubert_base": "facebook/hubert-base-ls960",  # 768
#     "wavlm_base": "microsoft/wavlm-base",  # 768
#     "wavlm_base_plus": "microsoft/wavlm-base-plus",  # 768
#     "hubert_multilingual": "utter-project/mHuBERT-147",  # 768
#     "clap": "laion/clap-htsat-fused",  # 768
#     "data2vec_base": "facebook/data2vec-audio-base",  # 768
#     "wav2vec2_large": "facebook/wav2vec2-large",  # 1024
#     "hubert_large": "facebook/hubert-large-ls960",  # 1024
#     # NOTE(review): HF hub id is "microsoft/wavlm-large"; a "-plus" variant
#     # exists only for the base model — verify this entry.
#     "wavlm_large": "microsoft/wavlm-large-plus",  # 1024
#     "data2vec_large": "facebook/data2vec-audio-large",  # 1024
#     "whisper_medium": "openai/whisper-medium",  # 1024
#     "whisper_large_v3_turbo": "openai/whisper-large-v3-turbo",  # 1280
#   }

# select pretrained SSL models
perceived_ssl_model: "wavlm_large"  # in pretrained_models
canonical_ssl_model: null

# models hidden size, varies by model
ENCODER_DIM: 1024

# How to fuse the features
# Options: "mono" for single ssl, "dual_ssl_enc" for dual ssl encoder,
# "dual_loss" for single SSL dual ssl loss
feature_fusion: "mono"
blend_alpha: 0.5  # If using "blend" fusion

# Input files
# Data files
data_folder_save: "./data"
train_annotation: !ref <data_folder_save>/train-train.json
valid_annotation: !ref <data_folder_save>/train-dev.json
test_annotation: !ref <data_folder_save>/test.json
# Extra data
train_annotation_extra: !ref <data_folder_save>/train-train_with_extra.json
use_extra_train_data: false

# use "mpd_f1_seq" for Transformer decoder path best mpd f1
#     "PER_seq" for Transformer decoder's best error rate
#     "PER" for ctc path best error rate
#     "mpd_f1" for ctc path best mpd f1
evaluate_key: "PER"
max_save_models: 3  # Maximum number of saved models for each metrics

# generate training id for output folder
# generate_training_id: !apply:trainer.generate_training_id.generate_training_id
#     [!ref <perceived_ssl_model>, !ref <feature_fusion>, !ref <training_target>, !ref <seed>]

# output files
# TODO(review): the folder-name pattern below is reconstructed (4 refs joined
# by the 3 surviving underscores) — confirm the exact keys used originally.
output_folder: !ref exp_l2arctic/<perceived_ssl_model>_<feature_fusion>_<training_target>_<seed>
per_file: !ref <output_folder>/per.txt
mpd_file: !ref <output_folder>/mpd.txt
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt
on_training_test_wer_folder: !ref <output_folder>/on_training_test_wer
on_training_test_mpd_folder: !ref <output_folder>/on_training_test_mpd

# Training Target
# "target": deduplicated canonical phoneme sequence
# "target_with_repeats": with repeats
# "canonical"
# "perceived": deduplicated perceived phoneme sequence
training_target: "target"

perceived_ssl: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
    source: "microsoft/wavlm-large"
    freeze: !ref <freeze_perceived_ssl>
    freeze_feature_extractor: !ref <freeze_perceived_feature_extractor>
    save_path: !ref <pretrained_models_path>
    output_all_hiddens: false

# NOTE: key name keeps the original "preceived" spelling — other code may
# read it; do not rename without updating the trainer.
preceived_ssl_emb_layer: -1

enc: !new:torch.nn.Sequential
    - !new:speechbrain.lobes.models.VanillaNN.VanillaNN
        input_shape: [null, null, !ref <ENCODER_DIM>]
        activation: !ref <activation>
        dnn_blocks: !ref <dnn_layers>
        dnn_neurons: !ref <dnn_neurons>
    - !new:torch.nn.LayerNorm
        normalized_shape: !ref <dnn_neurons>

ctc_lin: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <dnn_neurons>
    n_neurons: !ref <output_neurons>  # 40 phonemes + 1 blank + 1 err

# Model parameters
activation: !name:torch.nn.LeakyReLU
dnn_layers: 2
dnn_neurons: 384
freeze_perceived_ssl: false
freeze_canonical_ssl: false
freeze_perceived_feature_extractor: true  # freeze the CNN extractor in wav2vec
freeze_canonical_feature_extractor: true  # Freeze Whisper encoder?

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: true

ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

ctc_cost_mispro: !name:speechbrain.nnet.losses.ctc_loss
    blank_index: !ref <blank_index>

# Outputs
output_neurons: 44  # l2arctic: 40phns(sil)+err+blank + eos + bos =44
blank_index: 0

model: !new:torch.nn.ModuleList
    - [!ref <enc>, !ref <ctc_lin>]

adam_opt_class: !name:torch.optim.Adam
    lr: !ref <lr>
pretrained_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_pretrained>

checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        perceived_ssl: !ref <perceived_ssl>
        counter: !ref <epoch_counter>
        # canonical_ssl: !ref <canonical_ssl>
    allow_partial_load: true

augmentation: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: [95, 100, 105]

epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

ctc_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.ctc_loss
        blank_index: !ref <blank_index>
        reduction: batch

per_stats: !name:speechbrain.utils.metric_stats.ErrorRateStats

# # TIMIT
# timit_local_data_folder: "/common/db/TIMIT"  # Path to TIMIT dataset

seed: 3047
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# training parameters
number_of_epochs: 100
batch_size: 16
lr: 0.0003
sorting: ascending
sample_rate: 16000
gradient_accumulation: 2
lr_pretrained: 0.00001

# Mixed-Precision Training
auto_mix_prec: true
# or
precision: fp16  # supports "fp32", "fp16" or "bf16"
eval_precision: fp32  # inference can likewise be switched to FP16

# Dataloader options
train_dataloader_opts:
    batch_size: !ref <batch_size>
valid_dataloader_opts:
    batch_size: !ref <batch_size>
test_dataloader_opts:
    batch_size: 1

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    # TODO(review): reconstructed as <save_folder>/ — the original target of
    # this !ref was stripped; confirm.
    collect_in: !ref <save_folder>/
    loadables:
        perceived_ssl: !ref <perceived_ssl>
        model: !ref <model>
        tokenizer: !ref <tokenizer>

encoder: !new:speechbrain.nnet.containers.LengthsCapableSequential
    perceived_ssl: !ref <perceived_ssl>
    enc: !ref <enc>
    ctc_lin: !ref <ctc_lin>
    log_softmax: !ref <log_softmax>

decoding_function: !name:speechbrain.decoders.ctc_greedy_decode
    blank_id: !ref <blank_index>

tokenizer: !new:speechbrain.dataio.encoder.CTCTextEncoder

modules:
    encoder: !ref <encoder>