Spaces:
Sleeping
Sleeping
File size: 1,416 Bytes
37163a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch.nn as nn
from models import register
from .model import Encoder, Decoder
def _vqgan_config(z_channels, ch_mult):
    """Return a fresh VQGAN encoder/decoder kwargs dict.

    Every named config shares the same settings except the latent width
    (``z_channels``) and the per-level channel multipliers (``ch_mult``).
    A new dict (and new list objects) is produced on every call, so the
    named configs never alias each other's mutable values.
    """
    return dict(
        double_z=False,
        z_channels=z_channels,
        resolution=256,
        in_channels=3,
        out_ch=3,
        ch=128,
        ch_mult=ch_mult,
        num_res_blocks=2,
        attn_resolutions=[],
        dropout=0.0,
        give_pre_end=True,
    )


# Named base configs, looked up by ``config_name`` in the VQGAN encoder /
# decoder factories and splatted into ``Encoder`` / ``Decoder`` as kwargs.
# NOTE(review): the key names look like "f<downsampling factor>c<latent
# channels>". 'f16c8' fits that reading (z_channels=8, 5 ch_mult levels),
# but 'f8c4' has z_channels=64 and 9 ch_mult levels -- confirm the intended
# values against ``.model.Encoder`` before relying on the name.
default_configs = {
    'f8c4': _vqgan_config(
        z_channels=64,
        ch_mult=[1, 2, 2, 4, 4, 4, 4, 8, 8],
    ),
    'f16c8': _vqgan_config(
        z_channels=8,
        ch_mult=[1, 2, 4, 4, 4],
    ),
}
@register('vqgan_encoder')
def make_vqgan_encoder(config_name, **kwargs):
    """Build a VQGAN encoder: ``Encoder`` followed by a 1x1 ``nn.Conv2d``.

    Args:
        config_name: key into ``default_configs`` selecting the base config.
        **kwargs: overrides applied on top of the selected base config; the
            merged dict is passed to ``Encoder`` as keyword arguments.

    Returns:
        ``nn.Sequential(Encoder(...), nn.Conv2d(C, C, 1))`` where ``C`` is
        ``z_channels``, doubled when ``double_z`` is true.

    Raises:
        KeyError: if ``config_name`` is not a key of ``default_configs``.
    """
    try:
        base_config = default_configs[config_name]
    except KeyError:
        raise KeyError(
            f"unknown VQGAN config {config_name!r}; "
            f"available: {sorted(default_configs)}"
        ) from None
    # Merge into a NEW dict. Calling .update() on default_configs[config_name]
    # (as before) mutated the shared module-level config, leaking these
    # overrides into every later encoder/decoder built from the same name.
    encoder_kwargs = {**base_config, **kwargs}
    enc_out_channels = encoder_kwargs['z_channels'] * (2 if encoder_kwargs['double_z'] else 1)
    return nn.Sequential(
        Encoder(**encoder_kwargs),
        # 1x1 conv keeps the channel count unchanged.
        nn.Conv2d(enc_out_channels, enc_out_channels, 1),
    )
@register('vqgan_decoder')
def make_vqgan_decoder(config_name, **kwargs):
    """Build a VQGAN decoder: a 1x1 ``nn.Conv2d`` followed by ``Decoder``.

    Args:
        config_name: key into ``default_configs`` selecting the base config.
        **kwargs: overrides applied on top of the selected base config; the
            merged dict is passed to ``Decoder`` as keyword arguments.

    Returns:
        ``nn.Sequential(nn.Conv2d(C, C, 1), Decoder(...))`` where ``C`` is
        ``z_channels`` (``double_z`` is not applied on the decoder side).

    Raises:
        KeyError: if ``config_name`` is not a key of ``default_configs``.
    """
    try:
        base_config = default_configs[config_name]
    except KeyError:
        raise KeyError(
            f"unknown VQGAN config {config_name!r}; "
            f"available: {sorted(default_configs)}"
        ) from None
    # Merge into a NEW dict. Calling .update() on default_configs[config_name]
    # (as before) mutated the shared module-level config, leaking these
    # overrides into every later encoder/decoder built from the same name.
    decoder_kwargs = {**base_config, **kwargs}
    dec_in_channels = decoder_kwargs['z_channels']
    return nn.Sequential(
        # 1x1 conv keeps the channel count unchanged.
        nn.Conv2d(dec_in_channels, dec_in_channels, 1),
        Decoder(**decoder_kwargs),
    )
|