Xsmos committed on
Commit
1d94a19
·
verified ·
1 Parent(s): 615975e
Files changed (2) hide show
  1. diffusion.py +4 -2
  2. quantify_results.ipynb +0 -0
diffusion.py CHANGED
@@ -341,6 +341,8 @@ class DDPM21CM:
341
  # print(f"resumed nn_model from {config.resume}")
342
  self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
343
  print(f"device {torch.cuda.current_device()} resumed nn_model from {config.resume}")
 
 
344
 
345
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
346
  print(f" Number of parameters for nn_model: {self.number_of_params} ".center(100,'-'))
@@ -614,7 +616,7 @@ def train(rank, world_size):
614
 
615
 
616
  if __name__ == "__main__":
617
- world_size = torch.cuda.device_count()
618
  print(f" training, world_size = {world_size} ".center(100,'-'))
619
  # torch.multiprocessing.set_start_method("spawn")
620
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
@@ -691,7 +693,7 @@ if __name__ == "__main__":
691
 
692
  for num_image in num_train_image_list:
693
  config.num_image = num_image // world_size
694
- config.resume = f"./outputs/model_state-N8000-device_count4-epoch{config.n_epoch-1}"
695
 
696
  # print("ddpm21cm = DDPM21CM(config)")
697
  manager = mp.Manager()
 
341
  # print(f"resumed nn_model from {config.resume}")
342
  self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
343
  print(f"device {torch.cuda.current_device()} resumed nn_model from {config.resume}")
344
+ else:
345
+ print(f"device {torch.cuda.current_device()} initialized nn_model randomly")
346
 
347
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
348
  print(f" Number of parameters for nn_model: {self.number_of_params} ".center(100,'-'))
 
616
 
617
 
618
  if __name__ == "__main__":
619
+ world_size = 2#torch.cuda.device_count()
620
  print(f" training, world_size = {world_size} ".center(100,'-'))
621
  # torch.multiprocessing.set_start_method("spawn")
622
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
 
693
 
694
  for num_image in num_train_image_list:
695
  config.num_image = num_image // world_size
696
+ config.resume = f"./outputs/model_state-N4000-device_count2-epoch{config.n_epoch-1}"
697
 
698
  # print("ddpm21cm = DDPM21CM(config)")
699
  manager = mp.Manager()
quantify_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff