olivv-cs commited on
Commit
bd12953
·
verified ·
1 Parent(s): aebabbd

Create simple_alpha_model_merger.py

Browse files
Files changed (1) hide show
  1. simple_alpha_model_merger.py +68 -0
simple_alpha_model_merger.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # merge_two_fp8.py
2
+ # Merges two FP8 .safetensors files with weighted average
3
+ # Usage: python merge_two_fp8.py
4
+ # Or with args: python merge_two_fp8.py --model_a path/to/existing_fp8.safetensors --model_b path/to/new_model_fp8.safetensors --alpha 0.7 --output merged_result.safetensors
5
+
6
+ from safetensors.torch import load_file, save_file
7
+ import torch
8
+ import argparse
9
+ import os
10
+
11
+ def merge_weights(dict_a, dict_b, alpha):
12
+ """
13
+ alpha: weight for model_b (VBVR/new one). 0.0 = 100% model_a, 1.0 = 100% model_b
14
+ """
15
+ merged = {}
16
+ keys_a = set(dict_a.keys())
17
+ keys_b = set(dict_b.keys())
18
+
19
+ all_keys = keys_a.union(keys_b)
20
+
21
+ for key in all_keys:
22
+ if key in dict_a and key in dict_b:
23
+ # Weighted average for matching keys
24
+ merged[key] = (1 - alpha) * dict_a[key] + alpha * dict_b[key]
25
+ elif key in dict_a:
26
+ merged[key] = dict_a[key]
27
+ elif key in dict_b:
28
+ merged[key] = dict_b[key]
29
+ else:
30
+ continue # shouldn't happen
31
+
32
+ return merged
33
+
34
+ def main():
35
+ parser = argparse.ArgumentParser(description="Weighted merge two FP8 safetensors files (Wan2.2 style)")
36
+ parser.add_argument("--model_a", required=True, help="Path to existing merged FP8 model (base/reference)")
37
+ parser.add_argument("--model_b", required=True, help="Path to new VBVR FP8 model (the one to blend in)")
38
+ parser.add_argument("--alpha", type=float, default=0.7, help="Blend strength for model_b (0.0–1.0, default 0.7 = 70%% VBVR + 30%% existing)")
39
+ parser.add_argument("--output", default="vbvr_merged_with_existing_fp8.safetensors", help="Output filename")
40
+
41
+ args = parser.parse_args()
42
+
43
+ if not (0 <= args.alpha <= 1):
44
+ raise ValueError("Alpha must be between 0.0 and 1.0")
45
+
46
+ print(f"Merging:\n A: {args.model_a}\n B: {args.model_b}\n Alpha (for B): {args.alpha}\n Output: {args.output}")
47
+
48
+ print("Loading model A...")
49
+ dict_a = load_file(args.model_a, device="cpu")
50
+
51
+ print("Loading model B...")
52
+ dict_b = load_file(args.model_b, device="cpu")
53
+
54
+ print("Merging weights...")
55
+ merged_dict = merge_weights(dict_a, dict_b, args.alpha)
56
+
57
+ # Ensure output has .safetensors extension
58
+ if not args.output.endswith(".safetensors"):
59
+ args.output += ".safetensors"
60
+
61
+ print(f"Saving merged FP8 model...")
62
+ save_file(merged_dict, args.output)
63
+
64
+ print(f"Done! New file: {os.path.abspath(args.output)}")
65
+ print("Test it in ComfyUI or Diffusers — adjust alpha if motion/reasoning balance feels off.")
66
+
67
+ if __name__ == "__main__":
68
+ main()