Spaces:
Runtime error
Runtime error
| """ | |
| Implementation of "Attention is All You Need" | |
| """ | |
| import torch.nn as nn | |
| from .encoder import EncoderBase | |
| from .multi_headed_attn import MultiHeadedAttention | |
| from .position_ffn import PositionwiseFeedForward | |
| from .misc import sequence_mask | |
class TransformerEncoderLayer(nn.Module):
    """
    One block of the transformer encoder.

    The input is layer-normalised, passed through multi-headed
    self-attention, added back onto the un-normalised input and finally
    handed to the position-wise feed-forward module.

    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention; also the size given to LayerNorm and
            the first size argument of PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): the second size argument of PositionwiseFeedForward.
        dropout (float): dropout probability (0-1.0) applied to the
            attention output and passed to PositionwiseFeedForward.
        attention_dropout (float): dropout value passed to
            MultiHeadedAttention.
        max_relative_positions (int): forwarded unchanged to
            MultiHeadedAttention; defaults to 0.
    """

    def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
                 max_relative_positions=0):
        super().__init__()

        attention_options = {
            "dropout": attention_dropout,
            "max_relative_positions": max_relative_positions,
        }
        # Submodules are created in a fixed order: attention, feed-forward,
        # layer norm, dropout.
        self.self_attn = MultiHeadedAttention(
            heads, d_model, **attention_options)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, mask):
        """
        Run the block on one batch.

        Args:
            inputs (FloatTensor): ``(batch_size, src_len, model_dim)``
            mask (LongTensor): ``(batch_size, 1, src_len)``

        Returns:
            (FloatTensor):

            * outputs ``(batch_size, src_len, model_dim)``
        """
        normed = self.layer_norm(inputs)
        # Self-attention: the normalised input serves as all three of the
        # positional arguments; the second returned value is not needed.
        context, _unused = self.self_attn(
            normed, normed, normed, mask=mask, attn_type="self")
        # Residual connection around the attention sub-layer uses the
        # original (un-normalised) input.
        attended = self.dropout(context)
        residual = attended + inputs
        return self.feed_forward(residual)

    def update_dropout(self, dropout, attention_dropout):
        """
        Change the dropout rates of this block in place.

        Args:
            dropout (float): new value for the feed-forward module and
                for this layer's own ``nn.Dropout``.
            attention_dropout (float): new value for the attention module.
        """
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.dropout.p = dropout
class TransformerEncoder(EncoderBase):
    """The Transformer encoder from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    .. mermaid::

       graph BT
          A[input]
          B[multi-head self-attn]
          C[feed forward]
          O[output]
          A --> B
          B --> C
          C --> O

    Args:
        num_layers (int): number of encoder layers
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout parameters
        attention_dropout (float): dropout value passed to every
            layer's attention module
        embeddings (onmt.modules.Embeddings):
            embeddings to use, should have positional encodings
        max_relative_positions (int): forwarded unchanged to every
            encoder layer

    Returns:
        (torch.FloatTensor, torch.FloatTensor, lengths):

        * embeddings ``(src_len, batch_size, model_dim)``
        * memory_bank ``(src_len, batch_size, model_dim)``
        * lengths, exactly as passed to :meth:`forward`
    """

    def __init__(self, num_layers, d_model, heads, d_ff, dropout,
                 attention_dropout, embeddings, max_relative_positions):
        super().__init__()

        self.embeddings = embeddings
        self.transformer = nn.ModuleList(
            [TransformerEncoderLayer(
                d_model, heads, d_ff, dropout, attention_dropout,
                max_relative_positions=max_relative_positions)
             for _ in range(num_layers)])
        # Final normalisation applied once after the whole layer stack.
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor.

        Builds the encoder from an options object. ``opt.dropout`` and
        ``opt.attention_dropout`` may each be a single value or a list,
        in which case the first element is used.

        Args:
            opt: options object providing ``enc_layers``, ``enc_rnn_size``,
                ``heads``, ``transformer_ff``, ``dropout``,
                ``attention_dropout`` and ``max_relative_positions``.
            embeddings: embeddings module handed to the constructor.

        Returns:
            A new instance of ``cls``.
        """
        dropout = opt.dropout
        if isinstance(dropout, list):
            dropout = dropout[0]

        attention_dropout = opt.attention_dropout
        if isinstance(attention_dropout, list):
            attention_dropout = attention_dropout[0]

        return cls(
            opt.enc_layers,
            opt.enc_rnn_size,
            opt.heads,
            opt.transformer_ff,
            dropout,
            attention_dropout,
            embeddings,
            opt.max_relative_positions)

    def forward(self, src, lengths=None):
        """See :func:`EncoderBase.forward()`"""
        self._check_args(src, lengths)

        emb = self.embeddings(src)

        # The layers take the batch dimension first, so swap the first two
        # dimensions on the way in and swap them back on the way out.
        out = emb.transpose(0, 1).contiguous()
        # NOTE(review): ``lengths`` defaults to None but is passed to
        # ``sequence_mask`` unconditionally -- confirm that helper accepts
        # None, or that callers always supply lengths.
        mask = ~sequence_mask(lengths).unsqueeze(1)
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            out = layer(out, mask)
        out = self.layer_norm(out)

        return emb, out.transpose(0, 1).contiguous(), lengths

    def update_dropout(self, dropout, attention_dropout):
        """Propagate new dropout rates to the embeddings and every layer.

        Args:
            dropout (float): new value for the embeddings and each layer.
            attention_dropout (float): new value for each layer's
                attention module.
        """
        self.embeddings.update_dropout(dropout)
        for layer in self.transformer:
            layer.update_dropout(dropout, attention_dropout)