| """Encoder self-attention layer definition.""" | |
| import math | |
| import pdb | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from vita.model.multimodal_encoder.whale.module.layer.attention import ( | |
| Conv1dLinear, | |
| MultiHeadedAttention, | |
| MultiLayeredConv1d, | |
| PositionalEncoding, | |
| PositionwiseFeedForward, | |
| RelPositionalEncoding, | |
| ) | |
| # from vita.model.multimodal_encoder.whale.module.component.utils import * | |
| from vita.model.multimodal_encoder.whale.utils import IGNORE_ID, add_optional_chunk_mask, strtobool | |


def repeat(N, fn):
    """Repeat module N times.

    :param int N: repeat time
    :param function fn: function to generate module
    :return: repeated modules
    :rtype: MultiSequential
    """
    return MultiSequential(*[fn(n) for n in range(N)])


class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def forward(self, x, masks, pos_emb):
        """Repeat."""
        for m in self:
            x, masks, pos_emb = m(x, masks, pos_emb)
        return x, masks, pos_emb

    def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
        """Repeat."""
        for m in self:
            x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
                x, pos_emb, buffer, buffer_index, buffer_out
            )
        return x, pos_emb, buffer, buffer_index, buffer_out

    def infer_hidden(self, x, pos_emb, buffer, buffer_index, buffer_out, hidden_out):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]
        """Repeat, collecting each layer's hidden output."""
        for m in self:
            x, pos_emb, buffer, buffer_index, buffer_out = m.infer(
                x, pos_emb, buffer, buffer_index, buffer_out
            )
            hidden_out.append(x)
        return x, pos_emb, buffer, buffer_index, buffer_out, hidden_out


class TransformerLayer(nn.Module):
    """Transformer layer module.

    :param int size: input dim
    :param self_attn: self attention module
    :param feed_forward: feed forward module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output.
        If True, an additional linear layer is applied, i.e. x -> x + linear(concat(x, att(x))).
        If False, no additional linear layer is applied, i.e. x -> x + att(x).
    """

    def __init__(
        self, size, self_attn, feed_forward, dropout_rate, normalize_before=True, concat_after=False
    ):
        """Construct a TransformerLayer object."""
        super(TransformerLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = torch.nn.LayerNorm(size)
        self.norm2 = torch.nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        else:
            self.concat_linear = nn.Identity()

    def forward(self, x, mask, pos_emb):
        """Compute encoded features.

        :param torch.Tensor x: encoded source features (batch, max_time_in, size)
        :param torch.Tensor mask: mask for x (batch, max_time_in)
        :param torch.Tensor pos_emb: positional embedding
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, x, x, mask, pos_emb)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x, x, x, mask, pos_emb))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        return x, mask, pos_emb

    def infer(self, x, pos_emb, buffer, buffer_index, buffer_out):
        # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]
        residual = x.clone()
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
                x, x, x, pos_emb, buffer, buffer_index, buffer_out
            )
            x_concat = torch.cat((x, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x_att, buffer, buffer_index, buffer_out = self.self_attn.infer(
                x, x, x, pos_emb, buffer, buffer_index, buffer_out
            )
            x = residual + x_att
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x.clone()
        if self.normalize_before:
            x = self.norm2(x)
        x_feed, buffer, buffer_index, buffer_out = self.feed_forward.infer(
            x, buffer, buffer_index, buffer_out
        )
        x = residual + x_feed
        if not self.normalize_before:
            x = self.norm2(x)
        return x, pos_emb, buffer, buffer_index, buffer_out
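

# A minimal usage sketch for a single layer (hypothetical shapes and mask layout;
# it assumes the attention and feed-forward modules imported above accept the same
# arguments they are given in Transformer.__init__ below):
#
#     attn = MultiHeadedAttention(4, 256, 0.0, -1, -1, "abs-enc")
#     ffn = PositionwiseFeedForward(256, 1024, 0.1)
#     layer = TransformerLayer(256, attn, ffn, dropout_rate=0.1)
#     x = torch.randn(2, 50, 256)                     # (batch, time, size)
#     mask = torch.ones(2, 1, 50, dtype=torch.bool)   # assumed (batch, 1, time) mask
#     pos_emb = torch.zeros(1, 50, 256)               # placeholder positional embedding
#     x, mask, pos_emb = layer(x, mask, pos_emb)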


class Transformer(torch.nn.Module):
    @staticmethod
    def add_arguments(group):
        """Add Transformer common arguments."""
        group.add_argument(
            "--transformer-input-dim", default=256, type=int, help="Input dim of Transformer."
        )
        group.add_argument(
            "--transformer-output-dim", default=4, type=int, help="Output dim of Transformer."
        )
        group.add_argument(
            "--transformer-attention-dim", default=256, type=int, help="Dimension of attention."
        )
        group.add_argument(
            "--transformer-attention-heads",
            default=4,
            type=int,
            help="The number of heads of multi head attention.",
        )
        group.add_argument(
            "--transformer-linear-units",
            default=1024,
            type=int,
            help="The number of units of position-wise feed forward.",
        )
        group.add_argument(
            "--transformer-num-blocks", default=6, type=int, help="The number of attention blocks."
        )
        group.add_argument(
            "--transformer-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate in Transformer.",
        )
        group.add_argument(
            "--transformer-attention-dropout-rate",
            default=0.0,
            type=float,
            help="Dropout rate in attention.",
        )
        group.add_argument(
            "--transformer-positional-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate after adding positional encoding.",
        )
        group.add_argument(
            "--transformer-input-layer", default="linear", type=str, help="Type of input layer."
        )
        group.add_argument(
            "--transformer-pos-enc-class",
            default="abs-enc",
            type=str,
            help="Positional encoding type: abs-enc or rel-enc.",
        )
        group.add_argument(
            "--transformer-normalize-before",
            default=True,
            type=strtobool,
            help="Whether to use layer-norm before the first block.",
        )
        group.add_argument(
            "--transformer-concat-after",
            default=False,
            type=strtobool,
            help="Whether to concat attention layer's input and output.",
        )
        group.add_argument(
            "--transformer-positionwise-layer-type",
            default="linear",
            type=str,
            help="Positionwise layer type: linear, conv1d, or conv1d-linear.",
        )
        group.add_argument(
            "--transformer-positionwise-conv-kernel_size",
            default=1,
            type=int,
            help="Kernel size of positionwise conv1d layer.",
        )
        group.add_argument(
            "--transformer-chunk_size",
            default=-1,
            type=int,
            help="Static chunk size for chunk-based attention (-1 for full context).",
        )
        group.add_argument(
            "--transformer-left_chunks",
            default=-1,
            type=int,
            help="Number of left chunks attended to (-1 for all).",
        )
        group.add_argument(
            "--transformer-dynamic-chunks",
            default=True,
            type=strtobool,
            help="Whether to use dynamic chunk masking in forward().",
        )
        return group
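
    # Sketch of how these options are typically consumed (hypothetical driver code,
    # not part of this module):
    #
    #     parser = argparse.ArgumentParser()
    #     Transformer.add_arguments(parser.add_argument_group("transformer"))
    #     args = parser.parse_args()
    #     encoder = Transformer(args)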

    def __init__(
        self,
        args,
        input_dim=None,
        output_dim=None,
        attention_dim=None,
        attention_heads=None,
        linear_units=None,
        num_blocks=None,
        dropout_rate=None,
        positional_dropout_rate=None,
        attention_dropout_rate=None,
        input_layer=None,
        pos_enc_class=None,
        normalize_before=None,
        concat_after=None,
        positionwise_layer_type=None,
        positionwise_conv_kernel_size=None,
        chunk_size=None,
        left_chunks=None,
    ):
        """Construct an Encoder object."""
        super(Transformer, self).__init__()
        if args is None:
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.attention_dim = attention_dim
            self.attention_heads = attention_heads
            self.linear_units = linear_units
            self.num_blocks = num_blocks
            self.dropout_rate = dropout_rate
            self.positional_dropout_rate = positional_dropout_rate
            self.attention_dropout_rate = attention_dropout_rate
            self.input_layer = input_layer
            self.pos_enc_class = pos_enc_class
            self.normalize_before = normalize_before
            self.concat_after = concat_after
            self.positionwise_layer_type = positionwise_layer_type
            self.positionwise_conv_kernel_size = positionwise_conv_kernel_size
            self.chunk_size = chunk_size
            self.left_chunks = left_chunks
            # Mirror the command-line default so forward() does not fail when the
            # module is constructed without an args namespace.
            self.transformer_dynamic_chunks = True
        else:
            self.input_dim = args.transformer_input_dim
            self.output_dim = args.transformer_output_dim
            self.attention_dim = args.transformer_attention_dim
            self.attention_heads = args.transformer_attention_heads
            self.linear_units = args.transformer_linear_units
            self.num_blocks = args.transformer_num_blocks
            self.dropout_rate = args.transformer_dropout_rate
            self.positional_dropout_rate = args.transformer_positional_dropout_rate
            self.attention_dropout_rate = args.transformer_attention_dropout_rate
            self.input_layer = args.transformer_input_layer
            self.pos_enc_class = args.transformer_pos_enc_class
            self.normalize_before = args.transformer_normalize_before
            self.concat_after = args.transformer_concat_after
            self.positionwise_layer_type = args.transformer_positionwise_layer_type
            self.positionwise_conv_kernel_size = args.transformer_positionwise_conv_kernel_size
            self.chunk_size = args.transformer_chunk_size
            self.left_chunks = args.transformer_left_chunks
            self.transformer_dynamic_chunks = args.transformer_dynamic_chunks

        if self.pos_enc_class == "abs-enc":
            pos_enc_args = (self.attention_dim, self.positional_dropout_rate)
            pos_enc_class = PositionalEncoding
        elif self.pos_enc_class == "rel-enc":
            pos_enc_args = (
                self.attention_dim,
                self.positional_dropout_rate,
                self.chunk_size,
                self.left_chunks,
            )
            pos_enc_class = RelPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_class: " + self.pos_enc_class)

        if self.input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(self.input_dim, self.attention_dim),
                torch.nn.LayerNorm(self.attention_dim),
                torch.nn.Dropout(self.dropout_rate),
                torch.nn.ReLU(),
            )
        elif self.input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(self.input_dim, self.attention_dim, padding_idx=IGNORE_ID)
            )
        elif self.input_layer == "none":
            self.embed = torch.nn.Sequential(torch.nn.Identity())
        else:
            raise ValueError("unknown input_layer: " + self.input_layer)

        self.pe = pos_enc_class(*pos_enc_args)
        self.embed_layer_num = len(self.embed)

        if self.positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (self.attention_dim, self.linear_units, self.dropout_rate)
        elif self.positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                self.attention_dim,
                self.linear_units,
                self.positionwise_conv_kernel_size,
                self.dropout_rate,
            )
        elif self.positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                self.attention_dim,
                self.linear_units,
                self.positionwise_conv_kernel_size,
                self.dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")

        self.encoders = repeat(
            self.num_blocks,
            lambda lnum: TransformerLayer(
                self.attention_dim,
                MultiHeadedAttention(
                    self.attention_heads,
                    self.attention_dim,
                    self.attention_dropout_rate,
                    self.chunk_size,
                    self.left_chunks,
                    self.pos_enc_class,
                ),
                positionwise_layer(*positionwise_layer_args),
                self.dropout_rate,
                self.normalize_before,
                self.concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = torch.nn.LayerNorm(self.attention_dim)

    def forward(self, xs, ilens=None, masks=None):
        """Embed positions in tensor and encode with the attention blocks.

        :param torch.Tensor xs: input tensor
        :param torch.Tensor ilens: input lengths
        :param torch.Tensor masks: input mask
        :return: encoded tensor, input lengths, and mask
        :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        """
        if self.transformer_dynamic_chunks:  # and self.training:
            chunk_masks = add_optional_chunk_mask(xs, masks, True, True, 0, 0, -1)
        else:
            chunk_masks = add_optional_chunk_mask(
                xs, masks, False, False, self.chunk_size, self.chunk_size, self.left_chunks
            ).to(xs.device)
        xs = self.embed(xs)
        xs, pos_emb = self.pe(xs)
        xs, chunk_masks, pos_emb = self.encoders(xs, chunk_masks, pos_emb)
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, ilens, masks

    def infer(self, xs, buffer, buffer_index, buffer_out):
        xs = self.embed(xs)
        # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
        # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
        # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
        # buffer_index = buffer_index + 1
        xs, pos_emb, _ = self.pe.infer(xs, 0)
        xs, pos_emb, buffer, buffer_index, buffer_out = self.encoders.infer(
            xs, pos_emb, buffer, buffer_index, buffer_out
        )
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, buffer, buffer_index, buffer_out

    def infer_hidden(self, xs, buffer, buffer_index, buffer_out, hidden_out):
        xs = self.embed(xs)
        # pe_index = buffer[buffer_index: buffer_index + 1].reshape([1]).to(torch.int64)
        # xs, pos_emb, pe_index[0] = self.pe.infer(xs, pe_index[0])
        # buffer_out.append(pe_index.reshape(-1).to(torch.float32))
        # buffer_index = buffer_index + 1
        xs, pos_emb, _ = self.pe.infer(xs, 0)
        xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out = self.encoders.infer_hidden(
            xs, pos_emb, buffer, buffer_index, buffer_out, hidden_out
        )
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, buffer, buffer_index, buffer_out, hidden_out
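

# A quick smoke-test sketch using the keyword-argument constructor path (hypothetical
# shapes and mask layout; it assumes the imported attention and positional-encoding
# modules behave as they are invoked above):
if __name__ == "__main__":
    encoder = Transformer(
        None,
        input_dim=256, output_dim=4, attention_dim=256, attention_heads=4,
        linear_units=1024, num_blocks=6, dropout_rate=0.1,
        positional_dropout_rate=0.1, attention_dropout_rate=0.0,
        input_layer="linear", pos_enc_class="abs-enc",
        normalize_before=True, concat_after=False,
        positionwise_layer_type="linear", positionwise_conv_kernel_size=1,
        chunk_size=-1, left_chunks=-1,
    )
    xs = torch.randn(2, 50, 256)                     # (batch, time, feature)
    masks = torch.ones(2, 1, 50, dtype=torch.bool)   # assumed (batch, 1, time) padding mask
    ys, _, _ = encoder(xs, ilens=None, masks=masks)
    print(ys.shape)  # expected: torch.Size([2, 50, 256])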