File size: 1,044 Bytes
37163a6
 
 
 
 
 
 
 
 
2279ae0
 
 
37163a6
 
 
2279ae0
 
 
 
37163a6
 
 
 
 
 
 
 
 
 
2279ae0
37163a6
 
2279ae0
37163a6
 
 
 
 
 
 
2279ae0
37163a6
2279ae0
37163a6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch.nn as nn


from models import register
from .model import Encoder, Decoder, WNConv1d


# Named hyper-parameter presets, keyed by config name. The presets are
# identical except for the 'activation' entry, which equals the preset's
# own name, so they are generated from one template. The 'strides' list
# literal is evaluated once per preset, so no list is shared between them.
default_configs = {
    activation_name: {
        'd_model': 64,
        'strides': [2, 4, 5, 8],
        'd_latent': 64,
        'd_in': 1,
        'activation': activation_name,
    }
    for activation_name in ('snake', 'snakebeta')
}


@register('dac_encoder')
def make_dac_encoder(config_name, **kwargs):
    """Build an ``Encoder`` followed by a kernel-size-1 ``WNConv1d``.

    Args:
        config_name: Key into ``default_configs`` selecting the preset.
        **kwargs: Overrides merged on top of the preset; the merged
            mapping is forwarded to ``Encoder`` as keyword arguments.

    Returns:
        An ``nn.Sequential`` of ``Encoder(**merged)`` and
        ``WNConv1d(d_model, d_model, kernel_size=1)``.

    Raises:
        KeyError: If ``config_name`` is not a key of ``default_configs``.
    """
    try:
        preset = default_configs[config_name]
    except KeyError:
        raise KeyError(
            f"unknown config_name {config_name!r}; "
            f"expected one of {sorted(default_configs)}"
        ) from None
    # Merge into a fresh dict. Updating the preset in place would leak this
    # call's overrides into every later call that uses the same preset.
    encoder_kwargs = {**preset, **kwargs}
    d_model = encoder_kwargs['d_model']
    # NOTE(review): the trailing conv takes d_model input channels, which
    # presumably requires the Encoder output to have d_model channels (the
    # presets set d_model == d_latent) -- confirm before overriding either.
    return nn.Sequential(
        Encoder(**encoder_kwargs),
        WNConv1d(d_model, d_model, kernel_size=1),
    )


@register('vqgan_decoder')
def make_vqgan_decoder(config_name, **kwargs):
    """Build a kernel-size-1 ``WNConv1d`` followed by a ``Decoder``.

    Args:
        config_name: Key into ``default_configs`` selecting the preset.
        **kwargs: Overrides merged on top of the preset; the merged
            mapping is forwarded to ``Decoder`` as keyword arguments.

    Returns:
        An ``nn.Sequential`` of ``WNConv1d(d_model, d_model, kernel_size=1)``
        and ``Decoder(**merged)``.

    Raises:
        KeyError: If ``config_name`` is not a key of ``default_configs``.
    """
    try:
        preset = default_configs[config_name]
    except KeyError:
        raise KeyError(
            f"unknown config_name {config_name!r}; "
            f"expected one of {sorted(default_configs)}"
        ) from None
    # Merge into a fresh dict. Updating the preset in place would leak this
    # call's overrides into every later call that uses the same preset.
    decoder_kwargs = {**preset, **kwargs}
    d_model = decoder_kwargs['d_model']
    # NOTE(review): the leading conv emits d_model channels, which presumably
    # must match what Decoder expects as input (the presets set
    # d_model == d_latent) -- confirm before overriding either.
    return nn.Sequential(
        WNConv1d(d_model, d_model, kernel_size=1),
        Decoder(**decoder_kwargs),
    )