Xsmos commited on
Commit
10bd4ea
·
verified ·
1 Parent(s): 822afe0
Files changed (1) hide show
  1. diffusion.ipynb +192 -1219
diffusion.ipynb CHANGED
@@ -25,7 +25,10 @@
25
  "- 我統一了ddpm21cm這個module,能統一實現訓練和生成樣本,但目前有個bug, sample時總是會cuda out of memory,然而單獨resume model並sample就不會。\n",
26
  "- 解決了,問題出在我忘了寫with torch.no_grad():\n",
27
  "- 接下來就是生成800個lightcones,與此同時研究如何計算global signal以及power spectrum\n",
28
- "- 儅訓練圖片的數量達到5000時,生成的圖片與檢測數據的相似程度很高"
 
 
 
29
  ]
30
  },
31
  {
@@ -71,24 +74,9 @@
71
  "cell_type": "code",
72
  "execution_count": 2,
73
  "metadata": {},
74
- "outputs": [
75
- {
76
- "data": {
77
- "application/vnd.jupyter.widget-view+json": {
78
- "model_id": "930fd496e3474e459def52921323253c",
79
- "version_major": 2,
80
- "version_minor": 0
81
- },
82
- "text/plain": [
83
- "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
84
- ]
85
- },
86
- "metadata": {},
87
- "output_type": "display_data"
88
- }
89
- ],
90
  "source": [
91
- "notebook_login()"
92
  ]
93
  },
94
  {
@@ -331,6 +319,20 @@
331
  "execution_count": 6,
332
  "metadata": {},
333
  "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  "source": [
335
  "# @dataclass\n",
336
  "class DDPM21CM:\n",
@@ -385,7 +387,7 @@
385
  " dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)\n",
386
  " # self.shape_loaded = dataset.images.shape\n",
387
  " # print(\"shape_loaded =\", self.shape_loaded)\n",
388
- " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True)\n",
389
  " # del dataset\n",
390
  " # self.accelerate(self.config)\n",
391
  " del dataset\n",
@@ -557,31 +559,21 @@
557
  "cell_type": "code",
558
  "execution_count": 8,
559
  "metadata": {},
560
- "outputs": [],
561
- "source": [
562
- "num_image_list = [200]#[1600,3200,6400,12800,25600]"
563
- ]
564
- },
565
- {
566
- "cell_type": "code",
567
- "execution_count": 9,
568
- "metadata": {},
569
  "outputs": [
570
  {
571
  "name": "stdout",
572
  "output_type": "stream",
573
  "text": [
574
  "Number of parameters for nn_model: 306285057\n",
575
- "---------------- num_image = 200 -----------------\n",
576
- "run_name = 0705-1109\n",
577
  "Launching training on one GPU.\n",
578
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
579
  "51200 images can be loaded\n",
580
  "field.shape = (64, 64, 514)\n",
581
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
582
- "loading 200 images randomly\n",
583
- "images loaded: (200, 1, 64, 64, 64)\n",
584
- "params loaded: (200, 2)\n"
585
  ]
586
  },
587
  {
@@ -595,271 +587,20 @@
595
  "name": "stdout",
596
  "output_type": "stream",
597
  "text": [
598
- "images rescaled to [-1.0, 1.2652204036712646]\n",
599
- "params rescaled to [0.006617276082765411, 0.9997543714597507]\n"
 
600
  ]
601
  },
602
  {
603
  "data": {
604
  "application/vnd.jupyter.widget-view+json": {
605
- "model_id": "35fc33a0ae44486a96d3c802f009c5d3",
606
- "version_major": 2,
607
- "version_minor": 0
608
- },
609
- "text/plain": [
610
- " 0%| | 0/100 [00:00<?, ?it/s]"
611
- ]
612
- },
613
- "metadata": {},
614
- "output_type": "display_data"
615
- },
616
- {
617
- "data": {
618
- "application/vnd.jupyter.widget-view+json": {
619
- "model_id": "2a519ac5315b49c688d3b88355547fb9",
620
- "version_major": 2,
621
- "version_minor": 0
622
- },
623
- "text/plain": [
624
- " 0%| | 0/100 [00:00<?, ?it/s]"
625
- ]
626
- },
627
- "metadata": {},
628
- "output_type": "display_data"
629
- },
630
- {
631
- "data": {
632
- "application/vnd.jupyter.widget-view+json": {
633
- "model_id": "69e2ce385d9c47c29d7adfc91c11f181",
634
- "version_major": 2,
635
- "version_minor": 0
636
- },
637
- "text/plain": [
638
- " 0%| | 0/100 [00:00<?, ?it/s]"
639
- ]
640
- },
641
- "metadata": {},
642
- "output_type": "display_data"
643
- },
644
- {
645
- "data": {
646
- "application/vnd.jupyter.widget-view+json": {
647
- "model_id": "ecd2c3325006430da0b26fc08fdcbd14",
648
- "version_major": 2,
649
- "version_minor": 0
650
- },
651
- "text/plain": [
652
- " 0%| | 0/100 [00:00<?, ?it/s]"
653
- ]
654
- },
655
- "metadata": {},
656
- "output_type": "display_data"
657
- },
658
- {
659
- "data": {
660
- "application/vnd.jupyter.widget-view+json": {
661
- "model_id": "dbfe0fd5713042cd8d39e3742b2b0120",
662
- "version_major": 2,
663
- "version_minor": 0
664
- },
665
- "text/plain": [
666
- " 0%| | 0/100 [00:00<?, ?it/s]"
667
- ]
668
- },
669
- "metadata": {},
670
- "output_type": "display_data"
671
- },
672
- {
673
- "data": {
674
- "application/vnd.jupyter.widget-view+json": {
675
- "model_id": "130fa2e6bac14e3a8c1dcf911f341f11",
676
- "version_major": 2,
677
- "version_minor": 0
678
- },
679
- "text/plain": [
680
- " 0%| | 0/100 [00:00<?, ?it/s]"
681
- ]
682
- },
683
- "metadata": {},
684
- "output_type": "display_data"
685
- },
686
- {
687
- "data": {
688
- "application/vnd.jupyter.widget-view+json": {
689
- "model_id": "41543c28d3e64dbaaae9b27e41f2a2f7",
690
- "version_major": 2,
691
- "version_minor": 0
692
- },
693
- "text/plain": [
694
- " 0%| | 0/100 [00:00<?, ?it/s]"
695
- ]
696
- },
697
- "metadata": {},
698
- "output_type": "display_data"
699
- },
700
- {
701
- "data": {
702
- "application/vnd.jupyter.widget-view+json": {
703
- "model_id": "9583c3c18dcf4171a5480587a318d6fa",
704
- "version_major": 2,
705
- "version_minor": 0
706
- },
707
- "text/plain": [
708
- " 0%| | 0/100 [00:00<?, ?it/s]"
709
- ]
710
- },
711
- "metadata": {},
712
- "output_type": "display_data"
713
- },
714
- {
715
- "data": {
716
- "application/vnd.jupyter.widget-view+json": {
717
- "model_id": "ee45ac51feca414aa424f86df6064549",
718
- "version_major": 2,
719
- "version_minor": 0
720
- },
721
- "text/plain": [
722
- " 0%| | 0/100 [00:00<?, ?it/s]"
723
- ]
724
- },
725
- "metadata": {},
726
- "output_type": "display_data"
727
- },
728
- {
729
- "data": {
730
- "application/vnd.jupyter.widget-view+json": {
731
- "model_id": "0b8da38f4cd146bfb43e9e298df4b0b1",
732
- "version_major": 2,
733
- "version_minor": 0
734
- },
735
- "text/plain": [
736
- " 0%| | 0/100 [00:00<?, ?it/s]"
737
- ]
738
- },
739
- "metadata": {},
740
- "output_type": "display_data"
741
- },
742
- {
743
- "data": {
744
- "application/vnd.jupyter.widget-view+json": {
745
- "model_id": "88673c003dee4d02bccdd96447933490",
746
- "version_major": 2,
747
- "version_minor": 0
748
- },
749
- "text/plain": [
750
- " 0%| | 0/100 [00:00<?, ?it/s]"
751
- ]
752
- },
753
- "metadata": {},
754
- "output_type": "display_data"
755
- },
756
- {
757
- "data": {
758
- "application/vnd.jupyter.widget-view+json": {
759
- "model_id": "e5c76ae1dc1d41e6a12727cfee5d5967",
760
- "version_major": 2,
761
- "version_minor": 0
762
- },
763
- "text/plain": [
764
- " 0%| | 0/100 [00:00<?, ?it/s]"
765
- ]
766
- },
767
- "metadata": {},
768
- "output_type": "display_data"
769
- },
770
- {
771
- "data": {
772
- "application/vnd.jupyter.widget-view+json": {
773
- "model_id": "d387481a4ec2451d8eb5490893f45731",
774
- "version_major": 2,
775
- "version_minor": 0
776
- },
777
- "text/plain": [
778
- " 0%| | 0/100 [00:00<?, ?it/s]"
779
- ]
780
- },
781
- "metadata": {},
782
- "output_type": "display_data"
783
- },
784
- {
785
- "data": {
786
- "application/vnd.jupyter.widget-view+json": {
787
- "model_id": "7b39605940254f9ca903db94e4f97a95",
788
- "version_major": 2,
789
- "version_minor": 0
790
- },
791
- "text/plain": [
792
- " 0%| | 0/100 [00:00<?, ?it/s]"
793
- ]
794
- },
795
- "metadata": {},
796
- "output_type": "display_data"
797
- },
798
- {
799
- "data": {
800
- "application/vnd.jupyter.widget-view+json": {
801
- "model_id": "c506de02c9764e7b81f2dded2708015e",
802
- "version_major": 2,
803
- "version_minor": 0
804
- },
805
- "text/plain": [
806
- " 0%| | 0/100 [00:00<?, ?it/s]"
807
- ]
808
- },
809
- "metadata": {},
810
- "output_type": "display_data"
811
- },
812
- {
813
- "data": {
814
- "application/vnd.jupyter.widget-view+json": {
815
- "model_id": "b96ad7ec574b456ca63bda889e2e44f1",
816
- "version_major": 2,
817
- "version_minor": 0
818
- },
819
- "text/plain": [
820
- " 0%| | 0/100 [00:00<?, ?it/s]"
821
- ]
822
- },
823
- "metadata": {},
824
- "output_type": "display_data"
825
- },
826
- {
827
- "data": {
828
- "application/vnd.jupyter.widget-view+json": {
829
- "model_id": "6b1779e89c93494d96a036a9b58f6efb",
830
- "version_major": 2,
831
- "version_minor": 0
832
- },
833
- "text/plain": [
834
- " 0%| | 0/100 [00:00<?, ?it/s]"
835
- ]
836
- },
837
- "metadata": {},
838
- "output_type": "display_data"
839
- },
840
- {
841
- "data": {
842
- "application/vnd.jupyter.widget-view+json": {
843
- "model_id": "33f102b456ba432f8d2fc0aff2c835ef",
844
- "version_major": 2,
845
- "version_minor": 0
846
- },
847
- "text/plain": [
848
- " 0%| | 0/100 [00:00<?, ?it/s]"
849
- ]
850
- },
851
- "metadata": {},
852
- "output_type": "display_data"
853
- },
854
- {
855
- "data": {
856
- "application/vnd.jupyter.widget-view+json": {
857
- "model_id": "14baff112df54367905c1650b8e8936a",
858
  "version_major": 2,
859
  "version_minor": 0
860
  },
861
  "text/plain": [
862
- " 0%| | 0/100 [00:00<?, ?it/s]"
863
  ]
864
  },
865
  "metadata": {},
@@ -868,12 +609,12 @@
868
  {
869
  "data": {
870
  "application/vnd.jupyter.widget-view+json": {
871
- "model_id": "cadc0d6348f743c39d866c6e2bd933c0",
872
  "version_major": 2,
873
  "version_minor": 0
874
  },
875
  "text/plain": [
876
- " 0%| | 0/100 [00:00<?, ?it/s]"
877
  ]
878
  },
879
  "metadata": {},
@@ -882,12 +623,12 @@
882
  {
883
  "data": {
884
  "application/vnd.jupyter.widget-view+json": {
885
- "model_id": "fa5216c49dbb4b6d86222165c04bdc82",
886
  "version_major": 2,
887
  "version_minor": 0
888
  },
889
  "text/plain": [
890
- " 0%| | 0/100 [00:00<?, ?it/s]"
891
  ]
892
  },
893
  "metadata": {},
@@ -896,12 +637,12 @@
896
  {
897
  "data": {
898
  "application/vnd.jupyter.widget-view+json": {
899
- "model_id": "2b8098f2826a4c36b0379a6db7e8a3cf",
900
  "version_major": 2,
901
  "version_minor": 0
902
  },
903
  "text/plain": [
904
- " 0%| | 0/100 [00:00<?, ?it/s]"
905
  ]
906
  },
907
  "metadata": {},
@@ -910,12 +651,12 @@
910
  {
911
  "data": {
912
  "application/vnd.jupyter.widget-view+json": {
913
- "model_id": "66afeea89e8343f0b23e2e16ed53824a",
914
  "version_major": 2,
915
  "version_minor": 0
916
  },
917
  "text/plain": [
918
- " 0%| | 0/100 [00:00<?, ?it/s]"
919
  ]
920
  },
921
  "metadata": {},
@@ -924,12 +665,12 @@
924
  {
925
  "data": {
926
  "application/vnd.jupyter.widget-view+json": {
927
- "model_id": "e927c234e312458f8f5a90c2c8037515",
928
  "version_major": 2,
929
  "version_minor": 0
930
  },
931
  "text/plain": [
932
- " 0%| | 0/100 [00:00<?, ?it/s]"
933
  ]
934
  },
935
  "metadata": {},
@@ -938,12 +679,12 @@
938
  {
939
  "data": {
940
  "application/vnd.jupyter.widget-view+json": {
941
- "model_id": "27917b8d625c40cfa4040a704c912d1d",
942
  "version_major": 2,
943
  "version_minor": 0
944
  },
945
  "text/plain": [
946
- " 0%| | 0/100 [00:00<?, ?it/s]"
947
  ]
948
  },
949
  "metadata": {},
@@ -952,12 +693,12 @@
952
  {
953
  "data": {
954
  "application/vnd.jupyter.widget-view+json": {
955
- "model_id": "87c9e518f4604fafabe20e79c1c55da2",
956
  "version_major": 2,
957
  "version_minor": 0
958
  },
959
  "text/plain": [
960
- " 0%| | 0/100 [00:00<?, ?it/s]"
961
  ]
962
  },
963
  "metadata": {},
@@ -966,12 +707,12 @@
966
  {
967
  "data": {
968
  "application/vnd.jupyter.widget-view+json": {
969
- "model_id": "88affae0951f433ebe0ea0976e98358d",
970
  "version_major": 2,
971
  "version_minor": 0
972
  },
973
  "text/plain": [
974
- " 0%| | 0/100 [00:00<?, ?it/s]"
975
  ]
976
  },
977
  "metadata": {},
@@ -980,873 +721,52 @@
980
  {
981
  "data": {
982
  "application/vnd.jupyter.widget-view+json": {
983
- "model_id": "400bfcd1fea94067b06afb47408a4b0a",
984
  "version_major": 2,
985
  "version_minor": 0
986
  },
987
  "text/plain": [
988
- " 0%| | 0/100 [00:00<?, ?it/s]"
989
- ]
990
- },
991
- "metadata": {},
992
- "output_type": "display_data"
993
- },
994
- {
995
- "data": {
996
- "application/vnd.jupyter.widget-view+json": {
997
- "model_id": "80c18ea4fb4f42aabb7b13c44e4e4ad6",
998
- "version_major": 2,
999
- "version_minor": 0
1000
- },
1001
- "text/plain": [
1002
- " 0%| | 0/100 [00:00<?, ?it/s]"
1003
- ]
1004
- },
1005
- "metadata": {},
1006
- "output_type": "display_data"
1007
- },
1008
- {
1009
- "data": {
1010
- "application/vnd.jupyter.widget-view+json": {
1011
- "model_id": "276810f258834fd2bc842ba4eaa90919",
1012
- "version_major": 2,
1013
- "version_minor": 0
1014
- },
1015
- "text/plain": [
1016
- " 0%| | 0/100 [00:00<?, ?it/s]"
1017
- ]
1018
- },
1019
- "metadata": {},
1020
- "output_type": "display_data"
1021
- },
1022
- {
1023
- "data": {
1024
- "application/vnd.jupyter.widget-view+json": {
1025
- "model_id": "ff346b75aced4a86876729fc89011894",
1026
- "version_major": 2,
1027
- "version_minor": 0
1028
- },
1029
- "text/plain": [
1030
- " 0%| | 0/100 [00:00<?, ?it/s]"
1031
- ]
1032
- },
1033
- "metadata": {},
1034
- "output_type": "display_data"
1035
- },
1036
- {
1037
- "data": {
1038
- "application/vnd.jupyter.widget-view+json": {
1039
- "model_id": "bdac0a472b7b42a49585d86910d0d5c0",
1040
- "version_major": 2,
1041
- "version_minor": 0
1042
- },
1043
- "text/plain": [
1044
- " 0%| | 0/100 [00:00<?, ?it/s]"
1045
- ]
1046
- },
1047
- "metadata": {},
1048
- "output_type": "display_data"
1049
- },
1050
- {
1051
- "data": {
1052
- "application/vnd.jupyter.widget-view+json": {
1053
- "model_id": "4339b5b8a7404965a06140a53e705937",
1054
- "version_major": 2,
1055
- "version_minor": 0
1056
- },
1057
- "text/plain": [
1058
- " 0%| | 0/100 [00:00<?, ?it/s]"
1059
- ]
1060
- },
1061
- "metadata": {},
1062
- "output_type": "display_data"
1063
- },
1064
- {
1065
- "data": {
1066
- "application/vnd.jupyter.widget-view+json": {
1067
- "model_id": "be80960e91844e72bee0958dd66a129f",
1068
- "version_major": 2,
1069
- "version_minor": 0
1070
- },
1071
- "text/plain": [
1072
- " 0%| | 0/100 [00:00<?, ?it/s]"
1073
- ]
1074
- },
1075
- "metadata": {},
1076
- "output_type": "display_data"
1077
- },
1078
- {
1079
- "data": {
1080
- "application/vnd.jupyter.widget-view+json": {
1081
- "model_id": "decf9ba7e48643a8a2231ef41d739331",
1082
- "version_major": 2,
1083
- "version_minor": 0
1084
- },
1085
- "text/plain": [
1086
- " 0%| | 0/100 [00:00<?, ?it/s]"
1087
- ]
1088
- },
1089
- "metadata": {},
1090
- "output_type": "display_data"
1091
- },
1092
- {
1093
- "data": {
1094
- "application/vnd.jupyter.widget-view+json": {
1095
- "model_id": "4890152cdd7f4a8d94d29b2fea5c2d59",
1096
- "version_major": 2,
1097
- "version_minor": 0
1098
- },
1099
- "text/plain": [
1100
- " 0%| | 0/100 [00:00<?, ?it/s]"
1101
- ]
1102
- },
1103
- "metadata": {},
1104
- "output_type": "display_data"
1105
- },
1106
- {
1107
- "data": {
1108
- "application/vnd.jupyter.widget-view+json": {
1109
- "model_id": "f004da39d4ff49f5a90098f0bc0568a7",
1110
- "version_major": 2,
1111
- "version_minor": 0
1112
- },
1113
- "text/plain": [
1114
- " 0%| | 0/100 [00:00<?, ?it/s]"
1115
- ]
1116
- },
1117
- "metadata": {},
1118
- "output_type": "display_data"
1119
- },
1120
- {
1121
- "data": {
1122
- "application/vnd.jupyter.widget-view+json": {
1123
- "model_id": "7c50fe95d26b48dab44797c66242f666",
1124
- "version_major": 2,
1125
- "version_minor": 0
1126
- },
1127
- "text/plain": [
1128
- " 0%| | 0/100 [00:00<?, ?it/s]"
1129
- ]
1130
- },
1131
- "metadata": {},
1132
- "output_type": "display_data"
1133
- },
1134
- {
1135
- "data": {
1136
- "application/vnd.jupyter.widget-view+json": {
1137
- "model_id": "280299f0f7b94483bb88cce5f659d6ad",
1138
- "version_major": 2,
1139
- "version_minor": 0
1140
- },
1141
- "text/plain": [
1142
- " 0%| | 0/100 [00:00<?, ?it/s]"
1143
- ]
1144
- },
1145
- "metadata": {},
1146
- "output_type": "display_data"
1147
- },
1148
- {
1149
- "data": {
1150
- "application/vnd.jupyter.widget-view+json": {
1151
- "model_id": "18a2c92fe5954b48b17b286bb45ab845",
1152
- "version_major": 2,
1153
- "version_minor": 0
1154
- },
1155
- "text/plain": [
1156
- " 0%| | 0/100 [00:00<?, ?it/s]"
1157
- ]
1158
- },
1159
- "metadata": {},
1160
- "output_type": "display_data"
1161
- },
1162
- {
1163
- "data": {
1164
- "application/vnd.jupyter.widget-view+json": {
1165
- "model_id": "825c8957956f4f0a8c25b967015e113a",
1166
- "version_major": 2,
1167
- "version_minor": 0
1168
- },
1169
- "text/plain": [
1170
- " 0%| | 0/100 [00:00<?, ?it/s]"
1171
- ]
1172
- },
1173
- "metadata": {},
1174
- "output_type": "display_data"
1175
- },
1176
- {
1177
- "data": {
1178
- "application/vnd.jupyter.widget-view+json": {
1179
- "model_id": "503a4bcabd714820bfbf705078db422a",
1180
- "version_major": 2,
1181
- "version_minor": 0
1182
- },
1183
- "text/plain": [
1184
- " 0%| | 0/100 [00:00<?, ?it/s]"
1185
- ]
1186
- },
1187
- "metadata": {},
1188
- "output_type": "display_data"
1189
- },
1190
- {
1191
- "data": {
1192
- "application/vnd.jupyter.widget-view+json": {
1193
- "model_id": "ed04ab8b9bd54ff092ca89c478a7e681",
1194
- "version_major": 2,
1195
- "version_minor": 0
1196
- },
1197
- "text/plain": [
1198
- " 0%| | 0/100 [00:00<?, ?it/s]"
1199
- ]
1200
- },
1201
- "metadata": {},
1202
- "output_type": "display_data"
1203
- },
1204
- {
1205
- "data": {
1206
- "application/vnd.jupyter.widget-view+json": {
1207
- "model_id": "0fa61bc1ac594a5180a56fe21dd0d3a5",
1208
- "version_major": 2,
1209
- "version_minor": 0
1210
- },
1211
- "text/plain": [
1212
- " 0%| | 0/100 [00:00<?, ?it/s]"
1213
- ]
1214
- },
1215
- "metadata": {},
1216
- "output_type": "display_data"
1217
- },
1218
- {
1219
- "data": {
1220
- "application/vnd.jupyter.widget-view+json": {
1221
- "model_id": "6d60e361bb274a09a0739ee6780c7dc7",
1222
- "version_major": 2,
1223
- "version_minor": 0
1224
- },
1225
- "text/plain": [
1226
- " 0%| | 0/100 [00:00<?, ?it/s]"
1227
- ]
1228
- },
1229
- "metadata": {},
1230
- "output_type": "display_data"
1231
- },
1232
- {
1233
- "data": {
1234
- "application/vnd.jupyter.widget-view+json": {
1235
- "model_id": "3172f8d978c54e68ac51e6eec23ae9cb",
1236
- "version_major": 2,
1237
- "version_minor": 0
1238
- },
1239
- "text/plain": [
1240
- " 0%| | 0/100 [00:00<?, ?it/s]"
1241
- ]
1242
- },
1243
- "metadata": {},
1244
- "output_type": "display_data"
1245
- },
1246
- {
1247
- "data": {
1248
- "application/vnd.jupyter.widget-view+json": {
1249
- "model_id": "acf9dabf57df402088d659c85164155a",
1250
- "version_major": 2,
1251
- "version_minor": 0
1252
- },
1253
- "text/plain": [
1254
- " 0%| | 0/100 [00:00<?, ?it/s]"
1255
- ]
1256
- },
1257
- "metadata": {},
1258
- "output_type": "display_data"
1259
- },
1260
- {
1261
- "data": {
1262
- "application/vnd.jupyter.widget-view+json": {
1263
- "model_id": "f1281bf575364662a9ee843a2336b94a",
1264
- "version_major": 2,
1265
- "version_minor": 0
1266
- },
1267
- "text/plain": [
1268
- " 0%| | 0/100 [00:00<?, ?it/s]"
1269
- ]
1270
- },
1271
- "metadata": {},
1272
- "output_type": "display_data"
1273
- },
1274
- {
1275
- "data": {
1276
- "application/vnd.jupyter.widget-view+json": {
1277
- "model_id": "5f3c79aad7644ae4916d8a2bc73c3b76",
1278
- "version_major": 2,
1279
- "version_minor": 0
1280
- },
1281
- "text/plain": [
1282
- " 0%| | 0/100 [00:00<?, ?it/s]"
1283
- ]
1284
- },
1285
- "metadata": {},
1286
- "output_type": "display_data"
1287
- },
1288
- {
1289
- "data": {
1290
- "application/vnd.jupyter.widget-view+json": {
1291
- "model_id": "bd4b87c28e024b58a2a8c8abdca92a75",
1292
- "version_major": 2,
1293
- "version_minor": 0
1294
- },
1295
- "text/plain": [
1296
- " 0%| | 0/100 [00:00<?, ?it/s]"
1297
- ]
1298
- },
1299
- "metadata": {},
1300
- "output_type": "display_data"
1301
- }
1302
- ],
1303
- "source": [
1304
- "if __name__ == \"__main__\":\n",
1305
- " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
1306
- " config = TrainConfig()\n",
1307
- " for i, num_image in enumerate(num_image_list):\n",
1308
- " config.num_image = num_image\n",
1309
- " ddpm21cm = DDPM21CM(config)\n",
1310
- " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
1311
- " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
1312
- " notebook_launcher(ddpm21cm.train, num_processes=1)"
1313
- ]
1314
- },
1315
- {
1316
- "cell_type": "code",
1317
- "execution_count": null,
1318
- "metadata": {},
1319
- "outputs": [
1320
- {
1321
- "name": "stdout",
1322
- "output_type": "stream",
1323
- "text": [
1324
- "total 4.4G\n",
1325
- "-rw-r--r-- 1 bxia34 13M Jul 2 21:45 Tvir4.800000190734863-zeta131.34100341796875-N1600.npy\n",
1326
- "-rw-r--r-- 1 bxia34 13M Jul 2 21:26 Tvir5.4770002365112305-zeta200.0-N1600.npy\n",
1327
- "-rw-r--r-- 1 bxia34 13M Jul 2 21:08 Tvir4.698999881744385-zeta30.0-N1600.npy\n",
1328
- "-rw-r--r-- 1 bxia34 13M Jul 2 20:49 Tvir5.599999904632568-zeta19.03700065612793-N1600.npy\n",
1329
- "-rw-r--r-- 1 bxia34 13M Jul 2 20:31 Tvir4.400000095367432-zeta131.34100341796875-N1600.npy\n",
1330
- "-rw-r--r-- 1 bxia34 848M Jul 2 20:13 model_state-N25600\n",
1331
- "drwxr-xr-x 15 bxia34 4.0K Jul 2 19:09 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
1332
- "-rw-r--r-- 1 bxia34 848M Jul 2 18:45 model_state-N12800\n",
1333
- "-rw-r--r-- 1 bxia34 848M Jul 2 18:01 model_state-N6400\n",
1334
- "-rw-r--r-- 1 bxia34 848M Jul 2 17:37 model_state-N3200\n",
1335
- "-rw-r--r-- 1 bxia34 848M Jul 2 17:25 model_state-N1600\n",
1336
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 12:31 Tvir4.800000190734863-zeta131.34100341796875-N2000.npy\n",
1337
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 12:12 Tvir5.4770002365112305-zeta200.0-N2000.npy\n",
1338
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 11:54 Tvir4.698999881744385-zeta30.0-N2000.npy\n",
1339
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 11:35 Tvir5.599999904632568-zeta19.03700065612793-N2000.npy\n",
1340
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 11:17 Tvir4.400000095367432-zeta131.34100341796875-N2000.npy\n",
1341
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 10:37 Tvir4.800000190734863-zeta131.34100341796875-N32000.npy\n",
1342
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:47 Tvir5.4770002365112305-zeta200.0-N32000.npy\n",
1343
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:29 Tvir4.698999881744385-zeta30.0-N32000.npy\n",
1344
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:11 Tvir5.599999904632568-zeta19.03700065612793-N32000.npy\n",
1345
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 03:53 Tvir4.400000095367432-zeta131.34100341796875-N32000.npy\n",
1346
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:23 Tvir4.800000190734863-zeta131.34100341796875-N20000.npy\n",
1347
- "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:05 Tvir5.4770002365112305-zeta200.0-N20000.npy\n",
1348
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:47 Tvir4.698999881744385-zeta30.0-N20000.npy\n",
1349
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:29 Tvir5.599999904632568-zeta19.03700065612793-N20000.npy\n",
1350
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:11 Tvir4.400000095367432-zeta131.34100341796875-N20000.npy\n",
1351
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 20:08 Tvir4.800000190734863-zeta131.34100341796875-N15000.npy\n",
1352
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 19:50 Tvir5.4770002365112305-zeta200.0-N15000.npy\n",
1353
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 19:32 Tvir4.698999881744385-zeta30.0-N15000.npy\n",
1354
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 19:14 Tvir5.599999904632568-zeta19.03700065612793-N15000.npy\n",
1355
- "-rw-r--r-- 1 bxia34 3.1M Jun 30 18:57 Tvir4.400000095367432-zeta131.34100341796875-N15000.npy\n",
1356
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 12:41 Tvir4.800000190734863-zeta131.34100341796875-N7000.npy\n",
1357
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 12:23 Tvir5.4770002365112305-zeta200.0-N7000.npy\n",
1358
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 12:06 Tvir4.698999881744385-zeta30.0-N7000.npy\n",
1359
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 11:48 Tvir5.599999904632568-zeta19.03700065612793-N7000.npy\n",
1360
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 11:30 Tvir4.400000095367432-zeta131.34100341796875-N7000.npy\n",
1361
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:56 Tvir4.800000190734863-zeta131.34100341796875-N25600.npy\n",
1362
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:38 Tvir5.4770002365112305-zeta200.0-N25600.npy\n",
1363
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:21 Tvir4.698999881744385-zeta30.0-N25600.npy\n",
1364
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:03 Tvir5.599999904632568-zeta19.03700065612793-N25600.npy\n",
1365
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 03:45 Tvir4.400000095367432-zeta131.34100341796875-N25600.npy\n",
1366
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 00:35 Tvir4.800000190734863-zeta131.34100341796875-N3000.npy\n",
1367
- "-rw-r--r-- 1 bxia34 3.1M Jun 29 00:17 Tvir5.4770002365112305-zeta200.0-N3000.npy\n",
1368
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 23:59 Tvir4.698999881744385-zeta30.0-N3000.npy\n",
1369
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 23:42 Tvir5.599999904632568-zeta19.03700065612793-N3000.npy\n",
1370
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 23:20 Tvir4.400000095367432-zeta131.34100341796875-N3000.npy\n",
1371
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 21:06 Tvir4.800000190734863-zeta131.34100341796875-N10000.npy\n",
1372
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 20:49 Tvir5.4770002365112305-zeta200.0-N10000.npy\n",
1373
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 20:31 Tvir4.698999881744385-zeta30.0-N10000.npy\n",
1374
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 20:13 Tvir5.599999904632568-zeta19.03700065612793-N10000.npy\n",
1375
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 19:56 Tvir4.400000095367432-zeta131.34100341796875-N10000.npy\n",
1376
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 18:30 Tvir4.800000190734863-zeta131.34100341796875-N1000.npy\n",
1377
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 18:13 Tvir5.4770002365112305-zeta200.0-N1000.npy\n",
1378
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 17:55 Tvir4.698999881744385-zeta30.0-N1000.npy\n",
1379
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 17:37 Tvir5.599999904632568-zeta19.03700065612793-N1000.npy\n",
1380
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 17:20 Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\n",
1381
- "-rw-r--r-- 1 bxia34 3.1M Jun 28 14:03 Tvir4.400000095367432-zeta131.34100341796875-N5000.npy\n",
1382
- "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:58 Tvir4.800000190734863-zeta131.34100341796875-N5000.npy\n",
1383
- "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:40 Tvir5.4770002365112305-zeta200.0-N5000.npy\n",
1384
- "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:22 Tvir4.698999881744385-zeta30.0-N5000.npy\n",
1385
- "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:05 Tvir5.599999904632568-zeta19.03700065612793-N5000.npy\n"
1386
- ]
1387
- }
1388
- ],
1389
- "source": [
1390
- "# ll -lth outputs"
1391
- ]
1392
- },
1393
- {
1394
- "cell_type": "code",
1395
- "execution_count": null,
1396
- "metadata": {},
1397
- "outputs": [
1398
- {
1399
- "name": "stdout",
1400
- "output_type": "stream",
1401
- "text": [
1402
- "Number of parameters for nn_model: 111048705\n",
1403
- "sampling 800 images with normalized params = tensor([[0.2000, 0.5056]])\n",
1404
- "nn_model resumed from ./outputs/model_state-N3200\n"
1405
- ]
1406
- },
1407
- {
1408
- "data": {
1409
- "application/vnd.jupyter.widget-view+json": {
1410
- "model_id": "2d3427d677774c9785ee081b3b3b5542",
1411
- "version_major": 2,
1412
- "version_minor": 0
1413
- },
1414
- "text/plain": [
1415
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1416
- ]
1417
- },
1418
- "metadata": {},
1419
- "output_type": "display_data"
1420
- },
1421
- {
1422
- "name": "stdout",
1423
- "output_type": "stream",
1424
- "text": [
1425
- "sampling 800 images with normalized params = tensor([[0.8000, 0.0377]])\n",
1426
- "nn_model resumed from ./outputs/model_state-N3200\n"
1427
- ]
1428
- },
1429
- {
1430
- "data": {
1431
- "application/vnd.jupyter.widget-view+json": {
1432
- "model_id": "8f30542543bf4d96ac6284da1d3e2d91",
1433
- "version_major": 2,
1434
- "version_minor": 0
1435
- },
1436
- "text/plain": [
1437
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1438
- ]
1439
- },
1440
- "metadata": {},
1441
- "output_type": "display_data"
1442
- },
1443
- {
1444
- "name": "stdout",
1445
- "output_type": "stream",
1446
- "text": [
1447
- "sampling 800 images with normalized params = tensor([[0.3495, 0.0833]])\n",
1448
- "nn_model resumed from ./outputs/model_state-N3200\n"
1449
- ]
1450
- },
1451
- {
1452
- "data": {
1453
- "application/vnd.jupyter.widget-view+json": {
1454
- "model_id": "8b55d00d4ec74b1995fabcb27152a20c",
1455
- "version_major": 2,
1456
- "version_minor": 0
1457
- },
1458
- "text/plain": [
1459
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1460
- ]
1461
- },
1462
- "metadata": {},
1463
- "output_type": "display_data"
1464
- },
1465
- {
1466
- "name": "stdout",
1467
- "output_type": "stream",
1468
- "text": [
1469
- "sampling 800 images with normalized params = tensor([[0.7385, 0.7917]])\n",
1470
- "nn_model resumed from ./outputs/model_state-N3200\n"
1471
- ]
1472
- },
1473
- {
1474
- "data": {
1475
- "application/vnd.jupyter.widget-view+json": {
1476
- "model_id": "3fbfb9641b7c4709ab24f843b9ef9a41",
1477
- "version_major": 2,
1478
- "version_minor": 0
1479
- },
1480
- "text/plain": [
1481
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1482
- ]
1483
- },
1484
- "metadata": {},
1485
- "output_type": "display_data"
1486
- },
1487
- {
1488
- "name": "stdout",
1489
- "output_type": "stream",
1490
- "text": [
1491
- "sampling 800 images with normalized params = tensor([[0.4000, 0.5056]])\n",
1492
- "nn_model resumed from ./outputs/model_state-N3200\n"
1493
- ]
1494
- },
1495
- {
1496
- "data": {
1497
- "application/vnd.jupyter.widget-view+json": {
1498
- "model_id": "f146207ea2af4e2fbbe19a7255fe13bd",
1499
- "version_major": 2,
1500
- "version_minor": 0
1501
- },
1502
- "text/plain": [
1503
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1504
- ]
1505
- },
1506
- "metadata": {},
1507
- "output_type": "display_data"
1508
- },
1509
- {
1510
- "name": "stdout",
1511
- "output_type": "stream",
1512
- "text": [
1513
- "Number of parameters for nn_model: 111048705\n",
1514
- "sampling 800 images with normalized params = tensor([[0.2000, 0.5056]])\n",
1515
- "nn_model resumed from ./outputs/model_state-N6400\n"
1516
- ]
1517
- },
1518
- {
1519
- "data": {
1520
- "application/vnd.jupyter.widget-view+json": {
1521
- "model_id": "932fde5fc32e46719e809370a0145171",
1522
- "version_major": 2,
1523
- "version_minor": 0
1524
- },
1525
- "text/plain": [
1526
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1527
- ]
1528
- },
1529
- "metadata": {},
1530
- "output_type": "display_data"
1531
- },
1532
- {
1533
- "name": "stdout",
1534
- "output_type": "stream",
1535
- "text": [
1536
- "sampling 800 images with normalized params = tensor([[0.8000, 0.0377]])\n",
1537
- "nn_model resumed from ./outputs/model_state-N6400\n"
1538
- ]
1539
- },
1540
- {
1541
- "data": {
1542
- "application/vnd.jupyter.widget-view+json": {
1543
- "model_id": "ad63275c7ba7499aa34c7b7a1d600b01",
1544
- "version_major": 2,
1545
- "version_minor": 0
1546
- },
1547
- "text/plain": [
1548
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1549
- ]
1550
- },
1551
- "metadata": {},
1552
- "output_type": "display_data"
1553
- },
1554
- {
1555
- "name": "stdout",
1556
- "output_type": "stream",
1557
- "text": [
1558
- "sampling 800 images with normalized params = tensor([[0.3495, 0.0833]])\n",
1559
- "nn_model resumed from ./outputs/model_state-N6400\n"
1560
- ]
1561
- },
1562
- {
1563
- "data": {
1564
- "application/vnd.jupyter.widget-view+json": {
1565
- "model_id": "b92c97678a144c7db39e7504d9b51fdd",
1566
- "version_major": 2,
1567
- "version_minor": 0
1568
- },
1569
- "text/plain": [
1570
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1571
- ]
1572
- },
1573
- "metadata": {},
1574
- "output_type": "display_data"
1575
- },
1576
- {
1577
- "name": "stdout",
1578
- "output_type": "stream",
1579
- "text": [
1580
- "sampling 800 images with normalized params = tensor([[0.7385, 0.7917]])\n",
1581
- "nn_model resumed from ./outputs/model_state-N6400\n"
1582
- ]
1583
- },
1584
- {
1585
- "data": {
1586
- "application/vnd.jupyter.widget-view+json": {
1587
- "model_id": "ce95c55a33e94da39ad4761af0bca5da",
1588
- "version_major": 2,
1589
- "version_minor": 0
1590
- },
1591
- "text/plain": [
1592
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1593
- ]
1594
- },
1595
- "metadata": {},
1596
- "output_type": "display_data"
1597
- },
1598
- {
1599
- "name": "stdout",
1600
- "output_type": "stream",
1601
- "text": [
1602
- "sampling 800 images with normalized params = tensor([[0.4000, 0.5056]])\n",
1603
- "nn_model resumed from ./outputs/model_state-N6400\n"
1604
- ]
1605
- },
1606
- {
1607
- "data": {
1608
- "application/vnd.jupyter.widget-view+json": {
1609
- "model_id": "c7d30d6d4e9c46d6930da9b6d10e53c1",
1610
- "version_major": 2,
1611
- "version_minor": 0
1612
- },
1613
- "text/plain": [
1614
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1615
- ]
1616
- },
1617
- "metadata": {},
1618
- "output_type": "display_data"
1619
- },
1620
- {
1621
- "name": "stdout",
1622
- "output_type": "stream",
1623
- "text": [
1624
- "Number of parameters for nn_model: 111048705\n",
1625
- "sampling 800 images with normalized params = tensor([[0.2000, 0.5056]])\n",
1626
- "nn_model resumed from ./outputs/model_state-N12800\n"
1627
- ]
1628
- },
1629
- {
1630
- "data": {
1631
- "application/vnd.jupyter.widget-view+json": {
1632
- "model_id": "205b299c574343fca3f0f7bcce9c0fda",
1633
- "version_major": 2,
1634
- "version_minor": 0
1635
- },
1636
- "text/plain": [
1637
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1638
- ]
1639
- },
1640
- "metadata": {},
1641
- "output_type": "display_data"
1642
- },
1643
- {
1644
- "name": "stdout",
1645
- "output_type": "stream",
1646
- "text": [
1647
- "sampling 800 images with normalized params = tensor([[0.8000, 0.0377]])\n",
1648
- "nn_model resumed from ./outputs/model_state-N12800\n"
1649
- ]
1650
- },
1651
- {
1652
- "data": {
1653
- "application/vnd.jupyter.widget-view+json": {
1654
- "model_id": "8e2fea1d73a54d87984bc430584ed7ee",
1655
- "version_major": 2,
1656
- "version_minor": 0
1657
- },
1658
- "text/plain": [
1659
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1660
- ]
1661
- },
1662
- "metadata": {},
1663
- "output_type": "display_data"
1664
- },
1665
- {
1666
- "name": "stdout",
1667
- "output_type": "stream",
1668
- "text": [
1669
- "sampling 800 images with normalized params = tensor([[0.3495, 0.0833]])\n",
1670
- "nn_model resumed from ./outputs/model_state-N12800\n"
1671
- ]
1672
- },
1673
- {
1674
- "data": {
1675
- "application/vnd.jupyter.widget-view+json": {
1676
- "model_id": "8a4762d957a04a81b6283e303a0c1aa3",
1677
- "version_major": 2,
1678
- "version_minor": 0
1679
- },
1680
- "text/plain": [
1681
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1682
- ]
1683
- },
1684
- "metadata": {},
1685
- "output_type": "display_data"
1686
- },
1687
- {
1688
- "name": "stdout",
1689
- "output_type": "stream",
1690
- "text": [
1691
- "sampling 800 images with normalized params = tensor([[0.7385, 0.7917]])\n",
1692
- "nn_model resumed from ./outputs/model_state-N12800\n"
1693
- ]
1694
- },
1695
- {
1696
- "data": {
1697
- "application/vnd.jupyter.widget-view+json": {
1698
- "model_id": "67dabb9f208d4977bdb8180afc4eec20",
1699
- "version_major": 2,
1700
- "version_minor": 0
1701
- },
1702
- "text/plain": [
1703
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1704
- ]
1705
- },
1706
- "metadata": {},
1707
- "output_type": "display_data"
1708
- },
1709
- {
1710
- "name": "stdout",
1711
- "output_type": "stream",
1712
- "text": [
1713
- "sampling 800 images with normalized params = tensor([[0.4000, 0.5056]])\n",
1714
- "nn_model resumed from ./outputs/model_state-N12800\n"
1715
- ]
1716
- },
1717
- {
1718
- "data": {
1719
- "application/vnd.jupyter.widget-view+json": {
1720
- "model_id": "950ba8b752444e1e9a2d817968f5c3a3",
1721
- "version_major": 2,
1722
- "version_minor": 0
1723
- },
1724
- "text/plain": [
1725
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1726
- ]
1727
- },
1728
- "metadata": {},
1729
- "output_type": "display_data"
1730
- },
1731
- {
1732
- "name": "stdout",
1733
- "output_type": "stream",
1734
- "text": [
1735
- "Number of parameters for nn_model: 111048705\n",
1736
- "sampling 800 images with normalized params = tensor([[0.2000, 0.5056]])\n",
1737
- "nn_model resumed from ./outputs/model_state-N25600\n"
1738
- ]
1739
- },
1740
- {
1741
- "data": {
1742
- "application/vnd.jupyter.widget-view+json": {
1743
- "model_id": "fd753c01957447398ec61ff383616825",
1744
- "version_major": 2,
1745
- "version_minor": 0
1746
- },
1747
- "text/plain": [
1748
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1749
- ]
1750
- },
1751
- "metadata": {},
1752
- "output_type": "display_data"
1753
- },
1754
- {
1755
- "name": "stdout",
1756
- "output_type": "stream",
1757
- "text": [
1758
- "sampling 800 images with normalized params = tensor([[0.8000, 0.0377]])\n",
1759
- "nn_model resumed from ./outputs/model_state-N25600\n"
1760
- ]
1761
- },
1762
- {
1763
- "data": {
1764
- "application/vnd.jupyter.widget-view+json": {
1765
- "model_id": "a328cf05873a4370a645270ea917233d",
1766
- "version_major": 2,
1767
- "version_minor": 0
1768
- },
1769
- "text/plain": [
1770
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1771
- ]
1772
- },
1773
- "metadata": {},
1774
- "output_type": "display_data"
1775
- },
1776
- {
1777
- "name": "stdout",
1778
- "output_type": "stream",
1779
- "text": [
1780
- "sampling 800 images with normalized params = tensor([[0.3495, 0.0833]])\n",
1781
- "nn_model resumed from ./outputs/model_state-N25600\n"
1782
- ]
1783
- },
1784
- {
1785
- "data": {
1786
- "application/vnd.jupyter.widget-view+json": {
1787
- "model_id": "3838c0ba114c406caa3e39668feaea38",
1788
- "version_major": 2,
1789
- "version_minor": 0
1790
- },
1791
- "text/plain": [
1792
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1793
- ]
1794
- },
1795
- "metadata": {},
1796
- "output_type": "display_data"
1797
- },
1798
- {
1799
- "name": "stdout",
1800
- "output_type": "stream",
1801
- "text": [
1802
- "sampling 800 images with normalized params = tensor([[0.7385, 0.7917]])\n",
1803
- "nn_model resumed from ./outputs/model_state-N25600\n"
1804
- ]
1805
- },
1806
- {
1807
- "data": {
1808
- "application/vnd.jupyter.widget-view+json": {
1809
- "model_id": "a668e8fd474e4d659a8960e35598fb5f",
1810
- "version_major": 2,
1811
- "version_minor": 0
1812
- },
1813
- "text/plain": [
1814
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1815
- ]
1816
- },
1817
- "metadata": {},
1818
- "output_type": "display_data"
1819
- },
1820
- {
1821
- "name": "stdout",
1822
- "output_type": "stream",
1823
- "text": [
1824
- "sampling 800 images with normalized params = tensor([[0.4000, 0.5056]])\n",
1825
- "nn_model resumed from ./outputs/model_state-N25600\n"
1826
- ]
1827
- },
1828
- {
1829
- "data": {
1830
- "application/vnd.jupyter.widget-view+json": {
1831
- "model_id": "028b3cc2a3214999b6a76700579c4263",
1832
- "version_major": 2,
1833
- "version_minor": 0
1834
- },
1835
- "text/plain": [
1836
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1837
  ]
1838
  },
1839
  "metadata": {},
1840
  "output_type": "display_data"
1841
  }
1842
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1843
  "source": [
1844
  "if __name__ == \"__main__\":\n",
1845
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
1846
  " num_image_list = [200]\n",
1847
  " # num_image_list = [3200,6400,12800,25600]\n",
1848
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
1849
- " repeat = 2\n",
1850
  " config = TrainConfig()\n",
1851
  " for i, num_image in enumerate(num_image_list):\n",
1852
  " config.num_image = num_image\n",
@@ -1867,69 +787,38 @@
1867
  "cell_type": "code",
1868
  "execution_count": null,
1869
  "metadata": {},
1870
- "outputs": [
1871
- {
1872
- "name": "stdout",
1873
- "output_type": "stream",
1874
- "text": [
1875
- "total 995M\n",
1876
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 12:31 Tvir4.800000190734863-zeta131.34100341796875-N2000.npy\n",
1877
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 12:12 Tvir5.4770002365112305-zeta200.0-N2000.npy\n",
1878
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 11:54 Tvir4.698999881744385-zeta30.0-N2000.npy\n",
1879
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 11:35 Tvir5.599999904632568-zeta19.03700065612793-N2000.npy\n",
1880
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 11:17 Tvir4.400000095367432-zeta131.34100341796875-N2000.npy\n",
1881
- "-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 1 10:59 model_state.pth\n",
1882
- "drwxr-xr-x 12 bxia34 pace-jw254 4.0K Jul 1 10:50 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
1883
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 10:37 Tvir4.800000190734863-zeta131.34100341796875-N32000.npy\n",
1884
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:47 Tvir5.4770002365112305-zeta200.0-N32000.npy\n",
1885
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:29 Tvir4.698999881744385-zeta30.0-N32000.npy\n",
1886
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:11 Tvir5.599999904632568-zeta19.03700065612793-N32000.npy\n",
1887
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 03:53 Tvir4.400000095367432-zeta131.34100341796875-N32000.npy\n",
1888
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 00:23 Tvir4.800000190734863-zeta131.34100341796875-N20000.npy\n",
1889
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 00:05 Tvir5.4770002365112305-zeta200.0-N20000.npy\n",
1890
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 23:47 Tvir4.698999881744385-zeta30.0-N20000.npy\n",
1891
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 23:29 Tvir5.599999904632568-zeta19.03700065612793-N20000.npy\n",
1892
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 23:11 Tvir4.400000095367432-zeta131.34100341796875-N20000.npy\n",
1893
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 20:08 Tvir4.800000190734863-zeta131.34100341796875-N15000.npy\n",
1894
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 19:50 Tvir5.4770002365112305-zeta200.0-N15000.npy\n",
1895
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 19:32 Tvir4.698999881744385-zeta30.0-N15000.npy\n",
1896
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 19:14 Tvir5.599999904632568-zeta19.03700065612793-N15000.npy\n",
1897
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 18:57 Tvir4.400000095367432-zeta131.34100341796875-N15000.npy\n",
1898
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 12:41 Tvir4.800000190734863-zeta131.34100341796875-N7000.npy\n",
1899
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 12:23 Tvir5.4770002365112305-zeta200.0-N7000.npy\n",
1900
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 12:06 Tvir4.698999881744385-zeta30.0-N7000.npy\n",
1901
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 11:48 Tvir5.599999904632568-zeta19.03700065612793-N7000.npy\n",
1902
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 11:30 Tvir4.400000095367432-zeta131.34100341796875-N7000.npy\n",
1903
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 04:56 Tvir4.800000190734863-zeta131.34100341796875-N25600.npy\n",
1904
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 04:38 Tvir5.4770002365112305-zeta200.0-N25600.npy\n",
1905
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 04:21 Tvir4.698999881744385-zeta30.0-N25600.npy\n",
1906
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 04:03 Tvir5.599999904632568-zeta19.03700065612793-N25600.npy\n",
1907
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 03:45 Tvir4.400000095367432-zeta131.34100341796875-N25600.npy\n",
1908
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 00:35 Tvir4.800000190734863-zeta131.34100341796875-N3000.npy\n",
1909
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 29 00:17 Tvir5.4770002365112305-zeta200.0-N3000.npy\n",
1910
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 23:59 Tvir4.698999881744385-zeta30.0-N3000.npy\n",
1911
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 23:42 Tvir5.599999904632568-zeta19.03700065612793-N3000.npy\n",
1912
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 23:20 Tvir4.400000095367432-zeta131.34100341796875-N3000.npy\n",
1913
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 21:06 Tvir4.800000190734863-zeta131.34100341796875-N10000.npy\n",
1914
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 20:49 Tvir5.4770002365112305-zeta200.0-N10000.npy\n",
1915
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 20:31 Tvir4.698999881744385-zeta30.0-N10000.npy\n",
1916
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 20:13 Tvir5.599999904632568-zeta19.03700065612793-N10000.npy\n",
1917
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 19:56 Tvir4.400000095367432-zeta131.34100341796875-N10000.npy\n",
1918
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 18:30 Tvir4.800000190734863-zeta131.34100341796875-N1000.npy\n",
1919
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 18:13 Tvir5.4770002365112305-zeta200.0-N1000.npy\n",
1920
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 17:55 Tvir4.698999881744385-zeta30.0-N1000.npy\n",
1921
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 17:37 Tvir5.599999904632568-zeta19.03700065612793-N1000.npy\n",
1922
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 17:20 Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\n",
1923
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 28 14:03 Tvir4.400000095367432-zeta131.34100341796875-N5000.npy\n",
1924
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 10 18:58 Tvir4.800000190734863-zeta131.34100341796875-N5000.npy\n",
1925
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 10 18:40 Tvir5.4770002365112305-zeta200.0-N5000.npy\n",
1926
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 10 18:22 Tvir4.698999881744385-zeta30.0-N5000.npy\n",
1927
- "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 10 18:05 Tvir5.599999904632568-zeta19.03700065612793-N5000.npy\n"
1928
- ]
1929
- }
1930
- ],
1931
  "source": [
1932
- "ls -lth outputs"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1933
  ]
1934
  },
1935
  {
@@ -1986,7 +875,91 @@
1986
  "execution_count": null,
1987
  "metadata": {},
1988
  "outputs": [],
1989
- "source": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1990
  }
1991
  ],
1992
  "metadata": {
 
25
  "- 我統一了ddpm21cm這個module,能統一實現訓練和生成樣本,但目前有個bug, sample時總是會cuda out of memory,然而單獨resume model並sample就不會。\n",
26
  "- 解決了,問題出在我忘了寫with torch.no_grad():\n",
27
  "- 接下來就是生成800個lightcones,與此同時研究如何計算global signal以及power spectrum\n",
28
+ "- 儅訓練圖片的數量達到5000時,生成的圖片與檢測數據的相似程度很高\n",
29
+ "- it takes 62 mins to generated 8 images with shape of (64,64,64), which is even slower than simulation, which takes ~5 mins for each image. Besides, the batch_size during training and num of images to be generated are limited to be 2 and 8, respectively.\n",
30
+ "- the slowerness can be solved by using multi-GPUs, and the limited-num-of-images can be solved by multi-accuracy, multi-GPUs.\n",
31
+ "- In addtion, the performance of DDPM can looks better compared to computation-intensive simulations. "
32
  ]
33
  },
34
  {
 
74
  "cell_type": "code",
75
  "execution_count": 2,
76
  "metadata": {},
77
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  "source": [
79
+ "# notebook_login()"
80
  ]
81
  },
82
  {
 
319
  "execution_count": 6,
320
  "metadata": {},
321
  "outputs": [],
322
+ "source": [
323
+ "# import os\n",
324
+ "# print(os.cpu_count())\n",
325
+ "# print(len(os.sched_getaffinity(0)))\n",
326
+ "# import torch\n",
327
+ "# data = torch.randn((64,64))\n",
328
+ "# print(data.dtype)"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": 7,
334
+ "metadata": {},
335
+ "outputs": [],
336
  "source": [
337
  "# @dataclass\n",
338
  "class DDPM21CM:\n",
 
387
  " dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)\n",
388
  " # self.shape_loaded = dataset.images.shape\n",
389
  " # print(\"shape_loaded =\", self.shape_loaded)\n",
390
+ " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=len(os.sched_getaffinity(0)), pin_memory=True)\n",
391
  " # del dataset\n",
392
  " # self.accelerate(self.config)\n",
393
  " del dataset\n",
 
559
  "cell_type": "code",
560
  "execution_count": 8,
561
  "metadata": {},
 
 
 
 
 
 
 
 
 
562
  "outputs": [
563
  {
564
  "name": "stdout",
565
  "output_type": "stream",
566
  "text": [
567
  "Number of parameters for nn_model: 306285057\n",
568
+ "----------------- num_image = 20 -----------------\n",
569
+ "run_name = 0706-1527\n",
570
  "Launching training on one GPU.\n",
571
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
572
  "51200 images can be loaded\n",
573
  "field.shape = (64, 64, 514)\n",
574
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
575
+ "loading 20 images randomly\n",
576
+ "images loaded: (20, 1, 64, 64, 64)\n"
 
577
  ]
578
  },
579
  {
 
587
  "name": "stdout",
588
  "output_type": "stream",
589
  "text": [
590
+ "params loaded: (20, 2)\n",
591
+ "images rescaled to [-1.0, 0.9884977340698242]\n",
592
+ "params rescaled to [0.029776105270727538, 0.9947531424980958]\n"
593
  ]
594
  },
595
  {
596
  "data": {
597
  "application/vnd.jupyter.widget-view+json": {
598
+ "model_id": "c285cd667f5e47789fb7dc9483a8963d",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  "version_major": 2,
600
  "version_minor": 0
601
  },
602
  "text/plain": [
603
+ " 0%| | 0/10 [00:00<?, ?it/s]"
604
  ]
605
  },
606
  "metadata": {},
 
609
  {
610
  "data": {
611
  "application/vnd.jupyter.widget-view+json": {
612
+ "model_id": "b25fd0ace52848c1a1f96ee0b8aa92e0",
613
  "version_major": 2,
614
  "version_minor": 0
615
  },
616
  "text/plain": [
617
+ " 0%| | 0/10 [00:00<?, ?it/s]"
618
  ]
619
  },
620
  "metadata": {},
 
623
  {
624
  "data": {
625
  "application/vnd.jupyter.widget-view+json": {
626
+ "model_id": "8677a0b1b1f94a7db2f634d4d006a479",
627
  "version_major": 2,
628
  "version_minor": 0
629
  },
630
  "text/plain": [
631
+ " 0%| | 0/10 [00:00<?, ?it/s]"
632
  ]
633
  },
634
  "metadata": {},
 
637
  {
638
  "data": {
639
  "application/vnd.jupyter.widget-view+json": {
640
+ "model_id": "0caaf98111d844e7b158eaf65b71623e",
641
  "version_major": 2,
642
  "version_minor": 0
643
  },
644
  "text/plain": [
645
+ " 0%| | 0/10 [00:00<?, ?it/s]"
646
  ]
647
  },
648
  "metadata": {},
 
651
  {
652
  "data": {
653
  "application/vnd.jupyter.widget-view+json": {
654
+ "model_id": "90616fd5aafd4414876ee54d7017752e",
655
  "version_major": 2,
656
  "version_minor": 0
657
  },
658
  "text/plain": [
659
+ " 0%| | 0/10 [00:00<?, ?it/s]"
660
  ]
661
  },
662
  "metadata": {},
 
665
  {
666
  "data": {
667
  "application/vnd.jupyter.widget-view+json": {
668
+ "model_id": "9ff3680616ff4dc695599b9b091478a2",
669
  "version_major": 2,
670
  "version_minor": 0
671
  },
672
  "text/plain": [
673
+ " 0%| | 0/10 [00:00<?, ?it/s]"
674
  ]
675
  },
676
  "metadata": {},
 
679
  {
680
  "data": {
681
  "application/vnd.jupyter.widget-view+json": {
682
+ "model_id": "5c6146ebc5dc4194977014f93cb49cff",
683
  "version_major": 2,
684
  "version_minor": 0
685
  },
686
  "text/plain": [
687
+ " 0%| | 0/10 [00:00<?, ?it/s]"
688
  ]
689
  },
690
  "metadata": {},
 
693
  {
694
  "data": {
695
  "application/vnd.jupyter.widget-view+json": {
696
+ "model_id": "1aca831fe268486f83d7bb52eb129239",
697
  "version_major": 2,
698
  "version_minor": 0
699
  },
700
  "text/plain": [
701
+ " 0%| | 0/10 [00:00<?, ?it/s]"
702
  ]
703
  },
704
  "metadata": {},
 
707
  {
708
  "data": {
709
  "application/vnd.jupyter.widget-view+json": {
710
+ "model_id": "34bc05d0bd124e278b4e55abf95cfd8b",
711
  "version_major": 2,
712
  "version_minor": 0
713
  },
714
  "text/plain": [
715
+ " 0%| | 0/10 [00:00<?, ?it/s]"
716
  ]
717
  },
718
  "metadata": {},
 
721
  {
722
  "data": {
723
  "application/vnd.jupyter.widget-view+json": {
724
+ "model_id": "ce0ebad0a27643fd995fade810769c06",
725
  "version_major": 2,
726
  "version_minor": 0
727
  },
728
  "text/plain": [
729
+ " 0%| | 0/10 [00:00<?, ?it/s]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
  ]
731
  },
732
  "metadata": {},
733
  "output_type": "display_data"
734
  }
735
  ],
736
+ "source": [
737
+ "num_image_list = [20]#[200]#[1600,3200,6400,12800,25600]\n",
738
+ "if __name__ == \"__main__\":\n",
739
+ " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
740
+ " config = TrainConfig()\n",
741
+ " for i, num_image in enumerate(num_image_list):\n",
742
+ " config.num_image = num_image\n",
743
+ " ddpm21cm = DDPM21CM(config)\n",
744
+ " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
745
+ " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
746
+ " notebook_launcher(ddpm21cm.train, num_processes=1)"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": null,
752
+ "metadata": {},
753
+ "outputs": [],
754
+ "source": [
755
+ "# ll -lth outputs"
756
+ ]
757
+ },
758
+ {
759
+ "cell_type": "code",
760
+ "execution_count": null,
761
+ "metadata": {},
762
+ "outputs": [],
763
  "source": [
764
  "if __name__ == \"__main__\":\n",
765
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
766
  " num_image_list = [200]\n",
767
  " # num_image_list = [3200,6400,12800,25600]\n",
768
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
769
+ " repeat = 6\n",
770
  " config = TrainConfig()\n",
771
  " for i, num_image in enumerate(num_image_list):\n",
772
  " config.num_image = num_image\n",
 
787
  "cell_type": "code",
788
  "execution_count": null,
789
  "metadata": {},
790
+ "outputs": [],
791
+ "source": [
792
+ "ls -lth outputs | head"
793
+ ]
794
+ },
795
+ {
796
+ "cell_type": "code",
797
+ "execution_count": null,
798
+ "metadata": {},
799
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
800
  "source": [
801
+ "def plot_grid(samples, c=None, row=2, col=3):\n",
802
+ " print(\"samples.shape =\", samples.shape)\n",
803
+ " for j in range(samples.shape[2]):\n",
804
+ " plt.figure(figsize = (9,6), dpi=400)\n",
805
+ " for i in range(len(samples)):\n",
806
+ " plt.subplot(row,col,i+1)\n",
807
+ " plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n",
808
+ " plt.xticks([])\n",
809
+ " plt.yticks([])\n",
810
+ " # plt.suptitle(f\"ION_Tvir_MIN = {c[0][0]}, HII_EFF_FACTOR = {c[0][1]}\")\n",
811
+ " # plt.show()\n",
812
+ " # plt.suptitle('simulations')\n",
813
+ " plt.tight_layout()\n",
814
+ " plt.subplots_adjust(wspace=0, hspace=0)\n",
815
+ " plt.savefig(f\"test3D-{j:02d}.png\")\n",
816
+ " plt.close()\n",
817
+ " # plt.show()\n",
818
+ " \n",
819
+ "data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N200.npy\")\n",
820
+ "# print(data.shape)\n",
821
+ "plot_grid(data)"
822
  ]
823
  },
824
  {
 
875
  "execution_count": null,
876
  "metadata": {},
877
  "outputs": [],
878
+ "source": [
879
+ "import torch\n",
880
+ "import torch.nn as nn\n",
881
+ "import time\n",
882
+ "\n",
883
+ "class MyModel(nn.Module):\n",
884
+ " def __init__(self):\n",
885
+ " super().__init__()\n",
886
+ " self.fc = nn.Linear(100,50)\n",
887
+ "\n",
888
+ " def forward(self, x):\n",
889
+ " return self.fc(x)\n",
890
+ "\n",
891
+ "model = MyModel()\n",
892
+ "\n",
893
+ "device_count = torch.cuda.device_count()\n",
894
+ "print(\"device_count =\", device_count)\n",
895
+ "\n",
896
+ "if device_count > 1:\n",
897
+ " print(f\"using {device_count} GPUs!\")\n",
898
+ " model = nn.DataParallel(model)\n",
899
+ "\n",
900
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
901
+ "model.to(device)\n",
902
+ "\n",
903
+ "start_time = time.time()\n",
904
+ "for i in range(10):\n",
905
+ " myinput = torch.randn(10,10,32000,100).to(device)\n",
906
+ " output = model(myinput)\n",
907
+ " print(output.shape)\n",
908
+ "# plt.imshow(myinput.cpu()[0])\n",
909
+ "# plt.show()\n",
910
+ "# plt.imshow(output.detach().cpu().numpy()[0])\n",
911
+ "# plt.show()"
912
+ ]
913
+ },
914
+ {
915
+ "cell_type": "code",
916
+ "execution_count": null,
917
+ "metadata": {},
918
+ "outputs": [],
919
+ "source": [
920
+ "# import torch.distributed as dist\n",
921
+ "# dist.init_process_group(backend='nccl')"
922
+ ]
923
+ },
924
+ {
925
+ "cell_type": "code",
926
+ "execution_count": null,
927
+ "metadata": {},
928
+ "outputs": [],
929
+ "source": [
930
+ "import numpy as np\n",
931
+ "import torch\n",
932
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
933
+ "\n",
934
+ "data = torch.randn((64,64,64))\n",
935
+ "\n",
936
+ "num_elements = data.numpy().size\n",
937
+ "element_size = data.numpy().itemsize\n",
938
+ "\n",
939
+ "print(data.dtype)\n",
940
+ "print(num_elements, element_size)\n",
941
+ "print(f\"total size = {num_elements*element_size/1024/1024} MB\")\n",
942
+ "\n",
943
+ "print(\"---\"*30)\n",
944
+ "data = data.to(torch.float64)\n",
945
+ "\n",
946
+ "num_elements = data.numpy().size\n",
947
+ "element_size = data.numpy().itemsize\n",
948
+ "\n",
949
+ "print(data.dtype)\n",
950
+ "print(num_elements, element_size)\n",
951
+ "print(f\"total size = {num_elements*element_size/1024/1024} MB\")\n",
952
+ "\n",
953
+ "print(\"---\"*30)\n",
954
+ "data = data.to(torch.float16)\n",
955
+ "\n",
956
+ "num_elements = data.numpy().size\n",
957
+ "element_size = data.numpy().itemsize\n",
958
+ "\n",
959
+ "print(data.dtype)\n",
960
+ "print(num_elements, element_size)\n",
961
+ "print(f\"total size = {num_elements*element_size/1024/1024} MB\")"
962
+ ]
963
  }
964
  ],
965
  "metadata": {