#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""Decoder self-attention layer definition."""

import torch
from torch import nn

from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm


class DecoderLayer(nn.Module):
    """Single decoder layer module.

    :param int size: input dim
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        self_attn: self attention module
    :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention
        src_attn: source attention module
    :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
        PositionwiseFeedForward feed_forward: feed forward layer module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: whether to use layer_norm before the first block
    :param bool concat_after: whether to concat attention layer's input and output
        if True, an additional linear layer is applied,
        i.e. x -> x + linear(concat(x, att(x)))
        if False, no additional linear layer is applied, i.e. x -> x + att(x)

    """
    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayer object."""
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # extra projections that map concat(input, attention output),
            # of width 2 * size, back down to size
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor):
                decoded previous target features (batch, max_time_out, size)
            tgt_mask (torch.Tensor):
                mask for tgt (batch, max_time_out, max_time_out)
            memory (torch.Tensor): encoded source features (batch, max_time_in, size)
            memory_mask (torch.Tensor): mask for memory (batch, 1, max_time_in)
            cache (torch.Tensor): cached output (batch, max_time_out-1, size)

        Returns:
            torch.Tensor: output tensor (batch, max_time_out, size)
            torch.Tensor: target mask (batch, max_time_out, max_time_out)
            torch.Tensor: encoded source features (batch, max_time_in, size)
            torch.Tensor: memory mask (batch, 1, max_time_in)

        """
        # 1) masked self-attention block
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)

        # 2) source (encoder-decoder) attention block
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)

        # 3) position-wise feed-forward block
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        # prepend the cached outputs so callers always get the full sequence
        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
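

# ---------------------------------------------------------------------------
# Usage sketch (illustration only; not part of the original file). It assumes
# the standard ESPnet1 attention and feed-forward modules at the import paths
# below and uses arbitrary sizes. The second part exercises the `cache` path:
# in eval mode, recomputing only the newest frame on top of the cached prefix
# should reproduce the full forward pass.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from espnet.nets.pytorch_backend.transformer.attention import (
        MultiHeadedAttention,
    )
    from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
        PositionwiseFeedForward,
    )

    size, heads, units, dropout_rate = 256, 4, 2048, 0.1
    layer = DecoderLayer(
        size,
        MultiHeadedAttention(heads, size, dropout_rate),  # self attention
        MultiHeadedAttention(heads, size, dropout_rate),  # source attention
        PositionwiseFeedForward(size, units, dropout_rate),
        dropout_rate,
    )

    batch, t_out, t_in = 2, 5, 11
    tgt = torch.randn(batch, t_out, size)
    # causal mask: each target position attends to itself and the past only
    tgt_mask = torch.tril(torch.ones(t_out, t_out, dtype=torch.bool)).unsqueeze(0)
    memory = torch.randn(batch, t_in, size)
    memory_mask = torch.ones(batch, 1, t_in, dtype=torch.bool)

    layer.eval()  # disable dropout so the cache check below is deterministic
    with torch.no_grad():
        full, *_ = layer(tgt, tgt_mask, memory, memory_mask)
        # recompute only the last frame, reusing the first t_out - 1 outputs
        step, *_ = layer(tgt, tgt_mask, memory, memory_mask, cache=full[:, :-1, :])
    print(full.shape)                             # torch.Size([2, 5, 256])
    print(torch.allclose(full, step, atol=1e-5))  # True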