File size: 2,364 Bytes
5c97468 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
from encoder_decoder_transformer import Transformer
import json
# Load the model and its configuration from a checkpoint directory.
def load_model(model_path):
    """Rebuild a Transformer from a saved checkpoint directory.

    Args:
        model_path: Directory containing ``pytorch_model.bin`` (the
            checkpoint dict) and ``config.json`` (the hyperparameters).

    Returns:
        Tuple ``(model, vocab, idx_to_vocab)`` — the model is in eval
        mode on CPU; ``vocab``/``idx_to_vocab`` are taken straight from
        the checkpoint (presumably word->id and id->word maps — confirm
        against the training script).
    """
    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
    with open(f"{model_path}/config.json", "r") as f:
        config = json.load(f)

    # Rebuild the model with the saved hyperparameters.
    model = Transformer(
        src_vocab_size=config["vocab_size"],
        tgt_vocab_size=config["vocab_size"],
        d_model=config["d_model"],
        n_heads=config["n_heads"],
        n_encoder_layers=config["n_encoder_layers"],
        n_decoder_layers=config["n_decoder_layers"],
        d_ff=config["d_ff"],
        dropout=config["dropout"],
        pad_idx=config["pad_token_id"],
    )

    # Restore the trained weights and switch to inference mode.
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, checkpoint["vocab"], checkpoint["idx_to_vocab"]
# Usage example: greedy text generation from a plain-text prompt.
def generate_text(model, vocab, idx_to_vocab, input_text, max_length=50, pad_length=64):
    """Generate text from *input_text* with a trained Transformer.

    Args:
        model: Model exposing ``to(device)`` and
            ``generate(src, max_len, start_token, end_token)``.
        vocab: Mapping from word to token id; must contain the special
            tokens ``<PAD>``/``<UNK>``/``<BOS>``/``<EOS>`` (falls back to
            ids 0/1/2/3 when missing).
        idx_to_vocab: Inverse mapping from token id to word.
        input_text: Prompt; lower-cased and whitespace-tokenized.
        max_length: Maximum number of tokens to generate.
        pad_length: Fixed encoder input length; the prompt is padded or
            truncated to this many tokens (previously hard-coded to 64).

    Returns:
        The generated words joined with spaces, with special tokens
        stripped and generation stopped at the first ``<EOS>``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Simple whitespace tokenization; unknown words map to <UNK>.
    bos = vocab.get("<BOS>", 2)
    eos = vocab.get("<EOS>", 3)
    prompt_words = input_text.lower().split()
    tokens = [bos] + [vocab.get(word, vocab.get("<UNK>", 1)) for word in prompt_words] + [eos]

    # Pad (or truncate) to the fixed encoder length.
    # NOTE(review): truncation drops the trailing <EOS> — confirm the
    # model tolerates an unterminated source sequence.
    if len(tokens) < pad_length:
        tokens += [vocab.get("<PAD>", 0)] * (pad_length - len(tokens))
    else:
        tokens = tokens[:pad_length]
    src = torch.tensor(tokens).unsqueeze(0).to(device)

    # Generate with the vocab's actual special-token ids (the original
    # passed literal 2/3 here, which breaks for non-default vocabs).
    with torch.no_grad():
        generated = model.generate(src, max_len=max_length, start_token=bos, end_token=eos)

    # Convert ids back to words, skipping specials and stopping at <EOS>.
    out_words = []
    for token in generated[0]:
        word = idx_to_vocab.get(token.item(), "<UNK>")
        if word == "<EOS>":
            break
        if word in ("<PAD>", "<BOS>"):
            continue
        out_words.append(word)
    return " ".join(out_words)
# Example usage: load the checkpoint from the current directory and
# print generations for a few sample prompts.
def main():
    """Demo entry point: load the saved model and generate from sample prompts."""
    model, vocab, idx_to_vocab = load_model("./")
    input_texts = [
        "To be or not to be",
        "What is your name",
        "The king is dead",
    ]
    for text in input_texts:
        generated = generate_text(model, vocab, idx_to_vocab, text)
        print(f"輸入: {text}")
        print(f"生成: {generated}")
        print("-" * 50)


if __name__ == "__main__":
    main()
|