Xsmos committed on
Commit
6e6d516
·
verified ·
1 Parent(s): 6227394
Files changed (1) hide show
  1. diffusion.py +7 -5
diffusion.py CHANGED
@@ -361,7 +361,7 @@ class DDPM21CM:
361
  dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
362
  # self.shape_loaded = dataset.images.shape
363
  # print("shape_loaded =", self.shape_loaded)
364
- # print(f"load, current_device = {torch.cuda.current_device()}")
365
  self.dataloader = DataLoader(
366
  dataset=dataset,
367
  batch_size=self.config.batch_size,
@@ -394,7 +394,8 @@ class DDPM21CM:
394
  # distributed_type="MULTI_GPU",
395
  )
396
  # print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
397
- if self.accelerator.is_main_process:
 
398
  if self.config.output_dir is not None:
399
  os.makedirs(self.config.output_dir, exist_ok=True)
400
  if self.config.push_to_hub:
@@ -414,7 +415,7 @@ class DDPM21CM:
414
  self.nn_model, self.optimizer, self.lr_scheduler
415
  )
416
 
417
- print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
418
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
419
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.DistributedSampler =", self.dataloader.DistributedSampler)
420
 
@@ -468,7 +469,8 @@ class DDPM21CM:
468
 
469
  def save(self, ep):
470
  # save model
471
- if self.accelerator.is_main_process:
 
472
  if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:
473
  self.nn_model.eval()
474
  with torch.no_grad():
@@ -486,7 +488,7 @@ class DDPM21CM:
486
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
487
  }
488
  torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
489
- print('saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
490
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
491
 
492
  # def rescale(self, value, type='params', to_ranges=[0,1]):
 
361
  dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
362
  # self.shape_loaded = dataset.images.shape
363
  # print("shape_loaded =", self.shape_loaded)
364
+ # print(f"load, current_device() = {torch.cuda.current_device()}")
365
  self.dataloader = DataLoader(
366
  dataset=dataset,
367
  batch_size=self.config.batch_size,
 
394
  # distributed_type="MULTI_GPU",
395
  )
396
  # print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
397
+ # if self.accelerator.is_main_process:
398
+ if torch.cuda.current_device() == 0:
399
  if self.config.output_dir is not None:
400
  os.makedirs(self.config.output_dir, exist_ok=True)
401
  if self.config.push_to_hub:
 
415
  self.nn_model, self.optimizer, self.lr_scheduler
416
  )
417
 
418
+ # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
419
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
420
  # print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.DistributedSampler =", self.dataloader.DistributedSampler)
421
 
 
469
 
470
  def save(self, ep):
471
  # save model
472
+ # if self.accelerator.is_main_process:
473
+ if torch.cuda.current_device() == 0:
474
  if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:
475
  self.nn_model.eval()
476
  with torch.no_grad():
 
488
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
489
  }
490
  torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
491
+ print(f'device {torch.cuda.current_device()} saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
492
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
493
 
494
  # def rescale(self, value, type='params', to_ranges=[0,1]):