| import torch.nn as nn | |
| from torch.nn import functional as F | |
| __all__ = ['DSN'] | |
class DSN(nn.Module):
    """Deep Summarization Network.

    A bidirectional recurrent layer (LSTM or GRU) followed by a linear
    projection to a single value per time step, squashed through a sigmoid.

    Args:
        in_dim: Size of each input feature vector.
        hid_dim: Hidden size of the recurrent layer, per direction.
        num_layers: Number of stacked recurrent layers.
        cell: Recurrent cell type, either ``'lstm'`` or ``'gru'``.

    Raises:
        ValueError: If ``cell`` is not ``'lstm'`` or ``'gru'``.
    """

    def __init__(self, in_dim=1024, hid_dim=256, num_layers=1, cell='lstm'):
        super(DSN, self).__init__()
        rnn_types = {'lstm': nn.LSTM, 'gru': nn.GRU}
        # Raise explicitly instead of asserting: assertions are removed under
        # ``python -O``, which would let an unsupported value silently select
        # a GRU.
        if cell not in rnn_types:
            raise ValueError("cell must be either 'lstm' or 'gru'")
        self.rnn = rnn_types[cell](
            in_dim,
            hid_dim,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=True,
        )
        # The recurrent layer is bidirectional, so its output concatenates
        # both directions and carries ``hid_dim * 2`` features per step.
        self.fc = nn.Linear(hid_dim * 2, 1)

    def forward(self, x):
        """Return one sigmoid-activated score per time step.

        Args:
            x: Input sequence tensor. The recurrent layer is built with
                ``batch_first=True``, so a batched input is laid out as
                ``(batch, seq_len, in_dim)``.

        Returns:
            Tensor with the leading dimensions of the recurrent output and a
            final dimension of size 1 (``(batch, seq_len, 1)`` for batched
            input), with values in the sigmoid range.
            NOTE(review): presumably these are per-frame importance
            probabilities -- confirm against the caller.
        """
        # Both nn.LSTM and nn.GRU return ``(output, state)``; only the
        # per-step output features are needed here.
        h, _ = self.rnn(x)
        p = F.sigmoid(self.fc(h))
        return p