Xsmos commited on
Commit
7a9defa
·
verified ·
1 Parent(s): 37efc7f

0709-1331

Browse files
Files changed (1) hide show
  1. diffusion.ipynb +31 -193
diffusion.ipynb CHANGED
@@ -259,9 +259,9 @@
259
  " dim = 3\n",
260
  " stride = (2,2) if dim == 2 else (2,2,1)\n",
261
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
262
- " batch_size = 2#50#20#2#100 # 10\n",
263
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
264
- " HII_DIM = 32#64\n",
265
  " num_redshift = 4#128#64#512#256#256#64#512#128\n",
266
  " channel = 1\n",
267
  " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
@@ -564,16 +564,15 @@
564
  "name": "stdout",
565
  "output_type": "stream",
566
  "text": [
567
- "Number of parameters for nn_model: 190142209\n",
568
  "---------------- num_image = 100 -----------------\n",
569
- "run_name = 0708-1342\n",
570
- "Launching training on one GPU.\n",
571
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
572
  "51200 images can be loaded\n",
573
  "field.shape = (64, 64, 514)\n",
574
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
575
  "loading 100 images randomly\n",
576
- "images loaded: (100, 1, 32, 32, 4)\n"
577
  ]
578
  },
579
  {
@@ -588,14 +587,14 @@
588
  "output_type": "stream",
589
  "text": [
590
  "params loaded: (100, 2)\n",
591
- "images rescaled to [-1.0, 1.2789411544799805]\n",
592
- "params rescaled to [0.004197723271926046, 0.9944779188934443]\n"
593
  ]
594
  },
595
  {
596
  "data": {
597
  "application/vnd.jupyter.widget-view+json": {
598
- "model_id": "a9bfadac7d3841c9a5a8c3440649c4f0",
599
  "version_major": 2,
600
  "version_minor": 0
601
  },
@@ -609,7 +608,7 @@
609
  {
610
  "data": {
611
  "application/vnd.jupyter.widget-view+json": {
612
- "model_id": "9df4310f213742d9a7aae110fca32403",
613
  "version_major": 2,
614
  "version_minor": 0
615
  },
@@ -623,7 +622,7 @@
623
  {
624
  "data": {
625
  "application/vnd.jupyter.widget-view+json": {
626
- "model_id": "399101df4a5f4de8a4a3f155b3ade75b",
627
  "version_major": 2,
628
  "version_minor": 0
629
  },
@@ -637,7 +636,7 @@
637
  {
638
  "data": {
639
  "application/vnd.jupyter.widget-view+json": {
640
- "model_id": "ea4834c350594a9c9cbd87727a88a6b8",
641
  "version_major": 2,
642
  "version_minor": 0
643
  },
@@ -651,7 +650,7 @@
651
  {
652
  "data": {
653
  "application/vnd.jupyter.widget-view+json": {
654
- "model_id": "b7b0d9a8c2ad456387dc1b053550c702",
655
  "version_major": 2,
656
  "version_minor": 0
657
  },
@@ -665,7 +664,7 @@
665
  {
666
  "data": {
667
  "application/vnd.jupyter.widget-view+json": {
668
- "model_id": "17e73c3722d64ae895f337a7379b5225",
669
  "version_major": 2,
670
  "version_minor": 0
671
  },
@@ -679,7 +678,7 @@
679
  {
680
  "data": {
681
  "application/vnd.jupyter.widget-view+json": {
682
- "model_id": "0a743e1a2db2445d93533c9ec5ed921f",
683
  "version_major": 2,
684
  "version_minor": 0
685
  },
@@ -693,7 +692,7 @@
693
  {
694
  "data": {
695
  "application/vnd.jupyter.widget-view+json": {
696
- "model_id": "cfda06e79a0e4f8b8172e9314263fb5b",
697
  "version_major": 2,
698
  "version_minor": 0
699
  },
@@ -707,7 +706,7 @@
707
  {
708
  "data": {
709
  "application/vnd.jupyter.widget-view+json": {
710
- "model_id": "8cb50f206a844a9da91197c2a9ed715b",
711
  "version_major": 2,
712
  "version_minor": 0
713
  },
@@ -721,7 +720,7 @@
721
  {
722
  "data": {
723
  "application/vnd.jupyter.widget-view+json": {
724
- "model_id": "dbaed562dee44cddb3b0d17f439464b6",
725
  "version_major": 2,
726
  "version_minor": 0
727
  },
@@ -734,7 +733,7 @@
734
  }
735
  ],
736
  "source": [
737
- "num_image_list = [100]#[1000]#[200]#[1600,3200,6400,12800,25600]\n",
738
  "if __name__ == \"__main__\":\n",
739
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
740
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
@@ -744,47 +743,23 @@
744
  " ddpm21cm = DDPM21CM(config)\n",
745
  " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
746
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
747
- " notebook_launcher(\n",
748
- " ddpm21cm.train, num_processes=1#, mixed_precision='fp16'\n",
749
- " )"
750
  ]
751
  },
752
  {
753
- "cell_type": "code",
754
- "execution_count": null,
755
  "metadata": {},
756
- "outputs": [],
757
- "source": []
 
758
  },
759
  {
760
  "cell_type": "code",
761
- "execution_count": 15,
762
  "metadata": {},
763
- "outputs": [
764
- {
765
- "name": "stdout",
766
- "output_type": "stream",
767
- "text": [
768
- "Number of parameters for nn_model: 306285057\n",
769
- "sampling 2 images with normalized params = tensor([[0.2000, 0.5056]])\n",
770
- "nn_model resumed from ./outputs/model_state-N1000\n"
771
- ]
772
- },
773
- {
774
- "data": {
775
- "application/vnd.jupyter.widget-view+json": {
776
- "model_id": "69eb2e5d3375414cab966a9c8db91901",
777
- "version_major": 2,
778
- "version_minor": 0
779
- },
780
- "text/plain": [
781
- " 0%| | 0/1000 [00:00<?, ?it/s]"
782
- ]
783
- },
784
- "metadata": {},
785
- "output_type": "display_data"
786
- }
787
- ],
788
  "source": [
789
  "if __name__ == \"__main__\":\n",
790
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
@@ -810,43 +785,18 @@
810
  },
811
  {
812
  "cell_type": "code",
813
- "execution_count": 19,
814
  "metadata": {},
815
- "outputs": [
816
- {
817
- "name": "stdout",
818
- "output_type": "stream",
819
- "text": [
820
- "total 13G\n",
821
- "-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 23:59 Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\n",
822
- "-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 23:06 model_state-N1000\n",
823
- "drwxr-xr-x 56 bxia34 pace-jw254 4.0K Jul 6 22:09 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
824
- "-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 21:39 Tvir4.400000095367432-zeta131.34100341796875-N50.npy\n",
825
- "-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 21:25 model_state-N50\n",
826
- "-rw-r--r-- 1 bxia34 pace-jw254 193K Jul 6 20:46 Tvir4.400000095367432-zeta131.34100341796875-N20.npy\n",
827
- "-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 6 20:45 model_state-N20\n",
828
- "-rw-r--r-- 1 bxia34 pace-jw254 6.1M Jul 5 14:44 Tvir4.400000095367432-zeta131.34100341796875-N200.npy\n",
829
- "-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 5 12:20 model_state-N200\n"
830
- ]
831
- }
832
- ],
833
  "source": [
834
  "ls -lth outputs | head"
835
  ]
836
  },
837
  {
838
  "cell_type": "code",
839
- "execution_count": 21,
840
  "metadata": {},
841
- "outputs": [
842
- {
843
- "name": "stdout",
844
- "output_type": "stream",
845
- "text": [
846
- "samples.shape = (2, 1, 64, 64, 128)\n"
847
- ]
848
- }
849
- ],
850
  "source": [
851
  "def plot_grid(samples, c=None, row=1, col=2):\n",
852
  " print(\"samples.shape =\", samples.shape)\n",
@@ -899,118 +849,6 @@
899
  "# # plt.imshow(images[0,0])\n",
900
  "# # plt.show()"
901
  ]
902
- },
903
- {
904
- "cell_type": "code",
905
- "execution_count": null,
906
- "metadata": {},
907
- "outputs": [],
908
- "source": [
909
- "# plot(\"outputs/0528-1433.npy\")\n",
910
- "# plot(\"outputs/0520-2323.npy\")\n",
911
- "# plot(\"outputs/0604-2353.npy\")"
912
- ]
913
- },
914
- {
915
- "cell_type": "code",
916
- "execution_count": null,
917
- "metadata": {},
918
- "outputs": [],
919
- "source": [
920
- "# x = np.load(\"outputs/0528-1433.npy\")\n",
921
- "# print(x.shape)"
922
- ]
923
- },
924
- {
925
- "cell_type": "code",
926
- "execution_count": null,
927
- "metadata": {},
928
- "outputs": [],
929
- "source": [
930
- "import torch\n",
931
- "import torch.nn as nn\n",
932
- "import time\n",
933
- "\n",
934
- "class MyModel(nn.Module):\n",
935
- " def __init__(self):\n",
936
- " super().__init__()\n",
937
- " self.fc = nn.Linear(100,50)\n",
938
- "\n",
939
- " def forward(self, x):\n",
940
- " return self.fc(x)\n",
941
- "\n",
942
- "model = MyModel()\n",
943
- "\n",
944
- "device_count = torch.cuda.device_count()\n",
945
- "print(\"device_count =\", device_count)\n",
946
- "\n",
947
- "if device_count > 1:\n",
948
- " print(f\"using {device_count} GPUs!\")\n",
949
- " model = nn.DataParallel(model)\n",
950
- "\n",
951
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
952
- "model.to(device)\n",
953
- "\n",
954
- "start_time = time.time()\n",
955
- "for i in range(10):\n",
956
- " myinput = torch.randn(10,10,32000,100).to(device)\n",
957
- " output = model(myinput)\n",
958
- " print(output.shape)\n",
959
- "# plt.imshow(myinput.cpu()[0])\n",
960
- "# plt.show()\n",
961
- "# plt.imshow(output.detach().cpu().numpy()[0])\n",
962
- "# plt.show()"
963
- ]
964
- },
965
- {
966
- "cell_type": "code",
967
- "execution_count": null,
968
- "metadata": {},
969
- "outputs": [],
970
- "source": [
971
- "# import torch.distributed as dist\n",
972
- "# dist.init_process_group(backend='nccl')"
973
- ]
974
- },
975
- {
976
- "cell_type": "code",
977
- "execution_count": null,
978
- "metadata": {},
979
- "outputs": [],
980
- "source": [
981
- "import numpy as np\n",
982
- "import torch\n",
983
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
984
- "\n",
985
- "data = torch.randn((64,64,64))\n",
986
- "\n",
987
- "num_elements = data.numpy().size\n",
988
- "element_size = data.numpy().itemsize\n",
989
- "\n",
990
- "print(data.dtype)\n",
991
- "print(num_elements, element_size)\n",
992
- "print(f\"total size = {num_elements*element_size/1024/1024} MB\")\n",
993
- "\n",
994
- "print(\"---\"*30)\n",
995
- "data = data.to(torch.float64)\n",
996
- "\n",
997
- "num_elements = data.numpy().size\n",
998
- "element_size = data.numpy().itemsize\n",
999
- "\n",
1000
- "print(data.dtype)\n",
1001
- "print(num_elements, element_size)\n",
1002
- "print(f\"total size = {num_elements*element_size/1024/1024} MB\")\n",
1003
- "\n",
1004
- "print(\"---\"*30)\n",
1005
- "data = data.to(torch.float16)\n",
1006
- "\n",
1007
- "num_elements = data.numpy().size\n",
1008
- "element_size = data.numpy().itemsize\n",
1009
- "\n",
1010
- "print(data.dtype)\n",
1011
- "print(num_elements, element_size)\n",
1012
- "print(f\"total size = {num_elements*element_size/1024/1024} MB\")"
1013
- ]
1014
  }
1015
  ],
1016
  "metadata": {
 
259
  " dim = 3\n",
260
  " stride = (2,2) if dim == 2 else (2,2,1)\n",
261
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
262
+ " batch_size = 2#2#50#20#2#100 # 10\n",
263
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
264
+ " HII_DIM = 28#64\n",
265
  " num_redshift = 4#128#64#512#256#256#64#512#128\n",
266
  " channel = 1\n",
267
  " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
 
564
  "name": "stdout",
565
  "output_type": "stream",
566
  "text": [
567
+ "Number of parameters for nn_model: 160234497\n",
568
  "---------------- num_image = 100 -----------------\n",
569
+ "run_name = 0709-1331\n",
 
570
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
571
  "51200 images can be loaded\n",
572
  "field.shape = (64, 64, 514)\n",
573
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
574
  "loading 100 images randomly\n",
575
+ "images loaded: (100, 1, 28, 28, 4)\n"
576
  ]
577
  },
578
  {
 
587
  "output_type": "stream",
588
  "text": [
589
  "params loaded: (100, 2)\n",
590
+ "images rescaled to [-1.0, 1.1254141330718994]\n",
591
+ "params rescaled to [0.0022036265313531977, 0.9978807793709957]\n"
592
  ]
593
  },
594
  {
595
  "data": {
596
  "application/vnd.jupyter.widget-view+json": {
597
+ "model_id": "ae9f12def1154f6cb1eb0fc8d1e1871c",
598
  "version_major": 2,
599
  "version_minor": 0
600
  },
 
608
  {
609
  "data": {
610
  "application/vnd.jupyter.widget-view+json": {
611
+ "model_id": "cae19ac5ef7a4c34b6b57a6478dc159d",
612
  "version_major": 2,
613
  "version_minor": 0
614
  },
 
622
  {
623
  "data": {
624
  "application/vnd.jupyter.widget-view+json": {
625
+ "model_id": "b448b32948894b3c8e8780f1b6e6bf58",
626
  "version_major": 2,
627
  "version_minor": 0
628
  },
 
636
  {
637
  "data": {
638
  "application/vnd.jupyter.widget-view+json": {
639
+ "model_id": "0271ab7c081a43ebb830dcfd3db145c1",
640
  "version_major": 2,
641
  "version_minor": 0
642
  },
 
650
  {
651
  "data": {
652
  "application/vnd.jupyter.widget-view+json": {
653
+ "model_id": "1bdf4c28272840f496288545bfdbdb96",
654
  "version_major": 2,
655
  "version_minor": 0
656
  },
 
664
  {
665
  "data": {
666
  "application/vnd.jupyter.widget-view+json": {
667
+ "model_id": "01b50340758b4b05891e51a616660eb8",
668
  "version_major": 2,
669
  "version_minor": 0
670
  },
 
678
  {
679
  "data": {
680
  "application/vnd.jupyter.widget-view+json": {
681
+ "model_id": "20ba64496e5e4467b97428cf6dcdbeb5",
682
  "version_major": 2,
683
  "version_minor": 0
684
  },
 
692
  {
693
  "data": {
694
  "application/vnd.jupyter.widget-view+json": {
695
+ "model_id": "8023e0bde0c3438fb218126c58fee954",
696
  "version_major": 2,
697
  "version_minor": 0
698
  },
 
706
  {
707
  "data": {
708
  "application/vnd.jupyter.widget-view+json": {
709
+ "model_id": "22881fbe8aff4ac1a3bde8735ab4fd24",
710
  "version_major": 2,
711
  "version_minor": 0
712
  },
 
720
  {
721
  "data": {
722
  "application/vnd.jupyter.widget-view+json": {
723
+ "model_id": "3bf15c9ba5144fa69fdddae271c7dbee",
724
  "version_major": 2,
725
  "version_minor": 0
726
  },
 
733
  }
734
  ],
735
  "source": [
736
+ "num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]\n",
737
  "if __name__ == \"__main__\":\n",
738
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
739
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
 
743
  " ddpm21cm = DDPM21CM(config)\n",
744
  " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
745
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
746
+ " ddpm21cm.train()\n",
747
+ " # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
 
748
  ]
749
  },
750
  {
751
+ "attachments": {},
752
+ "cell_type": "markdown",
753
  "metadata": {},
754
+ "source": [
755
+ "# Sampling"
756
+ ]
757
  },
758
  {
759
  "cell_type": "code",
760
+ "execution_count": null,
761
  "metadata": {},
762
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763
  "source": [
764
  "if __name__ == \"__main__\":\n",
765
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
 
785
  },
786
  {
787
  "cell_type": "code",
788
+ "execution_count": null,
789
  "metadata": {},
790
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
791
  "source": [
792
  "ls -lth outputs | head"
793
  ]
794
  },
795
  {
796
  "cell_type": "code",
797
+ "execution_count": null,
798
  "metadata": {},
799
+ "outputs": [],
 
 
 
 
 
 
 
 
800
  "source": [
801
  "def plot_grid(samples, c=None, row=1, col=2):\n",
802
  " print(\"samples.shape =\", samples.shape)\n",
 
849
  "# # plt.imshow(images[0,0])\n",
850
  "# # plt.show()"
851
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
852
  }
853
  ],
854
  "metadata": {