Xsmos commited on
Commit
f49a1c6
·
verified ·
1 Parent(s): d08708f
Files changed (2) hide show
  1. context_unet.py +1 -1
  2. diffusion.py +2 -2
context_unet.py CHANGED
@@ -330,7 +330,7 @@ class ContextUnet(nn.Module):
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
- channel_mult = (1, 2, 4, 4, 4)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
 
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
+ channel_mult = (1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
diffusion.py CHANGED
@@ -239,7 +239,7 @@ class TrainConfig:
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 1#2#50#20#2#100 # 10
242
- n_epoch = 8#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
@@ -586,7 +586,7 @@ def train(rank, world_size):
586
 
587
  ddp_setup(rank, world_size)
588
 
589
- num_train_image_list = [200]#[3200]#[200]#[1600,3200,6400,12800,25600]
590
  for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size
 
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 1#2#50#20#2#100 # 10
242
+ n_epoch = 2#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
 
586
 
587
  ddp_setup(rank, world_size)
588
 
589
+ num_train_image_list = [3200]#[3200]#[200]#[1600,3200,6400,12800,25600]
590
  for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size