Spaces:
Runtime error
Runtime error
| """Position feed-forward network from "Attention is All You Need".""" | |
| import torch.nn as nn | |
class PositionwiseFeedForward(nn.Module):
    """A two-layer feed-forward network (FFN) with a pre-norm residual connection.

    The input is layer-normalized and passed through
    ``w_1 -> ReLU -> dropout -> w_2 -> dropout``. The result is then added
    back to the *un-normalized* input. The output itself is not normalized.

    Args:
        d_model (int): feature size of the input and output (the last
            dimension of ``x``). It is the input size of the first layer
            and the output size of the second layer.
        d_ff (int): hidden size of the FFN, i.e. the output size of the
            first layer and the input size of the second layer.
        dropout (float): dropout probability in :math:`[0, 1)`, applied
            after the ReLU and again after the second layer.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names (and therefore state_dict keys) are part of the
        # checkpoint format; keep them stable.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        # Explicit eps=1e-6 instead of nn.LayerNorm's default of 1e-5.
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.dropout_1 = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the normalized FFN and add the residual.

        Args:
            x: tensor whose last dimension is ``d_model``, typically
                ``(batch_size, input_len, model_dim)``.

        Returns:
            (FloatTensor): tensor of the same shape as ``x``.
        """
        inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        # Residual uses the original input, not the normalized one (pre-norm).
        return output + x

    def update_dropout(self, dropout):
        """Set the probability of both dropout layers in place.

        Args:
            dropout (float): new dropout probability.
        """
        self.dropout_1.p = dropout
        self.dropout_2.p = dropout