Xsmos commited on
Commit
37efc7f
·
verified ·
1 Parent(s): 2327293
Files changed (2) hide show
  1. context_unet.py +3 -1
  2. diffusion.ipynb +100 -57
context_unet.py CHANGED
@@ -327,8 +327,10 @@ class ContextUnet(nn.Module):
327
  channel_mult = (1, 1, 2, 3, 4)
328
  elif image_size == 64:
329
  channel_mult = (1, 1, 2, 2, 4, 4)#(1, 2, 3, 4)
 
 
330
  elif image_size == 28:
331
- channel_mult = (1, 2)#(1, 2, 3, 4)
332
  else:
333
  raise ValueError(f"unsupported image size: {image_size}")
334
  # else:
 
327
  channel_mult = (1, 1, 2, 3, 4)
328
  elif image_size == 64:
329
  channel_mult = (1, 1, 2, 2, 4, 4)#(1, 2, 3, 4)
330
+ elif image_size == 32:
331
+ channel_mult = (1, 2, 2, 4)
332
  elif image_size == 28:
333
+ channel_mult = (1, 2, 4)#(1, 2, 3, 4)
334
  else:
335
  raise ValueError(f"unsupported image size: {image_size}")
336
  # else:
diffusion.ipynb CHANGED
@@ -74,24 +74,9 @@
74
  "cell_type": "code",
75
  "execution_count": 2,
76
  "metadata": {},
77
- "outputs": [
78
- {
79
- "data": {
80
- "application/vnd.jupyter.widget-view+json": {
81
- "model_id": "9d207dc0a7a24bccb2f419954b64ddc0",
82
- "version_major": 2,
83
- "version_minor": 0
84
- },
85
- "text/plain": [
86
- "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
87
- ]
88
- },
89
- "metadata": {},
90
- "output_type": "display_data"
91
- }
92
- ],
93
  "source": [
94
- "notebook_login()"
95
  ]
96
  },
97
  {
@@ -272,12 +257,12 @@
272
  "\n",
273
  " # dim = 2\n",
274
  " dim = 3\n",
275
- " stride = (2,2) if dim == 2 else (2,2,2)\n",
276
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
277
- " batch_size = 2#2#50#20#2#100 # 10\n",
278
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
279
- " HII_DIM = 64\n",
280
- " num_redshift = 128#64#512#256#256#64#512#128\n",
281
  " channel = 1\n",
282
  " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
283
  "\n",
@@ -579,17 +564,16 @@
579
  "name": "stdout",
580
  "output_type": "stream",
581
  "text": [
582
- "Number of parameters for nn_model: 306285057\n",
583
- "---------------- num_image = 1000 ----------------\n",
584
- "run_name = 0706-2157\n",
585
  "Launching training on one GPU.\n",
586
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
587
  "51200 images can be loaded\n",
588
  "field.shape = (64, 64, 514)\n",
589
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
590
- "loading 1000 images randomly\n",
591
- "images loaded: (1000, 1, 64, 64, 128)\n",
592
- "params loaded: (1000, 2)\n"
593
  ]
594
  },
595
  {
@@ -603,19 +587,20 @@
603
  "name": "stdout",
604
  "output_type": "stream",
605
  "text": [
606
- "images rescaled to [-1.0, 1.2339041233062744]\n",
607
- "params rescaled to [0.0006788479393025145, 0.9997530171043563]\n"
 
608
  ]
609
  },
610
  {
611
  "data": {
612
  "application/vnd.jupyter.widget-view+json": {
613
- "model_id": "1d069efc366d480ba2515135ccb18f6c",
614
  "version_major": 2,
615
  "version_minor": 0
616
  },
617
  "text/plain": [
618
- " 0%| | 0/500 [00:00<?, ?it/s]"
619
  ]
620
  },
621
  "metadata": {},
@@ -624,12 +609,12 @@
624
  {
625
  "data": {
626
  "application/vnd.jupyter.widget-view+json": {
627
- "model_id": "93abe7efc5044d3796f67a0d243058b2",
628
  "version_major": 2,
629
  "version_minor": 0
630
  },
631
  "text/plain": [
632
- " 0%| | 0/500 [00:00<?, ?it/s]"
633
  ]
634
  },
635
  "metadata": {},
@@ -638,12 +623,12 @@
638
  {
639
  "data": {
640
  "application/vnd.jupyter.widget-view+json": {
641
- "model_id": "75e22a8998244becb916714cf4b8053e",
642
  "version_major": 2,
643
  "version_minor": 0
644
  },
645
  "text/plain": [
646
- " 0%| | 0/500 [00:00<?, ?it/s]"
647
  ]
648
  },
649
  "metadata": {},
@@ -652,12 +637,12 @@
652
  {
653
  "data": {
654
  "application/vnd.jupyter.widget-view+json": {
655
- "model_id": "7364f0924d2246e49fa02a596e0afde8",
656
  "version_major": 2,
657
  "version_minor": 0
658
  },
659
  "text/plain": [
660
- " 0%| | 0/500 [00:00<?, ?it/s]"
661
  ]
662
  },
663
  "metadata": {},
@@ -666,12 +651,12 @@
666
  {
667
  "data": {
668
  "application/vnd.jupyter.widget-view+json": {
669
- "model_id": "a5c57b0059154c2695e7fc0d426bb0d9",
670
  "version_major": 2,
671
  "version_minor": 0
672
  },
673
  "text/plain": [
674
- " 0%| | 0/500 [00:00<?, ?it/s]"
675
  ]
676
  },
677
  "metadata": {},
@@ -680,12 +665,12 @@
680
  {
681
  "data": {
682
  "application/vnd.jupyter.widget-view+json": {
683
- "model_id": "b0c9ee9c876144eab582e196e71ed83c",
684
  "version_major": 2,
685
  "version_minor": 0
686
  },
687
  "text/plain": [
688
- " 0%| | 0/500 [00:00<?, ?it/s]"
689
  ]
690
  },
691
  "metadata": {},
@@ -694,12 +679,12 @@
694
  {
695
  "data": {
696
  "application/vnd.jupyter.widget-view+json": {
697
- "model_id": "d2956adb8f4c45938f2d8f93baca2998",
698
  "version_major": 2,
699
  "version_minor": 0
700
  },
701
  "text/plain": [
702
- " 0%| | 0/500 [00:00<?, ?it/s]"
703
  ]
704
  },
705
  "metadata": {},
@@ -708,12 +693,12 @@
708
  {
709
  "data": {
710
  "application/vnd.jupyter.widget-view+json": {
711
- "model_id": "1f3031b4cd1040829ce3b457ee30e464",
712
  "version_major": 2,
713
  "version_minor": 0
714
  },
715
  "text/plain": [
716
- " 0%| | 0/500 [00:00<?, ?it/s]"
717
  ]
718
  },
719
  "metadata": {},
@@ -722,12 +707,12 @@
722
  {
723
  "data": {
724
  "application/vnd.jupyter.widget-view+json": {
725
- "model_id": "b4dcf375da4c4c4492d308e4bdc19358",
726
  "version_major": 2,
727
  "version_minor": 0
728
  },
729
  "text/plain": [
730
- " 0%| | 0/500 [00:00<?, ?it/s]"
731
  ]
732
  },
733
  "metadata": {},
@@ -736,12 +721,12 @@
736
  {
737
  "data": {
738
  "application/vnd.jupyter.widget-view+json": {
739
- "model_id": "50c6936f333a451394d6761fa9a085be",
740
  "version_major": 2,
741
  "version_minor": 0
742
  },
743
  "text/plain": [
744
- " 0%| | 0/500 [00:00<?, ?it/s]"
745
  ]
746
  },
747
  "metadata": {},
@@ -749,7 +734,7 @@
749
  }
750
  ],
751
  "source": [
752
- "num_image_list = [1000]#[200]#[1600,3200,6400,12800,25600]\n",
753
  "if __name__ == \"__main__\":\n",
754
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
755
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
@@ -759,7 +744,9 @@
759
  " ddpm21cm = DDPM21CM(config)\n",
760
  " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
761
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
762
- " notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
 
 
763
  ]
764
  },
765
  {
@@ -767,6 +754,37 @@
767
  "execution_count": null,
768
  "metadata": {},
769
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
770
  "source": [
771
  "if __name__ == \"__main__\":\n",
772
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
@@ -792,22 +810,47 @@
792
  },
793
  {
794
  "cell_type": "code",
795
- "execution_count": null,
796
  "metadata": {},
797
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
  "source": [
799
  "ls -lth outputs | head"
800
  ]
801
  },
802
  {
803
  "cell_type": "code",
804
- "execution_count": null,
805
  "metadata": {},
806
- "outputs": [],
 
 
 
 
 
 
 
 
807
  "source": [
808
  "def plot_grid(samples, c=None, row=1, col=2):\n",
809
  " print(\"samples.shape =\", samples.shape)\n",
810
- " for j in range(samples.shape[2]):\n",
811
  " plt.figure(figsize = (12,6), dpi=400)\n",
812
  " for i in range(len(samples)):\n",
813
  " plt.subplot(row,col,i+1)\n",
@@ -819,7 +862,7 @@
819
  " # plt.suptitle('simulations')\n",
820
  " plt.tight_layout()\n",
821
  " plt.subplots_adjust(wspace=0, hspace=0)\n",
822
- " plt.savefig(f\"test3D-{j:02d}.png\")\n",
823
  " plt.close()\n",
824
  " # plt.show()\n",
825
  " \n",
 
74
  "cell_type": "code",
75
  "execution_count": 2,
76
  "metadata": {},
77
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  "source": [
79
+ "# notebook_login()"
80
  ]
81
  },
82
  {
 
257
  "\n",
258
  " # dim = 2\n",
259
  " dim = 3\n",
260
+ " stride = (2,2) if dim == 2 else (2,2,1)\n",
261
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
262
+ " batch_size = 2#50#20#2#100 # 10\n",
263
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
264
+ " HII_DIM = 32#64\n",
265
+ " num_redshift = 4#128#64#512#256#256#64#512#128\n",
266
  " channel = 1\n",
267
  " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
268
  "\n",
 
564
  "name": "stdout",
565
  "output_type": "stream",
566
  "text": [
567
+ "Number of parameters for nn_model: 190142209\n",
568
+ "---------------- num_image = 100 -----------------\n",
569
+ "run_name = 0708-1342\n",
570
  "Launching training on one GPU.\n",
571
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
572
  "51200 images can be loaded\n",
573
  "field.shape = (64, 64, 514)\n",
574
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
575
+ "loading 100 images randomly\n",
576
+ "images loaded: (100, 1, 32, 32, 4)\n"
 
577
  ]
578
  },
579
  {
 
587
  "name": "stdout",
588
  "output_type": "stream",
589
  "text": [
590
+ "params loaded: (100, 2)\n",
591
+ "images rescaled to [-1.0, 1.2789411544799805]\n",
592
+ "params rescaled to [0.004197723271926046, 0.9944779188934443]\n"
593
  ]
594
  },
595
  {
596
  "data": {
597
  "application/vnd.jupyter.widget-view+json": {
598
+ "model_id": "a9bfadac7d3841c9a5a8c3440649c4f0",
599
  "version_major": 2,
600
  "version_minor": 0
601
  },
602
  "text/plain": [
603
+ " 0%| | 0/50 [00:00<?, ?it/s]"
604
  ]
605
  },
606
  "metadata": {},
 
609
  {
610
  "data": {
611
  "application/vnd.jupyter.widget-view+json": {
612
+ "model_id": "9df4310f213742d9a7aae110fca32403",
613
  "version_major": 2,
614
  "version_minor": 0
615
  },
616
  "text/plain": [
617
+ " 0%| | 0/50 [00:00<?, ?it/s]"
618
  ]
619
  },
620
  "metadata": {},
 
623
  {
624
  "data": {
625
  "application/vnd.jupyter.widget-view+json": {
626
+ "model_id": "399101df4a5f4de8a4a3f155b3ade75b",
627
  "version_major": 2,
628
  "version_minor": 0
629
  },
630
  "text/plain": [
631
+ " 0%| | 0/50 [00:00<?, ?it/s]"
632
  ]
633
  },
634
  "metadata": {},
 
637
  {
638
  "data": {
639
  "application/vnd.jupyter.widget-view+json": {
640
+ "model_id": "ea4834c350594a9c9cbd87727a88a6b8",
641
  "version_major": 2,
642
  "version_minor": 0
643
  },
644
  "text/plain": [
645
+ " 0%| | 0/50 [00:00<?, ?it/s]"
646
  ]
647
  },
648
  "metadata": {},
 
651
  {
652
  "data": {
653
  "application/vnd.jupyter.widget-view+json": {
654
+ "model_id": "b7b0d9a8c2ad456387dc1b053550c702",
655
  "version_major": 2,
656
  "version_minor": 0
657
  },
658
  "text/plain": [
659
+ " 0%| | 0/50 [00:00<?, ?it/s]"
660
  ]
661
  },
662
  "metadata": {},
 
665
  {
666
  "data": {
667
  "application/vnd.jupyter.widget-view+json": {
668
+ "model_id": "17e73c3722d64ae895f337a7379b5225",
669
  "version_major": 2,
670
  "version_minor": 0
671
  },
672
  "text/plain": [
673
+ " 0%| | 0/50 [00:00<?, ?it/s]"
674
  ]
675
  },
676
  "metadata": {},
 
679
  {
680
  "data": {
681
  "application/vnd.jupyter.widget-view+json": {
682
+ "model_id": "0a743e1a2db2445d93533c9ec5ed921f",
683
  "version_major": 2,
684
  "version_minor": 0
685
  },
686
  "text/plain": [
687
+ " 0%| | 0/50 [00:00<?, ?it/s]"
688
  ]
689
  },
690
  "metadata": {},
 
693
  {
694
  "data": {
695
  "application/vnd.jupyter.widget-view+json": {
696
+ "model_id": "cfda06e79a0e4f8b8172e9314263fb5b",
697
  "version_major": 2,
698
  "version_minor": 0
699
  },
700
  "text/plain": [
701
+ " 0%| | 0/50 [00:00<?, ?it/s]"
702
  ]
703
  },
704
  "metadata": {},
 
707
  {
708
  "data": {
709
  "application/vnd.jupyter.widget-view+json": {
710
+ "model_id": "8cb50f206a844a9da91197c2a9ed715b",
711
  "version_major": 2,
712
  "version_minor": 0
713
  },
714
  "text/plain": [
715
+ " 0%| | 0/50 [00:00<?, ?it/s]"
716
  ]
717
  },
718
  "metadata": {},
 
721
  {
722
  "data": {
723
  "application/vnd.jupyter.widget-view+json": {
724
+ "model_id": "dbaed562dee44cddb3b0d17f439464b6",
725
  "version_major": 2,
726
  "version_minor": 0
727
  },
728
  "text/plain": [
729
+ " 0%| | 0/50 [00:00<?, ?it/s]"
730
  ]
731
  },
732
  "metadata": {},
 
734
  }
735
  ],
736
  "source": [
737
+ "num_image_list = [100]#[1000]#[200]#[1600,3200,6400,12800,25600]\n",
738
  "if __name__ == \"__main__\":\n",
739
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
740
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
 
744
  " ddpm21cm = DDPM21CM(config)\n",
745
  " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
746
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
747
+ " notebook_launcher(\n",
748
+ " ddpm21cm.train, num_processes=1#, mixed_precision='fp16'\n",
749
+ " )"
750
  ]
751
  },
752
  {
 
754
  "execution_count": null,
755
  "metadata": {},
756
  "outputs": [],
757
+ "source": []
758
+ },
759
+ {
760
+ "cell_type": "code",
761
+ "execution_count": 15,
762
+ "metadata": {},
763
+ "outputs": [
764
+ {
765
+ "name": "stdout",
766
+ "output_type": "stream",
767
+ "text": [
768
+ "Number of parameters for nn_model: 306285057\n",
769
+ "sampling 2 images with normalized params = tensor([[0.2000, 0.5056]])\n",
770
+ "nn_model resumed from ./outputs/model_state-N1000\n"
771
+ ]
772
+ },
773
+ {
774
+ "data": {
775
+ "application/vnd.jupyter.widget-view+json": {
776
+ "model_id": "69eb2e5d3375414cab966a9c8db91901",
777
+ "version_major": 2,
778
+ "version_minor": 0
779
+ },
780
+ "text/plain": [
781
+ " 0%| | 0/1000 [00:00<?, ?it/s]"
782
+ ]
783
+ },
784
+ "metadata": {},
785
+ "output_type": "display_data"
786
+ }
787
+ ],
788
  "source": [
789
  "if __name__ == \"__main__\":\n",
790
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
 
810
  },
811
  {
812
  "cell_type": "code",
813
+ "execution_count": 19,
814
  "metadata": {},
815
+ "outputs": [
816
+ {
817
+ "name": "stdout",
818
+ "output_type": "stream",
819
+ "text": [
820
+ "total 13G\n",
821
+ "-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 23:59 Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\n",
822
+ "-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 23:06 model_state-N1000\n",
823
+ "drwxr-xr-x 56 bxia34 pace-jw254 4.0K Jul 6 22:09 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
824
+ "-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 21:39 Tvir4.400000095367432-zeta131.34100341796875-N50.npy\n",
825
+ "-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 21:25 model_state-N50\n",
826
+ "-rw-r--r-- 1 bxia34 pace-jw254 193K Jul 6 20:46 Tvir4.400000095367432-zeta131.34100341796875-N20.npy\n",
827
+ "-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 6 20:45 model_state-N20\n",
828
+ "-rw-r--r-- 1 bxia34 pace-jw254 6.1M Jul 5 14:44 Tvir4.400000095367432-zeta131.34100341796875-N200.npy\n",
829
+ "-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 5 12:20 model_state-N200\n"
830
+ ]
831
+ }
832
+ ],
833
  "source": [
834
  "ls -lth outputs | head"
835
  ]
836
  },
837
  {
838
  "cell_type": "code",
839
+ "execution_count": 21,
840
  "metadata": {},
841
+ "outputs": [
842
+ {
843
+ "name": "stdout",
844
+ "output_type": "stream",
845
+ "text": [
846
+ "samples.shape = (2, 1, 64, 64, 128)\n"
847
+ ]
848
+ }
849
+ ],
850
  "source": [
851
  "def plot_grid(samples, c=None, row=1, col=2):\n",
852
  " print(\"samples.shape =\", samples.shape)\n",
853
+ " for j in range(samples.shape[4]):\n",
854
  " plt.figure(figsize = (12,6), dpi=400)\n",
855
  " for i in range(len(samples)):\n",
856
  " plt.subplot(row,col,i+1)\n",
 
862
  " # plt.suptitle('simulations')\n",
863
  " plt.tight_layout()\n",
864
  " plt.subplots_adjust(wspace=0, hspace=0)\n",
865
+ " plt.savefig(f\"test3D-{j:03d}.png\")\n",
866
  " plt.close()\n",
867
  " # plt.show()\n",
868
  " \n",