checkpoint / mg.py
MoeZilla's picture
Update mg.py
a93476e verified
import os
import torch
from safetensors.torch import load_file, save_file # Updated import
def _blend_tensors(first, second, ratio_1, ratio_2):
    """Return the weighted sum of two same-shaped tensors with a sane dtype."""
    if first.is_complex() or second.is_complex():
        # Complex values cannot be staged through a real work dtype; use the
        # plain expression for them.
        return first * ratio_1 + second * ratio_2
    if not (first.is_floating_point() or second.is_floating_point()):
        # Multiplying an integer/bool tensor by a float ratio would promote it
        # to a float dtype and change what gets written to the checkpoint, so
        # such tensors are carried over unchanged from model 1.
        return first
    # Dtype the plain `first * ratio_1 + second * ratio_2` expression yields
    # when the ratios are Python floats.
    out_dtype = torch.promote_types(first.dtype, second.dtype)
    # Do the arithmetic in at least float32 so half-precision weights are
    # rounded once (on the final cast) rather than after every operation.
    work_dtype = torch.float64 if out_dtype == torch.float64 else torch.float32
    blended = first.to(work_dtype) * ratio_1 + second.to(work_dtype) * ratio_2
    return blended.to(out_dtype)


def merge_models(model_path_1, model_path_2, ratio_1, ratio_2, output_path):
    """Blend two safetensors checkpoints into one weighted-average checkpoint.

    Only tensors whose name appears in both files and whose shapes match are
    merged; every other tensor is left out of the output. Tensors with an
    integer or bool dtype in both files are copied from model 1 rather than
    interpolated.

    Args:
        model_path_1: Path to the first ``.safetensors`` file.
        model_path_2: Path to the second ``.safetensors`` file.
        ratio_1: Weight of the first model.
        ratio_2: Weight of the second model. The two weights are rescaled so
            that they sum to 1.
        output_path: Destination path for the merged ``.safetensors`` file.

    Returns:
        ``True`` if the merged file was written, ``False`` if loading, merging
        or saving failed (the error is printed, not propagated).

    Raises:
        ValueError: If ``ratio_1 + ratio_2`` is zero, so the weights cannot be
            normalised.
        FileNotFoundError: If either input path does not exist.
    """
    total_ratio = ratio_1 + ratio_2
    if total_ratio == 0:
        raise ValueError("ratio_1 and ratio_2 must not sum to zero.")
    # Dividing by a total of exactly 1 is a no-op, so normalise unconditionally.
    ratio_1 /= total_ratio
    ratio_2 /= total_ratio

    print(f"Merging models with ratios: {ratio_1:.2f} (Model 1) and {ratio_2:.2f} (Model 2)")

    if not (os.path.exists(model_path_1) and os.path.exists(model_path_2)):
        raise FileNotFoundError("One or both model files do not exist.")

    try:
        tensors_1 = load_file(model_path_1)
        tensors_2 = load_file(model_path_2)

        # Sorted so processing and log order are reproducible between runs.
        common_keys = sorted(tensors_1.keys() & tensors_2.keys())
        unshared = len(tensors_1) + len(tensors_2) - 2 * len(common_keys)
        if unshared:
            print(f"Ignoring {unshared} tensor(s) present in only one model")

        merged_tensors = {}
        for key in common_keys:
            # Pop instead of index: drops this function's references to the
            # source tensors as soon as they have been consumed.
            first = tensors_1.pop(key)
            second = tensors_2.pop(key)
            if first.shape != second.shape:
                print(f"Skipping tensor: {key} due to shape mismatch (Model 1: {first.shape}, Model 2: {second.shape})")
                continue
            print(f"Merging tensor: {key} (from both models)")
            merged_tensors[key] = _blend_tensors(first, second, ratio_1, ratio_2)

        if not merged_tensors:
            # Avoid writing an empty checkpoint that looks like a success.
            raise ValueError("No tensors with matching names and shapes; nothing to save.")

        save_file(merged_tensors, output_path)
        print(f"Merged model saved to: {output_path}")
    except Exception as e:
        # Failures in this stage are reported rather than propagated (the
        # original contract); the return value makes them detectable.
        print(f"An error occurred during model merging: {e}")
        return False
    return True
# Sample invocation: blend two checkpoints 60/40 and write the result to disk.
merge_models(
    model_path_1='flowgram01.safetensors',
    model_path_2='diffusion_pytorch_model-00001-of-00003.safetensors',
    ratio_1=0.6,  # weight of the first model (60%)
    ratio_2=0.4,  # weight of the second model (40%)
    output_path='01.safetensors',  # file the merged checkpoint is written to
)