File size: 894 Bytes
3f42bd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import numpy as np
from src.models.EA_LSTM.model import weightedLSTM
from src.datasets.dataloader import MyDataset, create_vocab


def test(args):
    """Run a one-sample qualitative check of a trained ``weightedLSTM``.

    Builds the vocabulary from ``args.data``, loads the checkpoint stored at
    ``args.save_path``, feeds the first validation sample through the model,
    samples one token id per position from the model output, and prints the
    decoded input and predicted strings.

    Args:
        args: Namespace-like config. Reads ``data`` and ``save_path`` (and
            whatever ``MyDataset`` reads); sets ``args.vocab_size`` to the size
            of the built vocabulary.
    """
    vocab, poetrys = create_vocab(args.data)
    # Vocabulary size; also used to size the model so it agrees with the vocab
    # instead of a hard-coded constant.
    args.vocab_size = len(vocab)
    # Lookup table: token id -> character.
    int2char = np.array(vocab)
    valid_dataset = MyDataset(vocab, poetrys, args, train=False)

    # NOTE(review): the remaining positional arguments are unchanged; their
    # meaning is defined by weightedLSTM and is not visible in this file.
    model = weightedLSTM(args.vocab_size, 256, 128, 2, [1.0] * 80, False)
    # The model runs on CPU here, so remap any GPU-saved tensors to CPU rather
    # than failing on a machine without CUDA.
    model.load_state_dict(torch.load(args.save_path, map_location="cpu"))
    # Inference mode: disables dropout and similar training-only behaviour.
    model.eval()

    input_example_batch, _ = valid_dataset[0]
    with torch.no_grad():
        example_batch_predictions = model(input_example_batch)
        # Categorical's first positional argument is `probs`; this assumes the
        # model outputs normalized probabilities — TODO confirm (pass
        # `logits=` instead if it returns raw scores).
        predicted_id = torch.distributions.Categorical(
            example_batch_predictions
        ).sample()
    predicted_id = torch.squeeze(predicted_id, -1).numpy()
    print("Input: \n", repr("".join(int2char[np.asarray(input_example_batch)])))
    print()
    print("Predictions: \n", repr("".join(int2char[predicted_id])))