# Source: nextShakespeare / llm / utils / checkpoint.py
# Committed by: LiManshu
# Commit message: "Add files using upload-large-folder tool"
# Revision: bf6be45 (verified)
"""检查点保存和加载"""
# 2026-01-23
import torch
from pathlib import Path
def save_checkpoint(model, optimizer, epoch, step, loss, checkpoint_dir, name='checkpoint'):
    """Save model/optimizer state and training progress to ``<name>_step_<step>.pt``.

    The checkpoint is first written to a sibling ``.tmp`` file and then moved
    into place with a rename. An interruption during the write therefore never
    leaves a truncated ``.pt`` file behind. It also never replaces an existing
    checkpoint for the same step with a partial one.

    Args:
        model: Object whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Object whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        epoch: Stored as-is under ``'epoch'``.
        step: Stored under ``'step'`` and embedded in the file name.
        loss: Stored as-is under ``'loss'``.
        checkpoint_dir: Target directory; created (including parents) if missing.
        name: File-name prefix.

    Returns:
        ``pathlib.Path`` of the written checkpoint file.

    Raises:
        Whatever ``torch.save`` raises (e.g. pickling or I/O errors). In that
        case no checkpoint file is created or modified.
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # Build the file name in one place so the step number is never duplicated.
    checkpoint_path = checkpoint_dir / f"{name}_step_{step}.pt"
    payload = {
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.tmp')
    try:
        torch.save(payload, tmp_path)
        # Path.replace == os.replace: overwrites the target, and is atomic on
        # POSIX because source and target live in the same directory.
        tmp_path.replace(checkpoint_path)
    except BaseException:
        # Covers KeyboardInterrupt too: never leave a half-written temp file.
        tmp_path.unlink(missing_ok=True)
        raise
    return checkpoint_path
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model and optimizer state from a checkpoint to resume training.

    The file is deserialized with ``map_location='cpu'``. The model's state is
    restored first, then the optimizer's.

    Args:
        model: Object receiving ``checkpoint['model_state_dict']`` via
            ``load_state_dict``.
        optimizer: Object receiving ``checkpoint['optimizer_state_dict']`` via
            ``load_state_dict``.
        checkpoint_path: Path of the file to read with ``torch.load``.

    Returns:
        Tuple ``(epoch, step, loss)``. ``loss`` is ``None`` when the checkpoint
        has no ``'loss'`` entry.

    Raises:
        KeyError: If a required entry (``'model_state_dict'``,
            ``'optimizer_state_dict'``, ``'epoch'`` or ``'step'``) is missing.
    """
    state = torch.load(checkpoint_path, map_location='cpu')
    restore_plan = (
        (model, 'model_state_dict'),
        (optimizer, 'optimizer_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(state[key])
    return state['epoch'], state['step'], state.get('loss')
def load_model_only(model, checkpoint_path):
    """Load only the model weights from a checkpoint (e.g. for inference).

    Any optimizer state stored in the file is ignored. The file is deserialized
    with ``map_location='cpu'``.

    Args:
        model: Object receiving ``checkpoint['model_state_dict']`` via
            ``load_state_dict``.
        checkpoint_path: Path of the file to read with ``torch.load``.

    Returns:
        Tuple ``(epoch, step, loss)``. Missing entries fall back to ``0``,
        ``0`` and ``None`` respectively.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    state = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state['model_state_dict'])
    fallbacks = (('epoch', 0), ('step', 0), ('loss', None))
    return tuple(state.get(key, default) for key, default in fallbacks)