Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
class LSTM(nn.Module):
    """Thin wrapper around ``torch.nn.LSTM`` configured from a config object.

    Args:
        cfg: Object exposing the attributes ``WORD_EMBED_SIZE``,
            ``HIDDEN_SIZE``, ``NUM_LAYERS``, ``BATCH_FIRST`` and
            ``BIDIRECTIONAL``. They are forwarded to ``nn.LSTM`` as
            ``input_size``, ``hidden_size``, ``num_layers``, ``batch_first``
            and ``bidirectional`` respectively.
    """

    def __init__(self, cfg):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=cfg.WORD_EMBED_SIZE,
            # Passed through unchanged; nn.LSTM treats this as the size of
            # each direction's hidden state.
            hidden_size=cfg.HIDDEN_SIZE,
            num_layers=cfg.NUM_LAYERS,
            # When true, nn.LSTM expects the batch dimension first.
            batch_first=cfg.BATCH_FIRST,
            bidirectional=cfg.BIDIRECTIONAL,
        )

    def forward(self, input, h0=None, c0=None):
        """Run the wrapped LSTM over ``input``.

        Args:
            input: Input tensor handed to ``nn.LSTM`` unchanged.
            h0: Optional initial hidden state.
            c0: Optional initial cell state.

        Returns:
            A tuple ``(output, hn, cn)``: the LSTM output together with the
            final hidden and cell states, unpacked from the ``(hn, cn)`` pair
            that ``nn.LSTM`` returns.

        Raises:
            ValueError: If exactly one of ``h0`` / ``c0`` is provided.
        """
        if (h0 is None) != (c0 is None):
            raise ValueError(
                "h0 and c0 must be provided together or both omitted"
            )
        # With no initial state, nn.LSTM zero-initialises both h0 and c0.
        initial_state = None if h0 is None else (h0, c0)
        output, (hn, cn) = self.lstm(input, initial_state)
        return output, hn, cn