Xsmos commited on
Commit
113cc01
·
verified ·
1 Parent(s): d30bfa5
Files changed (1) hide show
  1. diffusion.py +7 -7
diffusion.py CHANGED
@@ -236,12 +236,12 @@ class TrainConfig:
236
 
237
  # dim = 2
238
  dim = 3
239
- stride = (2,2) if dim == 2 else (2,2,2)
240
  num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 1#2#50#20#2#100 # 10
242
  n_epoch = 8#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
- num_redshift = 64#512#128#64#512#256#256#64#512#128
245
  channel = 1
246
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
247
 
@@ -586,7 +586,7 @@ def train(rank, world_size):
586
 
587
  ddp_setup(rank, world_size)
588
 
589
- num_train_image_list = [10]#[200]#[1600,3200,6400,12800,25600]
590
  for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size
@@ -677,9 +677,9 @@ if __name__ == "__main__":
677
  world_size = torch.cuda.device_count()
678
  print(f" sampling, world_size = {world_size} ".center(100,'-'))
679
  # num_train_image_list = [1600,3200,6400,12800,25600]
680
- num_train_image_list = [2000]
681
- num_new_img_per_gpu = 8
682
- max_num_img_per_gpu = 2
683
 
684
  params = torch.tensor([4.4, 131.341])
685
 
@@ -690,7 +690,7 @@ if __name__ == "__main__":
690
 
691
  for num_image in num_train_image_list:
692
  config.num_image = num_image
693
- config.resume = f"./outputs/model_state-N{num_image}-epoch3-device0"
694
 
695
  # print("ddpm21cm = DDPM21CM(config)")
696
  manager = mp.Manager()
 
236
 
237
  # dim = 2
238
  dim = 3
239
+ stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 1#2#50#20#2#100 # 10
242
  n_epoch = 8#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
+ num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
246
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
247
 
 
586
 
587
  ddp_setup(rank, world_size)
588
 
589
+ num_train_image_list = [3200]#[200]#[1600,3200,6400,12800,25600]
590
  for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size
 
677
  world_size = torch.cuda.device_count()
678
  print(f" sampling, world_size = {world_size} ".center(100,'-'))
679
  # num_train_image_list = [1600,3200,6400,12800,25600]
680
+ num_train_image_list = [3200]
681
+ num_new_img_per_gpu = 9
682
+ max_num_img_per_gpu = 1
683
 
684
  params = torch.tensor([4.4, 131.341])
685
 
 
690
 
691
  for num_image in num_train_image_list:
692
  config.num_image = num_image
693
+ config.resume = f"./outputs/model_state-N{num_image}-epoch7-device0"
694
 
695
  # print("ddpm21cm = DDPM21CM(config)")
696
  manager = mp.Manager()