| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | from modules.wenet_extractor.transducer.joint import TransducerJoint |
| | from modules.wenet_extractor.transducer.predictor import ( |
| | ConvPredictor, |
| | EmbeddingPredictor, |
| | RNNPredictor, |
| | ) |
| | from modules.wenet_extractor.transducer.transducer import Transducer |
| | from modules.wenet_extractor.transformer.asr_model import ASRModel |
| | from modules.wenet_extractor.transformer.cmvn import GlobalCMVN |
| | from modules.wenet_extractor.transformer.ctc import CTC |
| | from modules.wenet_extractor.transformer.decoder import ( |
| | BiTransformerDecoder, |
| | TransformerDecoder, |
| | ) |
| | from modules.wenet_extractor.transformer.encoder import ( |
| | ConformerEncoder, |
| | TransformerEncoder, |
| | ) |
| | from modules.wenet_extractor.squeezeformer.encoder import SqueezeformerEncoder |
| | from modules.wenet_extractor.efficient_conformer.encoder import ( |
| | EfficientConformerEncoder, |
| | ) |
| | from modules.wenet_extractor.paraformer.paraformer import Paraformer |
| | from modules.wenet_extractor.cif.predictor import Predictor |
| | from modules.wenet_extractor.utils.cmvn import load_cmvn |
| |
|
| |
|
def _init_global_cmvn(configs):
    """Build the optional ``GlobalCMVN`` normalization layer.

    Returns ``None`` when ``configs["cmvn_file"]`` is not set; otherwise
    loads the mean / inverse-std stats and wraps them as float tensors.
    """
    if configs["cmvn_file"] is None:
        return None
    mean, istd = load_cmvn(configs["cmvn_file"], configs["is_json_cmvn"])
    return GlobalCMVN(
        torch.from_numpy(mean).float(), torch.from_numpy(istd).float()
    )


def _init_encoder(configs, input_dim, global_cmvn):
    """Construct the acoustic encoder selected by ``configs["encoder"]``.

    Any unrecognized encoder type falls through to a plain
    ``TransformerEncoder`` (mirrors upstream WeNet behaviour).
    """
    encoder_type = configs.get("encoder", "conformer")
    if encoder_type == "conformer":
        return ConformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    if encoder_type == "squeezeformer":
        return SqueezeformerEncoder(
            input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
        )
    if encoder_type == "efficientConformer":
        # "efficient_conf" (when present) is flattened into the kwargs on
        # top of the regular encoder_conf entries, as upstream WeNet does.
        return EfficientConformerEncoder(
            input_dim,
            global_cmvn=global_cmvn,
            **configs["encoder_conf"],
            **configs["encoder_conf"]["efficient_conf"]
            if "efficient_conf" in configs["encoder_conf"]
            else {},
        )
    return TransformerEncoder(
        input_dim, global_cmvn=global_cmvn, **configs["encoder_conf"]
    )


def _init_decoder(configs, vocab_size, encoder_output_size):
    """Construct the attention decoder (uni- or bi-directional).

    NOTE(review): validation uses ``assert`` as upstream does; these checks
    are stripped under ``python -O``.
    """
    decoder_type = configs.get("decoder", "bitransformer")
    if decoder_type == "transformer":
        return TransformerDecoder(
            vocab_size, encoder_output_size, **configs["decoder_conf"]
        )
    # Bi-directional decoder needs a right-to-left branch and a meaningful
    # reverse weight.
    assert 0.0 < configs["model_conf"]["reverse_weight"] < 1.0
    assert configs["decoder_conf"]["r_num_blocks"] > 0
    return BiTransformerDecoder(
        vocab_size, encoder_output_size, **configs["decoder_conf"]
    )


def _init_predictor(configs, vocab_size):
    """Construct the transducer predictor.

    Side effect: for ``embedding`` / ``conv`` predictors this writes
    ``configs["predictor_conf"]["output_size"]`` (copied from
    ``embed_size``) so the joint network can read it later.  RNN
    predictors are assumed to already carry ``output_size`` in their conf.

    Raises:
        NotImplementedError: for any other predictor type.
    """
    predictor_type = configs.get("predictor", "rnn")
    if predictor_type == "rnn":
        return RNNPredictor(vocab_size, **configs["predictor_conf"])
    if predictor_type == "embedding":
        predictor = EmbeddingPredictor(vocab_size, **configs["predictor_conf"])
    elif predictor_type == "conv":
        predictor = ConvPredictor(vocab_size, **configs["predictor_conf"])
    else:
        raise NotImplementedError("only rnn, embedding and conv type support now")
    configs["predictor_conf"]["output_size"] = configs["predictor_conf"][
        "embed_size"
    ]
    return predictor


def init_model(configs):
    """Build an ASR model (Transducer, Paraformer or plain ASRModel).

    Args:
        configs: dict with at least ``cmvn_file``, ``is_json_cmvn``,
            ``input_dim``, ``output_dim``, ``encoder_conf``,
            ``decoder_conf`` and ``model_conf`` entries.  The presence of
            a ``predictor`` or ``paraformer`` key selects the model
            family.  NOTE: ``configs`` is mutated in place for transducer
            models (``predictor_conf`` / ``joint_conf`` output sizes).

    Returns:
        A ``Transducer``, ``Paraformer`` or ``ASRModel`` instance.
    """
    global_cmvn = _init_global_cmvn(configs)

    input_dim = configs["input_dim"]
    vocab_size = configs["output_dim"]

    encoder = _init_encoder(configs, input_dim, global_cmvn)
    decoder = _init_decoder(configs, vocab_size, encoder.output_size())
    ctc = CTC(vocab_size, encoder.output_size())

    if "predictor" in configs:
        predictor = _init_predictor(configs, vocab_size)
        # Wire the joint network input sizes from the encoder / predictor
        # configuration before instantiating it.
        configs["joint_conf"]["enc_output_size"] = configs["encoder_conf"][
            "output_size"
        ]
        configs["joint_conf"]["pred_output_size"] = configs["predictor_conf"][
            "output_size"
        ]
        joint = TransducerJoint(vocab_size, **configs["joint_conf"])
        model = Transducer(
            vocab_size=vocab_size,
            blank=0,
            predictor=predictor,
            encoder=encoder,
            attention_decoder=decoder,
            joint=joint,
            ctc=ctc,
            **configs["model_conf"],
        )
    elif "paraformer" in configs:
        predictor = Predictor(**configs["cif_predictor_conf"])
        model = Paraformer(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            predictor=predictor,
            **configs["model_conf"],
        )
    else:
        model = ASRModel(
            vocab_size=vocab_size,
            encoder=encoder,
            decoder=decoder,
            ctc=ctc,
            lfmmi_dir=configs.get("lfmmi_dir", ""),
            **configs["model_conf"],
        )
    return model
| |
|