asdfasdfdsafdsa's picture
Initial upload of PGPS demo with all dependencies
383bfb8 verified
raw
history blame contribute delete
565 Bytes
import torch.nn as nn
class LSTM(nn.Module):
    """Config-driven wrapper around ``torch.nn.LSTM``.

    The recurrent layer is stored on ``self.lstm`` and is built entirely
    from attributes read off the ``cfg`` object given to the constructor.
    """

    def __init__(self, cfg):
        """Build the underlying ``nn.LSTM`` from ``cfg``.

        Args:
            cfg: Object exposing ``WORD_EMBED_SIZE``, ``HIDDEN_SIZE``,
                ``NUM_LAYERS``, ``BATCH_FIRST`` and ``BIDIRECTIONAL``;
                each is forwarded to the matching ``nn.LSTM`` argument.
        """
        super().__init__()
        lstm_options = {
            "input_size": cfg.WORD_EMBED_SIZE,
            # Forwarded unchanged; it is NOT divided by the number of
            # directions when the layer is bidirectional.
            "hidden_size": cfg.HIDDEN_SIZE,
            "num_layers": cfg.NUM_LAYERS,
            # Whether the first dimension of the input is the batch size.
            "batch_first": cfg.BATCH_FIRST,
            "bidirectional": cfg.BIDIRECTIONAL,
        }
        self.lstm = nn.LSTM(**lstm_options)

    def forward(self, input, h0, c0):
        """Run the LSTM over ``input`` starting from the state ``(h0, c0)``.

        Args:
            input: Input sequence passed straight to ``self.lstm``.
            h0: Initial hidden state.
            c0: Initial cell state.

        Returns:
            A flat 3-tuple ``(output, hn, cn)``: the layer output followed
            by the final hidden and cell states, unpacked from the nested
            ``(output, (hn, cn))`` structure that ``nn.LSTM`` returns.
        """
        initial_state = (h0, c0)
        output, final_state = self.lstm(input, initial_state)
        hn, cn = final_state
        return output, hn, cn