Xsmos commited on
Commit
5ed3bf5
·
verified ·
1 Parent(s): 0190f2d
Files changed (4) hide show
  1. context_unet.py +1 -1
  2. diffusion.py +48 -35
  3. load_h5.py +16 -15
  4. quantify_results.ipynb +0 -0
context_unet.py CHANGED
@@ -330,7 +330,7 @@ class ContextUnet(nn.Module):
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
- channel_mult = (2,4,4,4,8)#(1, 2, 2, 4, 4)#(1, 2, 2, 4, 8)#(1, 1, 2, 2, 4, 4)#(1, 2, 4, 8, 16)#(1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
 
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
+ channel_mult = (1,2,2,4,4)#(1, 2, 4)#(2,4,4,4,8)#(1, 2, 2, 4, 4)#(1, 2, 2, 4, 8)#(1, 1, 2, 2, 4, 4)#(1, 2, 4, 8, 16)#(1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
diffusion.py CHANGED
@@ -154,7 +154,7 @@ class DDPMScheduler(nn.Module):
154
  # for i in range(self.num_timesteps, 0, -1):
155
  # print(f'sampling!!!')
156
  pbar_sample = tqdm(total=self.num_timesteps)
157
- pbar_sample.set_description(f"device {torch.cuda.current_device()} sampling")
158
  for i in reversed(range(0, self.num_timesteps)):
159
  # print(f'sampling timestep {i:4d}',end='\r')
160
  t_is = torch.tensor([i]).to(device)
@@ -232,16 +232,17 @@ class TrainConfig:
232
  hub_model_id = "Xsmos/ml21cm"
233
  hub_private_repo = False
234
  dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
235
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
236
- world_size = torch.cuda.device_count()
 
237
  # repeat = 2
238
 
239
  # dim = 2
240
  dim = 2
241
  stride = (2,4) if dim == 2 else (2,2,2)
242
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
243
- batch_size = 10#50#20#50#1#2#50#20#2#100 # 10
244
- n_epoch = 50#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
245
  HII_DIM = 64
246
  num_redshift = 512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
247
  channel = 1
@@ -268,7 +269,7 @@ class TrainConfig:
268
  # seed = 0
269
  # save_dir = './outputs/'
270
 
271
- save_period = 20#n_epoch // 2 #np.infty#.1 # the period of sampling
272
  # general parameters for the name and logger
273
  # device = "cuda" if torch.cuda.is_available() else "cpu"
274
  lrate = 1e-4
@@ -355,17 +356,21 @@ class DDPM21CM:
355
  # # print("shape_loaded =", self.shape_loaded)
356
  # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
357
  # del dataset
 
358
  self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, dtype=config.dtype)
359
 
 
360
  # initialize the unet
361
  self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype)
362
 
 
363
  # nn_model = ContextUnet(n_param=1, image_size=28)
364
  self.nn_model.train()
365
  # print("self.ddpm.device =", self.ddpm.device)
366
  self.nn_model.to(self.ddpm.device)
 
367
  self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
368
- # print("nn_model.device =", ddpm.device)
369
  # number of parameters to be trained
370
 
371
  if config.resume and os.path.exists(config.resume):
@@ -373,12 +378,12 @@ class DDPM21CM:
373
  # self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
374
  # print(f"resumed nn_model from {config.resume}")
375
  self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
376
- print(f"device {torch.cuda.current_device()} resumed nn_model from {config.resume}")
377
  else:
378
- print(f"device {torch.cuda.current_device()} initialized nn_model randomly")
379
 
380
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
381
- print(f" Number of parameters for nn_model: {self.number_of_params} ".center(100,'-'))
382
 
383
  # whether to use ema
384
  if config.ema:
@@ -405,12 +410,13 @@ class DDPM21CM:
405
  dataset = Dataset4h5(
406
  self.config.dataset_name,
407
  num_image=self.config.num_image,
408
- idx = 'range',
409
  HII_DIM=self.config.HII_DIM,
410
  num_redshift=self.config.num_redshift,
411
  drop_prob=self.config.drop_prob,
412
  dim=self.config.dim,
413
- ranges_dict=self.ranges_dict
 
414
  )
415
  # self.shape_loaded = dataset.images.shape
416
  # print("shape_loaded =", self.shape_loaded)
@@ -419,7 +425,7 @@ class DDPM21CM:
419
  dataset=dataset,
420
  batch_size=self.config.batch_size,
421
  shuffle=True,#False,
422
- num_workers=len(os.sched_getaffinity(0)),
423
  pin_memory=True,
424
  persistent_workers=True,
425
  # sampler=DistributedSampler(dataset),
@@ -478,9 +484,9 @@ class DDPM21CM:
478
  # self.dataloader.sampler.set_epoch(ep)
479
 
480
  pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
481
- pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
482
  for i, (x, c) in enumerate(self.dataloader):
483
- # print(f"device {torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
484
  with self.accelerator.accumulate(self.nn_model):
485
  x = x.to(self.config.device)
486
  # print("x = x.to(self.config.device), x.dtype =", x.dtype)
@@ -556,7 +562,7 @@ class DDPM21CM:
556
  }
557
  save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}"
558
  torch.save(model_state, save_name)
559
- print(f'device {torch.cuda.current_device()} saved model at ' + save_name)
560
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
561
 
562
  # def rescale(self, value, type='params', to_ranges=[0,1]):
@@ -580,7 +586,7 @@ class DDPM21CM:
580
  # n_sample = params.shape[0]
581
  # file = self.config.resume
582
 
583
- # print(f"device {torch.cuda.current_device()}, sample, params = {params}")
584
  if params is None:
585
  params = torch.tensor([4.4, 131.341])
586
  # params_backup = params.numpy().copy()
@@ -588,7 +594,7 @@ class DDPM21CM:
588
  params_backup = params.numpy().copy()
589
  params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])
590
 
591
- print(f"device {torch.cuda.current_device()} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}")
592
  params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
593
  assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"
594
  # print("params =", params)
@@ -603,7 +609,7 @@ class DDPM21CM:
603
  # self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
604
  # else:
605
  # self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
606
- # print(f"device {torch.cuda.current_device()} resumed nn_model from {file}")
607
  # nn_model = ContextUnet(n_param=1, image_size=28)
608
  # nn_model.train()
609
  # self.nn_model.to(self.ddpm.device)
@@ -636,28 +642,34 @@ class DDPM21CM:
636
  return x_last
637
  # %%
638
 
639
- num_train_image_list = [6000]#[600]#[8000]#[1000]#[100]#
640
 
641
  def train(rank, world_size):
642
- config = TrainConfig()
643
- config.world_size = world_size
644
 
 
645
  ddp_setup(rank, world_size)
 
 
 
 
 
 
646
 
647
  #[3200]#[200]#[1600,3200,6400,12800,25600]
648
  for i, num_image in enumerate(num_train_image_list):
649
- config.num_image = num_image // world_size
650
  # config.world_size = world_size
651
-
 
652
  ddpm21cm = DDPM21CM(config)
653
  # print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
654
  print(f"run_name = {ddpm21cm.config.run_name}")
655
  ddpm21cm.train()
656
  destroy_process_group()
657
 
658
- if __name__ == "__main__":
659
  world_size = torch.cuda.device_count()
660
- print(f" training, world_size = {world_size} ".center(100,'-'))
661
  # torch.multiprocessing.set_start_method("spawn")
662
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
663
 
@@ -675,7 +687,7 @@ if __name__ == "__main__":
675
  # num_new_img_per_gpu=max_num_img_per_gpu
676
  # )
677
 
678
- # print(f"device {torch.cuda.current_device()} generated sample of shape: {sample.shape}")
679
 
680
  # # samples.append(sample)
681
  # # ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img_per_gpu=max_num_img_per_gpu)
@@ -706,19 +718,19 @@ def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_
706
  num_new_img_per_gpu=max_num_img_per_gpu
707
  )
708
 
709
- print(f"device {torch.cuda.current_device()} generated sample of shape: {sample.shape}")
710
 
711
- # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
712
  # if rank == 0:
713
  # return_dict['samples'] = samples
714
- # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
715
 
716
  dist.destroy_process_group()
717
 
718
 
719
  if __name__ == "__main__":
720
  world_size = torch.cuda.device_count()
721
- # print(f" sampling, world_size = {world_size} ".center(100,'-'))
722
  # num_train_image_list = [1600,3200,6400,12800,25600]
723
  # num_train_image_list = [5000]
724
  num_new_img_per_gpu = 200
@@ -732,8 +744,9 @@ if __name__ == "__main__":
732
  # print("config.world_size = world_size")
733
 
734
  for num_image in num_train_image_list:
735
- config.num_image = num_image // world_size
736
  config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
 
737
 
738
  # print("ddpm21cm = DDPM21CM(config)")
739
  manager = mp.Manager()
@@ -747,14 +760,14 @@ if __name__ == "__main__":
747
  (4.8, 131.341),
748
  ]
749
  for params in params_pairs:
750
- print(f" sampling for {params}, world_size = {world_size} ".center(100,'-'))
751
  mp.spawn(generate_samples, args=(world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, torch.tensor(params)), nprocs=world_size, join=True)
752
 
753
  # print("---"*30)
754
- # print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
755
  # if "samples" in return_dict:
756
  # samples = return_dict["samples"]
757
- # print(f"device {torch.cuda.current_device()} generated samples shape: {samples.shape}")
758
 
759
 
760
  # %%
 
154
  # for i in range(self.num_timesteps, 0, -1):
155
  # print(f'sampling!!!')
156
  pbar_sample = tqdm(total=self.num_timesteps)
157
+ pbar_sample.set_description(f"cuda:{torch.cuda.current_device()} sampling")
158
  for i in reversed(range(0, self.num_timesteps)):
159
  # print(f'sampling timestep {i:4d}',end='\r')
160
  t_is = torch.tensor([i]).to(device)
 
232
  hub_model_id = "Xsmos/ml21cm"
233
  hub_private_repo = False
234
  dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
235
+ device = "cuda" if torch.cuda.is_available() else 'cpu'
236
+ # device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else 'cpu'
237
+ world_size = 1#torch.cuda.device_count()
238
  # repeat = 2
239
 
240
  # dim = 2
241
  dim = 2
242
  stride = (2,4) if dim == 2 else (2,2,2)
243
  num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
244
+ batch_size = 20#50#20#50#1#2#50#20#2#100 # 10
245
+ n_epoch = 100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
246
  HII_DIM = 64
247
  num_redshift = 512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
248
  channel = 1
 
269
  # seed = 0
270
  # save_dir = './outputs/'
271
 
272
+ save_period = n_epoch // 3 #np.infty#.1 # the period of sampling
273
  # general parameters for the name and logger
274
  # device = "cuda" if torch.cuda.is_available() else "cpu"
275
  lrate = 1e-4
 
356
  # # print("shape_loaded =", self.shape_loaded)
357
  # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
358
  # del dataset
359
+ # print("self.ddpm = DDPMScheduler")
360
  self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device, dtype=config.dtype)
361
 
362
+ # print("self.nn_model = ContextUnet")
363
  # initialize the unet
364
  self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride, dtype=config.dtype)
365
 
366
+ # print("self.nn_model.train()")
367
  # nn_model = ContextUnet(n_param=1, image_size=28)
368
  self.nn_model.train()
369
  # print("self.ddpm.device =", self.ddpm.device)
370
  self.nn_model.to(self.ddpm.device)
371
+ # print("before, nn_model.device =", self.ddpm.device)
372
  self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
373
+ # print("after, nn_model.device =", self.ddpm.device)
374
  # number of parameters to be trained
375
 
376
  if config.resume and os.path.exists(config.resume):
 
378
  # self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
379
  # print(f"resumed nn_model from {config.resume}")
380
  self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
381
+ print(f"cuda:{torch.cuda.current_device()} resumed nn_model from {config.resume}")
382
  else:
383
+ print(f"cuda:{torch.cuda.current_device()} initialized nn_model randomly")
384
 
385
  self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
386
+ print(f" Number of parameters for nn_model: {self.number_of_params} ".center(120,'-'))
387
 
388
  # whether to use ema
389
  if config.ema:
 
410
  dataset = Dataset4h5(
411
  self.config.dataset_name,
412
  num_image=self.config.num_image,
413
+ idx = "random",#'range',
414
  HII_DIM=self.config.HII_DIM,
415
  num_redshift=self.config.num_redshift,
416
  drop_prob=self.config.drop_prob,
417
  dim=self.config.dim,
418
+ ranges_dict=self.ranges_dict,
419
+ num_workers=len(os.sched_getaffinity(0))//self.config.world_size,
420
  )
421
  # self.shape_loaded = dataset.images.shape
422
  # print("shape_loaded =", self.shape_loaded)
 
425
  dataset=dataset,
426
  batch_size=self.config.batch_size,
427
  shuffle=True,#False,
428
+ num_workers=len(os.sched_getaffinity(0))//self.config.world_size,
429
  pin_memory=True,
430
  persistent_workers=True,
431
  # sampler=DistributedSampler(dataset),
 
484
  # self.dataloader.sampler.set_epoch(ep)
485
 
486
  pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
487
+ pbar_train.set_description(f"cuda:{torch.cuda.current_device()}, Epoch {ep}")
488
  for i, (x, c) in enumerate(self.dataloader):
489
+ # print(f"cuda:{torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
490
  with self.accelerator.accumulate(self.nn_model):
491
  x = x.to(self.config.device)
492
  # print("x = x.to(self.config.device), x.dtype =", x.dtype)
 
562
  }
563
  save_name = self.config.save_name+f"-N{self.config.num_image}-device_count{self.config.world_size}-epoch{ep}"
564
  torch.save(model_state, save_name)
565
+ print(f'cuda:{torch.cuda.current_device()} saved model at ' + save_name)
566
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
567
 
568
  # def rescale(self, value, type='params', to_ranges=[0,1]):
 
586
  # n_sample = params.shape[0]
587
  # file = self.config.resume
588
 
589
+ # print(f"cuda:{torch.cuda.current_device()}, sample, params = {params}")
590
  if params is None:
591
  params = torch.tensor([4.4, 131.341])
592
  # params_backup = params.numpy().copy()
 
594
  params_backup = params.numpy().copy()
595
  params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])
596
 
597
+ print(f"cuda:{torch.cuda.current_device()} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}")
598
  params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
599
  assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"
600
  # print("params =", params)
 
609
  # self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
610
  # else:
611
  # self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
612
+ # print(f"cuda:{torch.cuda.current_device()} resumed nn_model from {file}")
613
  # nn_model = ContextUnet(n_param=1, image_size=28)
614
  # nn_model.train()
615
  # self.nn_model.to(self.ddpm.device)
 
642
  return x_last
643
  # %%
644
 
645
+ num_train_image_list = [6000]#[60]#[8000]#[1000]#[100]#
646
 
647
  def train(rank, world_size):
 
 
648
 
649
+ # print("before ddp_setup")
650
  ddp_setup(rank, world_size)
651
+ # print("after ddp_setup")
652
+ # print("TrainConfig()")
653
+ config = TrainConfig()
654
+ config.device = f"cuda:{rank}"
655
+ # print("torch.cuda.current_device(), config.device =", torch.cuda.current_device(), config.device)
656
+ config.world_size = world_size
657
 
658
  #[3200]#[200]#[1600,3200,6400,12800,25600]
659
  for i, num_image in enumerate(num_train_image_list):
660
+ config.num_image = num_image
661
  # config.world_size = world_size
662
+ # print("ddpm21cm = DDPM21CM(config)")
663
+ # print(f"config.device, torch.cuda.current_device() = {config.device}, {torch.cuda.current_device()}")
664
  ddpm21cm = DDPM21CM(config)
665
  # print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
666
  print(f"run_name = {ddpm21cm.config.run_name}")
667
  ddpm21cm.train()
668
  destroy_process_group()
669
 
670
+ if __name__ == "__main__":# and False:
671
  world_size = torch.cuda.device_count()
672
+ print(f" training, world_size = {world_size} ".center(120,'-'))
673
  # torch.multiprocessing.set_start_method("spawn")
674
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
675
 
 
687
  # num_new_img_per_gpu=max_num_img_per_gpu
688
  # )
689
 
690
+ # print(f"cuda:{torch.cuda.current_device()} generated sample of shape: {sample.shape}")
691
 
692
  # # samples.append(sample)
693
  # # ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img_per_gpu=max_num_img_per_gpu)
 
718
  num_new_img_per_gpu=max_num_img_per_gpu
719
  )
720
 
721
+ print(f"cuda:{torch.cuda.current_device()} generated sample of shape: {sample.shape}")
722
 
723
+ # print(f"cuda:{torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
724
  # if rank == 0:
725
  # return_dict['samples'] = samples
726
+ # print(f"cuda:{torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
727
 
728
  dist.destroy_process_group()
729
 
730
 
731
  if __name__ == "__main__":
732
  world_size = torch.cuda.device_count()
733
+ # print(f" sampling, world_size = {world_size} ".center(120,'-'))
734
  # num_train_image_list = [1600,3200,6400,12800,25600]
735
  # num_train_image_list = [5000]
736
  num_new_img_per_gpu = 200
 
744
  # print("config.world_size = world_size")
745
 
746
  for num_image in num_train_image_list:
747
+ config.num_image = num_image# // world_size
748
  config.resume = f"./outputs/model_state-N{config.num_image}-device_count{world_size}-epoch{config.n_epoch-1}"
749
+ # config.resume = f"./outputs/model_state-N{config.num_image}-device_count1-epoch{config.n_epoch-1}"
750
 
751
  # print("ddpm21cm = DDPM21CM(config)")
752
  manager = mp.Manager()
 
760
  (4.8, 131.341),
761
  ]
762
  for params in params_pairs:
763
+ print(f" sampling for {params}, world_size = {world_size} ".center(120,'-'))
764
  mp.spawn(generate_samples, args=(world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, torch.tensor(params)), nprocs=world_size, join=True)
765
 
766
  # print("---"*30)
767
+ # print(f"cuda:{torch.cuda.current_device()}, keys = {return_dict.keys()}")
768
  # if "samples" in return_dict:
769
  # samples = return_dict["samples"]
770
+ # print(f"cuda:{torch.cuda.current_device()} generated samples shape: {samples.shape}")
771
 
772
 
773
  # %%
load_h5.py CHANGED
@@ -42,6 +42,7 @@ class Dataset4h5(Dataset):
42
  dim=2,
43
  transform=True,
44
  ranges_dict=None,
 
45
  # shuffle=False,
46
  ):
47
  super().__init__()
@@ -56,6 +57,7 @@ class Dataset4h5(Dataset):
56
  self.drop_prob = drop_prob
57
  self.dim = dim
58
  self.transform = transform
 
59
 
60
  # if ranges_dict == None:
61
  # ranges_dict = dict(
@@ -74,9 +76,8 @@ class Dataset4h5(Dataset):
74
  self.images = self.rescale(self.images, ranges=ranges_dict['images'], to=[-1,1])
75
  self.params = self.rescale(self.params, ranges=ranges_dict['params'], to=[0,1])
76
  rescale_end = time()
77
- print(f"rescaling costs {rescale_end-rescale_start:.3f} s")
78
- print(f"images rescaled to [{self.images.min()}, {self.images.max()}]")
79
- print(f"params rescaled to [{self.params.min()}, {self.params.max()}]")
80
 
81
  # from_numpy_start = time()
82
  self.len = len(self.params)
@@ -109,7 +110,7 @@ class Dataset4h5(Dataset):
109
  # print(f"loading {len(self.idx)} images with idx = {self.idx}")
110
  if self.idx == "random":
111
  self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
112
- print(f"loading {self.num_image} images randomly")
113
  # print(self.idx)
114
  elif self.idx == "range":
115
  rank = torch.cuda.current_device()
@@ -123,12 +124,12 @@ class Dataset4h5(Dataset):
123
  concurrent_start = time()
124
  self.images = []
125
  self.params = []
126
- max_workers = len(os.sched_getaffinity(0))
127
- with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
128
- print(f"concurrently loading by {max_workers} max_workers...")
129
  futures = []
130
- for idx in np.array_split(self.idx, max_workers):
131
- futures.append(executor.submit(self.read_data_chunk, self.dir_name, idx))
132
  for future in concurrent.futures.as_completed(futures):
133
  images, params = future.result()
134
  self.images.append(images)
@@ -136,7 +137,7 @@ class Dataset4h5(Dataset):
136
  self.images = np.concatenate(self.images, axis=0)
137
  self.params = np.concatenate(self.params, axis=0)
138
  concurrent_end = time()
139
- print(f"images {self.images.shape} & params {self.params.shape} concurrently loaded after {concurrent_end-concurrent_start:.3f}s")
140
 
141
  transform_start = time()
142
  if self.transform:
@@ -145,14 +146,12 @@ class Dataset4h5(Dataset):
145
  transform_end = time()
146
  print(f"images transformed after {transform_end-transform_start:.3f}s")
147
 
148
- def read_data_chunk(self, f, idx):
149
- pid = os.getpid()
150
  # process = psutil.Process(pid)
151
  # cpu_affinity = process.cpu_affinity()
152
  # cpu_num = psutil.Process().cpu_num()
153
-
154
  # print(f"cpu_num = {cpu_num}")#, cpu_affinity = {cpu_affinity}")
155
-
156
  with h5py.File(self.dir_name, 'r') as f:
157
  images_start = time()
158
  if self.dim == 2:
@@ -162,11 +161,13 @@ class Dataset4h5(Dataset):
162
  images = f[self.field][idx,:self.HII_DIM,:self.HII_DIM,-self.num_redshift:][:,None]
163
  images_end = time()
164
  # print(f"pid {pid}: images of shape {images.shape} loaded after {load_end-load_start:.3f} s")
 
 
165
 
166
  param_start = time()
167
  params = f['params']['values'][idx]
168
  param_end = time()
169
- print(f"pid {pid}: images {images.shape} & params {params.shape} loaded after {images_end-images_start:.3f}s & {param_end-param_start:.3f}s")
170
 
171
  return images, params
172
 
 
42
  dim=2,
43
  transform=True,
44
  ranges_dict=None,
45
+ num_workers=len(os.sched_getaffinity(0))//torch.cuda.device_count(),
46
  # shuffle=False,
47
  ):
48
  super().__init__()
 
57
  self.drop_prob = drop_prob
58
  self.dim = dim
59
  self.transform = transform
60
+ self.num_workers = num_workers
61
 
62
  # if ranges_dict == None:
63
  # ranges_dict = dict(
 
76
  self.images = self.rescale(self.images, ranges=ranges_dict['images'], to=[-1,1])
77
  self.params = self.rescale(self.params, ranges=ranges_dict['params'], to=[0,1])
78
  rescale_end = time()
79
+ # print(f"rescaling costs {rescale_end-rescale_start:.3f} s")
80
+ print(f"images & params rescaled to [{self.images.min()}, {self.images.max()}] & [{self.params.min()}, {self.params.max()}] after {rescale_end-rescale_start:.3f} s")
 
81
 
82
  # from_numpy_start = time()
83
  self.len = len(self.params)
 
110
  # print(f"loading {len(self.idx)} images with idx = {self.idx}")
111
  if self.idx == "random":
112
  self.idx = np.sort(random.sample(range(max_num_image), self.num_image))
113
+ print(f"loading {self.num_image} images randomly with idx = {self.idx[:5]}...{self.idx[-5:]}")
114
  # print(self.idx)
115
  elif self.idx == "range":
116
  rank = torch.cuda.current_device()
 
124
  concurrent_start = time()
125
  self.images = []
126
  self.params = []
127
+ # self.num_workers = len(os.sched_getaffinity(0))//torch.cuda.device_count()
128
+ with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_workers) as executor:
129
+ print(f" cuda:{torch.cuda.current_device()}, concurrently loading by {self.num_workers} workers ".center(120, '-'))
130
  futures = []
131
+ for idx in np.array_split(self.idx, self.num_workers):
132
+ futures.append(executor.submit(self.read_data_chunk, self.dir_name, idx, torch.cuda.current_device()))
133
  for future in concurrent.futures.as_completed(futures):
134
  images, params = future.result()
135
  self.images.append(images)
 
137
  self.images = np.concatenate(self.images, axis=0)
138
  self.params = np.concatenate(self.params, axis=0)
139
  concurrent_end = time()
140
+ print(f" cuda:{torch.cuda.current_device()}: images {self.images.shape} & params {self.params.shape} concurrently loaded after {concurrent_end-concurrent_start:.3f}s ".center(120, '-'))
141
 
142
  transform_start = time()
143
  if self.transform:
 
146
  transform_end = time()
147
  print(f"images transformed after {transform_end-transform_start:.3f}s")
148
 
149
+ def read_data_chunk(self, f, idx, device):
 
150
  # process = psutil.Process(pid)
151
  # cpu_affinity = process.cpu_affinity()
152
  # cpu_num = psutil.Process().cpu_num()
 
153
  # print(f"cpu_num = {cpu_num}")#, cpu_affinity = {cpu_affinity}")
154
+ torch.cuda.set_device(device)
155
  with h5py.File(self.dir_name, 'r') as f:
156
  images_start = time()
157
  if self.dim == 2:
 
161
  images = f[self.field][idx,:self.HII_DIM,:self.HII_DIM,-self.num_redshift:][:,None]
162
  images_end = time()
163
  # print(f"pid {pid}: images of shape {images.shape} loaded after {load_end-load_start:.3f} s")
164
+ pid = os.getpid()
165
+ cpu_num = psutil.Process(pid).cpu_num()
166
 
167
  param_start = time()
168
  params = f['params']['values'][idx]
169
  param_end = time()
170
+ print(f"cuda:{torch.cuda.current_device()}, CPU-pid {cpu_num}-{pid}: images {images.shape} & params {params.shape} loaded after {images_end-images_start:.3f}s & {param_end-param_start:.3f}s")
171
 
172
  return images, params
173
 
quantify_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff