| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Encoder definition.""" |
| from typing import Optional, Tuple |
|
|
| import torch |
|
|
| from wenet.utils.mask import make_pad_mask |
| from wenet.transformer.encoder import TransformerEncoder, ConformerEncoder |
|
|
|
|
class DualTransformerEncoder(TransformerEncoder):
    """Transformer encoder with an additional whole-input forward path.

    Construction is delegated unchanged to ``TransformerEncoder``; the only
    thing this subclass adds is ``forward_full``, which runs the encoder
    stack using a mask derived solely from the input lengths.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        selfattention_layer_type: str = "selfattn",
        mlp_type: str = 'position_wise_feed_forward',
        mlp_bias: bool = True,
        n_expert: int = 8,
        n_expert_activated: int = 2,
    ):
        """Build a DualTransformerEncoder.

        Supports both the full context mode and the streaming mode
        separately. Every argument is handed to the parent constructor
        positionally, so the order below must stay aligned with the
        parent's parameter order.
        """
        super().__init__(
            input_size,
            output_size,
            attention_heads,
            linear_units,
            num_blocks,
            dropout_rate,
            positional_dropout_rate,
            attention_dropout_rate,
            input_layer,
            pos_enc_layer_type,
            normalize_before,
            static_chunk_size,
            use_dynamic_chunk,
            global_cmvn,
            use_dynamic_left_chunk,
            query_bias,
            key_bias,
            value_bias,
            activation_type,
            gradient_checkpointing,
            use_sdpa,
            layer_norm_type,
            norm_eps,
            n_kv_head,
            head_dim,
            selfattention_layer_type,
            mlp_type,
            mlp_bias,
            n_expert,
            n_expert_activated,
        )

    def forward_full(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run all encoder layers over the input with a length-based mask.

        Args:
            xs: Input tensor; dimension 1 is taken as the padded length.
                Presumably (batch, time, feat) -- confirm against callers.
            xs_lens: Lengths handed to ``make_pad_mask`` to build the mask.
            decoding_chunk_size: Accepted but not read by this method.
            num_decoding_left_chunks: Accepted but not read by this method.

        Returns:
            A pair of the encoder output (passed through ``after_norm``
            when ``normalize_before`` is set) and the mask returned by the
            last layer, or by the embedding step if there are no layers.
        """
        max_len = xs.size(1)
        pad_mask = make_pad_mask(xs_lens, max_len)
        # Invert the padding mask after inserting a singleton axis at dim 1.
        masks = ~pad_mask.unsqueeze(1)

        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)

        xs, pos_emb, masks = self.embed(xs, masks)
        # The mask produced by the embedding step is reused as the fourth
        # argument of every layer, independent of the `masks` each layer
        # hands back.
        embed_mask = masks
        for encoder_layer in self.encoders:
            xs, masks, _, _ = encoder_layer(xs, masks, pos_emb, embed_mask)

        if not self.normalize_before:
            return xs, masks
        return self.after_norm(xs), masks
|
|
|
|
class DualConformerEncoder(ConformerEncoder):
    """Conformer encoder with an additional whole-input forward path.

    Construction is delegated unchanged to ``ConformerEncoder``; the only
    thing this subclass adds is ``forward_full``, which runs the encoder
    stack using a mask derived solely from the input lengths.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
        query_bias: bool = True,
        key_bias: bool = True,
        value_bias: bool = True,
        conv_bias: bool = True,
        gradient_checkpointing: bool = False,
        use_sdpa: bool = False,
        layer_norm_type: str = 'layer_norm',
        norm_eps: float = 1e-5,
        n_kv_head: Optional[int] = None,
        head_dim: Optional[int] = None,
        mlp_type: str = 'position_wise_feed_forward',
        mlp_bias: bool = True,
        n_expert: int = 8,
        n_expert_activated: int = 2,
    ):
        """Build a DualConformerEncoder.

        Supports both the full context mode and the streaming mode
        separately. Every argument is handed to the parent constructor
        positionally, so the order below must stay aligned with the
        parent's parameter order.
        """
        super().__init__(
            input_size,
            output_size,
            attention_heads,
            linear_units,
            num_blocks,
            dropout_rate,
            positional_dropout_rate,
            attention_dropout_rate,
            input_layer,
            pos_enc_layer_type,
            normalize_before,
            static_chunk_size,
            use_dynamic_chunk,
            global_cmvn,
            use_dynamic_left_chunk,
            positionwise_conv_kernel_size,
            macaron_style,
            selfattention_layer_type,
            activation_type,
            use_cnn_module,
            cnn_module_kernel,
            causal,
            cnn_module_norm,
            query_bias,
            key_bias,
            value_bias,
            conv_bias,
            gradient_checkpointing,
            use_sdpa,
            layer_norm_type,
            norm_eps,
            n_kv_head,
            head_dim,
            mlp_type,
            mlp_bias,
            n_expert,
            n_expert_activated,
        )

    def forward_full(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run all encoder layers over the input with a length-based mask.

        Args:
            xs: Input tensor; dimension 1 is taken as the padded length.
                Presumably (batch, time, feat) -- confirm against callers.
            xs_lens: Lengths handed to ``make_pad_mask`` to build the mask.
            decoding_chunk_size: Accepted but not read by this method.
            num_decoding_left_chunks: Accepted but not read by this method.

        Returns:
            A pair of the encoder output (passed through ``after_norm``
            when ``normalize_before`` is set) and the mask returned by the
            last layer, or by the embedding step if there are no layers.
        """
        max_len = xs.size(1)
        pad_mask = make_pad_mask(xs_lens, max_len)
        # Invert the padding mask after inserting a singleton axis at dim 1.
        masks = ~pad_mask.unsqueeze(1)

        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)

        xs, pos_emb, masks = self.embed(xs, masks)
        # The mask produced by the embedding step is reused as the fourth
        # argument of every layer, independent of the `masks` each layer
        # hands back.
        embed_mask = masks
        for encoder_layer in self.encoders:
            xs, masks, _, _ = encoder_layer(xs, masks, pos_emb, embed_mask)

        if not self.normalize_before:
            return xs, masks
        return self.after_norm(xs), masks
|
|