| """检查点保存和加载"""
|
|
|
|
|
| import torch
|
| from pathlib import Path
|
|
|
|
|
def save_checkpoint(model, optimizer, epoch, step, loss, checkpoint_dir, name='checkpoint'):
    """Save a training checkpoint and return the path it was written to.

    The checkpoint is a dict with keys ``'epoch'``, ``'step'``,
    ``'model_state_dict'``, ``'optimizer_state_dict'`` and ``'loss'``, stored at
    ``<checkpoint_dir>/<name>_step_<step>.pt``. ``checkpoint_dir`` (and any missing
    parents) is created if needed.

    The file is written atomically: data goes to a sibling ``.tmp`` file first and
    is then renamed over the final path. An interrupted or failed save therefore
    never leaves a truncated ``.pt`` file that a later resume could try to load.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a ``torch.optim.Optimizer``).
        epoch: Epoch value stored as-is.
        step: Step value stored as-is and used in the file name.
        loss: Loss value stored as-is.
        checkpoint_dir: Directory (str or path-like) to write into.
        name: File-name prefix.

    Returns:
        pathlib.Path of the written checkpoint.

    Raises:
        Whatever ``state_dict()`` or ``torch.save`` raises. In that case the
        temporary file is removed and no final file is created or replaced.
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_path = checkpoint_dir / f"{name}_step_{step}.pt"
    # Temp file lives in the same directory so the final rename stays on one
    # filesystem (required for the replace to be atomic). The ".tmp" suffix also
    # keeps it from matching "*.pt" globs used to find checkpoints.
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.tmp')

    try:
        torch.save({
            'epoch': epoch,
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, tmp_path)
        # Atomic rename: readers see either the old file or the complete new one.
        tmp_path.replace(checkpoint_path)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a partial temp file behind.
        tmp_path.unlink(missing_ok=True)
        raise

    return checkpoint_path
|
|
|
|
|
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model and optimizer state from a checkpoint to resume training.

    The file is loaded with ``map_location='cpu'``. The stored model state is then
    applied to ``model`` and the optimizer state to ``optimizer``, in that order.

    Args:
        model: Object exposing ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``load_state_dict`` (e.g. a ``torch.optim.Optimizer``).
        checkpoint_path: Path accepted by ``torch.load``.

    Returns:
        Tuple ``(epoch, step, loss)``. ``loss`` is ``None`` when the checkpoint
        has no ``'loss'`` entry.

    Raises:
        KeyError: if ``'model_state_dict'``, ``'optimizer_state_dict'``,
            ``'epoch'`` or ``'step'`` is missing.
    """
    state = torch.load(checkpoint_path, map_location='cpu')

    # Keep the original ordering: the model is restored before the optimizer
    # entry is even looked up.
    model_weights = state['model_state_dict']
    model.load_state_dict(model_weights)

    optim_state = state['optimizer_state_dict']
    optimizer.load_state_dict(optim_state)

    epoch = state['epoch']
    step = state['step']
    last_loss = state.get('loss', None)
    return epoch, step, last_loss
|
|
|
|
|
def load_model_only(model, checkpoint_path):
    """Load only the model weights from a checkpoint (e.g. for inference).

    Any optimizer state in the file is ignored. The file is loaded with
    ``map_location='cpu'``.

    Args:
        model: Object exposing ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        checkpoint_path: Path accepted by ``torch.load``.

    Returns:
        Tuple ``(epoch, step, loss)``. Missing entries default to
        ``0``, ``0`` and ``None`` respectively.

    Raises:
        KeyError: if the checkpoint has no ``'model_state_dict'`` entry.
    """
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(ckpt['model_state_dict'])

    # Metadata is optional here: fall back to neutral defaults when absent.
    metadata_defaults = (('epoch', 0), ('step', 0), ('loss', None))
    return tuple(ckpt.get(key, fallback) for key, fallback in metadata_defaults)
|
|
|