Spaces:
Sleeping
Sleeping
File size: 2,464 Bytes
b5a064f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from collections import OrderedDict
from typing import *
import torch
import tqdm
def merge(
    path_a: str,
    path_b: str,
    path_c: Optional[str],
    alpha: float,
    weights: Dict[str, float],
    method: str,
):
    """Merge the weights of two (optionally three) checkpoints.

    Each checkpoint is loaded with ``torch.load`` onto the CPU. A checkpoint
    that has a top-level ``"model"`` entry is reduced to its weights (entries
    whose key contains ``"enc_q"`` are dropped); otherwise its ``"weight"``
    entry is used directly.

    Args:
        path_a: Path of checkpoint A. Its metadata (``config``, ``sr``,
            ``f0``, ``info``, ...) is copied into the result.
        path_b: Path of checkpoint B.
        path_c: Path of checkpoint C, or ``None``. Required when ``method``
            is ``"add_diff"``.
        alpha: Default blend ratio, used for every key that matches no
            prefix in ``weights``.
        weights: Per-prefix blend ratios. For a given weight key, the longest
            prefix in this mapping that the key starts with wins. ``None`` or
            an empty mapping means "always use ``alpha``"; non-string keys
            are ignored.
        method: ``"weight_sum"`` computes ``(1 - alpha) * a + alpha * b``;
            ``"add_diff"`` computes ``a + (b - c) * alpha``.

    Returns:
        An ``OrderedDict`` with the merged ``"weight"`` mapping followed by
        the metadata entries taken from checkpoint A.

    Raises:
        ValueError: If ``method`` is not one of the supported names, or if
            ``"add_diff"`` is requested without ``path_c``.
        RuntimeError: If A and B do not have the same set of weight keys, or
            (for ``"add_diff"``) C lacks some of A's keys.
        KeyError: If checkpoint A lacks ``config``, ``sr``, ``f0`` or
            ``info``.
    """
    supported_methods = ("weight_sum", "add_diff")
    # Fail fast: an unknown method used to yield a checkpoint full of None.
    if method not in supported_methods:
        raise ValueError(
            f"Unknown merge method {method!r}; expected one of {supported_methods}."
        )
    if method == "add_diff" and path_c is None:
        raise ValueError('Merge method "add_diff" requires a third model (path_c).')

    def extract(ckpt: Dict[str, Any]) -> Dict[str, Any]:
        """Return the ``"model"`` entries of *ckpt* minus the ``enc_q`` ones."""
        return {
            key: value for key, value in ckpt["model"].items() if "enc_q" not in key
        }

    def load_weight(path: str):
        """Load *path* and return ``(weight mapping, full checkpoint)``."""
        print(f"Loading {path}...")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only merge checkpoints from trusted sources.
        state_dict = torch.load(path, map_location="cpu")
        if "model" in state_dict:
            # The flat key -> tensor mapping is what the merge operates on.
            weight = extract(state_dict)
        else:
            weight = state_dict["weight"]
        return weight, state_dict

    def get_alpha(key: str) -> float:
        """Return the blend ratio for *key* (longest matching prefix wins)."""
        if not weights:
            return alpha
        candidates = [
            prefix
            for prefix in weights
            if isinstance(prefix, str) and key.startswith(prefix)
        ]
        if not candidates:
            return alpha
        return weights[max(candidates, key=len)]

    weight_a, state_dict = load_weight(path_a)
    weight_b, _ = load_weight(path_b)
    weight_c = None
    if path_c is not None:
        weight_c, _ = load_weight(path_c)

    if sorted(weight_a) != sorted(weight_b):
        raise RuntimeError("Failed to merge models.")
    if method == "add_diff":
        missing = [key for key in weight_a if key not in weight_c]
        if missing:
            raise RuntimeError(
                f"Failed to merge models: model C is missing {len(missing)} "
                f"key(s), e.g. {missing[0]!r}."
            )

    merged = OrderedDict()
    merged["weight"] = {}

    for key in tqdm.tqdm(weight_a.keys()):
        ratio = get_alpha(key)
        if method == "weight_sum":
            merged["weight"][key] = (1 - ratio) * weight_a[key] + ratio * weight_b[key]
        else:  # "add_diff": apply the scaled difference (B - C) on top of A.
            merged["weight"][key] = (
                weight_a[key] + (weight_b[key] - weight_c[key]) * ratio
            )

    # Metadata is inherited from checkpoint A.
    merged["config"] = state_dict["config"]
    merged["params"] = state_dict.get("params")
    merged["version"] = state_dict.get("version", "v1")
    merged["sr"] = state_dict["sr"]
    merged["f0"] = state_dict["f0"]
    merged["info"] = state_dict["info"]
    merged["embedder_name"] = state_dict.get("embedder_name")
    merged["embedder_output_layer"] = state_dict.get("embedder_output_layer", "12")
    return merged
|