Xsmos commited on
Commit
39cf7eb
·
verified ·
1 Parent(s): 5cc510f
Files changed (1) hide show
  1. diffusion.py +10 -3
diffusion.py CHANGED
@@ -240,7 +240,7 @@ class TrainConfig:
240
  batch_size = 2#2#50#20#2#100 # 10
241
  n_epoch = 2#10#50#20#20#2#5#25 # 120
242
  HII_DIM = 28#64
243
- num_redshift = 4#128#64#512#256#256#64#512#128
244
  channel = 1
245
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
246
 
@@ -359,7 +359,14 @@ class DDPM21CM:
359
  dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
360
  # self.shape_loaded = dataset.images.shape
361
  # print("shape_loaded =", self.shape_loaded)
362
- self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=len(os.sched_getaffinity(0)), pin_memory=True)
 
 
 
 
 
 
 
363
  # del dataset
364
  # self.accelerate(self.config)
365
  del dataset
@@ -547,7 +554,7 @@ def single_main(rank, world_size):
547
  if __name__ == "__main__":
548
  # torch.multiprocessing.set_start_method("spawn")
549
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
550
- world_size = 1#torch.cuda.device_count()
551
 
552
  mp.spawn(single_main, args=(world_size,), nprocs=world_size)
553
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
 
240
  batch_size = 2#2#50#20#2#100 # 10
241
  n_epoch = 2#10#50#20#20#2#5#25 # 120
242
  HII_DIM = 28#64
243
+ num_redshift = 2#128#64#512#256#256#64#512#128
244
  channel = 1
245
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
246
 
 
359
  dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
360
  # self.shape_loaded = dataset.images.shape
361
  # print("shape_loaded =", self.shape_loaded)
362
+ self.dataloader = DataLoader(
363
+ dataset,
364
+ batch_size=self.config.batch_size,
365
+ shuffle=True,
366
+ num_workers=1,#len(os.sched_getaffinity(0)),
367
+ pin_memory=True,
368
+ persistent_workers=True,
369
+ )
370
  # del dataset
371
  # self.accelerate(self.config)
372
  del dataset
 
554
  if __name__ == "__main__":
555
  # torch.multiprocessing.set_start_method("spawn")
556
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
557
+ world_size = 2#torch.cuda.device_count()
558
 
559
  mp.spawn(single_main, args=(world_size,), nprocs=world_size)
560
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')