Xsmos committed
Commit 0afd45d · verified · 1 Parent(s): 4b76227
Files changed (3)
  1. context_unet.py +1 -1
  2. diffusion.py +18 -14
  3. phoenix_diffusion.sbatch +2 -3
context_unet.py CHANGED
@@ -334,7 +334,7 @@ class ContextUnet(nn.Module):
         elif image_size == 128:
             channel_mult = (1, 1, 2, 3, 4)
         elif image_size == 64:
-            channel_mult = (1,2,2,4,4)#(1,1,2,2,4)#(1,1,1,2,2)#(0.5,1,1,2,2)#(1,1,2)#(1,2)#(1,1,2,2)#(1,1,2,2,4)#(2,2,4,4,4)#(1, 2, 4)#(2,4,4,4,8)#(1, 2, 2, 4, 4)#(1, 2, 2, 4, 8)#(1, 1, 2, 2, 4, 4)#(1, 2, 4, 8, 16)#(1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
+            channel_mult = (1,2,2,2,4)#(1,1,2,2,4)#(1,1,1,2,2)#(0.5,1,1,2,2)#(1,1,2)#(1,2)#(1,1,2,2)#(1,1,2,2,4)#(2,2,4,4,4)#(1, 2, 4)#(2,4,4,4,8)#(1, 2, 2, 4, 4)#(1, 2, 2, 4, 8)#(1, 1, 2, 2, 4, 4)#(1, 2, 4, 8, 16)#(1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
         elif image_size == 32:
             channel_mult = (1, 2, 2, 4)
         elif image_size == 28:
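The only change here swaps the third multiplier from 4 to 2 for 64-pixel inputs, narrowing the middle resolution levels. A hedged sketch of what that means, assuming a guided-diffusion style UNet where each entry scales a base channel count per resolution level (`model_channels = 64` is an illustrative value, not read from this repo):

```python
# Hedged sketch: how a channel_mult tuple typically expands into per-level
# widths in a guided-diffusion style UNet. model_channels is an assumed value.
model_channels = 64
old_mult = (1, 2, 2, 4, 4)   # before this commit
new_mult = (1, 2, 2, 2, 4)   # after this commit
for tag, mult in (("old", old_mult), ("new", new_mult)):
    widths = [int(m * model_channels) for m in mult]
    print(tag, widths)
# old [64, 128, 128, 256, 256]
# new [64, 128, 128, 128, 256]  -> fewer channels at the fourth level
```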
diffusion.py CHANGED
@@ -271,7 +271,7 @@ class TrainConfig:
     stride = (2,4) if dim == 2 else (2,2,2)
     num_image = 2000#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
     batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
-    n_epoch = 15#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
+    n_epoch = 20#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
     HII_DIM = 64
     num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
     channel = 1
@@ -304,7 +304,7 @@ class TrainConfig:
     lrate = 1e-4
     lr_warmup_steps = 0#5#00
     output_dir = "./outputs/"
-    save_name = os.path.join(output_dir, 'model_state')
+    save_name = os.path.join(output_dir, 'model')
     # save_period = 1 #10 # the period of saving model
     # cond = True # if training using the conditional information
     # lr_decay = False #True# if using the learning rate decay
@@ -378,7 +378,7 @@ class DDPM21CM:
         # )
         # config = TrainConfig()
         # date = datetime.datetime.now().strftime("%m%d-%H%M")
-        config.run_name = datetime.datetime.now().strftime("%m%d-%H%M%S") # the unique name of each experiment
+        config.run_name = datetime.datetime.now().strftime("%d%H%M%S") # the unique name of each experiment
         self.config = config
         # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)
         # # self.shape_loaded = dataset.images.shape
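`run_name` drops the month field, so checkpoint suffixes get shorter. The strftime behavior is standard-library, so this comparison is exact:

```python
import datetime

t = datetime.datetime(2024, 3, 7, 14, 5, 9)
print(t.strftime("%m%d-%H%M%S"))  # 0307-140509 (old run_name format)
print(t.strftime("%d%H%M%S"))     # 07140509    (new format; values repeat across months)
```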
@@ -537,7 +537,7 @@ class DDPM21CM:
                 loss = F.mse_loss(noise, noise_pred)
                 #print(f"loss.dtype =", loss.dtype)
                 self.accelerator.backward(loss)
-                self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
+                #self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
                 self.optimizer.step()
                 self.lr_scheduler.step()
                 self.optimizer.zero_grad()
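This hunk turns gradient clipping off entirely, presumably because the run now uses gradient accumulation (see the sbatch change below). If clipping is still wanted, HF Accelerate's documented pattern is to guard the call with `accelerator.sync_gradients` rather than delete it. A minimal standalone sketch with a placeholder model and data:

```python
import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=40)
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(80):
    x = torch.randn(1, 8, device=accelerator.device)
    with accelerator.accumulate(model):
        loss = F.mse_loss(model(x), torch.zeros_like(x))
        accelerator.backward(loss)
        if accelerator.sync_gradients:  # True only when an optimizer step will run
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
```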
@@ -593,7 +593,7 @@ class DDPM21CM:
                 'unet_state_dict': self.nn_model.module.state_dict(),
                 # 'ema_unet_state_dict': self.ema_model.state_dict(),
             }
-            save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-node{int(os.environ['SLURM_NNODES'])}-epoch{ep}-{socket.gethostbyname(socket.gethostname())}"
+            save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-node{int(os.environ['SLURM_NNODES'])}-epoch{ep}-{self.config.run_name}"
             torch.save(model_state, save_name)
             print(f'cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved model at ' + save_name)
             # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
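The checkpoint suffix changes from the host IP to the per-run timestamp (`config.run_name`), so resuming no longer depends on which node wrote the file. A hypothetical helper showing the name this f-string composes (the function and its arguments are illustrative, not part of the repo):

```python
def checkpoint_name(save_name, num_image, world_size, nnodes, ep, run_name):
    # mirrors the f-string in the hunk above; all argument values are illustrative
    return f"{save_name}-N{num_image}-device_count{world_size}-node{nnodes}-epoch{ep}-{run_name}"

print(checkpoint_name("./outputs/model", 2000, 1, 8, 19, "07140509"))
# ./outputs/model-N2000-device_count1-node8-epoch19-07140509
```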
@@ -676,13 +676,15 @@ class DDPM21CM:
 
         if save:
             # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
-            savetime = datetime.datetime.now().strftime("%m%d-%H%M%S")
-            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}.npy")
+            savetime = datetime.datetime.now().strftime("%d%H%M%S")
+            savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]:.3f}-zeta{params_backup[1]:.3f}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}.npy")
+            if not os.path.exists(self.config.output_dir):
+                os.makedirs(self.config.output_dir)
             np.save(savename, x_last)
             print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved images of shape {x_last.shape} to {savename}")
 
             if entire:
-                savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}_entire.npy")
+                savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]:.3f}-zeta{params_backup[1]:.3f}-N{self.config.num_image}-device{self.config.global_rank}-{os.path.basename(self.config.resume)}-{savetime}{'ema' if ema else ''}_entire.npy")
                 np.save(savename, x_entire)
                 print(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} saved images of shape {x_entire.shape} to {savename}")
             # else:
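Besides the `:.3f` formatting of the Tvir/zeta parameters, the hunk now creates the output directory before saving. With several ranks saving concurrently, the exists-then-create pair can still race (another process may create the directory between the check and the call); `exist_ok=True` is the race-free one-liner, noted here as an alternative rather than what the commit does:

```python
import os

# idempotent: succeeds whether or not ./outputs/ already exists,
# even if another rank creates it first
os.makedirs("./outputs/", exist_ok=True)
```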
@@ -690,13 +692,13 @@
 # %%
 
 #num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
-def train(rank, world_size, local_world_size, master_addr, master_port):
+def train(rank, world_size, local_world_size, master_addr, master_port, config):
     global_rank = rank + local_world_size * int(os.environ["SLURM_NODEID"])
     ddp_setup(global_rank, world_size, master_addr, master_port)
     torch.cuda.set_device(rank)
     #print(f"rank = {rank}, global_rank = {global_rank}, world_size = {world_size}, local_world_size = {local_world_size}")
 
-    config = TrainConfig()
+    #config = TrainConfig()
     config.device = f"cuda:{rank}"
     config.world_size = local_world_size
     config.global_rank = global_rank
@@ -753,12 +755,14 @@ if __name__ == "__main__":
     total_nodes = int(os.environ["SLURM_NNODES"])
     world_size = local_world_size * total_nodes #6#int(os.environ["SLURM_NTASKS"])
 
+    config = TrainConfig()
+    config.gradient_accumulation_steps = args.gradient_accumulation_steps
     ############################ training ################################
     if args.train == 1:
         print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(120,'-'))
         mp.spawn(
             train,
-            args=(world_size, local_world_size, master_addr, master_port),
+            args=(world_size, local_world_size, master_addr, master_port, config),
             nprocs=local_world_size,
             join=True,
         )
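Together with the new `config` parameter on `train()` above, this moves `TrainConfig` construction to the parent process and forwards the instance through `mp.spawn`, which pickles `args` and calls `fn(rank, *args)` in every child. A runnable miniature of that hand-off (the dict stands in for `TrainConfig` and, like any spawned argument, must be picklable):

```python
import torch.multiprocessing as mp

def worker(rank, world_size, config):
    # each child receives its process index as `rank`, then the pickled extras
    print(f"rank {rank}/{world_size}: n_epoch = {config['n_epoch']}")

if __name__ == "__main__":
    config = {"n_epoch": 20}  # stand-in for TrainConfig(); must be picklable
    mp.spawn(worker, args=(2, config), nprocs=2, join=True)
```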
@@ -766,11 +770,11 @@ if __name__ == "__main__":
     if args.train == 0:
         num_new_img_per_gpu = args.num_new_img_per_gpu#200#4#200
         max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
-        config = TrainConfig()
+        #config = TrainConfig()
         #config.world_size = world_size
-        config.dtype = torch.float32
+        #config.dtype = torch.float32
         config.resume = args.resume
-        config.gradient_accumulation_steps = args.gradient_accumulation_steps
+        #config.gradient_accumulation_steps = args.gradient_accumulation_steps
         # config.resume = f"./outputs/model_state-N30-device_count3-epoch4-172.27.149.181"
         # config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
         # config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
 
phoenix_diffusion.sbatch CHANGED
@@ -7,7 +7,6 @@
 #SBATCH --mem-per-gpu=16G               # Memory per core
 #SBATCH -t 08:00:00                     # Duration of the job (Ex: 15 mins)
 #SBATCH -oReport-%j                     # Combined output and error messages file
-#SBATCH --error=error-%j
 #SBATCH --mail-type=BEGIN,END,FAIL      # Mail preferences
 
 #module load gcc/10.3.0-o57x6h
@@ -31,9 +30,9 @@ export MASTER_PORT=$MASTER_PORT
 
 srun python diffusion.py \
     --train 1 \
-    --resume outputs/model_state-N3000-device_count1-node3-epoch39-172.27.144.175 \
+    --resume outputs/model_state-N2000-device_count1-node8-epoch19-172.27.144.173 \
     --num_new_img_per_gpu 50 \
     --max_num_img_per_gpu 2 \
-    --gradient_accumulation_steps 1 \
+    --gradient_accumulation_steps 40 \
 ######################################################################################
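Raising `--gradient_accumulation_steps` from 1 to 40 changes the effective batch per optimizer step. Hedged arithmetic, assuming one optimizer step per 40 micro-batches and reading the GPU and node counts off the new `--resume` name (`device_count1`, `node8`) and `TrainConfig` (`batch_size = 1`); actual counts come from SLURM at run time:

```python
batch_size = 1          # per-GPU micro-batch (TrainConfig)
grad_accum = 40         # --gradient_accumulation_steps in this sbatch
gpus_per_node = 1       # assumption; device_count1 in the checkpoint name
nodes = 8               # assumption, from the node8 tag in --resume
effective_batch = batch_size * grad_accum * gpus_per_node * nodes
print(effective_batch)  # 320 samples per optimizer step under these assumptions
```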
 
 