zhangj726's picture
Upload 50 files
3f42bd3
import torch
import numpy as np
import torch.nn as nn
class Poetry_Model_lstm(nn.Module):
    """Stacked unidirectional LSTM that predicts a vocabulary logit per time step.

    The input is either a batch of token indices (looked up through the
    model's own ``nn.Embedding``) or, when ``Word2Vec`` is truthy, a batch of
    pre-computed embedding vectors that is fed to the LSTM directly.

    Args:
        hidden_num: Hidden size of every LSTM layer.
        word_size: Vocabulary size; the width of the output logits.
        embedding_num: Size of each embedding vector (the LSTM input size).
        Word2Vec: If truthy, ``forward`` expects already-embedded input and
            skips the internal embedding layer.
        num_layers: Number of stacked LSTM layers (keyword-only, default 2).
        dropout: Dropout probability applied to the LSTM output
            (keyword-only, default 0.3).
    """

    def __init__(self, hidden_num: int, word_size: int, embedding_num: int, Word2Vec,
                 *, num_layers: int = 2, dropout: float = 0.3):
        super().__init__()
        # Kept as a public hint for callers (e.g. ``model.to(model.device)``).
        # forward() does not rely on it: it follows the device the weights
        # actually live on, so an un-moved model still works on a CUDA machine.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.hidden_num = hidden_num
        self.Word2Vec = Word2Vec
        self.embedding = nn.Embedding(word_size, embedding_num)
        self.lstm = nn.LSTM(
            input_size=embedding_num,
            hidden_size=hidden_num,
            batch_first=True,
            num_layers=num_layers,
            bidirectional=False,
        )
        self.dropout = nn.Dropout(dropout)
        # Merges the batch and sequence dimensions before the output projection.
        self.flatten = nn.Flatten(0, 1)
        self.linear = nn.Linear(hidden_num, word_size)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, xs_embedding, h_0=None, c_0=None):
        """Run the LSTM over a batch and project every time step to logits.

        Args:
            xs_embedding: ``[batch_size, max_seq_len]`` token indices, or
                ``[batch_size, max_seq_len, embedding_num]`` embedding vectors
                when the model was built with ``Word2Vec`` truthy.
            h_0: Optional initial hidden state,
                ``[num_layers, batch_size, hidden_num]``. Zeros when omitted.
            c_0: Optional initial cell state, same shape as ``h_0``.
                Zeros when omitted.

        Returns:
            ``(pre, (h_n, c_n))`` where ``pre`` has shape
            ``[batch_size * max_seq_len, word_size]`` and ``h_n`` / ``c_n``
            are the final LSTM states.
        """
        # Follow the weights rather than a device guessed at construction
        # time; reading an attribute tensor (instead of self.parameters())
        # only depends on state that already existed on this class.
        weight = self.linear.weight
        device = weight.device

        # batch_first=True, so dimension 0 is the batch size.
        state_shape = (self.lstm.num_layers, xs_embedding.shape[0], self.hidden_num)
        # Fill in each missing state independently so that a state the caller
        # did supply is never silently replaced by zeros.
        if h_0 is None:
            h_0 = torch.zeros(state_shape, dtype=weight.dtype, device=device)
        if c_0 is None:
            c_0 = torch.zeros(state_shape, dtype=weight.dtype, device=device)

        h_0 = h_0.to(device)
        c_0 = c_0.to(device)
        xs_embedding = xs_embedding.to(device)

        if not self.Word2Vec:
            xs_embedding = self.embedding(xs_embedding)

        hidden, (h_n, c_n) = self.lstm(xs_embedding, (h_0, c_0))
        hidden_drop = self.dropout(hidden)
        hidden_flatten = self.flatten(hidden_drop)
        # pre: [batch_size * max_seq_len, word_size]
        pre = self.linear(hidden_flatten)
        return pre, (h_n, c_n)