# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)

class MultiheadAttention(nn.Layer):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O

        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: parallel attention layers, or heads
        dropout: dropout probability applied to the attention weights
        bias: whether the output projection uses a bias term
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, \
            "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
        self._reset_parameters()
        # add_bias_kv and add_zero_attn are accepted for interface
        # compatibility but are not used by this implementation.
        # The three 1x1 convolutions act as the input projections
        # for query, key and value respectively.
        self.conv1 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv2 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv3 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: [batch size, sequence length]; non-zero entries
                mark padding positions that are excluded from the attention
            incremental_state: accepted for interface compatibility; previous
                time steps are not cached by this implementation
            attn_mask: additive mask of shape [target length, sequence length]
        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
        """
        q_shape = paddle.shape(query)
        src_shape = paddle.shape(key)
        q = self._in_proj_q(query)
        k = self._in_proj_k(key)
        v = self._in_proj_v(value)
        q *= self.scaling
        # Reshape to [batch, num_heads, length, head_dim] for per-head attention.
        q = paddle.transpose(
            paddle.reshape(
                q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        k = paddle.transpose(
            paddle.reshape(
                k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        v = paddle.transpose(
            paddle.reshape(
                v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
            [1, 2, 0, 3])
        if key_padding_mask is not None:
            assert key_padding_mask.shape[0] == q_shape[1]
            assert key_padding_mask.shape[1] == src_shape[0]
        # Attention scores: [batch, num_heads, tgt_len, src_len].
        attn_output_weights = paddle.matmul(q,
                                            paddle.transpose(k, [0, 1, 3, 2]))
        if attn_mask is not None:
            attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
            attn_output_weights += attn_mask
        if key_padding_mask is not None:
            attn_output_weights = paddle.reshape(
                attn_output_weights,
                [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
            key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
            key = paddle.cast(key, 'float32')
            # Non-zero mask entries become -inf so softmax assigns them
            # zero attention weight.
            y = paddle.full(
                shape=paddle.shape(key),
                dtype='float32',
                fill_value=float('-inf'))
            y = paddle.where(key == 0., key, y)
            attn_output_weights += y
        attn_output_weights = F.softmax(
            attn_output_weights.astype('float32'),
            axis=-1,
            dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training)
        attn_output = paddle.matmul(attn_output_weights, v)
        # Merge the heads back to [tgt_len, batch, embed_dim].
        attn_output = paddle.reshape(
            paddle.transpose(attn_output, [2, 0, 1, 3]),
            [q_shape[0], q_shape[1], self.embed_dim])
        attn_output = self.out_proj(attn_output)
        return attn_output

    def _in_proj_q(self, query):
        # [tgt_len, batch, embed_dim] -> [batch, embed_dim, 1, tgt_len] for Conv2D,
        # then back to [tgt_len, batch, embed_dim].
        query = paddle.transpose(query, [1, 2, 0])
        query = paddle.unsqueeze(query, axis=2)
        res = self.conv1(query)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_k(self, key):
        key = paddle.transpose(key, [1, 2, 0])
        key = paddle.unsqueeze(key, axis=2)
        res = self.conv2(key)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res

    def _in_proj_v(self, value):
        value = paddle.transpose(value, [1, 2, 0])
        value = paddle.unsqueeze(value, axis=2)
        res = self.conv3(value)
        res = paddle.squeeze(res, axis=2)
        res = paddle.transpose(res, [2, 0, 1])
        return res
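

# Minimal usage sketch: it assumes the [seq_len, batch, embed_dim] input
# layout documented in forward(); the sizes below are arbitrary and purely
# for illustration.
if __name__ == "__main__":
    embed_dim, num_heads = 64, 8
    tgt_len, src_len, batch = 5, 7, 2

    attn = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)
    query = paddle.rand([tgt_len, batch, embed_dim])
    key = paddle.rand([src_len, batch, embed_dim])
    value = paddle.rand([src_len, batch, embed_dim])

    # Mask out the last two source positions of every batch element
    # (non-zero entries in key_padding_mask are excluded from the attention).
    key_padding_mask = paddle.concat(
        [paddle.zeros([batch, src_len - 2]), paddle.ones([batch, 2])], axis=1)

    out = attn(query, key, value, key_padding_mask=key_padding_mask)
    print(out.shape)  # expected: [tgt_len, batch, embed_dim] -> [5, 2, 64]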