File size: 2,364 Bytes
5c97468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

import torch
from encoder_decoder_transformer import Transformer
import json

# Rebuild the Transformer from its saved config and restore trained weights.
def load_model(model_path):
    """Load a Transformer checkpoint from *model_path*.

    Expects ``pytorch_model.bin`` (torch checkpoint holding the state dict and
    both vocab mappings) and ``config.json`` (architecture hyperparameters)
    inside the directory.

    Returns:
        (model, vocab, idx_to_vocab) — the model is already in eval mode.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    state = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")

    with open(f"{model_path}/config.json", "r") as cfg_file:
        cfg = json.load(cfg_file)

    # Recreate the architecture exactly as configured at training time.
    model = Transformer(
        src_vocab_size=cfg["vocab_size"],
        tgt_vocab_size=cfg["vocab_size"],
        d_model=cfg["d_model"],
        n_heads=cfg["n_heads"],
        n_encoder_layers=cfg["n_encoder_layers"],
        n_decoder_layers=cfg["n_decoder_layers"],
        d_ff=cfg["d_ff"],
        dropout=cfg["dropout"],
        pad_idx=cfg["pad_token_id"],
    )

    # Restore the trained parameters and switch off dropout for inference.
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    return model, state["vocab"], state["idx_to_vocab"]

# Usage example
def generate_text(model, vocab, idx_to_vocab, input_text, max_length=50, src_max_len=64):
    """Generate text from *input_text* with an encoder-decoder Transformer.

    Args:
        model: Transformer exposing ``generate(src, max_len, start_token, end_token)``.
        vocab: dict mapping token string -> id; expects <PAD>/<UNK>/<BOS>/<EOS> entries.
        idx_to_vocab: inverse mapping, id -> token string.
        input_text: raw source sentence; tokenized by lowercased whitespace split.
        max_length: maximum number of tokens to generate.
        src_max_len: fixed source length the input is padded/truncated to
            (generalizes the previously hard-coded 64; default preserves behavior).

    Returns:
        The generated sentence as a space-joined string with special tokens removed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Resolve special-token ids once, with the conventional fallback ids.
    pad_id = vocab.get("<PAD>", 0)
    unk_id = vocab.get("<UNK>", 1)
    bos_id = vocab.get("<BOS>", 2)
    eos_id = vocab.get("<EOS>", 3)

    # Simple tokenization: lowercase whitespace split, unknown words -> <UNK>.
    words = input_text.lower().split()
    tokens = [bos_id] + [vocab.get(word, unk_id) for word in words] + [eos_id]

    # Pad / truncate the source sequence to a fixed length.
    if len(tokens) < src_max_len:
        tokens += [pad_id] * (src_max_len - len(tokens))
    else:
        tokens = tokens[:src_max_len]

    src = torch.tensor(tokens).unsqueeze(0).to(device)

    # Generate. BUGFIX: pass the vocab's actual BOS/EOS ids instead of the
    # hard-coded 2/3, so encoding and decoding agree for non-default vocabs.
    with torch.no_grad():
        generated = model.generate(src, max_len=max_length, start_token=bos_id, end_token=eos_id)

    # Convert ids back to text: drop <PAD>/<BOS>, stop at the first <EOS>.
    out_words = []
    for token in generated[0]:
        word = idx_to_vocab.get(token.item(), "<UNK>")
        if word in ["<PAD>", "<BOS>", "<EOS>"]:
            if word == "<EOS>":
                break
            continue
        out_words.append(word)

    return " ".join(out_words)

# Demo entry point: load the checkpoint from the current directory and
# generate a continuation for a few sample prompts.
if __name__ == "__main__":
    model, vocab, idx_to_vocab = load_model("./")

    prompts = [
        "To be or not to be",
        "What is your name",
        "The king is dead",
    ]

    for prompt in prompts:
        output = generate_text(model, vocab, idx_to_vocab, prompt)
        print(f"輸入: {prompt}")
        print(f"生成: {output}")
        print("-" * 50)