import json


def create_model_from_config(model_config):
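    """Build a model from a parsed model config dict.

    Dispatches on the top-level "model_type" key; each branch imports its
    factory lazily so unused model families are never imported.
    """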
    model_type = model_config.get("model_type", None)

    assert model_type is not None, "model_type must be specified in model config"

    if model_type == "autoencoder":
        from .autoencoders import create_autoencoder_from_config

        return create_autoencoder_from_config(model_config)
    elif model_type == "diffusion_uncond":
        from .diffusion import create_diffusion_uncond_from_config

        return create_diffusion_uncond_from_config(model_config)
    elif model_type in ("diffusion_cond", "diffusion_cond_inpaint", "diffusion_prior"):
        from .diffusion import create_diffusion_cond_from_config

        return create_diffusion_cond_from_config(model_config)
    elif model_type == "diffusion_autoencoder":
        from .autoencoders import create_diffAE_from_config

        return create_diffAE_from_config(model_config)
    elif model_type == "lm":
        from .lm import create_audio_lm_from_config

        return create_audio_lm_from_config(model_config)
    else:
        raise NotImplementedError(f"Unknown model type: {model_type}")


def create_model_from_config_path(model_config_path):
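    """Load a JSON model config from a file path and build the model from it."""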
    with open(model_config_path) as f:
        model_config = json.load(f)

    return create_model_from_config(model_config)


def create_pretransform_from_config(pretransform_config, sample_rate):
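    """Build a pretransform (e.g. a wrapped autoencoder or PQMF) from its config.

    The parent model's sample rate is passed in so it does not have to be
    repeated inside the pretransform config.
    """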
    pretransform_type = pretransform_config.get("type", None)

    assert pretransform_type is not None, (
        "type must be specified in pretransform config"
    )

    if pretransform_type == "autoencoder":
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # Create fake top-level config to pass sample rate to autoencoder constructor
        # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
        autoencoder_config = {
            "sample_rate": sample_rate,
            "model": pretransform_config["config"],
        }
        autoencoder = create_autoencoder_from_config(autoencoder_config)

        scale = pretransform_config.get("scale", 1.0)
        model_half = pretransform_config.get("model_half", False)
        iterate_batch = pretransform_config.get("iterate_batch", False)
        chunked = pretransform_config.get("chunked", False)

        pretransform = AutoencoderPretransform(
            autoencoder,
            scale=scale,
            model_half=model_half,
            iterate_batch=iterate_batch,
            chunked=chunked,
        )
    elif pretransform_type == "wavelet":
        from .pretransforms import WaveletPretransform

        wavelet_config = pretransform_config["config"]
        channels = wavelet_config["channels"]
        levels = wavelet_config["levels"]
        wavelet = wavelet_config["wavelet"]

        pretransform = WaveletPretransform(channels, levels, wavelet)
    elif pretransform_type == "pqmf":
        from .pretransforms import PQMFPretransform

        pqmf_config = pretransform_config["config"]
        pretransform = PQMFPretransform(**pqmf_config)
    elif pretransform_type == "dac_pretrained":
        from .pretransforms import PretrainedDACPretransform

        pretrained_dac_config = pretransform_config["config"]
        pretransform = PretrainedDACPretransform(**pretrained_dac_config)
    elif pretransform_type == "audiocraft_pretrained":
        from .pretransforms import AudiocraftCompressionPretransform

        audiocraft_config = pretransform_config["config"]
        pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
    else:
        raise NotImplementedError(f"Unknown pretransform type: {pretransform_type}")

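    # Pretransforms are frozen in eval mode by default; gradients flow through
    # them only if "enable_grad" is set in the config.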
    enable_grad = pretransform_config.get("enable_grad", False)
    pretransform.enable_grad = enable_grad

    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform


def create_bottleneck_from_config(bottleneck_config):
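    """Build an autoencoder bottleneck (VAE, RVQ, FSQ, etc.) from its config dict."""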
    bottleneck_type = bottleneck_config.get("type", None)

    assert bottleneck_type is not None, "type must be specified in bottleneck config"

    if bottleneck_type == "tanh":
        from .bottleneck import TanhBottleneck

        bottleneck = TanhBottleneck()
    elif bottleneck_type == "vae":
        from .bottleneck import VAEBottleneck

        bottleneck = VAEBottleneck()
    elif bottleneck_type == "rvq":
        from .bottleneck import RVQBottleneck

        quantizer_params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }

        quantizer_params.update(bottleneck_config["config"])

        bottleneck = RVQBottleneck(**quantizer_params)
    elif bottleneck_type == "dac_rvq":
        from .bottleneck import DACRVQBottleneck

        bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == "rvq_vae":
        from .bottleneck import RVQVAEBottleneck

        quantizer_params = {
            "dim": 128,
            "codebook_size": 1024,
            "num_quantizers": 8,
            "decay": 0.99,
            "kmeans_init": True,
            "kmeans_iters": 50,
            "threshold_ema_dead_code": 2,
        }

        quantizer_params.update(bottleneck_config["config"])

        bottleneck = RVQVAEBottleneck(**quantizer_params)
    elif bottleneck_type == "dac_rvq_vae":
        from .bottleneck import DACRVQVAEBottleneck

        bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
    elif bottleneck_type == "l2_norm":
        from .bottleneck import L2Bottleneck

        bottleneck = L2Bottleneck()
    elif bottleneck_type == "wasserstein":
        from .bottleneck import WassersteinBottleneck

        bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
    elif bottleneck_type == "fsq":
        from .bottleneck import FSQBottleneck

        bottleneck = FSQBottleneck(**bottleneck_config["config"])
    else:
        raise NotImplementedError(f"Unknown bottleneck type: {bottleneck_type}")

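    # Optionally freeze the bottleneck's parameters (trainable by default)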
    requires_grad = bottleneck_config.get("requires_grad", True)
    if not requires_grad:
        for param in bottleneck.parameters():
            param.requires_grad = False

    return bottleneck