| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import List, Optional |
| |
|
| | from omegaconf import DictConfig |
| | from pytorch_lightning import Trainer |
| |
|
| | from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE |
| | from nemo.collections.asr.models.ctc_models import EncDecCTCModel |
| | from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel |
| | from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel |
| | from nemo.collections.asr.parts.k2.classes import ASRK2Mixin |
| | from nemo.core.classes.common import PretrainedModelInfo, typecheck |
| | from nemo.utils import logging |
| |
|
| |
|
class EncDecK2SeqModel(EncDecCTCModel, ASRK2Mixin):
    """Encoder decoder models with various lattice losses."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Check the configured loss type, construct the parent model, then initialize k2.

        Args:
            cfg: Model config. `cfg.graph_module_cfg` is read here; its `loss_type`
                entry falls back to "ctc" when it is missing.
            trainer: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If `loss_type` is anything other than "ctc" or "mmi".
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "ctc")
        # Reject unsupported losses before any parent-class construction happens.
        if loss_type not in ("ctc", "mmi"):
            raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={loss_type}")
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            None: this class does not list any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Swap the vocabulary used during CTC decoding, e.g. when fine-tuning from a pre-trained model.

        Only the decoder is changed; the encoder and pre-processing modules are left as they are.
        Typical uses are re-using a pre-trained encoder for data in another language, or teaching
        the model capitalization, punctuation and/or special characters.

        The vocabulary swap itself is delegated to the parent class. Afterwards
        `update_k2_modules` is always called with `self.graph_module_cfg`.

        Args:
            new_vocabulary: list with the new vocabulary. Must contain at least 2 elements.
                Typically, this is the target alphabet.

        Returns: None

        """
        super().change_vocabulary(new_vocabulary)

        # A graph LM was built for the previous token set, so it is dropped and the user is
        # told to provide a new one.
        if self.use_graph_lm:
            self.token_lm = None
            logging.warning(
                f"""With .change_vocabulary() call for a model with criterion_type=`{self.loss.criterion_type}`,
                a new token_lm has to be set manually: call .update_k2_modules(new_cfg)
                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."""
            )

        self.update_k2_modules(self.graph_module_cfg)

    @typecheck()
    def forward(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None,
    ):
        """
        Run the parent forward pass and hand its outputs to the k2 post-processing step.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            Whatever `_forward_k2_post_processing` returns for the parent's three outputs.
            NOTE(review): presumably still a tuple of 3 elements - confirm against the mixin:
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax)
        """
        parent_outputs = super().forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
        )
        # The parent is expected to yield exactly three values; unpacking enforces that.
        log_probs, encoded_len, greedy_predictions = parent_outputs
        return self._forward_k2_post_processing(
            log_probs=log_probs, encoded_length=encoded_len, greedy_predictions=greedy_predictions
        )
| |
|
| |
|
class EncDecK2SeqModelBPE(EncDecCTCModelBPE, ASRK2Mixin):
    """Encoder decoder models with Byte Pair Encoding and various lattice losses."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Check the configured loss type, construct the parent model, then initialize k2.

        Args:
            cfg: Model config. `cfg.graph_module_cfg` is read here; its `loss_type`
                entry falls back to "ctc" when it is missing.
            trainer: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If `loss_type` is anything other than "ctc" or "mmi".
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "ctc")
        # Reject unsupported losses before any parent-class construction happens.
        if loss_type not in ("ctc", "mmi"):
            raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={loss_type}")
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            None: this class does not list any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
        """
        Swap the tokenizer vocabulary used during CTC decoding, e.g. when fine-tuning from a
        pre-trained model.

        Only the decoder is changed; the encoder and pre-processing modules are left as they are.
        Typical uses are re-using a pre-trained encoder for data in another language, or teaching
        the model capitalization, punctuation and/or special characters.

        The tokenizer swap itself is delegated to the parent class. Afterwards
        `update_k2_modules` is always called with `self.graph_module_cfg`.

        Args:
            new_tokenizer_dir: Path to the new tokenizer directory.
            new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.

        Returns: None

        """
        super().change_vocabulary(new_tokenizer_dir, new_tokenizer_type)

        # A graph LM was built for the previous token set, so it is dropped and the user is
        # told to provide a new one.
        if self.use_graph_lm:
            self.token_lm = None
            logging.warning(
                f"""With .change_vocabulary() call for a model with criterion_type=`{self.loss.criterion_type}`,
                a new token_lm has to be set manually: call .update_k2_modules(new_cfg)
                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."""
            )

        self.update_k2_modules(self.graph_module_cfg)

    @typecheck()
    def forward(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None,
    ):
        """
        Run the parent forward pass and hand its outputs to the k2 post-processing step.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            Whatever `_forward_k2_post_processing` returns for the parent's three outputs.
            NOTE(review): presumably still a tuple of 3 elements - confirm against the mixin:
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax)
        """
        parent_outputs = super().forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
        )
        # The parent is expected to yield exactly three values; unpacking enforces that.
        log_probs, encoded_len, greedy_predictions = parent_outputs
        return self._forward_k2_post_processing(
            log_probs=log_probs, encoded_length=encoded_len, greedy_predictions=greedy_predictions
        )
| |
|
| |
|
class EncDecK2RnntSeqModel(EncDecRNNTModel, ASRK2Mixin):
    """Encoder decoder models with various lattice losses."""

    def __init__(self, cfg: DictConfig, trainer: Optional[Trainer] = None):
        """
        Check the configured loss/criterion pair, construct the parent model, then initialize k2.

        Args:
            cfg: Model config. `cfg.graph_module_cfg` is read here; `loss_type` falls back to
                "rnnt" and `criterion_type` falls back to "ml" when they are missing.
            trainer: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: Unless `loss_type` is "rnnt" and `criterion_type` is "ml".
        """
        loss_type = cfg.graph_module_cfg.get("loss_type", "rnnt")
        criterion_type = cfg.graph_module_cfg.get("criterion_type", "ml")
        # Only the rnnt + ml combination is accepted by this class.
        if loss_type != "rnnt" or criterion_type != "ml":
            raise ValueError(
                f"""Class {self.__class__.__name__} does not support
                `criterion_type`={criterion_type} with `loss_type`={loss_type}"""
            )
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models, or None. This class currently returns None.
        """
        return None

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Swap the vocabulary used during decoding, e.g. when fine-tuning from a pre-trained model.

        Typical uses are re-using a pre-trained encoder for data in another language, or teaching
        the model capitalization, punctuation and/or special characters.

        The vocabulary swap itself is delegated to the parent class.
        NOTE(review): which modules the parent rebuilds, and whether it is a no-op for an
        identical vocabulary, is not visible here - confirm against the parent class.
        Afterwards `update_k2_modules` is always called with `self.graph_module_cfg`.

        Args:
            new_vocabulary: list with the new vocabulary. Must contain at least 2 elements.
                Typically, this is the target alphabet.

        Returns: None

        """
        super().change_vocabulary(new_vocabulary)

        # A graph LM was built for the previous token set, so it is dropped and the user is
        # told to provide a new one.
        if self.use_graph_lm:
            self.token_lm = None
            logging.warning(
                f"""With .change_vocabulary() call for a model with criterion_type=`{self.loss.criterion_type}`,
                a new token_lm has to be set manually: call .update_k2_modules(new_cfg)
                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."""
            )

        self.update_k2_modules(self.graph_module_cfg)
| |
|
| |
|
class EncDecK2RnntSeqModelBPE(EncDecRNNTBPEModel, ASRK2Mixin):
    """Encoder decoder models with Byte Pair Encoding and various lattice losses."""

    def __init__(self, cfg: DictConfig, trainer: Optional[Trainer] = None):
        """
        Check the configured loss/criterion pair, construct the parent model, then initialize k2.

        Args:
            cfg: Model config. `cfg.graph_module_cfg` is read here; `loss_type` falls back to
                "rnnt" and `criterion_type` falls back to "ml" when they are missing.
            trainer: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: Unless `loss_type` is "rnnt" and `criterion_type` is "ml".
        """
        loss_type = cfg.graph_module_cfg.get("loss_type", "rnnt")
        criterion_type = cfg.graph_module_cfg.get("criterion_type", "ml")
        # Only the rnnt + ml combination is accepted by this class.
        if loss_type != "rnnt" or criterion_type != "ml":
            raise ValueError(
                f"""Class {self.__class__.__name__} does not support
                `criterion_type`={criterion_type} with `loss_type`={loss_type}"""
            )
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        Report the pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models, or None. This class currently returns None.
        """
        return None

    def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
        """
        Swap the tokenizer vocabulary used during decoding, e.g. when fine-tuning from a
        pre-trained model.

        Typical uses are re-using a pre-trained encoder for data in another language, or teaching
        the model capitalization, punctuation and/or special characters.

        The tokenizer swap itself is delegated to the parent class.
        NOTE(review): which modules the parent rebuilds is not visible here - confirm against
        the parent class.
        Afterwards `update_k2_modules` is always called with `self.graph_module_cfg`.

        Args:
            new_tokenizer_dir: Path to the new tokenizer directory.
            new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.

        Returns: None

        """
        super().change_vocabulary(new_tokenizer_dir, new_tokenizer_type)

        # A graph LM was built for the previous token set, so it is dropped and the user is
        # told to provide a new one.
        if self.use_graph_lm:
            self.token_lm = None
            logging.warning(
                f"""With .change_vocabulary() call for a model with criterion_type=`{self.loss.criterion_type}`,
                a new token_lm has to be set manually: call .update_k2_modules(new_cfg)
                or update .graph_module_cfg.backend_cfg.token_lm before calling this method."""
            )

        self.update_k2_modules(self.graph_module_cfg)
| |
|