# File size: 6,471 Bytes
# 8b57151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import sentencepiece as spm
import matplotlib.pyplot as plt
from model import BigramLM

def load_model(model_path):
    """Load a saved checkpoint and rebuild the model plus its training history.

    Args:
        model_path: Path to a ``.pth`` checkpoint written by the training
            script.  Expected keys: the architecture hyper-parameters
            (``vocab_size``, ``d_model``, ``block_size``, ``h``, ``Nx``,
            ``dropout_rate``), ``model_state_dict``, and (optionally) the
            loss/PPL histories and ``iteration``.

    Returns:
        Tuple ``(model, train_losses, valid_losses, train_ppls, valid_ppls,
        iteration)``.  History lists are empty (and iteration 0) when the
        checkpoint does not contain them.
    """
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location='cpu')

    # Rebuild the architecture from the stored hyper-parameters.
    model = BigramLM(
        vocab_size=checkpoint['vocab_size'],
        d_model=checkpoint['d_model'],
        block_size=checkpoint['block_size'],
        h=checkpoint['h'],
        Nx=checkpoint['Nx'],
        dropout_rate=checkpoint['dropout_rate']
    )

    # Restore the trained weights.
    model.load_state_dict(checkpoint['model_state_dict'])

    # Training history may be missing from older checkpoints; fall back to
    # empty/zero defaults so callers (main() checks `if train_losses:`) can
    # detect the absence instead of crashing with KeyError.
    train_losses = checkpoint.get('train_losses', [])
    valid_losses = checkpoint.get('valid_losses', [])
    train_ppls = checkpoint.get('train_ppls', [])
    valid_ppls = checkpoint.get('valid_ppls', [])
    iteration = checkpoint.get('iteration', 0)

    return model, train_losses, valid_losses, train_ppls, valid_ppls, iteration

def plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls,
                          iteration, eval_interval=500):
    """Plot loss and perplexity curves side by side.

    Args:
        train_losses: Training losses recorded every ``eval_interval`` steps.
        valid_losses: Validation losses on the same schedule.
        train_ppls: Training perplexities on the same schedule.
        valid_ppls: Validation perplexities on the same schedule.
        iteration: Total training iterations (shown in the titles).
        eval_interval: Steps between recorded points; was hard-coded to 500
            ("assumed evaluation interval") — now a parameter so callers with
            a different logging cadence get a correct x-axis.
    """
    # x-axis positions for each recorded point: 0, eval_interval, 2*eval_interval, ...
    iterations = list(range(0, len(train_losses) * eval_interval, eval_interval))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: loss curves.
    ax1.plot(iterations, train_losses, label='Train Loss', color='blue', linewidth=2)
    ax1.plot(iterations, valid_losses, label='Validation Loss', color='red', linewidth=2)
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')
    ax1.set_title(f'Training History (Iteration: {iteration})')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right panel: perplexity curves.
    ax2.plot(iterations, train_ppls, label='Train PPL', color='blue', linewidth=2)
    ax2.plot(iterations, valid_ppls, label='Validation PPL', color='red', linewidth=2)
    ax2.set_xlabel('Iterations')
    ax2.set_ylabel('Perplexity (PPL)')
    ax2.set_title(f'Perplexity History (Iteration: {iteration})')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def generate_text(model, sp, prompt="", max_new_tokens=200, device='cpu'):
    """Sample text from the model, optionally conditioned on a prompt.

    Args:
        model: Language model exposing ``eval``/``to``/``generate``.
        sp: SentencePiece processor used to encode the prompt and decode
            the sampled ids.
        prompt: Seed text; an empty string starts generation from token 0.
        max_new_tokens: Number of tokens to sample after the seed.
        device: Device on which the context tensor is placed.

    Returns:
        The decoded string (seed tokens included).
    """
    model.eval()
    model.to(device)

    # Build the seed context: the encoded prompt, or a single 0 token
    # when no prompt was supplied.
    if not prompt:
        seed = torch.zeros((1, 1), dtype=torch.long, device=device)
    else:
        prompt_ids = sp.encode(prompt, out_type=int)
        seed = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Sample without tracking gradients, then decode the full sequence.
    with torch.no_grad():
        sampled = model.generate(seed, max_new_tokens=max_new_tokens)
        text = sp.decode(sampled[0].tolist())

    return text

def compare_models(model_paths, sp, device='cpu'):
    """Print final metrics and a short generation sample for each checkpoint.

    Checkpoints that fail to load (or to generate) are reported and skipped;
    the loop continues with the remaining paths.
    """
    print("模型比较结果:")
    print("-" * 80)

    for path in model_paths:
        # Best-effort per checkpoint: any failure is printed, not raised.
        try:
            (model, train_losses, valid_losses,
             train_ppls, valid_ppls, iteration) = load_model(path)

            print(f"\n模型: {path}")
            print(f"训练迭代次数: {iteration}")
            print(f"最终训练损失: {train_losses[-1]:.4f}")
            print(f"最终验证损失: {valid_losses[-1]:.4f}")
            print(f"最终训练PPL: {train_ppls[-1]:.2f}")
            print(f"最终验证PPL: {valid_ppls[-1]:.2f}")

            # Short unconditional sample to eyeball output quality.
            sample = generate_text(model, sp, max_new_tokens=100, device=device)
            print(f"生成文本示例: {sample[:100]}...")

        except Exception as e:
            print(f"加载模型 {path} 失败: {e}")

def main():
    """Interactive driver: pick a saved checkpoint, show its training curves,
    then generate text from user prompts until the user types 'quit'."""
    # Hoisted to function top: the original imported `re` inside the per-file
    # loop (re-executed every iteration) and `os` mid-function.
    import os
    import re

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"使用设备: {device}")

    # Load the SentencePiece tokenizer used by the model.
    sp = spm.SentencePieceProcessor()
    sp.load("lyric_tokenizer.model")

    # List the saved checkpoint files.
    model_files = [f for f in os.listdir("saved_models") if f.endswith(".pth")]

    if not model_files:
        print("没有找到保存的模型文件")
        return

    # Filenames embed a timestamp, so reverse name order puts newest first.
    model_files.sort(reverse=True)

    # Compile the timestamp pattern once, outside the loop.
    time_pattern = re.compile(r'(\d{8}_\d{6})')

    print("可用的模型文件(按时间倒序):")
    for i, model_file in enumerate(model_files):
        # Friendlier label based on the filename.
        if "final" in model_file:
            file_type = "最终模型"
        elif "checkpoint" in model_file:
            file_type = "检查点"
        else:
            file_type = "模型"

        # Extract the YYYYMMDD_HHMMSS timestamp, if present.
        time_match = time_pattern.search(model_file)
        time_str = time_match.group(1) if time_match else "未知时间"

        print(f"{i+1}. {file_type} - {time_str} - {model_file}")

    # Ask the user which model to test; fall back to the first on bad input.
    try:
        choice = int(input("\n请选择要测试的模型编号: ")) - 1
        selected_model = os.path.join("saved_models", model_files[choice])
    except (ValueError, IndexError):
        print("无效的选择,使用第一个模型")
        selected_model = os.path.join("saved_models", model_files[0])

    # Load the chosen checkpoint.
    model, train_losses, valid_losses, train_ppls, valid_ppls, iteration = load_model(selected_model)
    model.to(device)

    # Show the training history, if the checkpoint recorded one.
    print(f"\n加载模型: {selected_model}")
    print(f"训练迭代次数: {iteration}")
    if train_losses:
        plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls, iteration)
    else:
        print("该模型没有训练历史数据")

    # Interactive generation loop.
    while True:
        print("\n" + "="*50)
        prompt = input("请输入提示文本(直接回车使用空提示,输入'quit'退出): ")

        if prompt.lower() == 'quit':
            break

        # Token budget for this round; default 200 on empty/invalid input.
        max_tokens = input("请输入要生成的token数量(默认200): ")
        try:
            max_tokens = int(max_tokens) if max_tokens else 200
        except ValueError:
            max_tokens = 200

        generated_text = generate_text(model, sp, prompt, max_tokens, device)
        print(f"\n生成的文本:")
        print(generated_text)

# Run the interactive CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()