# NeMo/nemo/collections/asr/models/k2_sequence_models.py
# camenduru's picture
# thanks to NVIDIA ❤
# 7934b29
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel
from nemo.collections.asr.parts.k2.classes import ASRK2Mixin
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.utils import logging
class EncDecK2SeqModel(EncDecCTCModel, ASRK2Mixin):
    """CTC encoder-decoder model (``EncDecCTCModel``) extended with k2 lattice losses via ``ASRK2Mixin``."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss type, build the parent model and initialize the k2 parts.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg.loss_type`` is read here and
                defaults to ``"ctc"`` when the key is absent.
            trainer: Trainer instance forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If the loss type is neither ``"ctc"`` nor ``"mmi"``.
        """
        requested_loss = cfg.graph_module_cfg.get("loss_type", "ctc")
        # The check runs before the parent constructor so an unsupported config fails early.
        if requested_loss not in ("ctc", "mmi"):
            raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={requested_loss}")
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        List pre-trained models of this class that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            Always ``None``: this class does not declare any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Replace the vocabulary used for CTC decoding and refresh the k2 modules afterwards.

        Intended for fine-tuning from a pre-trained model, e.g. to reuse a pretrained encoder for another
        language, or to let the model learn capitalization, punctuation and/or special characters.
        The vocabulary switch itself is delegated to the parent class.

        Args:
            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically,
                this is target alphabet.

        Returns: None
        """
        super().change_vocabulary(new_vocabulary)

        if self.use_graph_lm:
            # The previous token LM is dropped here; the user has to provide a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            logging.warning(
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`,\n"
                "a new token_lm has to be set manually: call .update_k2_modules(new_cfg)\n"
                "or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )

        # Executed whether or not a graph LM is in use.
        self.update_k2_modules(self.graph_module_cfg)

    @typecheck()
    def forward(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None,
    ):
        """
        Run the parent model's forward pass and hand its outputs to the k2 post-processing step.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            The result of ``self._forward_k2_post_processing`` applied to the three parent outputs:
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax)
        """
        parent_outputs = super().forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
        )
        log_probs, encoded_len, greedy_predictions = parent_outputs
        return self._forward_k2_post_processing(
            log_probs=log_probs, encoded_length=encoded_len, greedy_predictions=greedy_predictions
        )
class EncDecK2SeqModelBPE(EncDecCTCModelBPE, ASRK2Mixin):
    """BPE-tokenized CTC encoder-decoder model (``EncDecCTCModelBPE``) with k2 lattice losses via ``ASRK2Mixin``."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss type, build the parent model and initialize the k2 parts.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg.loss_type`` is read here and
                defaults to ``"ctc"`` when the key is absent.
            trainer: Trainer instance forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If the loss type is neither ``"ctc"`` nor ``"mmi"``.
        """
        requested_loss = cfg.graph_module_cfg.get("loss_type", "ctc")
        # The check runs before the parent constructor so an unsupported config fails early.
        if requested_loss not in ("ctc", "mmi"):
            raise ValueError(f"Class {self.__class__.__name__} does not support `loss_type`={requested_loss}")
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        List pre-trained models of this class that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            Always ``None``: this class does not declare any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
        """
        Replace the tokenizer vocabulary used for CTC decoding and refresh the k2 modules afterwards.

        Intended for fine-tuning from a pre-trained model, e.g. to reuse a pretrained encoder for another
        language, or to let the model learn capitalization, punctuation and/or special characters.
        The tokenizer switch itself is delegated to the parent class.

        Args:
            new_tokenizer_dir: Path to the new tokenizer directory.
            new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.

        Returns: None
        """
        super().change_vocabulary(new_tokenizer_dir, new_tokenizer_type)

        if self.use_graph_lm:
            # The previous token LM is dropped here; the user has to provide a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            logging.warning(
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`,\n"
                "a new token_lm has to be set manually: call .update_k2_modules(new_cfg)\n"
                "or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )

        # Executed whether or not a graph LM is in use.
        self.update_k2_modules(self.graph_module_cfg)

    @typecheck()
    def forward(
        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None,
    ):
        """
        Run the parent model's forward pass and hand its outputs to the k2 post-processing step.

        Args:
            input_signal: Tensor that represents a batch of raw audio signals,
                of shape [B, T]. T here represents timesteps, with 1 second of audio represented as
                `self.sample_rate` number of floating point values.
            input_signal_length: Vector of length B, that contains the individual lengths of the audio
                sequences.
            processed_signal: Tensor that represents a batch of processed audio signals,
                of shape (B, D, T) that has undergone processing via some DALI preprocessor.
            processed_signal_length: Vector of length B, that contains the individual lengths of the
                processed audio sequences.

        Returns:
            The result of ``self._forward_k2_post_processing`` applied to the three parent outputs:
            1) The log probabilities tensor of shape [B, T, D].
            2) The lengths of the acoustic sequence after propagation through the encoder, of shape [B].
            3) The greedy token predictions of the model of shape [B, T] (via argmax)
        """
        parent_outputs = super().forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            processed_signal=processed_signal,
            processed_signal_length=processed_signal_length,
        )
        log_probs, encoded_len, greedy_predictions = parent_outputs
        return self._forward_k2_post_processing(
            log_probs=log_probs, encoded_length=encoded_len, greedy_predictions=greedy_predictions
        )
class EncDecK2RnntSeqModel(EncDecRNNTModel, ASRK2Mixin):
    """RNNT encoder-decoder model (``EncDecRNNTModel``) extended with k2 lattice losses via ``ASRK2Mixin``."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss/criterion pair, build the parent model and initialize the k2 parts.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg.loss_type`` (default ``"rnnt"``) and
                ``cfg.graph_module_cfg.criterion_type`` (default ``"ml"``) are read here.
            trainer: Trainer instance forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If the pair is anything other than ``loss_type="rnnt"`` with
                ``criterion_type="ml"``.
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "rnnt")
        criterion_type = graph_cfg.get("criterion_type", "ml")
        # Only the single (rnnt, ml) combination is accepted; reject everything else up front.
        if not (loss_type == "rnnt" and criterion_type == "ml"):
            raise ValueError(
                f"Class {self.__class__.__name__} does not support\n"
                f"`criterion_type`={criterion_type} with `loss_type`={loss_type}"
            )
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        List pre-trained models of this class that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            Always ``None``: this class does not declare any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_vocabulary: List[str]):
        """
        Replace the vocabulary used by the RNNT model and refresh the k2 modules afterwards.

        Intended for fine-tuning from a pre-trained model, e.g. to reuse a pretrained encoder for another
        language, or to let the model learn capitalization, punctuation and/or special characters.
        The vocabulary switch itself is delegated to the parent class.

        Args:
            new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically,
                this is target alphabet.

        Returns: None
        """
        super().change_vocabulary(new_vocabulary)

        if self.use_graph_lm:
            # The previous token LM is dropped here; the user has to provide a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            logging.warning(
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`,\n"
                "a new token_lm has to be set manually: call .update_k2_modules(new_cfg)\n"
                "or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )

        # Executed whether or not a graph LM is in use.
        self.update_k2_modules(self.graph_module_cfg)
class EncDecK2RnntSeqModelBPE(EncDecRNNTBPEModel, ASRK2Mixin):
    """BPE-tokenized RNNT encoder-decoder model (``EncDecRNNTBPEModel``) with k2 lattice losses via ``ASRK2Mixin``."""

    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """
        Validate the configured loss/criterion pair, build the parent model and initialize the k2 parts.

        Args:
            cfg: Model config. ``cfg.graph_module_cfg.loss_type`` (default ``"rnnt"``) and
                ``cfg.graph_module_cfg.criterion_type`` (default ``"ml"``) are read here.
            trainer: Trainer instance forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If the pair is anything other than ``loss_type="rnnt"`` with
                ``criterion_type="ml"``.
        """
        graph_cfg = cfg.graph_module_cfg
        loss_type = graph_cfg.get("loss_type", "rnnt")
        criterion_type = graph_cfg.get("criterion_type", "ml")
        # Only the single (rnnt, ml) combination is accepted; reject everything else up front.
        if not (loss_type == "rnnt" and criterion_type == "ml"):
            raise ValueError(
                f"Class {self.__class__.__name__} does not support\n"
                f"`criterion_type`={criterion_type} with `loss_type`={loss_type}"
            )
        super().__init__(cfg=cfg, trainer=trainer)
        self._init_k2()

    @classmethod
    def list_available_models(cls) -> Optional[List[PretrainedModelInfo]]:
        """
        List pre-trained models of this class that can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            Always ``None``: this class does not declare any pre-trained models.
        """
        return None

    def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str):
        """
        Replace the tokenizer vocabulary used by the RNNT model and refresh the k2 modules afterwards.

        Intended for fine-tuning from a pre-trained model, e.g. to reuse a pretrained encoder for another
        language, or to let the model learn capitalization, punctuation and/or special characters.
        The tokenizer switch itself is delegated to the parent class.

        Args:
            new_tokenizer_dir: Path to the new tokenizer directory.
            new_tokenizer_type: Either `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.

        Returns: None
        """
        super().change_vocabulary(new_tokenizer_dir, new_tokenizer_type)

        if self.use_graph_lm:
            # The previous token LM is dropped here; the user has to provide a new one.
            self.token_lm = None
            criterion = self.loss.criterion_type
            logging.warning(
                f"With .change_vocabulary() call for a model with criterion_type=`{criterion}`,\n"
                "a new token_lm has to be set manually: call .update_k2_modules(new_cfg)\n"
                "or update .graph_module_cfg.backend_cfg.token_lm before calling this method."
            )

        # Executed whether or not a graph LM is in use.
        self.update_k2_modules(self.graph_module_cfg)