import torch
import sentencepiece as spm
import matplotlib.pyplot as plt
from model import BigramLM
def load_model(model_path):
    """Load a saved checkpoint and rebuild the model it was trained with.

    Parameters
    ----------
    model_path : str
        Path to a checkpoint file produced by the training script.

    Returns
    -------
    tuple
        ``(model, train_losses, valid_losses, train_ppls, valid_ppls,
        iteration)``.  History lists default to ``[]`` and ``iteration``
        to ``0`` when the checkpoint predates history tracking, so
        callers can detect "no history" without crashing.
    """
    # map_location='cpu' lets checkpoints saved on GPU load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location='cpu')
    # Rebuild the architecture from the hyperparameters stored with the weights.
    model = BigramLM(
        vocab_size=checkpoint['vocab_size'],
        d_model=checkpoint['d_model'],
        block_size=checkpoint['block_size'],
        h=checkpoint['h'],
        Nx=checkpoint['Nx'],
        dropout_rate=checkpoint['dropout_rate']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # Training history may be absent in older checkpoints; fall back to
    # empty/zero instead of raising KeyError (main() checks `if train_losses:`).
    train_losses = checkpoint.get('train_losses', [])
    valid_losses = checkpoint.get('valid_losses', [])
    train_ppls = checkpoint.get('train_ppls', [])
    valid_ppls = checkpoint.get('valid_ppls', [])
    iteration = checkpoint.get('iteration', 0)
    return model, train_losses, valid_losses, train_ppls, valid_ppls, iteration
def plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls, iteration, eval_interval=500):
    """Plot training/validation loss and perplexity histories side by side.

    Parameters
    ----------
    train_losses, valid_losses : list[float]
        Loss recorded at each evaluation point.
    train_ppls, valid_ppls : list[float]
        Perplexity recorded at each evaluation point.
    iteration : int
        Total iterations trained; shown in the plot titles.
    eval_interval : int, optional
        Iterations between evaluation points (default 500).  Previously
        hard-coded; pass the value actually used by the training script
        if it differs — TODO confirm against the trainer's config.
    """
    # x coordinate of each recorded point, in training iterations.
    iterations = list(range(0, len(train_losses) * eval_interval, eval_interval))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    # Loss curves.
    ax1.plot(iterations, train_losses, label='Train Loss', color='blue', linewidth=2)
    ax1.plot(iterations, valid_losses, label='Validation Loss', color='red', linewidth=2)
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')
    ax1.set_title(f'Training History (Iteration: {iteration})')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    # Perplexity curves.
    ax2.plot(iterations, train_ppls, label='Train PPL', color='blue', linewidth=2)
    ax2.plot(iterations, valid_ppls, label='Validation PPL', color='red', linewidth=2)
    ax2.set_xlabel('Iterations')
    ax2.set_ylabel('Perplexity (PPL)')
    ax2.set_title(f'Perplexity History (Iteration: {iteration})')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
def generate_text(model, sp, prompt="", max_new_tokens=200, device='cpu'):
    """Generate text from *model*, optionally seeded with *prompt*.

    An empty prompt falls back to a single zero-valued token as the seed
    context (presumably a safe start id for this tokenizer — verify
    against the tokenizer's vocabulary).  Returns the decoded string,
    which includes the prompt tokens.
    """
    model.eval()
    model.to(device)
    # Build the (1, T) seed context for generation.
    if not prompt:
        seed = torch.zeros((1, 1), dtype=torch.long, device=device)
    else:
        prompt_ids = sp.encode(prompt, out_type=int)
        seed = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    # Sample without tracking gradients, then decode the ids to a string.
    with torch.no_grad():
        token_ids = model.generate(seed, max_new_tokens=max_new_tokens)[0].tolist()
    return sp.decode(token_ids)
def compare_models(model_paths, sp, device='cpu'):
    """Print final metrics and a short generated sample for each checkpoint.

    Each path is processed independently: the whole per-model body runs
    under one try so any failure (loading OR generation) is reported and
    the remaining models are still compared.
    """
    print("模型比较结果:")
    print("-" * 80)
    for model_path in model_paths:
        try:
            (model, train_losses, valid_losses,
             train_ppls, valid_ppls, iteration) = load_model(model_path)
            print(f"\n模型: {model_path}")
            print(f"训练迭代次数: {iteration}")
            print(f"最终训练损失: {train_losses[-1]:.4f}")
            print(f"最终验证损失: {valid_losses[-1]:.4f}")
            print(f"最终训练PPL: {train_ppls[-1]:.2f}")
            print(f"最终验证PPL: {valid_ppls[-1]:.2f}")
            # Short sample so the comparison stays readable.
            sample = generate_text(model, sp, max_new_tokens=100, device=device)
            print(f"生成文本示例: {sample[:100]}...")
        except Exception as e:
            print(f"加载模型 {model_path} 失败: {e}")
def main():
    """Interactive entry point.

    Lists saved checkpoints, lets the user pick one, shows its training
    history, then loops generating text from user prompts until 'quit'.
    """
    # Local imports kept function-scoped as in the original, but hoisted
    # to the top (re was previously imported inside the listing loop).
    import os
    import re

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"使用设备: {device}")
    # Load the SentencePiece tokenizer used at training time.
    sp = spm.SentencePieceProcessor()
    sp.load("lyric_tokenizer.model")
    # Collect available checkpoint files.
    model_files = [f for f in os.listdir("saved_models") if f.endswith(".pth")]
    if not model_files:
        print("没有找到保存的模型文件")
        return
    # Timestamped filenames sort chronologically, so reverse = newest first.
    model_files.sort(reverse=True)
    print("可用的模型文件(按时间倒序):")
    # Compile once instead of per file inside the loop.
    timestamp_re = re.compile(r'(\d{8}_\d{6})')
    for i, model_file in enumerate(model_files):
        # Friendlier label derived from the filename.
        if "final" in model_file:
            file_type = "最终模型"
        elif "checkpoint" in model_file:
            file_type = "检查点"
        else:
            file_type = "模型"
        # Extract the YYYYMMDD_HHMMSS timestamp if present.
        time_match = timestamp_re.search(model_file)
        time_str = time_match.group(1) if time_match else "未知时间"
        print(f"{i+1}. {file_type} - {time_str} - {model_file}")
    # Let the user choose; any invalid input falls back to the newest model.
    try:
        choice = int(input("\n请选择要测试的模型编号: ")) - 1
        # Reject 0/negative input explicitly: a negative index would
        # otherwise silently select from the END of the list.
        if not 0 <= choice < len(model_files):
            raise IndexError(choice)
        selected_model = os.path.join("saved_models", model_files[choice])
    except (ValueError, IndexError):
        print("无效的选择,使用第一个模型")
        selected_model = os.path.join("saved_models", model_files[0])
    # Load the chosen model and report its history.
    model, train_losses, valid_losses, train_ppls, valid_ppls, iteration = load_model(selected_model)
    model.to(device)
    print(f"\n加载模型: {selected_model}")
    print(f"训练迭代次数: {iteration}")
    if train_losses:  # only plot when the checkpoint carried history data
        plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls, iteration)
    else:
        print("该模型没有训练历史数据")
    # Interactive generation loop; 'quit' exits.
    while True:
        print("\n" + "="*50)
        prompt = input("请输入提示文本(直接回车使用空提示,输入'quit'退出): ")
        if prompt.lower() == 'quit':
            break
        max_tokens = input("请输入要生成的token数量(默认200): ")
        try:
            max_tokens = int(max_tokens) if max_tokens else 200
        except ValueError:
            max_tokens = 200
        generated_text = generate_text(model, sp, prompt, max_tokens, device)
        print(f"\n生成的文本:")
        print(generated_text)
# Script entry point: run the interactive CLI only when executed directly.
if __name__ == "__main__":
    main()