|
|
|
|
|
import torch |
|
|
from encoder_decoder_transformer import Transformer |
|
|
import json |
|
|
|
|
|
|
|
|
def load_model(model_path):
    """Load a trained Transformer checkpoint together with its vocabularies.

    Reads ``config.json`` under *model_path* for the architecture
    hyper-parameters and ``pytorch_model.bin`` for the weights plus the
    token <-> id mappings stored alongside them.

    Args:
        model_path: Directory containing ``config.json`` and
            ``pytorch_model.bin``.

    Returns:
        Tuple of ``(model, vocab, idx_to_vocab)`` where the model is in
        eval mode on CPU.
    """
    # Architecture hyper-parameters come from the JSON config.
    with open(f"{model_path}/config.json", "r") as cfg_file:
        cfg = json.load(cfg_file)

    # Weights and vocab mappings live in the serialized checkpoint.
    ckpt = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")

    model = Transformer(
        src_vocab_size=cfg["vocab_size"],
        tgt_vocab_size=cfg["vocab_size"],
        d_model=cfg["d_model"],
        n_heads=cfg["n_heads"],
        n_encoder_layers=cfg["n_encoder_layers"],
        n_decoder_layers=cfg["n_decoder_layers"],
        d_ff=cfg["d_ff"],
        dropout=cfg["dropout"],
        pad_idx=cfg["pad_token_id"],
    )

    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()  # inference only — disable dropout etc.

    return model, ckpt["vocab"], ckpt["idx_to_vocab"]
|
|
|
|
|
|
|
|
def generate_text(model, vocab, idx_to_vocab, input_text, max_length=50, src_max_len=64):
    """Generate text from *input_text* with a trained encoder-decoder model.

    The input is lower-cased, whitespace-tokenized, wrapped in BOS/EOS
    markers, padded/truncated to a fixed source length, and fed to
    ``model.generate``. Special tokens are stripped from the output.

    Args:
        model: Model exposing ``to(device)`` and
            ``generate(src, max_len, start_token, end_token)``.
        vocab: Mapping from token string to integer id.
        idx_to_vocab: Mapping from integer id back to token string.
        input_text: Raw input sentence.
        max_length: Maximum number of tokens to generate.
        src_max_len: Fixed source-sequence length (pad/truncate target).

    Returns:
        The generated sentence as a space-joined string, without
        <PAD>/<BOS>/<EOS> markers.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Resolve special-token ids once; fall back to the conventional
    # defaults used elsewhere in this script.
    pad_id = vocab.get("<PAD>", 0)
    unk_id = vocab.get("<UNK>", 1)
    bos_id = vocab.get("<BOS>", 2)
    eos_id = vocab.get("<EOS>", 3)

    words = input_text.lower().split()
    tokens = [bos_id] + [vocab.get(word, unk_id) for word in words] + [eos_id]

    # Pad or truncate the source sequence to the fixed encoder length.
    if len(tokens) < src_max_len:
        tokens += [pad_id] * (src_max_len - len(tokens))
    else:
        tokens = tokens[:src_max_len]

    src = torch.tensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        # BUG FIX: start/end tokens were hard-coded as 2/3, which breaks
        # if the vocab assigns different ids to <BOS>/<EOS>. Use the
        # looked-up ids so encoding and decoding stay consistent.
        generated = model.generate(src, max_len=max_length,
                                   start_token=bos_id, end_token=eos_id)

    out_words = []
    for token in generated[0]:
        word = idx_to_vocab.get(token.item(), "<UNK>")
        if word in ["<PAD>", "<BOS>", "<EOS>"]:
            if word == "<EOS>":
                break  # stop at end-of-sequence
            continue  # skip padding / start markers
        out_words.append(word)

    return " ".join(out_words)
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
model, vocab, idx_to_vocab = load_model("./") |
|
|
|
|
|
input_texts = [ |
|
|
"To be or not to be", |
|
|
"What is your name", |
|
|
"The king is dead" |
|
|
] |
|
|
|
|
|
for text in input_texts: |
|
|
generated = generate_text(model, vocab, idx_to_vocab, text) |
|
|
print(f"輸入: {text}") |
|
|
print(f"生成: {generated}") |
|
|
print("-" * 50) |
|
|
|