import copy
from typing import Any, Dict, List, Optional

import torch


def average_checkpoints(
    checkpoint_paths: List[str], output_path: Optional[str] = None
):
    """
    Average the "model" and "model_ema" weights from multiple checkpoints.

    Parameters:
        checkpoint_paths: List of checkpoint file paths.
        output_path: Output path; if None, the averaged checkpoint dictionary
            is returned without being saved.

    Returns:
        Averaged checkpoint dictionary.
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")

    # Load the first checkpoint as the base
    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")

    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    # Initialize accumulators. Note: weights are accumulated in their stored
    # dtype; for fp16 checkpoints, prefer the memory-efficient variant below,
    # which casts to float32 before accumulating.
    avg_model_state = copy.deepcopy(avg_checkpoint["model"])
    avg_model_ema_state = None
    if "model_ema" in avg_checkpoint:
        avg_model_ema_state = copy.deepcopy(avg_checkpoint["model_ema"])

    # Accumulate the weights from the remaining checkpoints
    for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
        print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # Accumulate model weights
        for key in avg_model_state.keys():
            if key in ckpt["model"]:
                avg_model_state[key] += ckpt["model"][key]

        # Accumulate model_ema weights (if available)
        if avg_model_ema_state is not None and "model_ema" in ckpt:
            for key in avg_model_ema_state.keys():
                if key in ckpt["model_ema"]:
                    avg_model_ema_state[key] += ckpt["model_ema"][key]

    # Compute the average
    num_checkpoints = len(checkpoint_paths)
    print(f"Averaging over {num_checkpoints} checkpoints...")

    for key in avg_model_state.keys():
        avg_model_state[key] = avg_model_state[key] / num_checkpoints

    if avg_model_ema_state is not None:
        for key in avg_model_ema_state.keys():
            avg_model_ema_state[key] = avg_model_ema_state[key] / num_checkpoints

    # Update the checkpoint dictionary
    avg_checkpoint["model"] = avg_model_state
    if avg_model_ema_state is not None:
        avg_checkpoint["model_ema"] = avg_model_ema_state

    # Save (if an output path is specified)
    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)

    return avg_checkpoint
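
# A minimal sanity-check sketch (not part of the original API): averaging only
# makes sense when all checkpoints come from the same architecture, so this
# helper verifies that every file exposes the same set of "model" keys, under
# the same assumption as above that weights are stored under the "model" key.
def check_model_keys_match(checkpoint_paths: List[str]) -> bool:
    """Return True if all checkpoints share an identical set of 'model' keys."""
    reference_keys = None
    for path in checkpoint_paths:
        ckpt = torch.load(path, map_location="cpu")
        keys = set(ckpt["model"].keys())
        if reference_keys is None:
            reference_keys = keys
        elif keys != reference_keys:
            print(f"Key mismatch in {path}: {sorted(keys ^ reference_keys)}")
            return False
        del ckpt
    return True
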
def average_checkpoints_memory_efficient(
    checkpoint_paths: List[str], output_path: Optional[str] = None
):
    """
    Memory-efficient version: loads and processes checkpoints one at a time,
    which suits large models.
    """
    if not checkpoint_paths:
        raise ValueError("At least one checkpoint path is required")

    print(f"Loading base checkpoint: {checkpoint_paths[0]}")
    avg_checkpoint = torch.load(checkpoint_paths[0], map_location="cpu")

    if len(checkpoint_paths) == 1:
        if output_path:
            torch.save(avg_checkpoint, output_path)
        return avg_checkpoint

    # Convert to float32 for better accumulation precision
    for key in avg_checkpoint["model"].keys():
        avg_checkpoint["model"][key] = avg_checkpoint["model"][key].float()

    if "model_ema" in avg_checkpoint:
        for key in avg_checkpoint["model_ema"].keys():
            avg_checkpoint["model_ema"][key] = avg_checkpoint["model_ema"][key].float()

    # Load and accumulate checkpoints one by one
    for i, ckpt_path in enumerate(checkpoint_paths[1:], 1):
        print(f"Processing checkpoint {i + 1}/{len(checkpoint_paths)}: {ckpt_path}")
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # Accumulate model weights
        for key in avg_checkpoint["model"].keys():
            if key in ckpt["model"]:
                avg_checkpoint["model"][key] += ckpt["model"][key].float()

        # Accumulate model_ema weights
        if "model_ema" in avg_checkpoint and "model_ema" in ckpt:
            for key in avg_checkpoint["model_ema"].keys():
                if key in ckpt["model_ema"]:
                    avg_checkpoint["model_ema"][key] += ckpt["model_ema"][key].float()

        # Free memory (empty_cache is a no-op here since tensors are loaded on
        # CPU, but it is harmless if a custom map_location puts them on GPU)
        del ckpt
        torch.cuda.empty_cache()

    # Compute the average
    num_checkpoints = len(checkpoint_paths)
    print(f"Averaging over {num_checkpoints} checkpoints...")

    for key in avg_checkpoint["model"].keys():
        avg_checkpoint["model"][key] /= num_checkpoints

    if "model_ema" in avg_checkpoint:
        for key in avg_checkpoint["model_ema"].keys():
            avg_checkpoint["model_ema"][key] /= num_checkpoints

    if output_path:
        print(f"Saving averaged checkpoint to: {output_path}")
        torch.save(avg_checkpoint, output_path)

    return avg_checkpoint


# Example usage
if __name__ == "__main__":
    # Method 1: Simple usage. The paths below are placeholders; replace them
    # with real checkpoint files.
    checkpoint_paths = [
        "checkpoint_1.pt",
        "checkpoint_2.pt",
        "checkpoint_3.pt",
    ]

    # Average and save
    average_checkpoints(checkpoint_paths, "averaged_checkpoint.pt")

    # Method 2: Get the averaged checkpoint and process it further
    # avg_ckpt = average_checkpoints(checkpoint_paths)
    # print("Averaged checkpoint keys:", avg_ckpt.keys())

    # Method 3: Use the memory-efficient version (suitable for large models)
    # average_checkpoints_memory_efficient(checkpoint_paths, "averaged_checkpoint_efficient.pt")
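
# A minimal usage sketch for evaluating the result, assuming a hypothetical
# build_model() that returns a torch.nn.Module matching the averaged state
# dict. When EMA weights are present they are typically the better choice
# for evaluation, so they are preferred here.
#
# model = build_model()
# avg_ckpt = average_checkpoints(checkpoint_paths)
# state_dict = avg_ckpt.get("model_ema", avg_ckpt["model"])
# missing, unexpected = model.load_state_dict(state_dict, strict=False)
# print("Missing keys:", missing)
# print("Unexpected keys:", unexpected)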