File size: 1,416 Bytes
37163a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch.nn as nn


from models import register
from .model import Encoder, Decoder


# Default constructor keyword arguments, keyed by configuration name.
# The factory functions in this module look an entry up by name and unpack
# it into Encoder / Decoder.
# NOTE(review): the key names look like they encode a downsampling factor and
# a latent channel count, but the 'f8c4' values (z_channels=64, nine ch_mult
# entries) do not obviously match that reading -- confirm against training
# configs before relying on the names.
default_configs = {
    'f8c4': {
        'double_z': False,
        'z_channels': 64,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 2, 4, 4, 4, 4, 8, 8],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0,
        'give_pre_end': True,
    },
    'f16c8': {
        'double_z': False,
        'z_channels': 8,
        'resolution': 256,
        'in_channels': 3,
        'out_ch': 3,
        'ch': 128,
        'ch_mult': [1, 2, 4, 4, 4],
        'num_res_blocks': 2,
        'attn_resolutions': [],
        'dropout': 0.0,
        'give_pre_end': True,
    },
}


@register('vqgan_encoder')
def make_vqgan_encoder(config_name, **kwargs):
    """Build an Encoder followed by a 1x1 convolution.

    Args:
        config_name: Key into ``default_configs`` selecting the base
            keyword arguments for ``Encoder``.
        **kwargs: Overrides merged on top of the selected defaults before
            they are passed to ``Encoder``.

    Returns:
        An ``nn.Sequential`` of ``Encoder(**merged_kwargs)`` and an
        ``nn.Conv2d`` with kernel size 1 whose input and output channel
        counts are both ``z_channels`` (doubled when ``double_z`` is truthy).

    Raises:
        KeyError: If ``config_name`` is not a key of ``default_configs``.
    """
    # Merge into a fresh dict: calling .update() on the entry stored in
    # default_configs would permanently rewrite the shared defaults, leaking
    # this call's overrides into every later encoder/decoder construction.
    encoder_kwargs = {**default_configs[config_name], **kwargs}
    # Presumably matches the channel count Encoder emits -- verify against
    # the Encoder implementation.
    channel_factor = 2 if encoder_kwargs['double_z'] else 1
    enc_out_channels = encoder_kwargs['z_channels'] * channel_factor
    return nn.Sequential(
        Encoder(**encoder_kwargs),
        nn.Conv2d(enc_out_channels, enc_out_channels, 1),
    )


@register('vqgan_decoder')
def make_vqgan_decoder(config_name, **kwargs):
    """Build a 1x1 convolution followed by a Decoder.

    Args:
        config_name: Key into ``default_configs`` selecting the base
            keyword arguments for ``Decoder``.
        **kwargs: Overrides merged on top of the selected defaults before
            they are passed to ``Decoder``.

    Returns:
        An ``nn.Sequential`` of an ``nn.Conv2d`` with kernel size 1 (input
        and output channel counts both ``z_channels``) and
        ``Decoder(**merged_kwargs)``.

    Raises:
        KeyError: If ``config_name`` is not a key of ``default_configs``.
    """
    # Merge into a fresh dict: calling .update() on the entry stored in
    # default_configs would permanently rewrite the shared defaults, leaking
    # this call's overrides into every later encoder/decoder construction.
    decoder_kwargs = {**default_configs[config_name], **kwargs}
    dec_in_channels = decoder_kwargs['z_channels']
    return nn.Sequential(
        nn.Conv2d(dec_in_channels, dec_in_channels, 1),
        Decoder(**decoder_kwargs),
    )