File size: 3,665 Bytes
6a51385 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | import torch.nn as nn
import torch
from typing import Tuple
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
import os
def load_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    checkpoint_path: str,
    device: torch.device,
) -> Tuple[int, float]:
    """Restore model, optimizer and (optionally) scheduler state from disk.

    Args:
        model: Model to restore into. If it is wrapped in DDP, the weights are
            loaded into the wrapped ``.module``.
        optimizer: Optimizer whose state is replaced from
            ``'optimizer_state_dict'``.
        scheduler: Optional LR scheduler. Its state is restored only when it
            is not None and the checkpoint has a ``'scheduler_state_dict'``.
        checkpoint_path: Path of the file to read with ``torch.load``.
        device: Passed as ``map_location`` so tensors land on this device.

    Returns:
        ``(epoch, best_loss)`` read from the checkpoint, defaulting to ``0``
        and ``float('inf')`` when the keys are absent.

    Raises:
        KeyError: If ``'model_state_dict'`` or ``'optimizer_state_dict'`` is
            missing from the checkpoint.
    """
    state = torch.load(checkpoint_path, map_location=device)

    # Under DDP the real network lives in ``.module``; restore the weights
    # there rather than into the wrapper itself.
    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(state['model_state_dict'])

    optimizer.load_state_dict(state['optimizer_state_dict'])

    has_scheduler_state = 'scheduler_state_dict' in state
    if scheduler is not None and has_scheduler_state:
        scheduler.load_state_dict(state['scheduler_state_dict'])

    return state.get('epoch', 0), state.get('best_loss', float('inf'))
# Attributes copied from the training ``args`` namespace into
# ``checkpoint['config']`` (in this order) so downstream tasks can rebuild
# the model with the same hyper-parameters.
_CONFIG_KEYS = (
    'model_type',
    'embed_dim',
    'depth',
    'predictor_depth',
    'drop_path_rate',
    'rms_norm',
    'fused_add_norm',
    'residual_in_fp32',
    'bimamba_type',
    'if_bimamba',
    'mixer_type',
    'if_devide_out',
    'predictor_hidden',
    'momentum',
    'norm_target',
    'num_heads',
    'mlp_ratio',
)


def save_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    epoch: int,
    best_loss: float,
    checkpoint_dir: str,
    args=None,
    rank: int = 0,
    *,
    is_best: bool = True,
) -> None:
    """Save training state to ``checkpoint_latest.pt`` and optionally ``checkpoint_best.pt``.

    Only rank 0 writes; every other rank returns immediately without touching
    the filesystem. Files are written to a ``.tmp`` sibling and then moved
    into place with ``os.replace``, so an interrupted save never leaves a
    truncated checkpoint under the final name.

    Args:
        model: Model to save. A DDP wrapper is unwrapped via ``.module``
            before taking the state dict.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: Optional LR scheduler; its state is stored under
            ``'scheduler_state_dict'`` only when it is not None.
        epoch: Stored under ``'epoch'``.
        best_loss: Stored under ``'best_loss'``. The best checkpoint is never
            written when this is None.
        checkpoint_dir: Output directory, created if missing.
        args: Optional namespace. When given, the attributes named in
            ``_CONFIG_KEYS`` are stored under ``'config'``; a missing
            attribute raises ``AttributeError``.
        rank: Process rank; non-zero ranks do nothing.
        is_best: Whether this call should also refresh ``checkpoint_best.pt``.
            The default ``True`` keeps the historical behaviour (refresh on
            every call with a non-None ``best_loss``), which makes the "best"
            file identical to the latest one. Pass ``False`` for epochs that
            did not improve so the best checkpoint is preserved.
    """
    if rank != 0:
        return
    import shutil  # only needed on the writing rank

    os.makedirs(checkpoint_dir, exist_ok=True)

    net = model.module if isinstance(model, DDP) else model
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
    }
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    if args is not None:
        checkpoint['config'] = {key: getattr(args, key) for key in _CONFIG_KEYS}

    # Write-then-rename: the previous checkpoint stays intact until the new
    # one is fully on disk.
    latest_path = os.path.join(checkpoint_dir, 'checkpoint_latest.pt')
    tmp_latest = latest_path + '.tmp'
    torch.save(checkpoint, tmp_latest)
    os.replace(tmp_latest, latest_path)

    if is_best and best_loss is not None:
        best_path = os.path.join(checkpoint_dir, 'checkpoint_best.pt')
        tmp_best = best_path + '.tmp'
        # Copy the bytes just written rather than serializing the state twice.
        shutil.copyfile(latest_path, tmp_best)
        os.replace(tmp_best, best_path)
def save_downstream_checkpoint(model, optimizer, scheduler, epoch, metrics, checkpoint_dir, rank=0):
    """Save a downstream (fine-tuning) checkpoint as ``downstream_epoch_{epoch:03d}.pt``.

    Only rank 0 writes; other ranks return without touching the filesystem.
    The file is written to a ``.tmp`` sibling and moved into place with
    ``os.replace``, so an interrupted save never leaves a truncated file
    under the final name.

    Args:
        model: Model to save. A DDP wrapper is unwrapped via ``.module``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        scheduler: Optional LR scheduler. Its ``state_dict()`` is stored under
            ``'scheduler'``, or None when no scheduler is used.
        epoch: Stored under ``'epoch'`` and used (zero-padded to 3 digits,
            so it must be an int) in the file name.
        metrics: Arbitrary picklable object stored under ``'metrics'``.
        checkpoint_dir: Output directory, created if missing.
        rank: Process rank; non-zero ranks do nothing.
    """
    if rank != 0:
        return
    os.makedirs(checkpoint_dir, exist_ok=True)

    net = model.module if isinstance(model, DDP) else model
    checkpoint = {
        'epoch': epoch,
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        # Key is always present; None when training without a scheduler.
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        'metrics': metrics,
    }

    path = os.path.join(checkpoint_dir, f"downstream_epoch_{epoch:03d}.pt")
    tmp_path = f"{path}.tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)