# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import torch
import torch.distributed
import torch.nn as nn

from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling, StackingSubsampling
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType

__all__ = ['RNNEncoder']


class RNNEncoder(NeuralModule, Exportable):
    """
    The RNN-based encoder for ASR models.
    Follows the architecture suggested in the following paper:
    'STREAMING END-TO-END SPEECH RECOGNITION FOR MOBILE DEVICES' by Yanzhang He et al.
    https://arxiv.org/pdf/1811.06621.pdf

    Args:
        feat_in (int): the size of feature channels
        n_layers (int): number of layers of RNN
        d_model (int): the hidden size of the RNN layers
        proj_size (int): the size of the output projection after each RNN layer
            Defaults to -1, which sets it to d_model.
        rnn_type (str): the type of the RNN layers, choices=['lstm', 'gru', 'rnn']
            Defaults to 'lstm'.
        bidirectional (bool): whether the RNN layers are bidirectional
            Defaults to True.
        subsampling (str): the method of subsampling, choices=['stacking', 'stacking_norm', 'vggnet', 'striding']
            Defaults to 'striding'.
        subsampling_factor (int): the subsampling factor
            Defaults to 4.
        subsampling_conv_channels (int): the number of channels of the convolutions in the
            subsampling module for vggnet and striding
            Defaults to -1, which sets it to proj_size.
        dropout (float): the dropout rate used between all layers
            Defaults to 0.2.
    """

    def input_example(self):
        """
        Generates input examples for tracing etc.

        Returns:
            A tuple of input examples.
        """
        device = next(self.parameters()).device
        input_example = torch.randn(16, self._feat_in, 256, device=device)
        input_example_length = torch.randint(0, 256, (16,), device=device)
        return tuple([input_example, input_example_length])

    @property
    def input_types(self):
        """Returns definitions of module input ports."""
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )
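    # Shape convention for the neural types above and below: 'B' is batch size,
    # 'D' is feature channels, and 'T' is time frames. The encoder consumes
    # (B, D, T) spectrograms and produces (B, proj_size, T') encodings, where
    # T' is T reduced by the subsampling factor.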
""" return OrderedDict( { "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "encoded_lengths": NeuralType(tuple('B'), LengthsType()), } ) def __init__( self, feat_in: int, n_layers: int, d_model: int, proj_size: int = -1, rnn_type: str = 'lstm', bidirectional: bool = True, subsampling: str = 'striding', subsampling_factor: int = 4, subsampling_conv_channels: int = -1, dropout: float = 0.2, ): super().__init__() self.d_model = d_model self._feat_in = feat_in if subsampling_conv_channels == -1: subsampling_conv_channels = proj_size if subsampling and subsampling_factor > 1: if subsampling in ['stacking', 'stacking_norm']: self.pre_encode = StackingSubsampling( subsampling_factor=subsampling_factor, feat_in=feat_in, feat_out=proj_size, norm=True if 'norm' in subsampling else False, ) else: self.pre_encode = ConvSubsampling( subsampling=subsampling, subsampling_factor=subsampling_factor, feat_in=feat_in, feat_out=proj_size, conv_channels=subsampling_conv_channels, activation=nn.ReLU(), ) else: self.pre_encode = nn.Linear(feat_in, proj_size) self._feat_out = proj_size self.layers = nn.ModuleList() SUPPORTED_RNN = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN} if rnn_type not in SUPPORTED_RNN: raise ValueError(f"rnn_type can be one from the following:{SUPPORTED_RNN.keys()}") else: rnn_module = SUPPORTED_RNN[rnn_type] for i in range(n_layers): rnn_proj_size = proj_size // 2 if bidirectional else proj_size if rnn_type == "lstm": layer = rnn_module( input_size=self._feat_out, hidden_size=d_model, num_layers=1, batch_first=True, bidirectional=bidirectional, proj_size=rnn_proj_size, ) self.layers.append(layer) self.layers.append(nn.LayerNorm(proj_size)) self.layers.append(nn.Dropout(p=dropout)) self._feat_out = proj_size @typecheck() def forward(self, audio_signal, length=None): max_audio_length: int = audio_signal.size(-1) if length is None: length = audio_signal.new_full( audio_signal.size(0), max_audio_length, dtype=torch.int32, device=self.seq_range.device ) audio_signal = torch.transpose(audio_signal, 1, 2) if isinstance(self.pre_encode, nn.Linear): audio_signal = self.pre_encode(audio_signal) else: audio_signal, length = self.pre_encode(audio_signal, length) for lth, layer in enumerate(self.layers): audio_signal = layer(audio_signal) if isinstance(audio_signal, tuple): audio_signal, _ = audio_signal audio_signal = torch.transpose(audio_signal, 1, 2) return audio_signal, length