Xsmos committed
Commit d7890be · verified · Parent: 2688c84
Files changed (3)
  1. diffusion.py +5 -5
  2. load_h5.py +1 -1
  3. phoenix_diffusion.sbatch +1 -0
diffusion.py CHANGED
@@ -651,12 +651,12 @@ class DDPM21CM:
 
 #num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
 def train(rank, world_size, local_world_size, master_addr, master_port):
-    print("before ddp_setup")
+    #print("before ddp_setup")
     ddp_setup(rank, world_size, master_addr, master_port)
-    print("after ddp_setup")
+    #print("after ddp_setup")
     local_rank = rank % local_world_size
     torch.cuda.set_device(local_rank)
-    print("after set device")
+    #print("after set device")
 
     config = TrainConfig()
     config.device = f"cuda:{local_rank}"
@@ -668,11 +668,11 @@ def train(rank, world_size, local_world_size, master_addr, master_port):
     # config.world_size = world_size
     # print("ddpm21cm = DDPM21CM(config)")
     # print(f"config.device, torch.cuda.current_device() = {config.device}, {torch.cuda.current_device()}")
-    print("before dppm21cm")
+    #print("before dppm21cm")
     ddpm21cm = DDPM21CM(config)
     # print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
     # print(f"run_name = {ddpm21cm.config.run_name}")
-    print(f"run_name {ddpm21cm.config.run_name}, global_rank {rank}, local_rank {local_rank}, current_device {torch.cuda.current_device()}, local_world_size {local_world_size}, world_size {world_size}")
+    #print(f"run_name {ddpm21cm.config.run_name}, global_rank {rank}, local_rank {local_rank}, current_device {torch.cuda.current_device()}, local_world_size {local_world_size}, world_size {world_size}")
     ddpm21cm.train()
     destroy_process_group()
 # %%
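The commit silences the debug prints that bracketed the distributed setup. For context, ddp_setup(rank, world_size, master_addr, master_port) is defined elsewhere in diffusion.py; a minimal sketch of what such a helper typically does, assuming the standard torch.distributed NCCL initialization (the body below is an assumption, not the repository's actual implementation):

import os
from torch.distributed import init_process_group

def ddp_setup(rank, world_size, master_addr, master_port):
    # Hypothetical sketch: point every process at the same rendezvous address,
    # then join the NCCL process group. The real ddp_setup in diffusion.py may differ.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

Whatever the exact body, each rank joins one process group here, which is later torn down by the destroy_process_group() call at the end of train().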
load_h5.py CHANGED
@@ -168,7 +168,7 @@ class Dataset4h5(Dataset):
            param_start = time()
            params = f['params']['values'][idx]
            param_end = time()
-            print(f"ip_addr {socket.gethostbyname(socket.gethostname())}, cuda:{torch.cuda.current_device()}, CPU-pid {cpu_num}-{pid}: images {images.shape} & params {params.shape} loaded after {images_end-images_start:.3f}s & {param_end-param_start:.3f}s")
+            print(f"{socket.gethostbyname(socket.gethostname())}, cuda:{torch.cuda.current_device()}, CPU-pid {cpu_num}-{pid}: images {images.shape} & params {params.shape} loaded after {images_end-images_start:.3f}s & {param_end-param_start:.3f}s")
 
            return images, params
 
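The only change trims the "ip_addr " label from the per-worker timing log. As a minimal, standalone illustration of how that prefix identifies a worker in a multi-node run (cpu_num and pid are computed elsewhere in load_h5.py; the stand-ins below are assumptions):

import os
import socket
import torch

host_ip = socket.gethostbyname(socket.gethostname())  # IP of the node running this worker
device_id = torch.cuda.current_device() if torch.cuda.is_available() else -1  # selected CUDA device, or -1 on CPU-only
pid = os.getpid()  # process id of this dataloader worker
print(f"{host_ip}, cuda:{device_id}, pid {pid}: worker alive")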
phoenix_diffusion.sbatch CHANGED
@@ -3,6 +3,7 @@
 #SBATCH -A gts-jw254-coda20
 #SBATCH -qembers
 #SBATCH -N2 --gpus-per-node=RTX_6000:3 # -C A100-80GB # Number of nodes and cores per node required
+#SBATCH --ntasks-per-gpu=1
 #SBATCH --mem-per-gpu=32G # Memory per core
 #SBATCH -t 10:00 # Duration of the job (Ex: 15 mins)
 #SBATCH -oReport-%j # Combined output and error messages file
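The added --ntasks-per-gpu=1 line asks SLURM for one task per GPU, i.e. 2 nodes × 3 RTX 6000 cards = 6 ranks. A hedged sketch of how those ranks could be recovered on the Python side from variables that srun exports per task (whether diffusion.py actually reads them this way is an assumption):

import os

# Hypothetical mapping from SLURM's per-task environment to the arguments of train():
rank = int(os.environ["SLURM_PROCID"])        # global rank, 0..5 for 2 nodes x 3 GPUs
world_size = int(os.environ["SLURM_NTASKS"])  # total tasks across all nodes (6 here)
local_rank = int(os.environ["SLURM_LOCALID"]) # task index within this node, 0..2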