File size: 2,464 Bytes
b5a064f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from collections import OrderedDict
from typing import *

import torch
import tqdm


def merge(
    path_a: str,
    path_b: str,
    path_c: Optional[str],
    alpha: float,
    weights: Optional[Dict[str, float]],
    method: str,
):
    """Merge two (or three) model checkpoints into a new checkpoint dict.

    Args:
        path_a: Checkpoint providing the base tensors and all metadata
            (``config``, ``sr``, ``f0``, ``info``, ...) copied into the result.
        path_b: Second checkpoint; must have exactly the same weight keys as A.
        path_c: Third checkpoint, required (and only loaded) for ``"add_diff"``;
            must contain every weight key of A. Ignored by ``"weight_sum"``.
        alpha: Default blend ratio for keys not matched by ``weights``.
        weights: Optional per-prefix ratios. A weight key uses the ratio of the
            longest entry in ``weights`` that it starts with, falling back to
            ``alpha``. ``None`` or empty means ``alpha`` for every key.
        method: ``"weight_sum"`` computes ``(1 - r) * A + r * B``;
            ``"add_diff"`` computes ``A + (B - C) * r`` (``r`` = per-key ratio).

    Returns:
        An ``OrderedDict`` with the merged tensors under ``"weight"`` plus
        metadata copied from checkpoint A.

    Raises:
        ValueError: If ``method`` is unknown, or ``"add_diff"`` is requested
            without ``path_c``.
        RuntimeError: If the checkpoints' weight keys do not match.
        KeyError: If checkpoint A lacks ``config``/``sr``/``f0``/``info``.
    """
    # Validate arguments up front, before any (potentially large) file is loaded.
    if method not in ("weight_sum", "add_diff"):
        raise ValueError(f"Unknown merge method: {method!r}")
    if method == "add_diff" and path_c is None:
        raise ValueError('Merge method "add_diff" requires path_c')

    def extract(ckpt: Dict[str, Any]) -> Dict[str, Any]:
        """Return the tensors stored under ``ckpt["model"]``, minus ``enc_q`` keys."""
        # Keys containing "enc_q" are deliberately excluded from the merge
        # (presumably training-only parameters — confirm).
        return {k: v for k, v in ckpt["model"].items() if "enc_q" not in k}

    def load_weight(path: str):
        """Load ``path`` and return ``(flat weight dict, raw state dict)``."""
        print(f"Loading {path}...")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        if "model" in state_dict:
            # Return the flat key->tensor dict, same shape as the "weight"
            # branch below (previously the {"weight": ...} wrapper leaked out).
            weight = extract(state_dict)
        else:
            weight = state_dict["weight"]
        return weight, state_dict

    def get_alpha(key: str) -> float:
        """Return the ratio of the longest matching prefix in ``weights``, else ``alpha``."""
        if not weights:
            return alpha
        matches = [prefix for prefix in weights if key.startswith(prefix)]
        if not matches:
            return alpha
        return weights[max(matches, key=len)]

    def merge_weight(a, b, c, ratio):
        """Combine one tensor triple according to ``method`` (validated above)."""
        if method == "weight_sum":
            return (1 - ratio) * a + ratio * b
        return a + (b - c) * ratio  # "add_diff"

    weight_a, state_dict = load_weight(path_a)
    weight_b, _ = load_weight(path_b)
    weight_c = None
    if method == "add_diff":
        weight_c, _ = load_weight(path_c)

    keys_a = set(weight_a)
    if keys_a != set(weight_b):
        raise RuntimeError("Failed to merge models.")
    # add_diff indexes C with every key of A; catch mismatches before merging.
    if weight_c is not None and not keys_a.issubset(weight_c):
        raise RuntimeError("Failed to merge models.")

    merged = OrderedDict()
    merged["weight"] = {}
    for key in tqdm.tqdm(weight_a):
        c = weight_c[key] if weight_c is not None else None
        merged["weight"][key] = merge_weight(
            weight_a[key], weight_b[key], c, get_alpha(key)
        )

    # Metadata always comes from checkpoint A.
    # NOTE(review): a checkpoint loaded via the "model" branch may not carry
    # these keys, in which case the lookups below raise KeyError — confirm.
    merged["config"] = state_dict["config"]
    merged["params"] = state_dict.get("params")
    merged["version"] = state_dict.get("version", "v1")
    merged["sr"] = state_dict["sr"]
    merged["f0"] = state_dict["f0"]
    merged["info"] = state_dict["info"]
    merged["embedder_name"] = state_dict.get("embedder_name")
    merged["embedder_output_layer"] = state_dict.get("embedder_output_layer", "12")
    return merged