Xsmos committed on
Commit
a3983e0
·
verified ·
1 Parent(s): 5104d9c
Files changed (2) hide show
  1. diffusion.py +6 -6
  2. load_h5.py +2 -2
diffusion.py CHANGED
@@ -81,10 +81,10 @@ def ddp_setup(rank: int, world_size: int, local_world_size, master_addr, master_
81
  world_size: Total number of processes
82
  """
83
 
84
- print("inside ddp_setup")
85
  os.environ["MASTER_ADDR"] = master_addr
86
  os.environ["MASTER_PORT"] = master_port
87
- print("ddp_setup, rank =", rank)
88
  init_process_group(
89
  backend="nccl",
90
  init_method=f"tcp://{master_addr}:{master_port}",
@@ -574,7 +574,7 @@ class DDPM21CM:
574
  'unet_state_dict': self.nn_model.module.state_dict(),
575
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
576
  }
577
- save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}"
578
  torch.save(model_state, save_name)
579
  print(f'cuda:{torch.cuda.current_device()} saved model at ' + save_name)
580
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
@@ -658,12 +658,12 @@ class DDPM21CM:
658
 
659
  #num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
660
  def train(rank, world_size, local_world_size, master_addr, master_port):
661
- print("before ddp_setup")
662
  ddp_setup(rank, world_size, local_world_size, master_addr, master_port)
663
- print("after ddp_setup")
664
  local_rank = rank % local_world_size
665
  torch.cuda.set_device(local_rank)
666
- print("after set device")
667
  print(f"rank = {rank}, local_rank = {local_rank}, world_size = {world_size}, local_world_size = {local_world_size}")
668
 
669
  config = TrainConfig()
 
81
  world_size: Total number of processes
82
  """
83
 
84
+ #print("inside ddp_setup")
85
  os.environ["MASTER_ADDR"] = master_addr
86
  os.environ["MASTER_PORT"] = master_port
87
+ #print("ddp_setup, rank =", rank)
88
  init_process_group(
89
  backend="nccl",
90
  init_method=f"tcp://{master_addr}:{master_port}",
 
574
  'unet_state_dict': self.nn_model.module.state_dict(),
575
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
576
  }
577
+ save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}-{socket.gethostbyname(socket.gethostname())}"
578
  torch.save(model_state, save_name)
579
  print(f'cuda:{torch.cuda.current_device()} saved model at ' + save_name)
580
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
 
658
 
659
  #num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
660
  def train(rank, world_size, local_world_size, master_addr, master_port):
661
+ #print("before ddp_setup")
662
  ddp_setup(rank, world_size, local_world_size, master_addr, master_port)
663
+ #print("after ddp_setup")
664
  local_rank = rank % local_world_size
665
  torch.cuda.set_device(local_rank)
666
+ #print("after set device")
667
  print(f"rank = {rank}, local_rank = {local_rank}, world_size = {world_size}, local_world_size = {local_world_size}")
668
 
669
  config = TrainConfig()
load_h5.py CHANGED
@@ -96,10 +96,10 @@ class Dataset4h5(Dataset):
96
  print(f"dataset content: {f.keys()}")
97
  max_num_image = len(f['brightness_temp'])#.shape[0]
98
  field_shape = f['brightness_temp'].shape[1:]
99
- print(f"{max_num_image} images of shape {field_shape} can be loaded")
100
  #print(f"field.shape = {field_shape}")
101
  self.params_keys = list(f['params']['keys'])
102
- print(f"params keys = {self.params_keys}")
 
103
 
104
  # if self.idx is None:
105
  # if self.shuffle:
 
96
  print(f"dataset content: {f.keys()}")
97
  max_num_image = len(f['brightness_temp'])#.shape[0]
98
  field_shape = f['brightness_temp'].shape[1:]
 
99
  #print(f"field.shape = {field_shape}")
100
  self.params_keys = list(f['params']['keys'])
101
+ print(f"{max_num_image} images of shape {field_shape} can be loaded with different params.keys {self.params_keys}")
102
+ #print(f"params keys = {self.params_keys}")
103
 
104
  # if self.idx is None:
105
  # if self.shuffle: