File size: 3,665 Bytes
6a51385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def load_checkpoint(
    model: nn.Module,
    optimizer: Optional[optim.Optimizer],
    scheduler,
    checkpoint_path: str,
    device: torch.device,
) -> Tuple[int, float]:
    """Restore training state from a checkpoint file.

    Args:
        model: Module to load weights into. When it is wrapped in
            ``DistributedDataParallel`` the weights are loaded into the
            wrapped ``model.module``.
        optimizer: Optimizer whose state is restored from the
            ``'optimizer_state_dict'`` entry. Pass ``None`` to load the
            weights only (e.g. for evaluation); the optimizer entry is then
            not read at all.
        scheduler: Optional LR scheduler. Its state is restored only when the
            scheduler is not ``None`` and the checkpoint contains a
            ``'scheduler_state_dict'`` entry.
        checkpoint_path: Path of the file to read with ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(epoch, best_loss)`` as stored in the checkpoint, falling back to
        ``0`` and ``float('inf')`` for entries that are missing.

    Raises:
        KeyError: If ``'model_state_dict'`` is missing, or if an optimizer is
            given and ``'optimizer_state_dict'`` is missing.
    """
    # NOTE(security): torch.load is pickle-based; depending on the installed
    # PyTorch version's ``weights_only`` default, a malicious file can run
    # arbitrary code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # The state dict is loaded into the bare module, so unwrap DDP first.
    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(checkpoint['model_state_dict'])

    # Mirror the scheduler handling: a missing optimizer is simply skipped
    # instead of failing with AttributeError on ``None``.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    epoch = checkpoint.get('epoch', 0)
    best_loss = checkpoint.get('best_loss', float('inf'))

    return epoch, best_loss


# Attributes copied from ``args`` into the checkpoint's 'config' entry.
_CHECKPOINT_CONFIG_FIELDS = (
    'model_type',
    'embed_dim',
    'depth',
    'predictor_depth',
    'drop_path_rate',
    'rms_norm',
    'fused_add_norm',
    'residual_in_fp32',
    'bimamba_type',
    'if_bimamba',
    'mixer_type',
    'if_devide_out',
    'predictor_hidden',
    'momentum',
    'norm_target',
    'num_heads',
    'mlp_ratio',
)


def _atomic_torch_save(payload, path: str) -> None:
    """Write *payload* to *path* via a temp file so *path* is never partial."""
    # The temp file lives next to the target so os.replace stays on one
    # filesystem and can swap the file in as a single rename.
    tmp_path = f"{path}.tmp"
    try:
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only present when saving failed before the rename; never leave it.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    epoch: int,
    best_loss: float,
    checkpoint_dir: str,
    args=None,
    rank: int = 0,
) -> None:
    """Save the training state, plus the model configuration when given.

    Only rank 0 writes; every other rank returns immediately. Files are
    written to a temporary name and then renamed, so a failure while saving
    leaves any previously written checkpoint intact.

    Args:
        model: Module whose weights are saved. A ``DistributedDataParallel``
            wrapper is unwrapped so the saved keys are those of
            ``model.module``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        scheduler: Optional LR scheduler; when not ``None`` its state is
            stored under ``'scheduler_state_dict'``.
        epoch: Stored under ``'epoch'``.
        best_loss: Stored under ``'best_loss'``. ``checkpoint_best.pt`` is
            written whenever this is not ``None``; no comparison against a
            previously saved value happens here.
        checkpoint_dir: Output directory, created if it does not exist.
        args: Optional namespace-like object. When given, the attributes
            listed in ``_CHECKPOINT_CONFIG_FIELDS`` are copied into the
            ``'config'`` entry.
        rank: Process rank; nothing is written unless it is ``0``.

    Raises:
        AttributeError: If ``args`` is given but lacks one of the
            configuration attributes.
    """
    if rank != 0:
        return

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save the bare module's weights so they load with or without DDP.
    unwrapped = model.module if isinstance(model, DDP) else model

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': unwrapped.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
    }

    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()

    # Save model configuration for downstream tasks. getattr without a
    # default keeps a missing attribute a hard error.
    if args is not None:
        checkpoint['config'] = {
            field: getattr(args, field) for field in _CHECKPOINT_CONFIG_FIELDS
        }

    # Save latest checkpoint
    _atomic_torch_save(
        checkpoint, os.path.join(checkpoint_dir, 'checkpoint_latest.pt')
    )

    # Save best checkpoint
    if best_loss is not None:
        _atomic_torch_save(
            checkpoint, os.path.join(checkpoint_dir, 'checkpoint_best.pt')
        )


def save_downstream_checkpoint(model, optimizer, scheduler, epoch, metrics, checkpoint_dir, rank=0):
    """Save a per-epoch checkpoint for a downstream task.

    Only rank 0 writes; every other rank returns without side effects. The
    file is named ``downstream_epoch_{epoch:03d}.pt`` inside
    ``checkpoint_dir`` and holds the keys ``'epoch'``, ``'model'``,
    ``'optimizer'``, ``'scheduler'`` and ``'metrics'``.

    Args:
        model: Module whose weights are saved. A ``DistributedDataParallel``
            wrapper is unwrapped so the saved keys are those of
            ``model.module``.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored, or ``None``
            when training runs without one; ``None`` is then stored under
            ``'scheduler'`` so the key is always present.
        epoch: Integer epoch number; stored and used in the file name.
        metrics: Stored unchanged under ``'metrics'``.
        checkpoint_dir: Output directory, created if it does not exist.
        rank: Process rank; nothing is written unless it is ``0``.
    """
    if rank != 0:
        return

    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save the bare module's weights so they load with or without DDP.
    unwrapped = model.module if isinstance(model, DDP) else model

    checkpoint = {
        'epoch': epoch,
        'model': unwrapped.state_dict(),
        'optimizer': optimizer.state_dict(),
        # A scheduler is optional; calling state_dict() on None would raise.
        'scheduler': scheduler.state_dict() if scheduler is not None else None,
        'metrics': metrics,
    }
    path = os.path.join(checkpoint_dir, f"downstream_epoch_{epoch:03d}.pt")
    torch.save(checkpoint, path)