rvc_api / modules /merge.py
aryo100's picture
first commit
b5a064f
from collections import OrderedDict
from typing import *
import torch
import tqdm
def merge(
    path_a: str,
    path_b: str,
    path_c: str,
    alpha: float,
    weights: Dict[str, float],
    method: str,
):
    """Merge two (or three) model checkpoints into a new checkpoint dict.

    Every tensor in the result is computed key-by-key from the checkpoints at
    ``path_a``, ``path_b`` and, for ``"add_diff"``, ``path_c``:

    * ``"weight_sum"``: ``(1 - r) * a + r * b`` (linear interpolation of A and B).
    * ``"add_diff"``:   ``a + (b - c) * r`` (add the B-minus-C difference to A).

    The ratio ``r`` for a key is the ``weights`` value whose key is the longest
    prefix of that key; keys matching no prefix use ``alpha``.

    A loaded checkpoint that has a ``"model"`` entry is read from it, dropping
    every key containing ``"enc_q"``; otherwise its ``"weight"`` entry is used.

    Args:
        path_a: Checkpoint supplying the base weights and all metadata.
        path_b: Second checkpoint; must contain exactly the same keys as A.
        path_c: Third checkpoint, required by ``"add_diff"`` (must have the
            same keys as A); not loaded for ``"weight_sum"`` and may be ``None``.
        alpha: Default merge ratio.
        weights: Per-prefix ratio overrides; ``None`` or empty means ``alpha``
            is used for every key.
        method: ``"weight_sum"`` or ``"add_diff"``.

    Returns:
        OrderedDict with the merged ``"weight"`` mapping followed by the
        ``config``, ``params``, ``version``, ``sr``, ``f0``, ``info``,
        ``embedder_name`` and ``embedder_output_layer`` entries copied from
        checkpoint A (``params``/``embedder_name`` default to ``None``,
        ``version`` to ``"v1"``, ``embedder_output_layer`` to ``"12"``).

    Raises:
        ValueError: ``method`` is unknown, or ``"add_diff"`` is requested
            without ``path_c``.
        RuntimeError: the checkpoints do not share the same set of keys.
        KeyError: checkpoint A lacks ``config``, ``sr``, ``f0`` or ``info``.
    """
    # Validate arguments before spending time/memory loading checkpoints.
    if method not in ("weight_sum", "add_diff"):
        raise ValueError(f"Unknown merge method: {method!r}")
    if method == "add_diff" and path_c is None:
        raise ValueError("The 'add_diff' merge method requires path_c.")

    def load_weight(path: str):
        """Load a checkpoint; return (key -> tensor mapping, raw checkpoint)."""
        print(f"Loading {path}...")
        # NOTE(review): torch.load unpickles the file, which can execute code;
        # only merge checkpoints from trusted sources.
        state_dict = torch.load(path, map_location="cpu")
        if "model" in state_dict:
            # Return the bare key -> tensor mapping (not a {"weight": ...}
            # wrapper) so both branches yield the same shape; enc_q entries
            # are excluded from the merge.
            weight = {
                key: value
                for key, value in state_dict["model"].items()
                if "enc_q" not in key
            }
        else:
            weight = state_dict["weight"]
        return weight, state_dict

    # Sort override prefixes once, longest first, so the most specific
    # matching prefix wins for each key.
    prefixes = sorted(weights or {}, key=len, reverse=True)

    def get_alpha(key: str) -> float:
        """Return the ratio for ``key``: longest matching override, else alpha."""
        for prefix in prefixes:
            if key.startswith(prefix):
                return weights[prefix]
        return alpha

    weight_a, state_dict = load_weight(path_a)
    weight_b, _ = load_weight(path_b)
    weight_c = None
    if method == "add_diff":
        weight_c, _ = load_weight(path_c)

    # All checkpoints taking part must describe exactly the same keys.
    # (dict key views compare with set semantics.)
    expected_keys = weight_a.keys()
    if weight_b.keys() != expected_keys or (
        weight_c is not None and weight_c.keys() != expected_keys
    ):
        raise RuntimeError("Failed to merge models.")

    def merge_weight(a, b, c, ratio):
        """Combine one tensor triple according to ``method``."""
        if method == "weight_sum":
            return (1 - ratio) * a + ratio * b
        return a + (b - c) * ratio  # "add_diff" (validated above)

    merged = OrderedDict()
    merged["weight"] = {}
    for key in tqdm.tqdm(weight_a.keys()):
        ratio = get_alpha(key)
        c = weight_c[key] if weight_c is not None else None
        merged["weight"][key] = merge_weight(weight_a[key], weight_b[key], c, ratio)

    # Metadata is copied from checkpoint A.
    merged["config"] = state_dict["config"]
    merged["params"] = state_dict.get("params")
    merged["version"] = state_dict.get("version", "v1")
    merged["sr"] = state_dict["sr"]
    merged["f0"] = state_dict["f0"]
    merged["info"] = state_dict["info"]
    merged["embedder_name"] = state_dict.get("embedder_name")
    merged["embedder_output_layer"] = state_dict.get("embedder_output_layer", "12")
    return merged