Create simple_alpha_model_merger.py
Browse files- simple_alpha_model_merger.py +68 -0
simple_alpha_model_merger.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# merge_two_fp8.py
|
| 2 |
+
# Merges two FP8 .safetensors files with weighted average
|
| 3 |
+
# Usage: python merge_two_fp8.py
|
| 4 |
+
# Or with args: python merge_two_fp8.py --model_a path/to/existing_fp8.safetensors --model_b path/to/new_model_fp8.safetensors --alpha 0.7 --output merged_result.safetensors
|
| 5 |
+
|
| 6 |
+
from safetensors.torch import load_file, save_file
|
| 7 |
+
import torch
|
| 8 |
+
import argparse
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
def merge_weights(dict_a, dict_b, alpha):
|
| 12 |
+
"""
|
| 13 |
+
alpha: weight for model_b (VBVR/new one). 0.0 = 100% model_a, 1.0 = 100% model_b
|
| 14 |
+
"""
|
| 15 |
+
merged = {}
|
| 16 |
+
keys_a = set(dict_a.keys())
|
| 17 |
+
keys_b = set(dict_b.keys())
|
| 18 |
+
|
| 19 |
+
all_keys = keys_a.union(keys_b)
|
| 20 |
+
|
| 21 |
+
for key in all_keys:
|
| 22 |
+
if key in dict_a and key in dict_b:
|
| 23 |
+
# Weighted average for matching keys
|
| 24 |
+
merged[key] = (1 - alpha) * dict_a[key] + alpha * dict_b[key]
|
| 25 |
+
elif key in dict_a:
|
| 26 |
+
merged[key] = dict_a[key]
|
| 27 |
+
elif key in dict_b:
|
| 28 |
+
merged[key] = dict_b[key]
|
| 29 |
+
else:
|
| 30 |
+
continue # shouldn't happen
|
| 31 |
+
|
| 32 |
+
return merged
|
| 33 |
+
|
| 34 |
+
def main():
|
| 35 |
+
parser = argparse.ArgumentParser(description="Weighted merge two FP8 safetensors files (Wan2.2 style)")
|
| 36 |
+
parser.add_argument("--model_a", required=True, help="Path to existing merged FP8 model (base/reference)")
|
| 37 |
+
parser.add_argument("--model_b", required=True, help="Path to new VBVR FP8 model (the one to blend in)")
|
| 38 |
+
parser.add_argument("--alpha", type=float, default=0.7, help="Blend strength for model_b (0.0–1.0, default 0.7 = 70%% VBVR + 30%% existing)")
|
| 39 |
+
parser.add_argument("--output", default="vbvr_merged_with_existing_fp8.safetensors", help="Output filename")
|
| 40 |
+
|
| 41 |
+
args = parser.parse_args()
|
| 42 |
+
|
| 43 |
+
if not (0 <= args.alpha <= 1):
|
| 44 |
+
raise ValueError("Alpha must be between 0.0 and 1.0")
|
| 45 |
+
|
| 46 |
+
print(f"Merging:\n A: {args.model_a}\n B: {args.model_b}\n Alpha (for B): {args.alpha}\n Output: {args.output}")
|
| 47 |
+
|
| 48 |
+
print("Loading model A...")
|
| 49 |
+
dict_a = load_file(args.model_a, device="cpu")
|
| 50 |
+
|
| 51 |
+
print("Loading model B...")
|
| 52 |
+
dict_b = load_file(args.model_b, device="cpu")
|
| 53 |
+
|
| 54 |
+
print("Merging weights...")
|
| 55 |
+
merged_dict = merge_weights(dict_a, dict_b, args.alpha)
|
| 56 |
+
|
| 57 |
+
# Ensure output has .safetensors extension
|
| 58 |
+
if not args.output.endswith(".safetensors"):
|
| 59 |
+
args.output += ".safetensors"
|
| 60 |
+
|
| 61 |
+
print(f"Saving merged FP8 model...")
|
| 62 |
+
save_file(merged_dict, args.output)
|
| 63 |
+
|
| 64 |
+
print(f"Done! New file: {os.path.abspath(args.output)}")
|
| 65 |
+
print("Test it in ComfyUI or Diffusers — adjust alpha if motion/reasoning balance feels off.")
|
| 66 |
+
|
| 67 |
+
if __name__ == "__main__":
|
| 68 |
+
main()
|