import math |
|
|
from collections import OrderedDict |
|
|
from dataclasses import dataclass |
|
|
from typing import List, Optional, Set |
|
|
|
|
|
import torch |
|
|
import torch.distributed |
|
|
import torch.nn as nn |
|
|
from omegaconf import DictConfig, ListConfig |
|
|
|
|
|
from nemo.collections.asr.models.configs import CacheAwareStreamingConfig |
|
|
from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder |
|
|
from nemo.collections.asr.parts.submodules.causal_convs import CausalConv1D |
|
|
from nemo.collections.asr.parts.submodules.conformer_modules import ConformerLayer |
|
|
from nemo.collections.asr.parts.submodules.multi_head_attention import ( |
|
|
LocalAttRelPositionalEncoding, |
|
|
MultiHeadAttention, |
|
|
PositionalEncoding, |
|
|
RelPositionalEncoding, |
|
|
RelPositionMultiHeadAttention, |
|
|
RelPositionMultiHeadAttentionLongformer, |
|
|
) |
|
|
from nemo.collections.asr.parts.submodules.subsampling import ( |
|
|
ConvSubsampling, |
|
|
StackingSubsampling, |
|
|
SubsamplingReductionModule, |
|
|
) |
|
|
from nemo.collections.asr.parts.utils import adapter_utils |
|
|
from nemo.collections.asr.parts.utils.regularization_utils import compute_stochastic_depth_drop_probs |
|
|
from nemo.core.classes.common import typecheck |
|
|
from nemo.core.classes.exportable import Exportable |
|
|
from nemo.core.classes.mixins import AccessMixin, adapter_mixins |
|
|
from nemo.core.classes.module import NeuralModule |
|
|
from nemo.core.neural_types import AcousticEncodedRepresentation, ChannelType, LengthsType, NeuralType, SpectrogramType |
|
|
|
|
|
__all__ = ['ConformerEncoder'] |
|
|
|
|
|
|
|
|
class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): |
|
|
""" |
|
|
    The encoder for Conformer-based ASR models.
|
|
Based on this paper: |
|
|
'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al. |
|
|
https://arxiv.org/abs/2005.08100 |
|
|
|
|
|
Args: |
|
|
feat_in (int): the size of feature channels |
|
|
n_layers (int): number of layers of ConformerBlock |
|
|
d_model (int): the hidden size of the model |
|
|
feat_out (int): the size of the output features |
|
|
Defaults to -1 (means feat_out is d_model) |
|
|
        subsampling (str): the method of subsampling, choices=['vggnet', 'striding', 'stacking', 'stacking_norm']
|
|
Defaults to striding. |
|
|
subsampling_factor (int): the subsampling factor which should be power of 2 |
|
|
Defaults to 4. |
|
|
subsampling_conv_channels (int): the size of the convolutions in the subsampling module |
|
|
Defaults to -1 which would set it to d_model. |
|
|
reduction (str, Optional): the method of reduction, choices=['pooling', 'striding']. If no value |
|
|
            is passed, then no reduction is performed and the model runs with the original 4x subsampling.
|
|
reduction_position (int, Optional): the index of the layer to apply reduction. If -1, apply reduction |
|
|
at the end. |
|
|
reduction_factor (int): the reduction factor which should be either 1 or a power of 2 |
|
|
Defaults to 1. |
|
|
ff_expansion_factor (int): the expansion factor in feed forward layers |
|
|
Defaults to 4. |
|
|
self_attention_model (str): type of the attention layer and positional encoding |
|
|
'rel_pos': relative positional embedding and Transformer-XL |
|
|
'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using |
|
|
overlapping chunks. Attention context is determined by att_context_size parameter. |
|
|
'abs_pos': absolute positional embedding and Transformer |
|
|
Default is rel_pos. |
|
|
pos_emb_max_len (int): the maximum length of positional embeddings |
|
|
Defaults to 5000 |
|
|
n_heads (int): number of heads in multi-headed attention layers |
|
|
Defaults to 4. |
|
|
att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes, |
|
|
or None for full context. |
|
|
Defaults to None. |
|
|
xscaling (bool): enables scaling the inputs to the multi-headed attention layers by sqrt(d_model) |
|
|
Defaults to True. |
|
|
untie_biases (bool): whether to not share (untie) the bias weights between layers of Transformer-XL |
|
|
Defaults to True. |
|
|
conv_kernel_size (int): the size of the convolutions in the convolutional modules |
|
|
Defaults to 31. |
|
|
conv_norm_type (str): the type of the normalization in the convolutional modules |
|
|
Defaults to 'batch_norm'. |
|
|
dropout (float): the dropout rate used in all layers except the attention layers |
|
|
Defaults to 0.1. |
|
|
dropout_pre_encoder (float): the dropout rate used before the encoder |
|
|
Defaults to 0.1. |
|
|
dropout_emb (float): the dropout rate used for the positional embeddings |
|
|
Defaults to 0.1. |
|
|
dropout_att (float): the dropout rate used for the attention layer |
|
|
Defaults to 0.0. |
|
|
stochastic_depth_drop_prob (float): if non-zero, will randomly drop |
|
|
layers during training. The higher this value, the more often layers |
|
|
are dropped. Defaults to 0.0. |
|
|
stochastic_depth_mode (str): can be either "linear" or "uniform". If |
|
|
set to "uniform", all layers have the same probability of drop. If |
|
|
set to "linear", the drop probability grows linearly from 0 for the |
|
|
first layer to the desired value for the final layer. Defaults to |
|
|
"linear". |
|
|
stochastic_depth_start_layer (int): starting layer for stochastic depth. |
|
|
All layers before this will never be dropped. Note that drop |
|
|
probability will be adjusted accordingly if mode is "linear" when |
|
|
start layer is > 1. Defaults to 1. |
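
    Example (illustrative only; the values below are arbitrary and not taken from any
    particular released configuration):

        encoder = ConformerEncoder(
            feat_in=80,
            n_layers=16,
            d_model=256,
            subsampling='striding',
            subsampling_factor=4,
            n_heads=4,
        )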
|
|
""" |
|
|
|
|
|
def input_example(self, max_batch=1, max_dim=256): |
|
|
""" |
|
|
Generates input examples for tracing etc. |
|
|
Returns: |
|
|
A tuple of input examples. |
|
|
""" |
|
|
dev = next(self.parameters()).device |
|
|
if self.export_cache_support: |
|
|
window_size = max_dim |
|
|
if self.streaming_cfg is not None: |
|
|
if isinstance(self.streaming_cfg.chunk_size, list): |
|
|
chunk_size = self.streaming_cfg.chunk_size[1] |
|
|
else: |
|
|
chunk_size = self.streaming_cfg.chunk_size |
|
|
if isinstance(self.streaming_cfg.pre_encode_cache_size, list): |
|
|
pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size[1] |
|
|
else: |
|
|
pre_encode_cache_size = self.streaming_cfg.pre_encode_cache_size |
|
|
window_size = chunk_size + pre_encode_cache_size |
|
|
input_example = torch.randn(max_batch, self._feat_in, window_size, device=dev) |
|
|
input_example_length = torch.randint( |
|
|
window_size // 4, window_size, (max_batch,), device=dev, dtype=torch.int64 |
|
|
) |
|
|
cache_last_channel, cache_last_time, cache_last_channel_len = self.get_initial_cache_state( |
|
|
batch_size=max_batch, device=dev, max_dim=max_dim |
|
|
) |
|
|
all_input_example = tuple( |
|
|
[ |
|
|
input_example, |
|
|
input_example_length, |
|
|
cache_last_channel.transpose(0, 1), |
|
|
cache_last_time.transpose(0, 1), |
|
|
cache_last_channel_len, |
|
|
] |
|
|
) |
|
|
else: |
|
|
input_example = torch.randn(max_batch, self._feat_in, max_dim, device=dev) |
|
|
input_example_length = torch.randint(max_dim // 4, max_dim, (max_batch,), device=dev, dtype=torch.int64) |
|
|
all_input_example = tuple([input_example, input_example_length]) |
|
|
|
|
|
return all_input_example |
|
|
|
|
|
@property |
|
|
def input_types(self): |
|
|
"""Returns definitions of module input ports.""" |
|
|
return OrderedDict( |
|
|
{ |
|
|
"audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), |
|
|
"length": NeuralType(tuple('B'), LengthsType()), |
|
|
"cache_last_channel": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True), |
|
|
"cache_last_time": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True), |
|
|
"cache_last_channel_len": NeuralType(tuple('B'), LengthsType(), optional=True), |
|
|
} |
|
|
) |
|
|
|
|
|
@property |
|
|
def output_types(self): |
|
|
"""Returns definitions of module output ports.""" |
|
|
return OrderedDict( |
|
|
{ |
|
|
"outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), |
|
|
"encoded_lengths": NeuralType(tuple('B'), LengthsType()), |
|
|
"cache_last_channel_next": NeuralType(('D', 'B', 'T', 'D'), ChannelType(), optional=True), |
|
|
"cache_last_time_next": NeuralType(('D', 'B', 'D', 'T'), ChannelType(), optional=True), |
|
|
"cache_last_channel_next_len": NeuralType(tuple('B'), LengthsType(), optional=True), |
|
|
} |
|
|
) |
|
|
|
|
|
@property |
|
|
def disabled_deployment_input_names(self): |
|
|
if not self.export_cache_support: |
|
|
return set(["cache_last_channel", "cache_last_time", "cache_last_channel_len"]) |
|
|
else: |
|
|
return set() |
|
|
|
|
|
@property |
|
|
def disabled_deployment_output_names(self): |
|
|
if not self.export_cache_support: |
|
|
return set(["cache_last_channel_next", "cache_last_time_next", "cache_last_channel_next_len"]) |
|
|
else: |
|
|
return set() |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
feat_in, |
|
|
n_layers, |
|
|
d_model, |
|
|
feat_out=-1, |
|
|
causal_downsampling=False, |
|
|
subsampling='striding', |
|
|
subsampling_factor=4, |
|
|
subsampling_conv_channels=-1, |
|
|
reduction=None, |
|
|
reduction_position=None, |
|
|
reduction_factor=1, |
|
|
ff_expansion_factor=4, |
|
|
self_attention_model='rel_pos', |
|
|
n_heads=4, |
|
|
att_context_size=None, |
|
|
att_context_style='regular', |
|
|
xscaling=True, |
|
|
untie_biases=True, |
|
|
pos_emb_max_len=5000, |
|
|
conv_kernel_size=31, |
|
|
conv_norm_type='batch_norm', |
|
|
conv_context_size=None, |
|
|
dropout=0.1, |
|
|
dropout_pre_encoder=0.1, |
|
|
dropout_emb=0.1, |
|
|
dropout_att=0.0, |
|
|
stochastic_depth_drop_prob: float = 0.0, |
|
|
stochastic_depth_mode: str = "linear", |
|
|
stochastic_depth_start_layer: int = 1, |
|
|
): |
|
|
super().__init__() |
|
|
d_ff = d_model * ff_expansion_factor |
|
|
self.d_model = d_model |
|
|
self.n_layers = n_layers |
|
|
self._feat_in = feat_in |
|
|
self.scale = math.sqrt(self.d_model) |
|
|
self.att_context_style = att_context_style |
|
|
self.subsampling_factor = subsampling_factor |
|
|
self.self_attention_model = self_attention_model |
|
|
|
|
|
if att_context_size: |
|
|
self.att_context_size = list(att_context_size) |
|
|
else: |
|
|
self.att_context_size = [-1, -1] |
|
|
|
|
|
if isinstance(conv_context_size, ListConfig): |
|
|
conv_context_size = list(conv_context_size) |
|
|
|
|
|
if conv_context_size is not None: |
|
|
if ( |
|
|
not isinstance(conv_context_size, list) |
|
|
and not isinstance(conv_context_size, str) |
|
|
and not isinstance(conv_context_size, ListConfig) |
|
|
): |
|
|
raise ValueError( |
|
|
f"Invalid conv_context_size! It should be the string 'causal' or a list of two integers." |
|
|
) |
|
|
if conv_context_size == "causal": |
|
|
conv_context_size = [conv_kernel_size - 1, 0] |
|
|
else: |
|
|
if conv_context_size[0] + conv_context_size[1] + 1 != conv_kernel_size: |
|
|
raise ValueError(f"Invalid conv_context_size: {self.conv_context_size}!") |
|
|
else: |
|
|
conv_context_size = [(conv_kernel_size - 1) // 2, (conv_kernel_size - 1) // 2] |
|
|
self.conv_context_size = conv_context_size |
|
|
|
|
|
if att_context_style == "chunked_limited": |
|
|
|
|
|
|
|
|
if self.att_context_size[0] > 0 and self.att_context_size[0] % (self.att_context_size[1] + 1) > 0: |
|
|
raise ValueError("att_context_size[0] % (att_context_size[1] + 1) should be zero!") |
|
|
if self.att_context_size[1] < 0: |
|
|
raise ValueError("Right context can not be unlimited for chunked_limited style!") |
|
|
self.chunk_size = self.att_context_size[1] + 1 |
|
|
|
|
|
|
|
|
if self.att_context_size[0] >= 0: |
|
|
self.left_chunks_num = self.att_context_size[0] // self.chunk_size |
|
|
else: |
|
|
self.left_chunks_num = 100000 |
|
|
|
|
|
elif att_context_style == "regular": |
|
|
self.chunk_size = None |
|
|
else: |
|
|
raise ValueError("Invalid att_context_style!") |
|
|
|
|
|
if xscaling: |
|
|
self.xscale = math.sqrt(d_model) |
|
|
else: |
|
|
self.xscale = None |
|
|
|
|
|
|
|
|
if subsampling_conv_channels == -1: |
|
|
subsampling_conv_channels = d_model |
|
|
if subsampling and subsampling_factor > 1: |
|
|
if subsampling in ['stacking', 'stacking_norm']: |
|
|
|
|
|
self.pre_encode = StackingSubsampling( |
|
|
subsampling_factor=subsampling_factor, |
|
|
feat_in=feat_in, |
|
|
feat_out=d_model, |
|
|
norm=True if subsampling == 'stacking_norm' else False, |
|
|
) |
|
|
else: |
|
|
self.pre_encode = ConvSubsampling( |
|
|
subsampling=subsampling, |
|
|
subsampling_factor=subsampling_factor, |
|
|
feat_in=feat_in, |
|
|
feat_out=d_model, |
|
|
conv_channels=subsampling_conv_channels, |
|
|
activation=nn.ReLU(True), |
|
|
is_causal=causal_downsampling, |
|
|
) |
|
|
else: |
|
|
self.pre_encode = nn.Linear(feat_in, d_model) |
|
|
|
|
|
|
|
|
if reduction and reduction_factor > 1: |
|
|
assert reduction_position >= -1 and reduction_position < n_layers |
|
|
self.reduction_subsampling = SubsamplingReductionModule( |
|
|
reduction=reduction, d_model=d_model, reduction_factor=reduction_factor, |
|
|
) |
|
|
self.reduction_position = reduction_position |
|
|
else: |
|
|
self.reduction_subsampling = None |
|
|
self.reduction_position = None |
|
|
|
|
|
self._feat_out = d_model |
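
        # With untie_biases=False and 'rel_pos' attention, a single pair of positional bias
        # parameters (pos_bias_u / pos_bias_v) is created here and shared by all layers;
        # otherwise each ConformerLayer creates its own biases.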
|
|
|
|
|
if not untie_biases and self_attention_model == "rel_pos": |
|
|
d_head = d_model // n_heads |
|
|
pos_bias_u = nn.Parameter(torch.Tensor(n_heads, d_head)) |
|
|
pos_bias_v = nn.Parameter(torch.Tensor(n_heads, d_head)) |
|
|
nn.init.zeros_(pos_bias_u) |
|
|
nn.init.zeros_(pos_bias_v) |
|
|
else: |
|
|
pos_bias_u = None |
|
|
pos_bias_v = None |
|
|
|
|
|
self.pos_emb_max_len = pos_emb_max_len |
|
|
self.att_mask = None |
|
|
if self_attention_model == "rel_pos": |
|
|
self.pos_enc = RelPositionalEncoding( |
|
|
d_model=d_model, |
|
|
dropout_rate=dropout_pre_encoder, |
|
|
max_len=pos_emb_max_len, |
|
|
xscale=self.xscale, |
|
|
dropout_rate_emb=dropout_emb, |
|
|
) |
|
|
elif self_attention_model == 'rel_pos_local_attn': |
|
|
if max(att_context_size) <= 0: |
|
|
raise ValueError("When using local attention, context size must be set > 0") |
|
|
self.pos_enc = LocalAttRelPositionalEncoding( |
|
|
att_context_size=att_context_size, |
|
|
d_model=d_model, |
|
|
dropout_rate=dropout, |
|
|
max_len=pos_emb_max_len, |
|
|
xscale=self.xscale, |
|
|
dropout_rate_emb=dropout_emb, |
|
|
) |
|
|
elif self_attention_model == "abs_pos": |
|
|
pos_bias_u = None |
|
|
pos_bias_v = None |
|
|
self.pos_enc = PositionalEncoding( |
|
|
d_model=d_model, dropout_rate=dropout_pre_encoder, max_len=pos_emb_max_len, xscale=self.xscale |
|
|
) |
|
|
else: |
|
|
raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") |
|
|
|
|
|
self.layers = nn.ModuleList() |
|
|
for i in range(n_layers): |
|
|
layer = ConformerLayer( |
|
|
d_model=d_model, |
|
|
d_ff=d_ff, |
|
|
self_attention_model=self_attention_model, |
|
|
n_heads=n_heads, |
|
|
conv_kernel_size=conv_kernel_size, |
|
|
conv_norm_type=conv_norm_type, |
|
|
conv_context_size=self.conv_context_size, |
|
|
dropout=dropout, |
|
|
dropout_att=dropout_att, |
|
|
pos_bias_u=pos_bias_u, |
|
|
pos_bias_v=pos_bias_v, |
|
|
att_context_size=self.att_context_size, |
|
|
) |
|
|
self.layers.append(layer) |
|
|
|
|
|
if feat_out > 0 and feat_out != self._feat_out: |
|
|
self.out_proj = nn.Linear(self._feat_out, feat_out) |
|
|
self._feat_out = feat_out |
|
|
else: |
|
|
self.out_proj = None |
|
|
self._feat_out = d_model |
|
|
self.set_max_audio_length(self.pos_emb_max_len) |
|
|
self.use_pad_mask = True |
|
|
|
|
|
self.setup_streaming_params() |
|
|
self.export_cache_support = False |
|
|
|
|
|
self.layer_drop_probs = compute_stochastic_depth_drop_probs( |
|
|
len(self.layers), stochastic_depth_drop_prob, stochastic_depth_mode, stochastic_depth_start_layer |
|
|
) |
|
|
|
|
|
self.interctc_capture_at_layers = None |
|
|
|
|
|
def update_max_seq_length(self, seq_length: int, device): |
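        # Track the longest sequence seen so far. When running distributed, take the max across
        # workers so every rank builds positional encodings and attention masks of the same size.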
|
|
|
|
|
if torch.distributed.is_initialized(): |
|
|
global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) |
|
|
|
|
|
|
|
|
torch.distributed.all_reduce(global_max_len, op=torch.distributed.ReduceOp.MAX) |
|
|
|
|
|
seq_length = global_max_len.to(torch.int64).item() |
|
|
|
|
|
if seq_length > self.max_audio_length: |
|
|
self.set_max_audio_length(seq_length) |
|
|
|
|
|
def set_max_audio_length(self, max_audio_length): |
|
|
""" |
|
|
Sets maximum input length. |
|
|
Pre-calculates internal seq_range mask. |
|
|
""" |
|
|
self.max_audio_length = max_audio_length |
|
|
device = next(self.parameters()).device |
|
|
self.pos_enc.extend_pe(max_audio_length, device) |
|
|
|
|
|
if self.self_attention_model != "rel_pos_local_attn": |
|
|
att_mask = torch.ones(1, max_audio_length, max_audio_length, dtype=torch.bool, device=device) |
|
|
if self.chunk_size is None: |
|
|
if self.att_context_size[0] >= 0: |
|
|
att_mask = att_mask.triu(diagonal=-self.att_context_size[0]) |
|
|
if self.att_context_size[1] >= 0: |
|
|
att_mask = att_mask.tril(diagonal=self.att_context_size[1]) |
|
|
else: |
|
|
chunk_idx = torch.arange(0, max_audio_length, dtype=torch.int64, device=att_mask.device) |
|
|
chunk_idx = torch.div(chunk_idx, self.chunk_size, rounding_mode="trunc") |
|
|
diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0) |
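                # diff_chunks[i, j] is the chunk index of query frame i minus that of key frame j;
                # key j is visible to query i when it lies in the same chunk or in one of the
                # left_chunks_num preceding chunks (0 <= diff <= left_chunks_num).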
|
|
chunked_limited_mask = torch.logical_and( |
|
|
torch.le(diff_chunks, self.left_chunks_num), torch.ge(diff_chunks, 0) |
|
|
) |
|
|
att_mask = torch.logical_and(att_mask, chunked_limited_mask.unsqueeze(0)) |
|
|
|
|
|
if hasattr(self, 'att_mask'): |
|
|
self.att_mask = att_mask |
|
|
else: |
|
|
self.register_buffer('att_mask', att_mask, persistent=False) |
|
|
else: |
|
|
self.att_mask = None |
|
|
|
|
|
def forward_for_export( |
|
|
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None |
|
|
): |
|
|
if cache_last_channel is not None: |
|
|
cache_last_channel = cache_last_channel.transpose(0, 1) |
|
|
cache_last_time = cache_last_time.transpose(0, 1) |
|
|
|
|
|
rets = self.forward_internal( |
|
|
audio_signal, |
|
|
length, |
|
|
cache_last_channel=cache_last_channel, |
|
|
cache_last_time=cache_last_time, |
|
|
cache_last_channel_len=cache_last_channel_len, |
|
|
) |
|
|
rets = self.streaming_post_process(rets, keep_all_outputs=False) |
|
|
if len(rets) == 2: |
|
|
return rets |
|
|
else: |
|
|
return ( |
|
|
rets[0], |
|
|
rets[1], |
|
|
rets[2].transpose(0, 1), |
|
|
rets[3].transpose(0, 1), |
|
|
rets[4], |
|
|
) |
|
|
|
|
|
def streaming_post_process(self, rets, keep_all_outputs=True): |
|
|
if len(rets) == 2: |
|
|
return rets |
|
|
|
|
|
(encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) = rets |
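
        # Keep only the most recent last_channel_cache_size frames of the channel cache, and
        # truncate the encoder output to streaming_cfg.valid_out_len frames when not keeping
        # all outputs (or when using regular attention).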
|
|
|
|
|
if cache_last_channel_next is not None and self.streaming_cfg.last_channel_cache_size >= 0: |
|
|
if self.streaming_cfg.last_channel_cache_size > 0: |
|
|
cache_last_channel_next = cache_last_channel_next[ |
|
|
:, :, -self.streaming_cfg.last_channel_cache_size :, : |
|
|
] |
|
|
|
|
|
if self.streaming_cfg.valid_out_len > 0 and (not keep_all_outputs or self.att_context_style == "regular"): |
|
|
encoded = encoded[:, :, : self.streaming_cfg.valid_out_len] |
|
|
encoded_len = torch.clamp(encoded_len, max=self.streaming_cfg.valid_out_len) |
|
|
|
|
|
return (encoded, encoded_len, cache_last_channel_next, cache_last_time_next, cache_last_channel_next_len) |
|
|
|
|
|
@typecheck() |
|
|
def forward( |
|
|
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None |
|
|
): |
|
|
return self.forward_internal( |
|
|
audio_signal, |
|
|
length, |
|
|
cache_last_channel=cache_last_channel, |
|
|
cache_last_time=cache_last_time, |
|
|
cache_last_channel_len=cache_last_channel_len, |
|
|
) |
|
|
|
|
|
def forward_internal( |
|
|
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None |
|
|
): |
|
|
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) |
|
|
max_audio_length = audio_signal.size(-1) |
|
|
|
|
|
if length is None: |
|
|
length = audio_signal.new_full( |
|
|
(audio_signal.size(0),), max_audio_length, dtype=torch.int64, device=audio_signal.device |
|
|
) |
|
|
|
|
|
if cache_last_time is not None: |
|
|
cache_last_time_next = torch.zeros_like(cache_last_time) |
|
|
else: |
|
|
cache_last_time_next = None |
|
|
audio_signal = torch.transpose(audio_signal, 1, 2) |
|
|
|
|
|
if isinstance(self.pre_encode, nn.Linear): |
|
|
audio_signal = self.pre_encode(audio_signal) |
|
|
else: |
|
|
audio_signal, length = self.pre_encode(x=audio_signal, lengths=length) |
|
|
|
|
|
if self.streaming_cfg.drop_extra_pre_encoded > 0 and cache_last_channel is not None: |
|
|
audio_signal = audio_signal[:, self.streaming_cfg.drop_extra_pre_encoded :, :] |
|
|
length = (length - self.streaming_cfg.drop_extra_pre_encoded).clamp(min=0) |
|
|
|
|
|
max_audio_length = audio_signal.size(1) |
|
|
|
|
|
if self.reduction_position is not None and cache_last_channel is not None: |
|
|
raise ValueError("Caching with reduction feature is not supported yet!") |
|
|
|
|
|
if cache_last_channel is not None: |
|
|
cache_len = self.streaming_cfg.last_channel_cache_size |
|
|
cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size |
|
|
cache_last_channel_next = torch.zeros_like(cache_last_channel) |
|
|
max_audio_length = max_audio_length + cache_len |
|
|
padding_length = length + cache_len |
|
|
offset = torch.neg(cache_last_channel_len) + cache_len |
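            # offset = cache_len - cache_last_channel_len gives, per sample, the number of leading
            # cache positions that are still empty; _create_masks uses it to mask them out.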
|
|
else: |
|
|
padding_length = length |
|
|
cache_last_channel_next = None |
|
|
cache_len = 0 |
|
|
offset = None |
|
|
|
|
|
if self.self_attention_model == 'abs_pos': |
|
|
audio_signal, pos_emb = self.pos_enc(x=audio_signal) |
|
|
else: |
|
|
audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) |
|
|
|
|
|
|
|
|
pad_mask, att_mask = self._create_masks(max_audio_length, padding_length, offset, audio_signal.device) |
|
|
|
|
|
if cache_last_channel is not None: |
|
|
pad_mask = pad_mask[:, cache_len:] |
|
|
if self.att_mask is not None: |
|
|
att_mask = att_mask[:, cache_len:] |
|
|
|
|
|
for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)): |
|
|
original_signal = audio_signal |
|
|
audio_signal = layer( |
|
|
x=audio_signal, |
|
|
att_mask=att_mask, |
|
|
pos_emb=pos_emb, |
|
|
pad_mask=pad_mask, |
|
|
cache_last_channel=cache_last_channel, |
|
|
cache_last_time=cache_last_time, |
|
|
cache_last_channel_next=cache_last_channel_next, |
|
|
cache_last_time_next=cache_last_time_next, |
|
|
) |
|
|
|
|
|
if self.training and drop_prob > 0.0: |
|
|
should_drop = torch.rand(1) < drop_prob |
|
|
|
|
|
if should_drop: |
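                    # Layer dropped: keep only the residual input. The layer output is still
                    # multiplied by 0.0 (instead of skipping the call) so the computation graph
                    # stays identical, e.g. across distributed workers.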
|
|
|
|
|
|
|
|
|
|
|
audio_signal = audio_signal * 0.0 + original_signal |
|
|
else: |
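                    # Layer kept: rescale its residual contribution by 1 / (1 - drop_prob) so the
                    # expected output during training matches inference, where nothing is dropped.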
|
|
|
|
|
audio_signal = (audio_signal - original_signal) / (1.0 - drop_prob) + original_signal |
|
|
|
|
|
if self.reduction_position == lth: |
|
|
audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length) |
|
|
max_audio_length = audio_signal.size(1) |
|
|
|
|
|
|
|
|
_, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) |
|
|
pad_mask, att_mask = self._create_masks(max_audio_length, length, offset, audio_signal.device) |
|
|
|
|
|
|
|
|
if self.is_access_enabled(): |
|
|
if self.interctc_capture_at_layers is None: |
|
|
self.interctc_capture_at_layers = self.access_cfg.get('interctc', {}).get('capture_layers', []) |
|
|
if lth in self.interctc_capture_at_layers: |
|
|
lth_audio_signal = audio_signal |
|
|
if self.out_proj is not None: |
|
|
lth_audio_signal = self.out_proj(audio_signal) |
|
|
|
|
|
self.register_accessible_tensor( |
|
|
name=f'interctc/layer_output_{lth}', tensor=torch.transpose(lth_audio_signal, 1, 2) |
|
|
) |
|
|
self.register_accessible_tensor(name=f'interctc/layer_length_{lth}', tensor=length) |
|
|
|
|
|
if self.out_proj is not None: |
|
|
audio_signal = self.out_proj(audio_signal) |
|
|
|
|
|
|
|
|
if self.reduction_position == -1: |
|
|
audio_signal, length = self.reduction_subsampling(x=audio_signal, lengths=length) |
|
|
|
|
|
audio_signal = torch.transpose(audio_signal, 1, 2) |
|
|
length = length.to(dtype=torch.int64) |
|
|
|
|
|
if cache_last_channel is not None: |
|
|
return ( |
|
|
audio_signal, |
|
|
length, |
|
|
cache_last_channel_next, |
|
|
cache_last_time_next, |
|
|
torch.clamp(cache_last_channel_len + cache_keep_size, max=cache_len), |
|
|
) |
|
|
else: |
|
|
return audio_signal, length |
|
|
|
|
|
def _create_masks(self, max_audio_length, padding_length, offset, device): |
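        # Build padding and attention masks. Both are returned inverted (True marks a padded or
        # disallowed position), matching the convention used by the attention/convolution modules.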
|
|
|
|
|
pad_mask = torch.arange(0, max_audio_length, device=device).expand( |
|
|
padding_length.size(0), -1 |
|
|
) < padding_length.unsqueeze(-1) |
|
|
|
|
|
if offset is not None: |
|
|
pad_mask_off = torch.arange(0, max_audio_length, device=device).expand( |
|
|
padding_length.size(0), -1 |
|
|
) >= offset.unsqueeze(-1) |
|
|
|
|
|
pad_mask = pad_mask_off.logical_and(pad_mask) |
|
|
|
|
|
if self.att_mask is not None: |
|
|
|
|
|
pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat([1, max_audio_length, 1]) |
|
|
pad_mask_for_att_mask = torch.logical_and(pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2)) |
|
|
|
|
|
att_mask = self.att_mask[:, :max_audio_length, :max_audio_length] |
|
|
|
|
|
att_mask = torch.logical_and(pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device)) |
|
|
|
|
|
att_mask = ~att_mask |
|
|
else: |
|
|
att_mask = None |
|
|
|
|
|
pad_mask = ~pad_mask |
|
|
|
|
|
return pad_mask, att_mask |
|
|
|
|
|
def enable_pad_mask(self, on=True): |
|
|
|
|
|
mask = self.use_pad_mask |
|
|
self.use_pad_mask = on |
|
|
return mask |
|
|
|
|
|
def setup_streaming_params( |
|
|
self, chunk_size: int = None, shift_size: int = None, left_chunks: int = None, max_context: int = 10000 |
|
|
): |
|
|
""" |
|
|
        This function sets the values and parameters needed to simulate streaming. The configuration is stored in self.streaming_cfg.
|
|
The streaming configuration is needed to simulate streaming inference. |
|
|
Args: |
|
|
chunk_size (int): overrides the chunk size |
|
|
shift_size (int): overrides the shift size for chunks |
|
|
left_chunks (int): overrides the number of left chunks visible to each chunk |
|
|
max_context (int): the value used for the cache size of last_channel layers if left context is set to infinity (-1) |
|
|
                Defaults to 10000.
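
        Example (illustrative values):

            encoder.setup_streaming_params(chunk_size=16, shift_size=16, left_chunks=2)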
|
|
""" |
|
|
streaming_cfg = CacheAwareStreamingConfig() |
|
|
if chunk_size is not None: |
|
|
if chunk_size < 1: |
|
|
raise ValueError("chunk_size needs to be a number larger or equal to one.") |
|
|
lookahead_steps = chunk_size - 1 |
|
|
streaming_cfg.cache_drop_size = chunk_size - shift_size |
|
|
elif self.att_context_style == "chunked_limited": |
|
|
lookahead_steps = self.att_context_size[1] |
|
|
streaming_cfg.cache_drop_size = 0 |
|
|
elif self.att_context_style == "regular": |
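            # With regular (non-chunked) attention, each layer adds its own attention and
            # causal-conv right context, so the total lookahead grows with n_layers and must all
            # be dropped from the cache between streaming steps.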
|
|
lookahead_steps = self.att_context_size[1] * self.n_layers + self.conv_context_size[1] * self.n_layers |
|
|
streaming_cfg.cache_drop_size = lookahead_steps |
|
|
else: |
|
|
streaming_cfg.cache_drop_size = 0 |
|
|
lookahead_steps = None |
|
|
|
|
|
if chunk_size is None: |
|
|
streaming_cfg.last_channel_cache_size = ( |
|
|
self.att_context_size[0] if self.att_context_size[0] >= 0 else max_context |
|
|
) |
|
|
else: |
|
|
if left_chunks is None: |
|
|
raise ValueError("left_chunks can not be None when chunk_size is set.") |
|
|
streaming_cfg.last_channel_cache_size = left_chunks * chunk_size |
|
|
|
|
|
if hasattr(self.pre_encode, "get_sampling_frames"): |
|
|
sampling_frames = self.pre_encode.get_sampling_frames() |
|
|
else: |
|
|
sampling_frames = 0 |
|
|
|
|
|
if isinstance(sampling_frames, list): |
|
|
streaming_cfg.chunk_size = [ |
|
|
sampling_frames[0] + self.subsampling_factor * lookahead_steps, |
|
|
sampling_frames[1] + self.subsampling_factor * lookahead_steps, |
|
|
] |
|
|
else: |
|
|
streaming_cfg.chunk_size = sampling_frames * (1 + lookahead_steps) |
|
|
|
|
|
if isinstance(sampling_frames, list): |
|
|
streaming_cfg.shift_size = [ |
|
|
sampling_frames[0] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size), |
|
|
sampling_frames[1] + sampling_frames[1] * (lookahead_steps - streaming_cfg.cache_drop_size), |
|
|
] |
|
|
else: |
|
|
streaming_cfg.shift_size = sampling_frames * (1 + lookahead_steps - streaming_cfg.cache_drop_size) |
|
|
|
|
|
if isinstance(streaming_cfg.shift_size, list): |
|
|
streaming_cfg.valid_out_len = ( |
|
|
streaming_cfg.shift_size[1] - sampling_frames[1] |
|
|
) // self.subsampling_factor + 1 |
|
|
else: |
|
|
streaming_cfg.valid_out_len = streaming_cfg.shift_size // self.subsampling_factor |
|
|
|
|
|
if hasattr(self.pre_encode, "get_streaming_cache_size"): |
|
|
streaming_cfg.pre_encode_cache_size = self.pre_encode.get_streaming_cache_size() |
|
|
else: |
|
|
streaming_cfg.pre_encode_cache_size = 0 |
|
|
|
|
|
if isinstance(streaming_cfg.pre_encode_cache_size, list): |
|
|
if streaming_cfg.pre_encode_cache_size[1] >= 1: |
|
|
streaming_cfg.drop_extra_pre_encoded = ( |
|
|
1 + (streaming_cfg.pre_encode_cache_size[1] - 1) // self.subsampling_factor |
|
|
) |
|
|
else: |
|
|
streaming_cfg.drop_extra_pre_encoded = 0 |
|
|
else: |
|
|
streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor |
|
|
|
|
|
|
|
|
streaming_cfg.last_channel_num = 0 |
|
|
streaming_cfg.last_time_num = 0 |
|
|
for m in self.layers.modules(): |
|
|
if hasattr(m, "_max_cache_len"): |
|
|
if isinstance(m, MultiHeadAttention): |
|
|
m._cache_id = streaming_cfg.last_channel_num |
|
|
m.cache_drop_size = streaming_cfg.cache_drop_size |
|
|
streaming_cfg.last_channel_num += 1 |
|
|
|
|
|
if isinstance(m, CausalConv1D): |
|
|
m._cache_id = streaming_cfg.last_time_num |
|
|
m.cache_drop_size = streaming_cfg.cache_drop_size |
|
|
streaming_cfg.last_time_num += 1 |
|
|
|
|
|
self.streaming_cfg = streaming_cfg |
|
|
|
|
|
def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None, max_dim=0): |
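        # Returns the initial streaming cache tensors:
        #   cache_last_channel: [num_self_attention_layers, batch, last_channel_cache_size, d_model]
        #   cache_last_time:    [num_causal_conv_layers, batch, d_model, conv_context_size[0]]
        # Zero-filled by default; filled with random values when max_dim > 0 (used for export tracing).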
|
|
if device is None: |
|
|
device = next(self.parameters()).device |
|
|
if max_dim > 0: |
|
|
create_tensor = torch.randn |
|
|
else: |
|
|
create_tensor = torch.zeros |
|
|
last_time_cache_size = self.conv_context_size[0] |
|
|
cache_last_channel = create_tensor( |
|
|
( |
|
|
self.streaming_cfg.last_channel_num, |
|
|
batch_size, |
|
|
self.streaming_cfg.last_channel_cache_size, |
|
|
self.d_model, |
|
|
), |
|
|
device=device, |
|
|
dtype=dtype, |
|
|
) |
|
|
cache_last_time = create_tensor( |
|
|
(self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size), |
|
|
device=device, |
|
|
dtype=dtype, |
|
|
) |
|
|
if max_dim > 0: |
|
|
cache_last_channel_len = torch.randint( |
|
|
0, |
|
|
min(max_dim, self.streaming_cfg.last_channel_cache_size), |
|
|
(batch_size,), |
|
|
device=device, |
|
|
dtype=torch.int64, |
|
|
) |
|
|
for i in range(batch_size): |
|
|
cache_last_channel[:, i, cache_last_channel_len[i] :, :] = 0 |
|
|
|
|
|
if cache_last_channel_len[i] == 0: |
|
|
cache_last_time[:, i, :, :] = 0 |
|
|
else: |
|
|
cache_last_channel_len = torch.zeros(batch_size, device=device, dtype=torch.int64) |
|
|
return cache_last_channel, cache_last_time, cache_last_channel_len |
|
|
|
|
|
def change_attention_model( |
|
|
self, |
|
|
self_attention_model: str = None, |
|
|
att_context_size: List[int] = None, |
|
|
update_config: bool = True, |
|
|
device: torch.device = None, |
|
|
): |
|
|
|
|
|
""" |
|
|
Update the self_attention_model which changes the positional encoding and attention layers. |
|
|
|
|
|
Args: |
|
|
self_attention_model (str): type of the attention layer and positional encoding |
|
|
'rel_pos': relative positional embedding and Transformer-XL |
|
|
'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using |
|
|
overlapping windows. Attention context is determined by att_context_size parameter. |
|
|
'abs_pos': absolute positional embedding and Transformer |
|
|
                If None is provided, the self_attention_model isn't changed. Defaults to None.
|
|
att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes, |
|
|
                or None to keep it as is. Defaults to None.
|
|
update_config (bool): Whether to update the config or not with the new attention model. |
|
|
Defaults to True. |
|
|
device (torch.device): If provided, new layers will be moved to the device. |
|
|
Defaults to None. |
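
        Example (illustrative; assumes this encoder is attached to a model as `model.encoder`):

            model.encoder.change_attention_model(
                self_attention_model='rel_pos_local_attn', att_context_size=[128, 128]
            )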
|
|
""" |
|
|
|
|
|
if att_context_size: |
|
|
att_context_size = list(att_context_size) |
|
|
else: |
|
|
att_context_size = self._cfg.att_context_size |
|
|
|
|
|
if self_attention_model is None: |
|
|
self_attention_model = self._cfg.self_attention_model |
|
|
|
|
|
if self_attention_model == 'rel_pos_local_attn' and max(att_context_size) <= 0: |
|
|
raise ValueError("When using local attention, context size must be set > 0") |
|
|
|
|
|
if self_attention_model == "rel_pos": |
|
|
self.att_mask = None |
|
|
new_pos_enc = RelPositionalEncoding( |
|
|
d_model=self._cfg.d_model, |
|
|
dropout_rate=self._cfg.dropout, |
|
|
max_len=self._cfg.pos_emb_max_len, |
|
|
xscale=self.xscale, |
|
|
dropout_rate_emb=self._cfg.dropout_emb, |
|
|
) |
|
|
elif self_attention_model == 'rel_pos_local_attn': |
|
|
new_pos_enc = LocalAttRelPositionalEncoding( |
|
|
att_context_size=att_context_size, |
|
|
d_model=self._cfg.d_model, |
|
|
dropout_rate=self._cfg.dropout, |
|
|
max_len=self._cfg.pos_emb_max_len, |
|
|
xscale=self.xscale, |
|
|
dropout_rate_emb=self._cfg.dropout_emb, |
|
|
) |
|
|
elif self_attention_model == "abs_pos": |
|
|
new_pos_enc = PositionalEncoding( |
|
|
d_model=self._cfg.d_model, |
|
|
dropout_rate=self._cfg.dropout, |
|
|
max_len=self._cfg.pos_emb_max_len, |
|
|
xscale=self.xscale, |
|
|
) |
|
|
else: |
|
|
raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") |
|
|
|
|
|
if device is not None: |
|
|
new_pos_enc = new_pos_enc.to(device=device) |
|
|
del self.pos_enc |
|
|
self.pos_enc = new_pos_enc |
|
|
self.self_attention_model = self_attention_model |
|
|
self.att_context_size = att_context_size |
|
|
self.set_max_audio_length(self.pos_emb_max_len) |
|
|
|
|
|
for name, m in self.named_modules(): |
|
|
if type(m) == ConformerLayer: |
|
|
|
|
|
if self_attention_model == 'rel_pos': |
|
|
new_attn = RelPositionMultiHeadAttention( |
|
|
n_head=self._cfg.n_heads, |
|
|
n_feat=self._cfg.d_model, |
|
|
dropout_rate=self._cfg.dropout_att, |
|
|
max_cache_len=att_context_size[0], |
|
|
pos_bias_u=None, |
|
|
pos_bias_v=None, |
|
|
) |
|
|
elif self_attention_model == 'rel_pos_local_attn': |
|
|
new_attn = RelPositionMultiHeadAttentionLongformer( |
|
|
n_head=self._cfg.n_heads, |
|
|
n_feat=self._cfg.d_model, |
|
|
dropout_rate=self._cfg.dropout_att, |
|
|
max_cache_len=att_context_size[0], |
|
|
att_context_size=att_context_size, |
|
|
pos_bias_u=None, |
|
|
pos_bias_v=None, |
|
|
) |
|
|
elif self_attention_model == 'abs_pos': |
|
|
new_attn = MultiHeadAttention( |
|
|
n_head=self._cfg.n_heads, |
|
|
n_feat=self._cfg.d_model, |
|
|
dropout_rate=self._cfg.dropout_att, |
|
|
max_cache_len=att_context_size[0], |
|
|
) |
|
|
else: |
|
|
raise ValueError( |
|
|
f"'{self_attention_model}' is not not a valid value for 'self_attention_model', " |
|
|
f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']" |
|
|
) |
|
|
if device is not None: |
|
|
new_attn = new_attn.to(device=device) |
|
|
new_attn.load_state_dict(m.self_attn.state_dict(), strict=False) |
|
|
del m.self_attn |
|
|
m.self_attn = new_attn |
|
|
m.self_attention_model = self_attention_model |
|
|
|
|
|
if update_config: |
|
|
self._cfg.self_attention_model = self_attention_model |
|
|
self._cfg.att_context_size = att_context_size |
|
|
|
|
|
|
|
|
class ConformerEncoderAdapter(ConformerEncoder, adapter_mixins.AdapterModuleMixin): |
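
    """ConformerEncoder with adapter support: adapter operations are delegated to every ConformerLayer."""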
|
|
|
|
|
|
|
|
def add_adapter(self, name: str, cfg: dict): |
|
|
cfg = self._update_adapter_cfg_input_dim(cfg) |
|
|
for conformer_layer in self.layers: |
|
|
conformer_layer.add_adapter(name, cfg) |
|
|
|
|
|
def is_adapter_available(self) -> bool: |
|
|
return any([conformer_layer.is_adapter_available() for conformer_layer in self.layers]) |
|
|
|
|
|
def set_enabled_adapters(self, name: Optional[str] = None, enabled: bool = True): |
|
|
for conformer_layer in self.layers: |
|
|
conformer_layer.set_enabled_adapters(name=name, enabled=enabled) |
|
|
|
|
|
def get_enabled_adapters(self) -> List[str]: |
|
|
names = set([]) |
|
|
for conformer_layer in self.layers: |
|
|
names.update(conformer_layer.get_enabled_adapters()) |
|
|
|
|
|
names = sorted(list(names)) |
|
|
return names |
|
|
|
|
|
def _update_adapter_cfg_input_dim(self, cfg: DictConfig): |
|
|
cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model) |
|
|
return cfg |
|
|
|
|
|
def get_accepted_adapter_types(self,) -> Set[type]: |
|
|
types = super().get_accepted_adapter_types() |
|
|
|
|
|
if len(types) == 0: |
|
|
self.set_accepted_adapter_types( |
|
|
[ |
|
|
adapter_utils.LINEAR_ADAPTER_CLASSPATH, |
|
|
adapter_utils.MHA_ADAPTER_CLASSPATH, |
|
|
adapter_utils.RELMHA_ADAPTER_CLASSPATH, |
|
|
] |
|
|
) |
|
|
types = self.get_accepted_adapter_types() |
|
|
return types |
|
|
|
|
|
|
|
|
""" |
|
|
Register any additional information |
|
|
""" |
|
|
if adapter_mixins.get_registered_adapter(ConformerEncoder) is None: |
|
|
adapter_mixins.register_adapter(base_class=ConformerEncoder, adapter_class=ConformerEncoderAdapter) |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class ConformerChangeConfig: |
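    # Attention model to switch to: one of 'rel_pos', 'rel_pos_local_attn' or 'abs_pos'
    # (see ConformerEncoder.change_attention_model). If None, the attention model is not changed.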
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self_attention_model: Optional[str] = None |
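
    # New attention context size as a list of two ints [left, right].
    # If None, the attention context size is not changed.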
|
|
|
|
|
|
|
|
|
|
|
|
|
|
att_context_size: Optional[List[int]] = None |
|
|
|