import math
import os

import torch
from safetensors.torch import load_file, save_file


def merge_models(model_path_1, model_path_2, ratio_1, ratio_2, output_path):
    """Write a weighted average of two safetensors checkpoints to disk.

    ``ratio_1`` and ``ratio_2`` are normalized so they sum to 1 before use.
    A tensor is blended only when its key exists in both files and both
    copies have the same shape. All other tensors are left out of the
    output, and a summary of what was left out is printed.

    Floating-point tensors are blended at float32 precision or higher and
    then cast back to the first model's dtype. Non-floating tensors (e.g.
    integer buffers) cannot be meaningfully averaged, so they are copied
    from the first model unchanged.

    Args:
        model_path_1: Path to the first ``.safetensors`` file.
        model_path_2: Path to the second ``.safetensors`` file.
        ratio_1: Relative weight of the first model.
        ratio_2: Relative weight of the second model.
        output_path: Where the merged ``.safetensors`` file is written.

    Raises:
        ValueError: If the ratios sum to zero or to a non-finite value, or
            if the two files have no tensor that can be merged.
        FileNotFoundError: If either input file does not exist.
        Exception: Any error raised while loading, merging or saving is
            printed and then re-raised.
    """
    total_ratio = ratio_1 + ratio_2
    # Normalizing by a zero sum raises ZeroDivisionError, and a NaN/inf sum
    # would silently turn every merged weight into NaN.
    if total_ratio == 0 or not math.isfinite(total_ratio):
        raise ValueError(
            f"ratio_1 + ratio_2 must be a finite, non-zero number, got {total_ratio}"
        )
    if total_ratio != 1:
        ratio_1 /= total_ratio
        ratio_2 /= total_ratio

    print(f"Merging models with ratios: {ratio_1:.2f} (Model 1) and {ratio_2:.2f} (Model 2)")

    if not (os.path.exists(model_path_1) and os.path.exists(model_path_2)):
        raise FileNotFoundError("One or both model files do not exist.")

    try:
        tensors_1 = load_file(model_path_1)
        tensors_2 = load_file(model_path_2)

        # Sorted so that processing and log output are deterministic.
        common_keys = sorted(tensors_1.keys() & tensors_2.keys())

        # Report tensors that exist in only one checkpoint; they are not
        # written to the output, so the result may be incomplete.
        only_in_1 = len(tensors_1) - len(common_keys)
        only_in_2 = len(tensors_2) - len(common_keys)
        if only_in_1 or only_in_2:
            print(
                f"Skipping {only_in_1} tensor(s) only in Model 1 and "
                f"{only_in_2} tensor(s) only in Model 2"
            )

        merged_tensors = {}
        for key in common_keys:
            t1, t2 = tensors_1[key], tensors_2[key]
            if t1.shape != t2.shape:
                print(f"Skipping tensor: {key} due to shape mismatch (Model 1: {t1.shape}, Model 2: {t2.shape})")
                continue

            # Averaging integer/bool tensors would turn them into floats;
            # keep Model 1's copy untouched instead.
            if not (t1.is_floating_point() and t2.is_floating_point()):
                print(f"Copying non-floating tensor: {key} (from Model 1)")
                merged_tensors[key] = t1
                continue

            print(f"Merging tensor: {key} (from both models)")
            # Blend in at least float32 so fp16/bf16 weights do not lose
            # precision, then store in Model 1's dtype so the output dtype
            # does not depend on how the two inputs' dtypes promote.
            compute_dtype = torch.promote_types(
                torch.promote_types(t1.dtype, t2.dtype), torch.float32
            )
            merged = t1.to(compute_dtype) * ratio_1 + t2.to(compute_dtype) * ratio_2
            merged_tensors[key] = merged.to(t1.dtype)

        if not merged_tensors:
            raise ValueError(
                "No tensors could be merged: the models share no compatible keys."
            )

        save_file(merged_tensors, output_path)
        print(f"Merged model saved to: {output_path}")

    except Exception as e:
        print(f"An error occurred during model merging: {e}")
        # Re-raise so callers (and the script's exit status) see the failure
        # instead of assuming the output file was written.
        raise


# Run the merge only when executed as a script, not as a side effect of
# importing this module.
if __name__ == "__main__":
    # NOTE(review): the second path is named like shard 1 of a 3-shard
    # checkpoint. Only tensors present in both files are merged, so the
    # output will lack tensors stored in the other shards. Confirm this is
    # intended.
    merge_models(
        model_path_1='flowgram01.safetensors',
        model_path_2='diffusion_pytorch_model-00001-of-00003.safetensors',
        ratio_1=0.6,
        ratio_2=0.4,
        output_path='01.safetensors',
    )