Xsmos committed on
Commit
1aea5e1
·
verified ·
1 Parent(s): 5ea9467

0716-2131

Browse files
Files changed (3) hide show
  1. context_unet.py +1 -1
  2. diffusion.py +82 -55
  3. quantify_results.ipynb +26 -7
context_unet.py CHANGED
@@ -330,7 +330,7 @@ class ContextUnet(nn.Module):
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
- channel_mult = (1, 2, 2, 4, 4)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
 
330
  elif image_size == 128:
331
  channel_mult = (1, 1, 2, 3, 4)
332
  elif image_size == 64:
333
+ channel_mult = (1, 2, 4, 4, 4)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
334
  elif image_size == 32:
335
  channel_mult = (1, 2, 2, 4)
336
  elif image_size == 28:
diffusion.py CHANGED
@@ -239,7 +239,7 @@ class TrainConfig:
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 1#2#50#20#2#100 # 10
242
- n_epoch = 4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
@@ -508,7 +508,10 @@ class DDPM21CM:
508
  # for i, from_ranges in self.ranges_dict[type].items():
509
  # value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize
510
  # value[i] =
511
- def rescale(self, value, ranges, to: list):
 
 
 
512
  if value.ndim == 1:
513
  value = value.view(-1,len(value))
514
 
@@ -518,20 +521,21 @@ class DDPM21CM:
518
  value = value * (to[1]-to[0]) + to[0]
519
  return value
520
 
521
- def sample(self, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
522
  # n_sample = params.shape[0]
523
  # file = self.config.resume
524
 
 
525
  if params is None:
526
- params = torch.tensor([0.20000000000000018, 0.5055875000000001])
527
- params_backup = params.numpy().copy()
528
- else:
529
- params_backup = params.numpy().copy()
530
- params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
531
-
532
- print(f"device {torch.cuda.current_device()} sampling {num_new_img} images with normalized params = {params}")
533
- params = params.repeat(num_new_img,1)
534
- assert params.dim() == 2, "params must be a 2D torch.tensor"
535
  # print("params =", params)
536
  # print("params =", params)
537
  # print("len(params) =", len(params))
@@ -557,18 +561,24 @@ class DDPM21CM:
557
  with torch.no_grad():
558
  x_last, x_entire = self.ddpm.sample(
559
  nn_model=self.nn_model,
560
- params=params.to(self.config.device),
561
  device=self.config.device,
562
  guide_w=self.config.guide_w
563
  )
564
 
565
  if save:
566
  # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
567
- np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy"), x_last)
 
 
 
 
568
  if entire:
569
- np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy"), x_last)
570
- else:
571
- return x_last
 
 
572
  # %%
573
  def train(rank, world_size):
574
  config = TrainConfig()
@@ -576,8 +586,8 @@ def train(rank, world_size):
576
 
577
  ddp_setup(rank, world_size)
578
 
579
- num_image_list = [2000]#[200]#[1600,3200,6400,12800,25600]
580
- for i, num_image in enumerate(num_image_list):
581
  config.num_image = num_image
582
  # config.world_size = world_size
583
 
@@ -614,68 +624,85 @@ if __name__ == "__main__":
614
 
615
  # %%
616
 
617
- def generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size):
618
- samples = []
619
- for _ in range(num_new_img // max_num_img_per_gpu):
620
- sample = ddpm21cm.sample(params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
621
- samples.append(sample)
622
- # ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
623
- # ddpm21cm.sample(params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
624
- # ddpm21cm.sample(params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
625
- # ddpm21cm.sample(params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
626
- samples = np.concatenate(samples, axis=0)
627
-
628
- samples_list = [np.empty_like(samples) for _ in range(world_size)]
629
- dist.all_gather_object(samples_list, samples)
630
-
631
- if rank == 0:
632
- all_samples = np.concatenate(samples_list, axis=0)
633
- return all_samples
634
- else:
635
- return None
636
-
637
- def sample(rank, world_size, config, num_new_img, max_num_img_per_gpu, return_dict):
 
 
 
 
 
 
638
  ddp_setup(rank, world_size)
639
  ddpm21cm = DDPM21CM(config)
640
 
641
- samples = generate_samples(ddpm21cm, num_new_img, max_num_img_per_gpu, rank, world_size)
 
 
 
 
 
 
 
 
 
642
 
643
  # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
644
- if rank == 0:
645
- return_dict['samples'] = samples
646
  # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
647
 
648
  dist.destroy_process_group()
649
 
650
 
651
- if __name__ == False:#"__main__":
652
- print(" sampling ".center(100,'-'))
653
  world_size = torch.cuda.device_count()
654
- # num_image_list = [1600,3200,6400,12800,25600]
655
- num_image_list = [10]
656
- num_new_img = 4
657
- max_num_img_per_gpu = 2
 
 
 
658
 
659
  # print("config = TrainConfig()")
660
  config = TrainConfig()
661
  config.world_size = world_size
662
  # print("config.world_size = world_size")
663
 
664
- for num_image in num_image_list:
665
  config.num_image = num_image
666
- config.resume = f"./outputs/model_state-N{num_image}-epoch1-device0"
667
 
668
  # print("ddpm21cm = DDPM21CM(config)")
669
  manager = mp.Manager()
670
  return_dict = manager.dict()
671
 
672
- mp.spawn(sample, args=(world_size, config, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
673
 
674
  # print("---"*30)
675
  # print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
676
- if "samples" in return_dict:
677
- samples = return_dict["samples"]
678
- print(f"device {torch.cuda.current_device()} generated samples shape: {samples.shape}")
679
 
680
 
681
  # %%
 
239
  stride = (2,2) if dim == 2 else (2,2,4)
240
  num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
241
  batch_size = 1#2#50#20#2#100 # 10
242
+ n_epoch = 8#4# 10#50#20#20#2#5#25 # 120
243
  HII_DIM = 64
244
  num_redshift = 512#128#64#512#256#256#64#512#128
245
  channel = 1
 
508
  # for i, from_ranges in self.ranges_dict[type].items():
509
  # value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize
510
  # value[i] =
511
+ def rescale(self, params, ranges, to: list):
512
+ # value = np.array(params).copy()
513
+ value = params.clone()
514
+
515
  if value.ndim == 1:
516
  value = value.view(-1,len(value))
517
 
 
521
  value = value * (to[1]-to[0]) + to[0]
522
  return value
523
 
524
+ def sample(self, params:torch.tensor=None, num_new_img_per_gpu=192, ema=False, entire=False, save=True):
525
  # n_sample = params.shape[0]
526
  # file = self.config.resume
527
 
528
+ print(f"device {torch.cuda.current_device()}, sample, params = {params}")
529
  if params is None:
530
+ params = torch.tensor([4.4, 131.341])
531
+ # params_backup = params.numpy().copy()
532
+ # else:
533
+ params_backup = params.numpy().copy()
534
+ params_normalized = self.rescale(params, self.ranges_dict['params'], to=[0,1])
535
+
536
+ print(f"device {torch.cuda.current_device()} sampling {num_new_img_per_gpu} images with normalized params = {params_normalized}")
537
+ params_normalized = params_normalized.repeat(num_new_img_per_gpu,1)
538
+ assert params_normalized.dim() == 2, "params_normalized must be a 2D torch.tensor"
539
  # print("params =", params)
540
  # print("params =", params)
541
  # print("len(params) =", len(params))
 
561
  with torch.no_grad():
562
  x_last, x_entire = self.ddpm.sample(
563
  nn_model=self.nn_model,
564
+ params=params_normalized.to(self.config.device),
565
  device=self.config.device,
566
  guide_w=self.config.guide_w
567
  )
568
 
569
  if save:
570
  # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
571
+ savetime = datetime.datetime.now().strftime("%m%d-%H%M")
572
+ savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{torch.cuda.current_device()}-{savetime}{'ema' if ema else ''}.npy")
573
+ print(f"saving {savename} ...")
574
+ np.save(savename, x_last)
575
+
576
  if entire:
577
+ savename = os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}-device{torch.cuda.current_device()}-{savetime}{'ema' if ema else ''}_entire.npy")
578
+ print(f"saving {savename} ...")
579
+ np.save(savename, x_entire)
580
+ # else:
581
+ return x_last
582
  # %%
583
  def train(rank, world_size):
584
  config = TrainConfig()
 
586
 
587
  ddp_setup(rank, world_size)
588
 
589
+ num_train_image_list = [10]#[200]#[1600,3200,6400,12800,25600]
590
+ for i, num_image in enumerate(num_train_image_list):
591
  config.num_image = num_image
592
  # config.world_size = world_size
593
 
 
624
 
625
  # %%
626
 
627
+ # def generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params):
628
+ # # samples = []
629
+ # for _ in range(num_new_img_per_gpu // max_num_img_per_gpu):
630
+ # sample = ddpm21cm.sample(
631
+ # params=params,
632
+ # num_new_img_per_gpu=max_num_img_per_gpu
633
+ # )
634
+
635
+ # print(f"device {torch.cuda.current_device()} generated sample of shape: {sample.shape}")
636
+
637
+ # # samples.append(sample)
638
+ # # ddpm21cm.sample(params=torch.tensor((5.6, 19.037)), num_new_img_per_gpu=max_num_img_per_gpu)
639
+ # # ddpm21cm.sample(params=torch.tensor((4.699, 30)), num_new_img_per_gpu=max_num_img_per_gpu)
640
+ # # ddpm21cm.sample(params=torch.tensor((5.477, 200)), num_new_img_per_gpu=max_num_img_per_gpu)
641
+ # # ddpm21cm.sample(params=torch.tensor((4.8, 131.341)), num_new_img_per_gpu=max_num_img_per_gpu)
642
+ # # samples = np.concatenate(samples, axis=0)
643
+
644
+ # # samples_list = [np.empty_like(samples) for _ in range(world_size)]
645
+ # # dist.all_gather_object(samples_list, samples)
646
+
647
+ # # if rank == 0:
648
+ # # all_samples = np.concatenate(samples_list, axis=0)
649
+ # # return all_samples
650
+ # # else:
651
+ # # return None
652
+
653
+ def generate_samples(rank, world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, params):
654
  ddp_setup(rank, world_size)
655
  ddpm21cm = DDPM21CM(config)
656
 
657
+ # generate_samples(ddpm21cm, num_new_img_per_gpu, max_num_img_per_gpu, rank, world_size, params)
658
+
659
+ # samples = []
660
+ for _ in range(num_new_img_per_gpu // max_num_img_per_gpu):
661
+ sample = ddpm21cm.sample(
662
+ params=params,
663
+ num_new_img_per_gpu=max_num_img_per_gpu
664
+ )
665
+
666
+ print(f"device {torch.cuda.current_device()} generated sample of shape: {sample.shape}")
667
 
668
  # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}, samples.shape = {np.shape(samples)}")
669
+ # if rank == 0:
670
+ # return_dict['samples'] = samples
671
  # print(f"device {torch.cuda.current_device()}, rank = {rank}, keys = {return_dict.keys()}")
672
 
673
  dist.destroy_process_group()
674
 
675
 
676
+ if __name__ == "__main__":
 
677
  world_size = torch.cuda.device_count()
678
+ print(f" sampling, world_size = {world_size} ".center(100,'-'))
679
+ # num_train_image_list = [1600,3200,6400,12800,25600]
680
+ num_train_image_list = [2000]
681
+ num_new_img_per_gpu = 8
682
+ max_num_img_per_gpu = 1
683
+
684
+ params = torch.tensor([4.4, 131.341])
685
 
686
  # print("config = TrainConfig()")
687
  config = TrainConfig()
688
  config.world_size = world_size
689
  # print("config.world_size = world_size")
690
 
691
+ for num_image in num_train_image_list:
692
  config.num_image = num_image
693
+ config.resume = f"./outputs/model_state-N{num_image}-epoch3-device0"
694
 
695
  # print("ddpm21cm = DDPM21CM(config)")
696
  manager = mp.Manager()
697
  return_dict = manager.dict()
698
 
699
+ mp.spawn(generate_samples, args=(world_size, config, num_new_img_per_gpu, max_num_img_per_gpu, return_dict, params), nprocs=world_size, join=True)
700
 
701
  # print("---"*30)
702
  # print(f"device {torch.cuda.current_device()}, keys = {return_dict.keys()}")
703
+ # if "samples" in return_dict:
704
+ # samples = return_dict["samples"]
705
+ # print(f"device {torch.cuda.current_device()} generated samples shape: {samples.shape}")
706
 
707
 
708
  # %%
quantify_results.ipynb CHANGED
@@ -1971,24 +1971,43 @@
1971
  },
1972
  {
1973
  "cell_type": "code",
1974
- "execution_count": null,
1975
  "metadata": {},
1976
- "outputs": [],
1977
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
1978
  },
1979
  {
1980
  "cell_type": "code",
1981
- "execution_count": null,
1982
  "metadata": {},
1983
  "outputs": [],
1984
- "source": []
 
 
1985
  },
1986
  {
1987
  "cell_type": "code",
1988
- "execution_count": null,
1989
  "metadata": {},
1990
  "outputs": [],
1991
- "source": []
 
 
 
 
 
1992
  },
1993
  {
1994
  "cell_type": "code",
 
1971
  },
1972
  {
1973
  "cell_type": "code",
1974
+ "execution_count": 6,
1975
  "metadata": {},
1976
+ "outputs": [
1977
+ {
1978
+ "name": "stdout",
1979
+ "output_type": "stream",
1980
+ "text": [
1981
+ "(1, 1, 64, 64, 512)\n"
1982
+ ]
1983
+ }
1984
+ ],
1985
+ "source": [
1986
+ "import numpy as np\n",
1987
+ "data = np.load('/storage/home/hcoda1/3/bxia34/p-jw254-0/ml21cm/outputs/Tvir4.400000095367432-zeta131.34100341796875-N2000-device0-0716-1726.npy')\n",
1988
+ "print(data.shape)"
1989
+ ]
1990
  },
1991
  {
1992
  "cell_type": "code",
1993
+ "execution_count": 7,
1994
  "metadata": {},
1995
  "outputs": [],
1996
+ "source": [
1997
+ "Tb = data[0,0]"
1998
+ ]
1999
  },
2000
  {
2001
  "cell_type": "code",
2002
+ "execution_count": 8,
2003
  "metadata": {},
2004
  "outputs": [],
2005
+ "source": [
2006
+ "import matplotlib.pyplot as plt\n",
2007
+ "for i in range(Tb.shape[-1]):\n",
2008
+ " plt.imshow(Tb[:,:,i])\n",
2009
+ " plt.savefig(f\"Tb{i:03d}.png\")"
2010
+ ]
2011
  },
2012
  {
2013
  "cell_type": "code",