Spaces:
Runtime error
Runtime error
File size: 894 Bytes
3f42bd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import numpy as np
from src.models.EA_LSTM.model import weightedLSTM
from src.datasets.dataloader import MyDataset, create_vocab
def test(args):
    """Load a trained weightedLSTM checkpoint and print one sampled prediction.

    Builds the vocabulary from ``args.data``, restores the model weights from
    ``args.save_path``, runs the model on the first validation sample and
    prints the input characters alongside characters sampled from the model
    output.

    Args:
        args: Namespace-like config object. Reads ``data`` and ``save_path``,
            is passed through to ``MyDataset``, and gets ``vocab_size`` set
            to the vocabulary length as a side effect.
    """
    vocab, poetrys = create_vocab(args.data)
    # Vocabulary size
    args.vocab_size = len(vocab)
    int2char = np.array(vocab)
    valid_dataset = MyDataset(vocab, poetrys, args, train=False)

    # Size the model from the vocabulary actually built above instead of a
    # hard-coded 6110, so every predicted id can be looked up in ``int2char``
    # and a mismatched checkpoint fails loudly in load_state_dict.
    # NOTE(review): assumes weightedLSTM's first argument is the vocab size and
    # that the remaining hyper-parameters match the training run — confirm
    # against the model definition.
    model = weightedLSTM(args.vocab_size, 256, 128, 2, [1.0] * 80, False)
    # map_location="cpu" lets a checkpoint saved on a GPU be restored on a
    # CPU-only host; the results are converted with .numpy() below, which
    # requires CPU tensors anyway.
    model.load_state_dict(torch.load(args.save_path, map_location="cpu"))
    # Inference only: switch off training-mode behaviour such as dropout.
    model.eval()

    input_example_batch, _target_example_batch = valid_dataset[0]
    with torch.no_grad():
        example_batch_predictions = model(input_example_batch)
        # NOTE(review): Categorical(x) treats x as (unnormalised) probabilities,
        # which must be non-negative. If weightedLSTM returns raw logits this
        # should be Categorical(logits=example_batch_predictions) — confirm.
        predicted_id = torch.distributions.Categorical(
            example_batch_predictions
        ).sample()
    predicted_id = torch.squeeze(predicted_id, -1).numpy()

    # np.asarray accepts a CPU tensor, an ndarray or a list alike, so the
    # lookup does not depend on the dataset's exact return type.
    input_ids = np.asarray(input_example_batch)
    print("Input: \n", repr("".join(int2char[input_ids])))
    print()
    print("Predictions: \n", repr("".join(int2char[predicted_id])))
|