# own_gpt / test_model.py
# Uploaded by AISkywalker ("Upload 21 files", commit 8b57151, verified).
import os
import re

import matplotlib.pyplot as plt
import sentencepiece as spm
import torch

from model import BigramLM
def load_model(model_path):
    """Load a saved checkpoint: rebuild the model and return its training history.

    Parameters
    ----------
    model_path : str
        Path to a checkpoint written by the training script. Must contain the
        hyperparameters ('vocab_size', 'd_model', 'block_size', 'h', 'Nx',
        'dropout_rate'), 'model_state_dict', the four history lists and
        'iteration'.

    Returns
    -------
    tuple
        (model, train_losses, valid_losses, train_ppls, valid_ppls, iteration)

    Raises
    ------
    KeyError
        If the checkpoint is missing any of the expected keys.
    """
    # weights_only=False is required on torch >= 2.6 (where the default flipped
    # to True): the checkpoint stores plain Python lists alongside the state
    # dict, which the weights-only unpickler rejects. Only load trusted files.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
    # Rebuild the architecture from the saved hyperparameters so the state
    # dict shapes line up.
    model = BigramLM(
        vocab_size=checkpoint['vocab_size'],
        d_model=checkpoint['d_model'],
        block_size=checkpoint['block_size'],
        h=checkpoint['h'],
        Nx=checkpoint['Nx'],
        dropout_rate=checkpoint['dropout_rate']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # Training history: one entry per evaluation step.
    train_losses = checkpoint['train_losses']
    valid_losses = checkpoint['valid_losses']
    train_ppls = checkpoint['train_ppls']
    valid_ppls = checkpoint['valid_ppls']
    iteration = checkpoint['iteration']
    return model, train_losses, valid_losses, train_ppls, valid_ppls, iteration
def plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls,
                          iteration, eval_interval=500):
    """Plot loss and perplexity curves side by side.

    Parameters
    ----------
    train_losses, valid_losses : list[float]
        Loss recorded at each evaluation step.
    train_ppls, valid_ppls : list[float]
        Perplexity recorded at each evaluation step.
    iteration : int
        Total training iterations (shown in the plot titles).
    eval_interval : int, optional
        Number of training iterations between recorded evaluations; used to
        reconstruct the x-axis. Defaults to 500 (the previous hard-coded
        value), but can now be passed to match a different training config.
    """
    # Reconstruct the iteration number of each recorded point.
    iterations = list(range(0, len(train_losses) * eval_interval, eval_interval))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    # Left panel: loss curves.
    ax1.plot(iterations, train_losses, label='Train Loss', color='blue', linewidth=2)
    ax1.plot(iterations, valid_losses, label='Validation Loss', color='red', linewidth=2)
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')
    ax1.set_title(f'Training History (Iteration: {iteration})')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    # Right panel: perplexity curves.
    ax2.plot(iterations, train_ppls, label='Train PPL', color='blue', linewidth=2)
    ax2.plot(iterations, valid_ppls, label='Validation PPL', color='red', linewidth=2)
    ax2.set_xlabel('Iterations')
    ax2.set_ylabel('Perplexity (PPL)')
    ax2.set_title(f'Perplexity History (Iteration: {iteration})')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
def generate_text(model, sp, prompt="", max_new_tokens=200, device='cpu'):
    """Sample text from the model, optionally conditioned on a prompt.

    ``prompt`` is tokenized with the SentencePiece processor ``sp``; an empty
    prompt starts generation from a single zero token. Returns the decoded
    string (prompt tokens included).
    """
    model.eval()
    model.to(device)
    # Build the starting context for generation.
    if not prompt:
        # No prompt: seed with a single zero token.
        context = torch.zeros((1, 1), dtype=torch.long, device=device)
    else:
        token_ids = sp.encode(prompt, out_type=int)
        context = torch.tensor([token_ids], dtype=torch.long, device=device)
    # Sample without tracking gradients, then decode the first batch row.
    with torch.no_grad():
        sampled = model.generate(context, max_new_tokens=max_new_tokens)
    return sp.decode(sampled[0].tolist())
def compare_models(model_paths, sp, device='cpu'):
    """Print a summary (final losses/PPLs and a sample generation) for each checkpoint.

    Checkpoints that fail to load are reported and skipped rather than
    aborting the whole comparison.
    """
    print("模型比较结果:")
    print("-" * 80)
    for path in model_paths:
        try:
            loaded = load_model(path)
        except Exception as e:
            # Best-effort: report the failure and move on to the next file.
            print(f"加载模型 {path} 失败: {e}")
            continue
        try:
            model, tr_loss, va_loss, tr_ppl, va_ppl, step = loaded
            print(f"\n模型: {path}")
            print(f"训练迭代次数: {step}")
            print(f"最终训练损失: {tr_loss[-1]:.4f}")
            print(f"最终验证损失: {va_loss[-1]:.4f}")
            print(f"最终训练PPL: {tr_ppl[-1]:.2f}")
            print(f"最终验证PPL: {va_ppl[-1]:.2f}")
            # Produce a short sample so the checkpoints can be compared by eye.
            sample = generate_text(model, sp, max_new_tokens=100, device=device)
            print(f"生成文本示例: {sample[:100]}...")
        except Exception as e:
            print(f"加载模型 {path} 失败: {e}")
def main():
    """Interactive entry point.

    Lists saved checkpoints, lets the user pick one, plots its training
    history (if any), then loops prompting for text generation until the
    user types 'quit'.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"使用设备: {device}")
    # Load the SentencePiece tokenizer produced at training time.
    sp = spm.SentencePieceProcessor()
    sp.load("lyric_tokenizer.model")
    # Collect saved checkpoints.
    model_files = [f for f in os.listdir("saved_models") if f.endswith(".pth")]
    if not model_files:
        print("没有找到保存的模型文件")
        return
    # Filenames embed a timestamp, so reverse lexical sort == newest first.
    model_files.sort(reverse=True)
    print("可用的模型文件(按时间倒序):")
    # Compile once outside the loop; matches a YYYYMMDD_HHMMSS timestamp.
    time_pattern = re.compile(r'(\d{8}_\d{6})')
    for i, model_file in enumerate(model_files):
        # Show a friendlier label for each file.
        if "final" in model_file:
            file_type = "最终模型"
        elif "checkpoint" in model_file:
            file_type = "检查点"
        else:
            file_type = "模型"
        time_match = time_pattern.search(model_file)
        time_str = time_match.group(1) if time_match else "未知时间"
        print(f"{i+1}. {file_type} - {time_str} - {model_file}")
    # Pick a model. Reject out-of-range numbers explicitly: previously an
    # input of 0 became index -1 and silently selected the LAST file.
    try:
        choice = int(input("\n请选择要测试的模型编号: ")) - 1
        if not 0 <= choice < len(model_files):
            raise IndexError(choice)
        selected_model = os.path.join("saved_models", model_files[choice])
    except (ValueError, IndexError):
        print("无效的选择,使用第一个模型")
        selected_model = os.path.join("saved_models", model_files[0])
    # Load the chosen checkpoint and show its training history.
    model, train_losses, valid_losses, train_ppls, valid_ppls, iteration = load_model(selected_model)
    model.to(device)
    print(f"\n加载模型: {selected_model}")
    print(f"训练迭代次数: {iteration}")
    if train_losses:
        plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls, iteration)
    else:
        print("该模型没有训练历史数据")
    # Interactive generation loop.
    while True:
        print("\n" + "="*50)
        prompt = input("请输入提示文本(直接回车使用空提示,输入'quit'退出): ")
        if prompt.lower() == 'quit':
            break
        max_tokens = input("请输入要生成的token数量(默认200): ")
        try:
            max_tokens = int(max_tokens) if max_tokens else 200
        except ValueError:
            # Non-numeric input falls back to the default length.
            max_tokens = 200
        generated_text = generate_text(model, sp, prompt, max_tokens, device)
        print(f"\n生成的文本:")
        print(generated_text)
# Run the interactive tester only when executed as a script, not on import.
if __name__ == "__main__":
    main()