File size: 1,105 Bytes
6434535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

from modules.unet import UNetModel
from generative.networks.nets import VQVAE
from config import config

# Denoising UNet that operates on the VQ-VAE latent grid rather than on pixels.
# NOTE(review): assumes config.r is the pixel-to-latent downsampling factor
# (the VQ-VAE in this file appears to downsample 8x via three stride-2
# stages) — confirm against config.
if config.image_size % config.r != 0:
    # A non-divisible pair would otherwise yield a fractional latent size.
    raise ValueError(
        f"config.image_size ({config.image_size}) must be divisible by "
        f"config.r ({config.r}) to give an integer latent resolution"
    )

myUnet = UNetModel(
    # Integer division: a spatial size must be an int, whereas `/` always
    # produces a float (e.g. 256 / 4 -> 64.0).
    image_size=config.image_size // config.r,
    model_channels=128,
    # NOTE(review): 8 in/out channels while the VQ-VAE latent has
    # embedding_dim=4 — presumably the latent is concatenated with a
    # conditioning latent (and/or outputs carry extra terms); confirm.
    in_channels=8,
    out_channels=8,
    num_res_blocks=8,
    num_heads=8,
    # NOTE(review): in guided-diffusion-style UNets these entries are
    # downsample rates (1, 2, 4, ...), not pixel resolutions; with six
    # channel_mult levels the largest rate reached is 32, so 64 may never
    # match — confirm against modules.unet.
    attention_resolutions=(64, 32, 16, 8),
    num_heads_upsample=-1,
    num_head_channels=-1,
    resblock_updown=True,
    channel_mult=(1, 1, 2, 2, 4, 4),
    use_scale_shift_norm=True,
    use_new_attention_order=True,
)

# Vector-quantized autoencoder for single-channel 2-D images: a 1024-entry
# codebook of 4-dimensional embeddings. (Despite the variable name, this is a
# MONAI-generative VQVAE; any adversarial/GAN training is not set up here.)
#
# Three resolution levels, each configured identically. Per MONAI's VQVAE,
# a downsample tuple is (stride, kernel_size, dilation, padding) and an
# upsample tuple adds output_padding; stride 2 / kernel 4 / padding 1 halves
# (resp. doubles) an even spatial size at every level.
myVQGANModel = VQVAE(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(128, 256, 512),
    num_res_layers=2,
    num_res_channels=512,
    downsample_parameters=((2, 4, 1, 1),) * 3,
    upsample_parameters=((2, 4, 1, 1, 0),) * 3,
    embedding_dim=4,
    num_embeddings=1024,
)

if __name__ == "__main__":
    # Report each network's parameter count under its own label; previously
    # both lines were printed with the same text and were indistinguishable.
    for label, model in (("UNet", myUnet), ("VQ-VAE", myVQGANModel)):
        # Generator expression: no need to materialize a list just to sum it.
        n_params = sum(p.numel() for p in model.parameters())
        print(f"Number of {label} parameters: {n_params:,}")