Spaces:
Running
Running
File size: 5,445 Bytes
70d8fcf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | import torch
import copy
from typing import List, Dict, Any
def average_checkpoints(
    checkpoint_paths: List[str], output_path: str = None
) -> Dict[str, Any]:
    """
    Average the model and model_ema weights from multiple checkpoints.

    Each checkpoint is expected to be a dict with a ``"model"`` state dict and,
    optionally, a ``"model_ema"`` state dict. Every other entry of the result
    (optimizer state, epoch, ...) is taken unchanged from the first checkpoint.

    Per state-dict entry:
      * floating-point / complex tensors are summed in at least float32, so
        fp16/bf16 weights do not lose precision, and the mean is cast back to
        the entry's original dtype;
      * integer tensors (e.g. step counters) get the floor of the mean and keep
        their integer dtype;
      * bool tensors and non-tensor values are copied from the first checkpoint;
      * an entry is divided by the number of checkpoints that actually contain
        it, so keys (or a whole ``model_ema``) missing from some checkpoints
        are not biased toward zero.
    Keys that appear only in later checkpoints are ignored.

    Parameters:
        checkpoint_paths: List of checkpoint file paths; the first is the base.
        output_path: If truthy, the averaged checkpoint is also saved there;
            otherwise it is only returned.

    Returns:
        Averaged checkpoint dictionary.

    Raises:
        ValueError: if ``checkpoint_paths`` is empty, or a later checkpoint
            holds a non-tensor / differently shaped value for an averaged key.
        KeyError: if a checkpoint has no ``"model"`` entry.

    Note:
        ``torch.load`` may unpickle arbitrary objects (depending on the torch
        version's ``weights_only`` default) - only load trusted files.
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")

    # Load the first checkpoint as the base
    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    def accumulator_dtype(value):
        # Dtype used to sum ``value``, or None when the entry is not averaged.
        if not torch.is_tensor(value) or value.dtype == torch.bool:
            return None
        if value.is_floating_point() or value.is_complex():
            # fp16/bf16 -> float32; float32/float64/complex keep their type.
            return torch.promote_types(value.dtype, torch.float32)
        return torch.int64

    # "model" is required (KeyError otherwise); "model_ema" is optional.
    sections = ["model"]
    if "model_ema" in avg_checkpoint:
        sections.append("model_ema")

    # Accumulators start as *copies* of the base tensors. Tied weights are
    # stored as aliased tensors; accumulating into them in place would add
    # each later checkpoint once per alias.
    sums: Dict[str, Dict[str, torch.Tensor]] = {}
    counts: Dict[str, Dict[str, int]] = {}
    for section in sections:
        sums[section], counts[section] = {}, {}
        for key, value in avg_checkpoint[section].items():
            acc_dtype = accumulator_dtype(value)
            if acc_dtype is not None:
                sums[section][key] = value.detach().to(dtype=acc_dtype, copy=True)
                counts[section][key] = 1

    # Accumulate the weights from the other checkpoints
    for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
        print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        for section, section_sums in sums.items():
            if section == "model_ema" and section not in ckpt:
                continue  # checkpoints without EMA weights simply don't contribute
            state = ckpt[section]
            missing = 0
            for key, total in section_sums.items():
                if key not in state:
                    missing += 1
                    continue
                other = state[key]
                # Guard against silent broadcasting of mismatched shapes.
                if not torch.is_tensor(other) or other.shape != total.shape:
                    raise ValueError(
                        f"'{section}.{key}' in {ckpt_path} does not match the base "
                        f"checkpoint (expected a tensor of shape {tuple(total.shape)})"
                    )
                total += other.to(total.dtype)
                counts[section][key] += 1
            if missing:
                print(
                    f"  {missing} '{section}' key(s) missing from {ckpt_path}; "
                    "they are averaged over the checkpoints that contain them"
                )
        # Release this checkpoint before loading the next one
        del ckpt

    # Compute the average and write it back with the original dtype. Mutating
    # the loaded base dict keeps its container type and metadata.
    num_checkpoints = len(checkpoint_paths)
    print(f"Averaging over {num_checkpoints} checkpoints...")
    for section, section_sums in sums.items():
        state = avg_checkpoint[section]
        for key, total in section_sums.items():
            count = counts[section][key]
            if total.is_floating_point() or total.is_complex():
                mean = total / count
            else:
                mean = torch.div(total, count, rounding_mode="floor")
            state[key] = mean.to(state[key].dtype)

    # Save (if an output path is specified)
    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)
    return avg_checkpoint
def average_checkpoints_memory_efficient(
    checkpoint_paths: List[str], output_path: str = None
) -> Dict[str, Any]:
    """
    Memory efficient version: Load and process checkpoints one by one, suitable for large models.

    The base checkpoint's averaged entries are replaced *in place* by wide
    accumulators, so at most one base-sized accumulator set plus one loaded
    checkpoint is alive at a time.

    Per state-dict entry in ``"model"`` / ``"model_ema"``:
      * floating-point / complex tensors are summed in at least float32
        (float64 stays float64) and the mean is cast back to the original dtype;
      * integer tensors get the floor of the mean and keep their dtype;
      * bool tensors and non-tensor values are kept from the first checkpoint;
      * an entry is divided by the number of checkpoints that contain it.
    All other top-level entries come from the first checkpoint.

    Parameters:
        checkpoint_paths: List of checkpoint file paths; the first is the base.
        output_path: If truthy, the averaged checkpoint is also saved there.

    Returns:
        Averaged checkpoint dictionary.

    Raises:
        ValueError: if ``checkpoint_paths`` is empty, or a later checkpoint
            holds a non-tensor / differently shaped value for an averaged key.
        KeyError: if a checkpoint has no ``"model"`` entry.
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")
    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")
    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    sections = ["model"]
    if "model_ema" in avg_checkpoint:
        sections.append("model_ema")

    # section -> key -> [original dtype, number of contributing checkpoints]
    meta: Dict[str, Dict[str, list]] = {}
    for section in sections:
        state = avg_checkpoint[section]
        meta[section] = {}
        # Iterate a key snapshot so only one base tensor is referenced at a time.
        for key in list(state.keys()):
            value = state[key]
            if not torch.is_tensor(value) or value.dtype == torch.bool:
                continue
            if value.is_floating_point() or value.is_complex():
                acc_dtype = torch.promote_types(value.dtype, torch.float32)
            else:
                acc_dtype = torch.int64
            meta[section][key] = [value.dtype, 1]
            # copy=True: tied weights are stored as aliased tensors, and
            # accumulating into an alias would count later checkpoints twice.
            state[key] = value.detach().to(dtype=acc_dtype, copy=True)

    # Load and accumulate checkpoints one by one
    for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
        print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        for section in sections:
            if section == "model_ema" and section not in ckpt:
                continue  # no EMA weights in this checkpoint: it doesn't contribute
            src = ckpt[section]
            dst = avg_checkpoint[section]
            for key, info in meta[section].items():
                if key not in src:
                    continue
                other = src[key]
                acc = dst[key]
                if not torch.is_tensor(other) or other.shape != acc.shape:
                    raise ValueError(
                        f"'{section}.{key}' in {ckpt_path} does not match the base "
                        f"checkpoint (expected a tensor of shape {tuple(acc.shape)})"
                    )
                acc += other.to(acc.dtype)
                info[1] += 1
        # Free memory (tensors were loaded on CPU, so dropping the reference is enough)
        del ckpt

    # Compute the average and restore each entry's original dtype
    num_checkpoints = len(checkpoint_paths)
    print(f"Averaging over {num_checkpoints} checkpoints...")
    for section in sections:
        state = avg_checkpoint[section]
        for key, (orig_dtype, count) in meta[section].items():
            acc = state[key]
            if acc.is_floating_point() or acc.is_complex():
                mean = acc / count
            else:
                mean = torch.div(acc, count, rounding_mode="floor")
            state[key] = mean.to(orig_dtype)

    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)
    return avg_checkpoint
# Example usage (command line):
#   python this_script.py ckpt_1.pt ckpt_2.pt ckpt_3.pt -o averaged.pt
#   python this_script.py ckpt_*.pt -o averaged_efficient.pt --memory-efficient
# Library usage:
#   avg_ckpt = average_checkpoints(checkpoint_paths)   # returns the dict, saves nothing
#   print("Averaged checkpoint keys:", avg_ckpt.keys())
#   average_checkpoints_memory_efficient(checkpoint_paths, "averaged_checkpoint_efficient.pt")
if __name__ == "__main__":
    # Imported here so importing this module stays side-effect free.
    import argparse

    parser = argparse.ArgumentParser(
        description="Average the 'model' and 'model_ema' weights of several checkpoints."
    )
    parser.add_argument(
        "checkpoint_paths",
        nargs="+",
        help="checkpoint files to average; the first one is the base checkpoint",
    )
    parser.add_argument(
        "-o", "--output", required=True, help="path to save the averaged checkpoint"
    )
    parser.add_argument(
        "--memory-efficient",
        action="store_true",
        help="use average_checkpoints_memory_efficient (for large models)",
    )
    args = parser.parse_args()

    average = (
        average_checkpoints_memory_efficient
        if args.memory_efficient
        else average_checkpoints
    )
    average(args.checkpoint_paths, args.output)
|