File size: 3,614 Bytes
184401d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# general settings
# Run identity and seed. `name` labels this run as the stage-3 configuration;
# the `path` section of this file points at CodeFormer_stage2 checkpoints.
# NOTE(review): presumably read by a BasicSR-style training launcher - confirm.
name: CodeFormer_stage3
model_type: CodeFormerJointModel
# NOTE(review): with `batch_size_per_gpu: 3` in datasets.train, the global batch
# is presumably 8 x 3 = 24 - confirm how the trainer shards data across GPUs.
num_gpu: 8
# The same value also appears as `train.manual_seed`; which of the two the
# trainer actually reads is not visible from this file.
manual_seed: 0

# dataset and data loader settings
datasets:
  train:
    name: FFHQ
    type: FFHQBlindJointDataset
    dataroot_gt: datasets/ffhq/ffhq_512
    filename_tmpl: '{}'
    io_backend:
      type: disk

    # image sizes, normalization constants and augmentation switches
    in_size: 512
    gt_size: 512
    mean: [0.5, 0.5, 0.5]
    std: [0.5, 0.5, 0.5]
    use_hflip: true
    use_corrupt: true

    # blur kernel settings
    # NOTE(review): motion_kernel_prob is presumably ignored while
    # use_motion_kernel is false - confirm in the dataset class.
    blur_kernel_size: 41
    use_motion_kernel: false
    motion_kernel_prob: 0.001
    kernel_list: ['iso', 'aniso']
    kernel_prob: [0.5, 0.5]
    # small degradation ranges (stage III)
    blur_sigma: [0.1, 10]
    downsample_range: [1, 12]
    noise_range: [0, 15]
    jpeg_range: [60, 100]
    # large degradation ranges (stage II)
    blur_sigma_large: [1, 15]
    downsample_range_large: [4, 30]
    noise_range_large: [0, 20]
    jpeg_range_large: [30, 80]

    # null = run without a pre-calculated latent code; the commented line
    # below is the alternative that points at a pre-calculated file.
    latent_gt_path: null
    # latent_gt_path: './experiments/pretrained_models/VQGAN/latent_gt_code1024.pth'

    # data loader
    num_worker_per_gpu: 1
    batch_size_per_gpu: 3
    dataset_enlarge_ratio: 100
    prefetch_mode: null

  # Validation dataset is disabled (kept commented out for reference).
  # val:
  #   name: CelebA-HQ-512
  #   type: PairedImageDataset
  #   dataroot_lq: datasets/faces/validation/lq
  #   dataroot_gt: datasets/faces/validation/gt
  #   io_backend:
  #     type: disk
  #   mean: [0.5, 0.5, 0.5]
  #   std: [0.5, 0.5, 0.5]
  #   scale: 1
    
# network structures
# Generator definition. `connect_list` entries are quoted on purpose so they
# stay strings rather than integers.
network_g:
  type: CodeFormer
  dim_embd: 512
  n_head: 8
  n_layers: 9
  codebook_size: 1024
  connect_list: ['32', '64', '128', '256']
  fix_modules: ['quantize', 'generator']

# This config is needed if no pre-calculated latent is supplied
# (datasets.train.latent_gt_path is null in this file).
network_vqgan:
  type: VQAutoEncoder
  img_size: 512
  nf: 64
  ch_mult: [1, 2, 2, 4, 4, 8]
  quantizer: 'nearest'
  codebook_size: 1024

# Discriminator network definition.
# NOTE(review): presumably paired with `train.gan_opt` / `train.optim_d` and
# initialised from `path.pretrain_network_d` - confirm against the model class.
network_d:
  type: VQGANDiscriminator
  nc: 3
  ndf: 64
  n_layers: 4

# path
path:
  # pretrained G model in StageII
  pretrain_network_g: './experiments/pretrained_models/CodeFormer_stage2/net_g_latest.pth'
  param_key_g: params_ema
  strict_load_g: false
  # pretrained D model in StageII
  pretrain_network_d: './experiments/pretrained_models/CodeFormer_stage2/net_d_latest.pth'
  # null = no resume state configured
  resume_state: null

# base_lr(4.5e-6)*batch_size(4)
# Training schedule, optimizers and loss configuration.
# The explicit `!!float` tags force a float; some YAML 1.1 loaders would
# otherwise read exponent literals without a decimal point (e.g. 5e-5) as strings.
train:
  # Model-specific loss switches and weights.
  # NOTE(review): their exact meaning is defined by the class named in
  # `model_type` - confirm there.
  use_hq_feat_loss: true
  feat_loss_weight: 1.0
  cross_entropy_loss: true
  entropy_loss_weight: 0.5
  scale_adaptive_gan_weight: 0.1

  # optim_g and optim_d use identical settings (Adam, lr 5e-5, no weight decay).
  optim_g:
    type: Adam
    lr: !!float 5e-5
    weight_decay: 0
    betas: [0.9, 0.99]
  optim_d:
    type: Adam
    lr: !!float 5e-5
    weight_decay: 0
    betas: [0.9, 0.99]

  # A single period of 150000 equals `total_iter`, so presumably no restart
  # happens within the run - confirm against the scheduler implementation.
  scheduler:
    type: CosineAnnealingRestartLR
    periods: [150000]
    restart_weights: [1]
    eta_min: !!float 2e-5


  total_iter: 150000

  warmup_iter: -1  # no warm up
  ema_decay: 0.997

  pixel_opt:
    type: L1Loss
    loss_weight: 1.0
    reduction: mean

  perceptual_opt:
    type: LPIPSLoss
    loss_weight: 1.0
    use_input_norm: true
    range_norm: true

  gan_opt:
    type: GANLoss
    gan_type: hinge
    loss_weight: !!float 1.0 # adaptive_weighting

  use_adaptive_weight: true

  # NOTE(review): presumably the generator trains from iteration 0 while the
  # discriminator only starts after iteration 5000 - confirm in the model class.
  net_g_start_iter: 0
  net_d_iters: 1
  net_d_start_iter: 5001
  # Same value as the top-level `manual_seed`; which one is read is not
  # visible from this file.
  manual_seed: 0

# validation settings
# NOTE(review): `datasets.val` is commented out in this file and `val_freq`
# (5e10) is far larger than `train.total_iter` (150000), so the settings below
# are presumably inert for this run - confirm against the training loop.
val:
  val_freq: !!float 5e10 # no validation
  save_img: true

  metrics:
    psnr: # metric name, can be arbitrary
      type: calculate_psnr
      crop_border: 4
      test_y_channel: false

# logging settings
# Both wandb fields are null in this file.
logger:
  print_freq: 100
  save_checkpoint_freq: !!float 5e3
  use_tb_logger: true
  wandb:
    project: null
    resume_id: null

# dist training settings
# NOTE(review): presumably forwarded to the distributed process-group setup -
# confirm against the launcher.
dist_params:
  backend: nccl
  port: 29413

# NOTE(review): the key name matches the `find_unused_parameters` argument of
# torch DistributedDataParallel; presumably passed through when the networks
# are wrapped - confirm in the model code.
find_unused_parameters: true