PoetryGeneration / src /apis /evaluate.py
zhangj726's picture
Upload 50 files
3f42bd3
raw
history blame contribute delete
894 Bytes
import torch
import numpy as np
from src.models.EA_LSTM.model import weightedLSTM
from src.datasets.dataloader import MyDataset, create_vocab
def test(args):
vocab, poetrys = create_vocab(args.data)
# 词汇表长度
args.vocab_size = len(vocab)
int2char = np.array(vocab)
valid_dataset = MyDataset(vocab, poetrys, args, train=False)
model = weightedLSTM(6110, 256, 128, 2, [1.0] * 80, False)
model.load_state_dict(torch.load(args.save_path))
input_example_batch, target_example_batch = valid_dataset[0]
example_batch_predictions = model(input_example_batch)
predicted_id = torch.distributions.Categorical(example_batch_predictions).sample()
predicted_id = torch.squeeze(predicted_id, -1).numpy()
print("Input: \n", repr("".join(int2char[input_example_batch])))
print()
print("Predictions: \n", repr("".join(int2char[predicted_id])))