"""Evaluate saved BigramLM checkpoints: load a model, plot its training history,
and generate text interactively."""

import os
import re

import torch
import sentencepiece as spm
import matplotlib.pyplot as plt

from model import BigramLM

def load_model(model_path):
    """Load a saved model checkpoint together with its training history."""
    checkpoint = torch.load(model_path, map_location='cpu')

    # Rebuild the model with the same hyperparameters it was trained with.
    model = BigramLM(
        vocab_size=checkpoint['vocab_size'],
        d_model=checkpoint['d_model'],
        block_size=checkpoint['block_size'],
        h=checkpoint['h'],
        Nx=checkpoint['Nx'],
        dropout_rate=checkpoint['dropout_rate']
    )
    model.load_state_dict(checkpoint['model_state_dict'])

    # Training history recorded during training.
    train_losses = checkpoint['train_losses']
    valid_losses = checkpoint['valid_losses']
    train_ppls = checkpoint['train_ppls']
    valid_ppls = checkpoint['valid_ppls']
    iteration = checkpoint['iteration']

    return model, train_losses, valid_losses, train_ppls, valid_ppls, iteration

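# Sketch of the checkpoint layout load_model expects. The key names come from the reads
# above; the torch.save call itself lives in the training script and is an assumption here:
#
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'vocab_size': vocab_size, 'd_model': d_model, 'block_size': block_size,
#       'h': h, 'Nx': Nx, 'dropout_rate': dropout_rate,
#       'train_losses': train_losses, 'valid_losses': valid_losses,
#       'train_ppls': train_ppls, 'valid_ppls': valid_ppls,
#       'iteration': iteration,
#   }, model_path)
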
def plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls, iteration):
    """Plot the loss and perplexity curves recorded during training."""
    # Assumes the training script evaluated (and logged) every 500 iterations.
    eval_interval = 500
    iterations = list(range(0, len(train_losses) * eval_interval, eval_interval))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left panel: loss curves.
    ax1.plot(iterations, train_losses, label='Train Loss', color='blue', linewidth=2)
    ax1.plot(iterations, valid_losses, label='Validation Loss', color='red', linewidth=2)
    ax1.set_xlabel('Iterations')
    ax1.set_ylabel('Loss')
    ax1.set_title(f'Training History (Iteration: {iteration})')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right panel: perplexity curves.
    ax2.plot(iterations, train_ppls, label='Train PPL', color='blue', linewidth=2)
    ax2.plot(iterations, valid_ppls, label='Validation PPL', color='red', linewidth=2)
    ax2.set_xlabel('Iterations')
    ax2.set_ylabel('Perplexity (PPL)')
    ax2.set_title(f'Perplexity History (Iteration: {iteration})')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def generate_text(model, sp, prompt="", max_new_tokens=200, device='cpu'):
    """Generate text with the model, optionally conditioned on a prompt."""
    model.eval()
    model.to(device)

    if prompt:
        encoded_prompt = sp.encode(prompt, out_type=int)
        context = torch.tensor([encoded_prompt], dtype=torch.long, device=device)
    else:
        # With no prompt, start generation from a single token with id 0.
        context = torch.zeros((1, 1), dtype=torch.long, device=device)

    with torch.no_grad():
        generated_tokens = model.generate(context, max_new_tokens=max_new_tokens)[0].tolist()
        generated_text = sp.decode(generated_tokens)

    return generated_text

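# Hedged usage sketch (checkpoint path and prompt are placeholders; `sp` is a loaded
# SentencePieceProcessor, as in main()):
#   model, *_ = load_model("saved_models/<checkpoint>.pth")
#   print(generate_text(model, sp, prompt="...", max_new_tokens=100))
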
def compare_models(model_paths, sp, device='cpu'):
    """Compare the performance of several saved models."""
    print("Model comparison results:")
    print("-" * 80)

    for model_path in model_paths:
        try:
            model, train_losses, valid_losses, train_ppls, valid_ppls, iteration = load_model(model_path)

            print(f"\nModel: {model_path}")
            print(f"Training iterations: {iteration}")
            print(f"Final train loss: {train_losses[-1]:.4f}")
            print(f"Final validation loss: {valid_losses[-1]:.4f}")
            print(f"Final train PPL: {train_ppls[-1]:.2f}")
            print(f"Final validation PPL: {valid_ppls[-1]:.2f}")

            # Print a short generation sample from each model.
            generated_text = generate_text(model, sp, max_new_tokens=100, device=device)
            print(f"Sample generated text: {generated_text[:100]}...")

        except Exception as e:
            print(f"Failed to load model {model_path}: {e}")

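# compare_models is not invoked by main(); a usage sketch with hypothetical paths:
#   compare_models(
#       ["saved_models/model_final_20240101_120000.pth",
#        "saved_models/model_checkpoint_20240102_093000.pth"],
#       sp,
#   )
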
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load the SentencePiece tokenizer.
    sp = spm.SentencePieceProcessor()
    sp.load("lyric_tokenizer.model")

    # Collect every saved checkpoint.
    model_files = [f for f in os.listdir("saved_models") if f.endswith(".pth")]

    if not model_files:
        print("No saved model files found")
        return

    # Filenames embed a timestamp, so a reverse lexicographic sort lists the newest first.
    model_files.sort(reverse=True)

    print("Available model files (newest first):")
    for i, model_file in enumerate(model_files):
        if "final" in model_file:
            file_type = "final model"
        elif "checkpoint" in model_file:
            file_type = "checkpoint"
        else:
            file_type = "model"

        # Extract the YYYYMMDD_HHMMSS timestamp from the filename, if present.
        time_match = re.search(r'(\d{8}_\d{6})', model_file)
        time_str = time_match.group(1) if time_match else "unknown time"

        print(f"{i+1}. {file_type} - {time_str} - {model_file}")

    try:
        choice = int(input("\nSelect the number of the model to test: ")) - 1
        selected_model = os.path.join("saved_models", model_files[choice])
    except (ValueError, IndexError):
        print("Invalid choice, using the first model")
        selected_model = os.path.join("saved_models", model_files[0])

    model, train_losses, valid_losses, train_ppls, valid_ppls, iteration = load_model(selected_model)
    model.to(device)

    print(f"\nLoaded model: {selected_model}")
    print(f"Training iterations: {iteration}")
    if train_losses:
        plot_training_history(train_losses, valid_losses, train_ppls, valid_ppls, iteration)
    else:
        print("This model has no training history data")

    # Interactive generation loop.
    while True:
        print("\n" + "=" * 50)
        prompt = input("Enter a prompt (press Enter for an empty prompt, type 'quit' to exit): ")

        if prompt.lower() == 'quit':
            break

        max_tokens = input("Enter the number of tokens to generate (default 200): ")
        try:
            max_tokens = int(max_tokens) if max_tokens else 200
        except ValueError:
            max_tokens = 200

        generated_text = generate_text(model, sp, prompt, max_tokens, device)
        print("\nGenerated text:")
        print(generated_text)


if __name__ == "__main__":
    main()