Xsmos committed
Commit bde1d63 · verified · 1 Parent(s): 8702435
Files changed (1)
  1. diffusion.py +61 -50
diffusion.py CHANGED
@@ -65,17 +65,18 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed import init_process_group, destroy_process_group
 import torch.distributed as dist
 
+import argparse
+
 # %%
-def ddp_setup(rank: int, world_size: int):
+def ddp_setup(rank: int, world_size: int, master_addr, master_port):
     """
     Args:
         rank: Unique identifier of each process
         world_size: Total number of processes
     """
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12355"
+    os.environ["MASTER_ADDR"] = master_addr
+    os.environ["MASTER_PORT"] = master_port
     # print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!ddp_setup, rank =", rank)
-    torch.cuda.set_device(rank)
     init_process_group(backend="nccl", rank=rank, world_size=world_size)
 
 # %%
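Note on the hunk above: init_process_group(backend="nccl", ...) performs its rendezvous through the MASTER_ADDR/MASTER_PORT environment variables, which ddp_setup now takes as parameters instead of hardcoding localhost:12355; torch.cuda.set_device(rank) moves out of ddp_setup because on multi-node runs the global rank no longer equals the local GPU index. A minimal single-node smoke test of the new signature (the _worker function below is hypothetical, not part of the commit):

import torch
import torch.multiprocessing as mp
from torch.distributed import destroy_process_group

def _worker(rank: int, world_size: int):
    # On a single node, the global rank doubles as the local GPU index.
    ddp_setup(rank, world_size, master_addr="localhost", master_port="12355")
    torch.cuda.set_device(rank)
    print(f"rank {rank}/{world_size} -> cuda:{torch.cuda.current_device()}")
    destroy_process_group()

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    mp.spawn(_worker, args=(n_gpus,), nprocs=n_gpus, join=True)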
@@ -240,7 +241,7 @@ class TrainConfig:
     # dim = 2
     dim = 2
     stride = (2,4) if dim == 2 else (2,2,2)
-    num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
+    num_image = 60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
     batch_size = 10#50#20#50#1#2#50#20#2#100 # 10
     n_epoch = 50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
     HII_DIM = 64
@@ -642,41 +643,30 @@ class DDPM21CM:
         return x_last
 # %%
 
-num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
+#num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
+def train(rank, world_size, local_world_size, master_addr, master_port):
+    ddp_setup(rank, world_size, master_addr, master_port)
+
+    local_rank = rank % local_world_size
+    torch.cuda.set_device(local_rank)
 
-def train(rank, world_size):
+    print(f"Global rank {rank}, local rank {local_rank}, current_device {torch.cuda.current_device()}")
 
-    # print("before ddp_setup")
-    ddp_setup(rank, world_size)
-    # print("after ddp_setup")
-    # print("TrainConfig()")
     config = TrainConfig()
-    config.device = f"cuda:{rank}"
-    # print("torch.cuda.current_device(), config.device =", torch.cuda.current_device(), config.device)
+    config.device = f"cuda:{local_rank}"
     config.world_size = world_size
 
     #[3200]#[200]#[1600,3200,6400,12800,25600]
-    for i, num_image in enumerate(num_train_image_list):
-        config.num_image = num_image
+    #for i, num_image in enumerate(num_train_image_list):
+    #config.num_image = num_image
         # config.world_size = world_size
         # print("ddpm21cm = DDPM21CM(config)")
        # print(f"config.device, torch.cuda.current_device() = {config.device}, {torch.cuda.current_device()}")
-        ddpm21cm = DDPM21CM(config)
-        # print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
-        print(f"run_name = {ddpm21cm.config.run_name}")
-        ddpm21cm.train()
-    destroy_process_group()
-
-if __name__ == "__main__":# and False:
-    world_size = torch.cuda.device_count()
-    print(f" training, world_size = {world_size} ".center(120,'-'))
-    # torch.multiprocessing.set_start_method("spawn")
-    # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
-
-    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
-    # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
-
-
+    ddpm21cm = DDPM21CM(config)
+    # print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
+    print(f"run_name = {ddpm21cm.config.run_name}")
+    ddpm21cm.train()
+    destroy_process_group()
 # %%
 
 # def generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params):
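The rewritten train separates the global rank (used for the process group) from the local GPU index via rank % local_world_size. A worked illustration of that mapping, not taken from the commit; note that mp.spawn only hands each worker a per-node index 0..nprocs-1, so a true multi-node launch would also need a node-rank offset when forming the global rank:

# Hypothetical 2-node job with 4 GPUs per node: how global ranks
# would map onto per-node device indices under rank % local_world_size.
local_world_size = 4
for node_rank in range(2):
    for i in range(local_world_size):
        rank = node_rank * local_world_size + i  # global rank 0..7
        local_rank = rank % local_world_size     # GPU index on that node
        print(f"global rank {rank} -> node {node_rank}, cuda:{local_rank}")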
@@ -705,8 +695,8 @@ if __name__ == "__main__":# and False:
 # # else:
 # # return None
 
-def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, params):
-    ddp_setup(rank, world_size)
+def generate_samples(rank, world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, params):
+    ddp_setup(rank, world_size, master_addr, master_port)
     ddpm21cm = DDPM21CM(config)
 
     # generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params)
@@ -729,28 +719,43 @@ def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_
 
 
 if __name__ == "__main__":
-    world_size = torch.cuda.device_count()
-    # print(f" sampling, world_size = {world_size} ".center(120,'-'))
-    # num_train_image_list = [1600,3200,6400,12800,25600]
-    # num_train_image_list = [5000]
-    num_new_img_per_gpu = 200
-    max_num_img_per_gpu = 20
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--train", type=int, required=False, help="whether to train the model", default=1)
+    parser.add_argument("--sample", type=int, required=False, help="whether to sample", default=0)
+    args = parser.parse_args()
 
-    # params = torch.tensor([4.4, 131.341])
+    master_addr = os.environ["SLURM_NODELIST"].split(",")[0]
+    master_port = "12355"
+    world_size = int(os.environ["SLURM_NTASKS"])
+    local_world_size = torch.cuda.device_count()
 
-    # print("config = TrainConfig()")
-    config = TrainConfig()
-    config.world_size = world_size
-    # print("config.world_size = world_size")
+    ############################ training ################################
+    world_size = torch.cuda.device_count()
+    if args.train:
+        print(f" training, world_size = {world_size} ".center(120,'-'))
+        mp.spawn(
+            train,
+            args=(world_size, local_world_size, master_addr, master_port),
+            nprocs=local_world_size,
+            join=True
+        )
 
-    for num_image in num_train_image_list:
-        config.num_image = num_image# // world_size
+
+    ############################ sampling ################################
+    if args.sample:
+        num_new_img_per_gpu = 200
+        max_num_img_per_gpu = 20
+        config = TrainConfig()
+        config.world_size = world_size
+        # print("config.world_size = world_size")
+
+        #for num_image in num_train_image_list:
+        #config.num_image = num_image# // world_size
         config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
         # config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
 
-        # print("ddpm21cm = DDPM21CM(config)")
-        manager = mp.Manager()
-        return_dict = manager.dict()
+        # manager = mp.Manager()
+        # return_dict = manager.dict()
 
         params_pairs = [
            (4.4, 131.341),
@@ -759,9 +764,15 @@ if __name__ == "__main__":
            (5.477, 200),
            (4.8, 131.341),
        ]
+
        for params in params_pairs:
            print(f" sampling for {params}, world_size = {world_size} ".center(120,'-'))
-           mp.spawn(generate_samples, args=(world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, torch.tensor(params)), nprocs=world_size, join=True)
+           mp.spawn(
+               generate_samples,
+               args=(world_size, local_world_size, master_addr, master_port, config, num_new_img_per_gpu, max_num_img_per_gpu, torch.tensor(params)),
+               nprocs=local_world_size,
+               join=True
+           )
 
            # print("---"*30)
            # print(f"cuda:{torch.cuda.current_device()}, keys = {return_dict.keys()}")
 