from collections import OrderedDict from typing import * import torch import tqdm def merge( path_a: str, path_b: str, path_c: str, alpha: float, weights: Dict[str, float], method: str, ): def extract(ckpt: Dict[str, Any]): a = ckpt["model"] opt = OrderedDict() opt["weight"] = {} for key in a.keys(): if "enc_q" in key: continue opt["weight"][key] = a[key] return opt def load_weight(path: str): print(f"Loading {path}...") state_dict = torch.load(path, map_location="cpu") if "model" in state_dict: weight = extract(state_dict) else: weight = state_dict["weight"] return weight, state_dict def get_alpha(key: str): try: filtered = sorted( [x for x in weights.keys() if key.startswith(x)], key=len, reverse=True ) if len(filtered) < 1: return alpha return weights[filtered[0]] except: return alpha weight_a, state_dict = load_weight(path_a) weight_b, _ = load_weight(path_b) if path_c is not None: weight_c, _ = load_weight(path_c) if sorted(list(weight_a.keys())) != sorted(list(weight_b.keys())): raise RuntimeError("Failed to merge models.") merged = OrderedDict() merged["weight"] = {} def merge_weight(a, b, c, alpha): if method == "weight_sum": return (1 - alpha) * a + alpha * b elif method == "add_diff": return a + (b - c) * alpha for key in tqdm.tqdm(weight_a.keys()): a = get_alpha(key) if path_c is not None: merged["weight"][key] = merge_weight( weight_a[key], weight_b[key], weight_c[key], a ) else: merged["weight"][key] = merge_weight(weight_a[key], weight_b[key], None, a) merged["config"] = state_dict["config"] merged["params"] = state_dict["params"] if "params" in state_dict else None merged["version"] = state_dict.get("version", "v1") merged["sr"] = state_dict["sr"] merged["f0"] = state_dict["f0"] merged["info"] = state_dict["info"] merged["embedder_name"] = ( state_dict["embedder_name"] if "embedder_name" in state_dict else None ) merged["embedder_output_layer"] = state_dict.get("embedder_output_layer", "12") return merged