| import json
|
| import torch
|
| from main_model import Seq2Seq , generate_answer
|
|
|
|
|
# Read the model hyperparameters from config.json. JSON text is UTF-8
# (RFC 8259), so decode it explicitly instead of relying on the platform's
# locale-dependent default encoding for open().
with open("./config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

# These sizes are used to rebuild the Seq2Seq architecture; presumably they
# must match the values the saved checkpoint was trained with — confirm.
vocab_size = config["vocab_size"]
embedding_dim = config["embedding_dim"]
hidden_dim = config["hidden_dim"]
# NOTE(review): max_len is read here but not passed to generate_answer in the
# visible code — confirm whether generate_answer should receive it.
max_len = config["max_len"]
|
|
|
|
|
# Rebuild the network with the configured sizes, then load the trained weights.
model = Seq2Seq(vocab_size, embedding_dim, hidden_dim)

# map_location="cpu": the model is not moved to any device in this file, so
# (assuming Seq2Seq's constructor doesn't do so itself) its parameters live on
# the CPU. Loading the tensors there as well lets a checkpoint that was saved
# from a CUDA device load on a machine without a GPU.
# weights_only=True restricts unpickling to tensors and plain containers
# rather than arbitrary Python objects.
model.load_state_dict(
    torch.load("./seq2seq_model.pth", map_location="cpu", weights_only=True)
)

# Put layers such as dropout / batch-norm into evaluation behaviour.
model.eval()
|
|
|
# Load the vocabulary. Decode it explicitly as UTF-8: vocabulary entries may be
# non-ASCII, and open()'s default encoding depends on the platform locale.
with open("./ma_vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)

# The JSON object is used directly as the word -> index mapping; it is what
# gets passed to generate_answer as `vocab=`.
word2idx = vocab

# Reverse mapping (index -> word), presumably for decoding model output.
# NOTE(review): not used in the visible code — generate_answer only receives
# word2idx. If two words share an index, the last one wins here.
idx2word = {idx: word for word, idx in vocab.items()}
|
|
|
|
|
# Run one example question through the loaded model and print the result.
question = "what is MA?"

# Inference only: disable autograd tracking so no computation graph is built.
# model.eval() changes layer behaviour but does not turn gradients off.
with torch.no_grad():
    answer = generate_answer(model, question, vocab=word2idx)

print("Answer:", answer)