File size: 5,445 Bytes
70d8fcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
import copy
from typing import List, Dict, Any


def _is_averageable(value) -> bool:
    """Return True if ``value`` is a tensor that can be averaged numerically.

    Non-tensor entries and bool tensors have no meaningful mean, so they are
    left untouched by the averaging code.
    """
    return torch.is_tensor(value) and value.dtype != torch.bool


def _accumulation_dtype(tensor):
    """Pick the dtype in which ``tensor`` is summed across checkpoints."""
    if tensor.is_floating_point():
        # Summing in float16/bfloat16 overflows or loses precision quickly,
        # so low-precision floats are promoted to float32 while summing.
        return torch.float64 if tensor.dtype == torch.float64 else torch.float32
    if tensor.is_complex():
        return tensor.dtype
    # Integer buffers (e.g. step counters) are summed exactly in int64.
    return torch.int64


def _new_accumulators(state):
    """Build ``(sums, counts)`` for ``state``, seeded with its own values."""
    sums = {
        # copy=True gives every key its own storage, so the in-place
        # accumulation cannot be applied twice through aliased tensors.
        key: value.detach().to(_accumulation_dtype(value), copy=True)
        for key, value in state.items()
        if _is_averageable(value)
    }
    return sums, dict.fromkeys(sums, 1)


def _accumulate_state(sums, counts, state):
    """Add the entries of ``state`` into ``sums`` and count each contribution."""
    for key, total in sums.items():
        if key not in state:
            continue
        value = state[key]
        if not _is_averageable(value):
            continue
        total.add_(value.detach().to(total.dtype))
        counts[key] += 1


def _store_averages(state, sums, counts):
    """Write ``sums[key] / counts[key]`` into ``state`` using the original dtype."""
    for key, total in sums.items():
        original_dtype = state[key].dtype
        if total.is_floating_point() or total.is_complex():
            mean = total / counts[key]
        else:
            # Keep integer buffers integral instead of turning them into floats.
            mean = torch.div(total, counts[key], rounding_mode="floor")
        state[key] = mean.to(original_dtype)


def average_checkpoints(checkpoint_paths: List[str], output_path: str = None):
    """
    Average the model and model_ema weights from multiple checkpoints

    The first checkpoint is used as the base: its key set decides which
    entries are averaged, and every other top-level entry of the result is
    taken from it unchanged. Each tensor is averaged over the checkpoints that
    actually contain it, so a key (or a whole "model_ema" section) missing from
    some checkpoints does not bias the result. Floating-point tensors are
    summed in at least float32 and returned in their original dtype; integer
    tensors are floor-averaged and keep their dtype; bool tensors and
    non-tensor entries keep the value from the first checkpoint.

    NOTE: torch.load unpickles the files, so only pass checkpoints from a
    trusted source.

    Parameters:
    checkpoint_paths: List of checkpoint file paths
    output_path: Output path; if None (or empty), nothing is written

    Returns:
    Averaged checkpoint dictionary

    Raises:
    ValueError: If checkpoint_paths is empty
    KeyError: If a checkpoint has no "model" entry
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")

    # Load the first checkpoint as the base
    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")

    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    # One (sums, counts) pair per averaged section; "model" is mandatory,
    # "model_ema" is only averaged when the base checkpoint has it.
    accumulators = {
        section: _new_accumulators(avg_checkpoint[section])
        for section in ("model", "model_ema")
        if section == "model" or section in avg_checkpoint
    }

    # Accumulate the weights from the other checkpoints
    num_checkpoints = len(checkpoint_paths)
    for position, ckpt_path in enumerate(checkpoint_paths[1:], 2):
        print(f"Processing checkpoint {position}/{num_checkpoints}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        for section, (sums, counts) in accumulators.items():
            if section == "model" or section in ckpt:
                _accumulate_state(sums, counts, ckpt[section])
        # Drop the checkpoint before loading the next one to limit peak memory.
        del ckpt

    # Compute the average and write it back into the base checkpoint
    print(f"Averaging over {num_checkpoints} checkpoints...")
    for section, (sums, counts) in accumulators.items():
        _store_averages(avg_checkpoint[section], sums, counts)

    # Save (if an output path is specified)
    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)

    return avg_checkpoint


def average_checkpoints_memory_efficient(
    checkpoint_paths: List[str], output_path: str = None
):
    """
    Memory efficient version: Load and process checkpoints one by one, suitable for large models

    The first checkpoint is converted into the accumulator in place (no second
    copy of the whole model is kept), and every further checkpoint is released
    before the next one is loaded. Each tensor is averaged over the checkpoints
    that actually contain it. Floating-point tensors are accumulated and
    returned in float32 (float64 inputs stay float64); integer tensors are
    floor-averaged and keep their dtype; bool tensors and non-tensor entries
    keep the value from the first checkpoint.

    NOTE: torch.load unpickles the files, so only pass checkpoints from a
    trusted source.

    Parameters:
    checkpoint_paths: List of checkpoint file paths
    output_path: Output path; if None (or empty), nothing is written

    Returns:
    Averaged checkpoint dictionary

    Raises:
    ValueError: If checkpoint_paths is empty
    KeyError: If a checkpoint has no "model" entry
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")

    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")

    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    sections = ["model"]
    if "model_ema" in avg_checkpoint:
        sections.append("model_ema")

    # (section, key) -> number of checkpoints that contributed to the sum
    counts = {}
    # (section, key) -> original dtype, recorded for integer tensors only
    int_dtypes = {}

    # Turn the base checkpoint into the accumulator, one entry at a time
    for section in sections:
        state = avg_checkpoint[section]
        for key in list(state.keys()):
            value = state[key]
            if not torch.is_tensor(value) or value.dtype == torch.bool:
                # Nothing to average; keep the base checkpoint's value.
                continue
            if value.is_floating_point():
                # Convert to float32 for better precision (never downcast float64)
                acc_dtype = (
                    torch.float64 if value.dtype == torch.float64 else torch.float32
                )
            elif value.is_complex():
                acc_dtype = value.dtype
            else:
                acc_dtype = torch.int64
                int_dtypes[(section, key)] = value.dtype
            # copy=True gives every key its own storage, so the in-place
            # accumulation cannot be applied twice through aliased tensors.
            state[key] = value.detach().to(acc_dtype, copy=True)
            counts[(section, key)] = 1

    # Load and accumulate checkpoints one by one
    num_checkpoints = len(checkpoint_paths)
    for position, ckpt_path in enumerate(checkpoint_paths[1:], 2):
        print(f"Processing checkpoint {position}/{num_checkpoints}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")

        for section in sections:
            if section != "model" and section not in ckpt:
                continue
            incoming = ckpt[section]
            for key, total in avg_checkpoint[section].items():
                if (section, key) not in counts or key not in incoming:
                    continue
                value = incoming[key]
                if not torch.is_tensor(value) or value.dtype == torch.bool:
                    continue
                total.add_(value.detach().to(total.dtype))
                counts[(section, key)] += 1

        # Free memory before the next load; everything is mapped to the CPU,
        # so there is no CUDA cache to clear.
        del ckpt

    # Compute the average
    print(f"Averaging over {num_checkpoints} checkpoints...")

    for (section, key), count in counts.items():
        state = avg_checkpoint[section]
        if (section, key) in int_dtypes:
            # Keep integer buffers integral instead of turning them into floats.
            state[key] = torch.div(state[key], count, rounding_mode="floor").to(
                int_dtypes[(section, key)]
            )
        else:
            state[key].div_(count)

    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)

    return avg_checkpoint


# Example usage
if __name__ == "__main__":
    # Command-line entry point. The checkpoint paths come from the command
    # line; a hard-coded empty list would make every run fail with ValueError.
    import argparse

    parser = argparse.ArgumentParser(
        description="Average the model / model_ema weights of several checkpoints."
    )
    parser.add_argument(
        "checkpoints",
        nargs="+",
        help="paths of the checkpoint files to average",
    )
    parser.add_argument(
        "-o",
        "--output",
        default=None,
        help="where to save the averaged checkpoint (not saved when omitted)",
    )
    parser.add_argument(
        "--memory-efficient",
        action="store_true",
        help="use average_checkpoints_memory_efficient (suitable for large models)",
    )
    cli_args = parser.parse_args()

    if cli_args.memory_efficient:
        averaged = average_checkpoints_memory_efficient(
            cli_args.checkpoints, cli_args.output
        )
    else:
        averaged = average_checkpoints(cli_args.checkpoints, cli_args.output)

    print("Averaged checkpoint keys:", list(averaged.keys()))