Xsmos commited on
Commit
7d468d2
·
verified ·
1 Parent(s): ed88e88

0702-1312

Browse files
Files changed (2) hide show
  1. diffusion.ipynb +144 -266
  2. quantify_results.ipynb +0 -0
diffusion.ipynb CHANGED
@@ -281,7 +281,7 @@
281
  " lrate = 1e-4\n",
282
  " lr_warmup_steps = 0#5#00\n",
283
  " output_dir = \"./outputs/\"\n",
284
- " save_name = os.path.join(output_dir, 'model_state.pth')\n",
285
  " # save_freq = 1 #10 # the period of saving model\n",
286
  " # cond = True # if training using the conditional information\n",
287
  " # lr_decay = False #True# if using the learning rate decay\n",
@@ -460,8 +460,8 @@
460
  " 'unet_state_dict': self.nn_model.state_dict(),\n",
461
  " 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
462
  " }\n",
463
- " torch.save(model_state, self.config.save_name)\n",
464
- " print('saved model at ' + self.config.save_name)\n",
465
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
466
  "\n",
467
  " # def rescale(self, value, type='params', to_ranges=[0,1]):\n",
@@ -537,7 +537,7 @@
537
  {
538
  "data": {
539
  "application/vnd.jupyter.widget-view+json": {
540
- "model_id": "b23ad327ed7d48b0856c0eb9a66be943",
541
  "version_major": 2,
542
  "version_minor": 0
543
  },
@@ -563,17 +563,15 @@
563
  "output_type": "stream",
564
  "text": [
565
  "-------------------- round 0 ---------------------\n",
566
- "resumed nn_model from ./outputs/model_state.pth\n",
567
  "Number of parameters for nn_model: 111048705\n",
568
- "resumed ema_model from ./outputs/model_state.pth\n",
569
- "run_name = 0701-1047\n",
570
  "Launching training on one GPU.\n",
571
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
572
  "51200 images can be loaded\n",
573
  "field.shape = (64, 64, 514)\n",
574
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
575
- "loading 2000 images randomly\n",
576
- "images loaded: (2000, 1, 64, 64)\n"
577
  ]
578
  },
579
  {
@@ -587,20 +585,20 @@
587
  "name": "stdout",
588
  "output_type": "stream",
589
  "text": [
590
- "params loaded: (2000, 2)\n",
591
  "images rescaled to [-1.0, 1.1335339546203613]\n",
592
- "params rescaled to [0.00010391510804133771, 0.9998883049763877]\n"
593
  ]
594
  },
595
  {
596
  "data": {
597
  "application/vnd.jupyter.widget-view+json": {
598
- "model_id": "6b39caf3d0c74a389e15275ad0422549",
599
  "version_major": 2,
600
  "version_minor": 0
601
  },
602
  "text/plain": [
603
- " 0%| | 0/40 [00:00<?, ?it/s]"
604
  ]
605
  },
606
  "metadata": {},
@@ -609,12 +607,12 @@
609
  {
610
  "data": {
611
  "application/vnd.jupyter.widget-view+json": {
612
- "model_id": "08c1cd82b84d46988cad13d1679f15dc",
613
  "version_major": 2,
614
  "version_minor": 0
615
  },
616
  "text/plain": [
617
- " 0%| | 0/40 [00:00<?, ?it/s]"
618
  ]
619
  },
620
  "metadata": {},
@@ -623,12 +621,12 @@
623
  {
624
  "data": {
625
  "application/vnd.jupyter.widget-view+json": {
626
- "model_id": "1dd978d29d0c41efa15af7d8b85c2554",
627
  "version_major": 2,
628
  "version_minor": 0
629
  },
630
  "text/plain": [
631
- " 0%| | 0/40 [00:00<?, ?it/s]"
632
  ]
633
  },
634
  "metadata": {},
@@ -637,12 +635,12 @@
637
  {
638
  "data": {
639
  "application/vnd.jupyter.widget-view+json": {
640
- "model_id": "a0c894fe81a54e2899a4efdb6cb838ab",
641
  "version_major": 2,
642
  "version_minor": 0
643
  },
644
  "text/plain": [
645
- " 0%| | 0/40 [00:00<?, ?it/s]"
646
  ]
647
  },
648
  "metadata": {},
@@ -651,12 +649,12 @@
651
  {
652
  "data": {
653
  "application/vnd.jupyter.widget-view+json": {
654
- "model_id": "cf595c2ffd3a4a74b0a22a546848f961",
655
  "version_major": 2,
656
  "version_minor": 0
657
  },
658
  "text/plain": [
659
- " 0%| | 0/40 [00:00<?, ?it/s]"
660
  ]
661
  },
662
  "metadata": {},
@@ -665,12 +663,12 @@
665
  {
666
  "data": {
667
  "application/vnd.jupyter.widget-view+json": {
668
- "model_id": "010dc738cc4040d9b5589ef1621b2db8",
669
  "version_major": 2,
670
  "version_minor": 0
671
  },
672
  "text/plain": [
673
- " 0%| | 0/40 [00:00<?, ?it/s]"
674
  ]
675
  },
676
  "metadata": {},
@@ -679,12 +677,12 @@
679
  {
680
  "data": {
681
  "application/vnd.jupyter.widget-view+json": {
682
- "model_id": "a880d0d06737431788b6e4dc937bf388",
683
  "version_major": 2,
684
  "version_minor": 0
685
  },
686
  "text/plain": [
687
- " 0%| | 0/40 [00:00<?, ?it/s]"
688
  ]
689
  },
690
  "metadata": {},
@@ -693,12 +691,12 @@
693
  {
694
  "data": {
695
  "application/vnd.jupyter.widget-view+json": {
696
- "model_id": "d65997de48844e35b530ce02adbe2e5b",
697
  "version_major": 2,
698
  "version_minor": 0
699
  },
700
  "text/plain": [
701
- " 0%| | 0/40 [00:00<?, ?it/s]"
702
  ]
703
  },
704
  "metadata": {},
@@ -707,12 +705,12 @@
707
  {
708
  "data": {
709
  "application/vnd.jupyter.widget-view+json": {
710
- "model_id": "f8b7d43ee1234eb8b9e238a52b6fe9ad",
711
  "version_major": 2,
712
  "version_minor": 0
713
  },
714
  "text/plain": [
715
- " 0%| | 0/40 [00:00<?, ?it/s]"
716
  ]
717
  },
718
  "metadata": {},
@@ -721,12 +719,12 @@
721
  {
722
  "data": {
723
  "application/vnd.jupyter.widget-view+json": {
724
- "model_id": "35ca65ab005e4c3db3163abe39e9fcf5",
725
  "version_major": 2,
726
  "version_minor": 0
727
  },
728
  "text/plain": [
729
- " 0%| | 0/40 [00:00<?, ?it/s]"
730
  ]
731
  },
732
  "metadata": {},
@@ -735,12 +733,12 @@
735
  {
736
  "data": {
737
  "application/vnd.jupyter.widget-view+json": {
738
- "model_id": "b27fac35150745f098f228025b95fc10",
739
  "version_major": 2,
740
  "version_minor": 0
741
  },
742
  "text/plain": [
743
- " 0%| | 0/40 [00:00<?, ?it/s]"
744
  ]
745
  },
746
  "metadata": {},
@@ -749,12 +747,12 @@
749
  {
750
  "data": {
751
  "application/vnd.jupyter.widget-view+json": {
752
- "model_id": "4f470fe28a074f0f9d7ee3c6163fcecf",
753
  "version_major": 2,
754
  "version_minor": 0
755
  },
756
  "text/plain": [
757
- " 0%| | 0/40 [00:00<?, ?it/s]"
758
  ]
759
  },
760
  "metadata": {},
@@ -763,12 +761,12 @@
763
  {
764
  "data": {
765
  "application/vnd.jupyter.widget-view+json": {
766
- "model_id": "ba8a1edd0738475183dbe5ec874d4eb7",
767
  "version_major": 2,
768
  "version_minor": 0
769
  },
770
  "text/plain": [
771
- " 0%| | 0/40 [00:00<?, ?it/s]"
772
  ]
773
  },
774
  "metadata": {},
@@ -777,12 +775,12 @@
777
  {
778
  "data": {
779
  "application/vnd.jupyter.widget-view+json": {
780
- "model_id": "a912fc9cfe86402cb4a51617d15815a6",
781
  "version_major": 2,
782
  "version_minor": 0
783
  },
784
  "text/plain": [
785
- " 0%| | 0/40 [00:00<?, ?it/s]"
786
  ]
787
  },
788
  "metadata": {},
@@ -791,12 +789,12 @@
791
  {
792
  "data": {
793
  "application/vnd.jupyter.widget-view+json": {
794
- "model_id": "151ce7bd0ab44372852a6ce202ace98c",
795
  "version_major": 2,
796
  "version_minor": 0
797
  },
798
  "text/plain": [
799
- " 0%| | 0/40 [00:00<?, ?it/s]"
800
  ]
801
  },
802
  "metadata": {},
@@ -805,12 +803,12 @@
805
  {
806
  "data": {
807
  "application/vnd.jupyter.widget-view+json": {
808
- "model_id": "f9748b330d324009ab5b7b8404d6849d",
809
  "version_major": 2,
810
  "version_minor": 0
811
  },
812
  "text/plain": [
813
- " 0%| | 0/40 [00:00<?, ?it/s]"
814
  ]
815
  },
816
  "metadata": {},
@@ -819,12 +817,12 @@
819
  {
820
  "data": {
821
  "application/vnd.jupyter.widget-view+json": {
822
- "model_id": "525e0eabf5354eafa6a85cf09e8c1d30",
823
  "version_major": 2,
824
  "version_minor": 0
825
  },
826
  "text/plain": [
827
- " 0%| | 0/40 [00:00<?, ?it/s]"
828
  ]
829
  },
830
  "metadata": {},
@@ -833,12 +831,12 @@
833
  {
834
  "data": {
835
  "application/vnd.jupyter.widget-view+json": {
836
- "model_id": "23d44749ed294e8fa5611ba5653af5c3",
837
  "version_major": 2,
838
  "version_minor": 0
839
  },
840
  "text/plain": [
841
- " 0%| | 0/40 [00:00<?, ?it/s]"
842
  ]
843
  },
844
  "metadata": {},
@@ -847,12 +845,12 @@
847
  {
848
  "data": {
849
  "application/vnd.jupyter.widget-view+json": {
850
- "model_id": "782ecf23b66b4a248d1ab3e27166617b",
851
  "version_major": 2,
852
  "version_minor": 0
853
  },
854
  "text/plain": [
855
- " 0%| | 0/40 [00:00<?, ?it/s]"
856
  ]
857
  },
858
  "metadata": {},
@@ -861,12 +859,12 @@
861
  {
862
  "data": {
863
  "application/vnd.jupyter.widget-view+json": {
864
- "model_id": "a03f4cf684fb4d67879e2e9bfd4b1cd8",
865
  "version_major": 2,
866
  "version_minor": 0
867
  },
868
  "text/plain": [
869
- " 0%| | 0/40 [00:00<?, ?it/s]"
870
  ]
871
  },
872
  "metadata": {},
@@ -875,12 +873,12 @@
875
  {
876
  "data": {
877
  "application/vnd.jupyter.widget-view+json": {
878
- "model_id": "02b85ce288ce4ab0a83ea0da7ccc21a6",
879
  "version_major": 2,
880
  "version_minor": 0
881
  },
882
  "text/plain": [
883
- " 0%| | 0/40 [00:00<?, ?it/s]"
884
  ]
885
  },
886
  "metadata": {},
@@ -889,12 +887,12 @@
889
  {
890
  "data": {
891
  "application/vnd.jupyter.widget-view+json": {
892
- "model_id": "1f603325944a4674bbbb846ada2a2984",
893
  "version_major": 2,
894
  "version_minor": 0
895
  },
896
  "text/plain": [
897
- " 0%| | 0/40 [00:00<?, ?it/s]"
898
  ]
899
  },
900
  "metadata": {},
@@ -903,12 +901,12 @@
903
  {
904
  "data": {
905
  "application/vnd.jupyter.widget-view+json": {
906
- "model_id": "8f0dc0491a3749c9b8284cab27686961",
907
  "version_major": 2,
908
  "version_minor": 0
909
  },
910
  "text/plain": [
911
- " 0%| | 0/40 [00:00<?, ?it/s]"
912
  ]
913
  },
914
  "metadata": {},
@@ -917,12 +915,12 @@
917
  {
918
  "data": {
919
  "application/vnd.jupyter.widget-view+json": {
920
- "model_id": "f545ac9f48ef4434b4c6537f1c77f3ab",
921
  "version_major": 2,
922
  "version_minor": 0
923
  },
924
  "text/plain": [
925
- " 0%| | 0/40 [00:00<?, ?it/s]"
926
  ]
927
  },
928
  "metadata": {},
@@ -931,12 +929,12 @@
931
  {
932
  "data": {
933
  "application/vnd.jupyter.widget-view+json": {
934
- "model_id": "ce605dce528c4ca69b8dea4858216c92",
935
  "version_major": 2,
936
  "version_minor": 0
937
  },
938
  "text/plain": [
939
- " 0%| | 0/40 [00:00<?, ?it/s]"
940
  ]
941
  },
942
  "metadata": {},
@@ -945,12 +943,12 @@
945
  {
946
  "data": {
947
  "application/vnd.jupyter.widget-view+json": {
948
- "model_id": "d907f2fefb974be38a050677785bd1a7",
949
  "version_major": 2,
950
  "version_minor": 0
951
  },
952
  "text/plain": [
953
- " 0%| | 0/40 [00:00<?, ?it/s]"
954
  ]
955
  },
956
  "metadata": {},
@@ -959,12 +957,12 @@
959
  {
960
  "data": {
961
  "application/vnd.jupyter.widget-view+json": {
962
- "model_id": "f041d4784e5740c881929426c477408a",
963
  "version_major": 2,
964
  "version_minor": 0
965
  },
966
  "text/plain": [
967
- " 0%| | 0/40 [00:00<?, ?it/s]"
968
  ]
969
  },
970
  "metadata": {},
@@ -973,12 +971,12 @@
973
  {
974
  "data": {
975
  "application/vnd.jupyter.widget-view+json": {
976
- "model_id": "0f4130e3aa104e578812e22b98778193",
977
  "version_major": 2,
978
  "version_minor": 0
979
  },
980
  "text/plain": [
981
- " 0%| | 0/40 [00:00<?, ?it/s]"
982
  ]
983
  },
984
  "metadata": {},
@@ -987,12 +985,12 @@
987
  {
988
  "data": {
989
  "application/vnd.jupyter.widget-view+json": {
990
- "model_id": "0bc61fe4d64b47abbe97bb358cce85eb",
991
  "version_major": 2,
992
  "version_minor": 0
993
  },
994
  "text/plain": [
995
- " 0%| | 0/40 [00:00<?, ?it/s]"
996
  ]
997
  },
998
  "metadata": {},
@@ -1001,12 +999,12 @@
1001
  {
1002
  "data": {
1003
  "application/vnd.jupyter.widget-view+json": {
1004
- "model_id": "1ae54adaa4434bfa93a748e52d259139",
1005
  "version_major": 2,
1006
  "version_minor": 0
1007
  },
1008
  "text/plain": [
1009
- " 0%| | 0/40 [00:00<?, ?it/s]"
1010
  ]
1011
  },
1012
  "metadata": {},
@@ -1015,12 +1013,12 @@
1015
  {
1016
  "data": {
1017
  "application/vnd.jupyter.widget-view+json": {
1018
- "model_id": "a032ece6fc2740cd8506608fbef80325",
1019
  "version_major": 2,
1020
  "version_minor": 0
1021
  },
1022
  "text/plain": [
1023
- " 0%| | 0/40 [00:00<?, ?it/s]"
1024
  ]
1025
  },
1026
  "metadata": {},
@@ -1029,12 +1027,12 @@
1029
  {
1030
  "data": {
1031
  "application/vnd.jupyter.widget-view+json": {
1032
- "model_id": "cc2a8c07ec6047f7b99466c1e5e4d310",
1033
  "version_major": 2,
1034
  "version_minor": 0
1035
  },
1036
  "text/plain": [
1037
- " 0%| | 0/40 [00:00<?, ?it/s]"
1038
  ]
1039
  },
1040
  "metadata": {},
@@ -1043,12 +1041,12 @@
1043
  {
1044
  "data": {
1045
  "application/vnd.jupyter.widget-view+json": {
1046
- "model_id": "5cf08a55fefd4859b2cf35cf859c4869",
1047
  "version_major": 2,
1048
  "version_minor": 0
1049
  },
1050
  "text/plain": [
1051
- " 0%| | 0/40 [00:00<?, ?it/s]"
1052
  ]
1053
  },
1054
  "metadata": {},
@@ -1057,12 +1055,12 @@
1057
  {
1058
  "data": {
1059
  "application/vnd.jupyter.widget-view+json": {
1060
- "model_id": "f336833fd44e40eb98eb9e42ed011cf6",
1061
  "version_major": 2,
1062
  "version_minor": 0
1063
  },
1064
  "text/plain": [
1065
- " 0%| | 0/40 [00:00<?, ?it/s]"
1066
  ]
1067
  },
1068
  "metadata": {},
@@ -1071,12 +1069,12 @@
1071
  {
1072
  "data": {
1073
  "application/vnd.jupyter.widget-view+json": {
1074
- "model_id": "e2843d610d50429a8cc65e9c4051a9c5",
1075
  "version_major": 2,
1076
  "version_minor": 0
1077
  },
1078
  "text/plain": [
1079
- " 0%| | 0/40 [00:00<?, ?it/s]"
1080
  ]
1081
  },
1082
  "metadata": {},
@@ -1085,12 +1083,12 @@
1085
  {
1086
  "data": {
1087
  "application/vnd.jupyter.widget-view+json": {
1088
- "model_id": "9302becd557a4d5faa8ccde0aacbb4bc",
1089
  "version_major": 2,
1090
  "version_minor": 0
1091
  },
1092
  "text/plain": [
1093
- " 0%| | 0/40 [00:00<?, ?it/s]"
1094
  ]
1095
  },
1096
  "metadata": {},
@@ -1099,12 +1097,12 @@
1099
  {
1100
  "data": {
1101
  "application/vnd.jupyter.widget-view+json": {
1102
- "model_id": "66c8271cfece47918ecab1522cb5e3e7",
1103
  "version_major": 2,
1104
  "version_minor": 0
1105
  },
1106
  "text/plain": [
1107
- " 0%| | 0/40 [00:00<?, ?it/s]"
1108
  ]
1109
  },
1110
  "metadata": {},
@@ -1113,12 +1111,12 @@
1113
  {
1114
  "data": {
1115
  "application/vnd.jupyter.widget-view+json": {
1116
- "model_id": "491cbe8db73d41bc8e1a980e6343086b",
1117
  "version_major": 2,
1118
  "version_minor": 0
1119
  },
1120
  "text/plain": [
1121
- " 0%| | 0/40 [00:00<?, ?it/s]"
1122
  ]
1123
  },
1124
  "metadata": {},
@@ -1127,12 +1125,12 @@
1127
  {
1128
  "data": {
1129
  "application/vnd.jupyter.widget-view+json": {
1130
- "model_id": "26623ba0a2424c448c91f7359f81c6cd",
1131
  "version_major": 2,
1132
  "version_minor": 0
1133
  },
1134
  "text/plain": [
1135
- " 0%| | 0/40 [00:00<?, ?it/s]"
1136
  ]
1137
  },
1138
  "metadata": {},
@@ -1141,12 +1139,12 @@
1141
  {
1142
  "data": {
1143
  "application/vnd.jupyter.widget-view+json": {
1144
- "model_id": "aea1c9aabd7848e19ee0dea60ba43966",
1145
  "version_major": 2,
1146
  "version_minor": 0
1147
  },
1148
  "text/plain": [
1149
- " 0%| | 0/40 [00:00<?, ?it/s]"
1150
  ]
1151
  },
1152
  "metadata": {},
@@ -1155,12 +1153,12 @@
1155
  {
1156
  "data": {
1157
  "application/vnd.jupyter.widget-view+json": {
1158
- "model_id": "7dd06fec865945aa8d8181312c2f08d8",
1159
  "version_major": 2,
1160
  "version_minor": 0
1161
  },
1162
  "text/plain": [
1163
- " 0%| | 0/40 [00:00<?, ?it/s]"
1164
  ]
1165
  },
1166
  "metadata": {},
@@ -1169,12 +1167,12 @@
1169
  {
1170
  "data": {
1171
  "application/vnd.jupyter.widget-view+json": {
1172
- "model_id": "645917cca25a40e8b311cf0fca4b5fac",
1173
  "version_major": 2,
1174
  "version_minor": 0
1175
  },
1176
  "text/plain": [
1177
- " 0%| | 0/40 [00:00<?, ?it/s]"
1178
  ]
1179
  },
1180
  "metadata": {},
@@ -1183,12 +1181,12 @@
1183
  {
1184
  "data": {
1185
  "application/vnd.jupyter.widget-view+json": {
1186
- "model_id": "dd7b77cb6200422f87c7ca0a3218b008",
1187
  "version_major": 2,
1188
  "version_minor": 0
1189
  },
1190
  "text/plain": [
1191
- " 0%| | 0/40 [00:00<?, ?it/s]"
1192
  ]
1193
  },
1194
  "metadata": {},
@@ -1197,12 +1195,12 @@
1197
  {
1198
  "data": {
1199
  "application/vnd.jupyter.widget-view+json": {
1200
- "model_id": "a2441972859a4739bf804ac7960c33cd",
1201
  "version_major": 2,
1202
  "version_minor": 0
1203
  },
1204
  "text/plain": [
1205
- " 0%| | 0/40 [00:00<?, ?it/s]"
1206
  ]
1207
  },
1208
  "metadata": {},
@@ -1211,12 +1209,12 @@
1211
  {
1212
  "data": {
1213
  "application/vnd.jupyter.widget-view+json": {
1214
- "model_id": "5852dd77bf6e409baa1f3d049e41fdb3",
1215
  "version_major": 2,
1216
  "version_minor": 0
1217
  },
1218
  "text/plain": [
1219
- " 0%| | 0/40 [00:00<?, ?it/s]"
1220
  ]
1221
  },
1222
  "metadata": {},
@@ -1225,12 +1223,12 @@
1225
  {
1226
  "data": {
1227
  "application/vnd.jupyter.widget-view+json": {
1228
- "model_id": "64d524b098b94eff8b8b715fad31e614",
1229
  "version_major": 2,
1230
  "version_minor": 0
1231
  },
1232
  "text/plain": [
1233
- " 0%| | 0/40 [00:00<?, ?it/s]"
1234
  ]
1235
  },
1236
  "metadata": {},
@@ -1239,12 +1237,12 @@
1239
  {
1240
  "data": {
1241
  "application/vnd.jupyter.widget-view+json": {
1242
- "model_id": "8592a5488c024a6d952ce2aae60dc5a5",
1243
  "version_major": 2,
1244
  "version_minor": 0
1245
  },
1246
  "text/plain": [
1247
- " 0%| | 0/40 [00:00<?, ?it/s]"
1248
  ]
1249
  },
1250
  "metadata": {},
@@ -1253,12 +1251,12 @@
1253
  {
1254
  "data": {
1255
  "application/vnd.jupyter.widget-view+json": {
1256
- "model_id": "55e79db64846496f937b8429ba61f01f",
1257
  "version_major": 2,
1258
  "version_minor": 0
1259
  },
1260
  "text/plain": [
1261
- " 0%| | 0/40 [00:00<?, ?it/s]"
1262
  ]
1263
  },
1264
  "metadata": {},
@@ -1267,12 +1265,12 @@
1267
  {
1268
  "data": {
1269
  "application/vnd.jupyter.widget-view+json": {
1270
- "model_id": "c17e75bcb5ea4cf99d8860ff01721cc1",
1271
  "version_major": 2,
1272
  "version_minor": 0
1273
  },
1274
  "text/plain": [
1275
- " 0%| | 0/40 [00:00<?, ?it/s]"
1276
  ]
1277
  },
1278
  "metadata": {},
@@ -1281,12 +1279,12 @@
1281
  {
1282
  "data": {
1283
  "application/vnd.jupyter.widget-view+json": {
1284
- "model_id": "adb7ec9dfd2c469680b49d2d5c337b71",
1285
  "version_major": 2,
1286
  "version_minor": 0
1287
  },
1288
  "text/plain": [
1289
- " 0%| | 0/40 [00:00<?, ?it/s]"
1290
  ]
1291
  },
1292
  "metadata": {},
@@ -1294,12 +1292,14 @@
1294
  }
1295
  ],
1296
  "source": [
 
 
1297
  "if __name__ == \"__main__\":\n",
1298
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
1299
- " num_round = 1\n",
1300
- " for i in range(num_round):\n",
1301
  " print(f\" round {i} \".center(50, '-'))\n",
1302
  " ddpm21cm = DDPM21CM()\n",
 
1303
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
1304
  " notebook_launcher(ddpm21cm.train, num_processes=1)"
1305
  ]
@@ -1313,13 +1313,14 @@
1313
  "name": "stdout",
1314
  "output_type": "stream",
1315
  "text": [
1316
- "total 980M\n",
 
 
 
1317
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:47 Tvir5.4770002365112305-zeta200.0-N32000.npy\n",
1318
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:29 Tvir4.698999881744385-zeta30.0-N32000.npy\n",
1319
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:11 Tvir5.599999904632568-zeta19.03700065612793-N32000.npy\n",
1320
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 03:53 Tvir4.400000095367432-zeta131.34100341796875-N32000.npy\n",
1321
- "-rw-r--r-- 1 bxia34 848M Jul 1 03:35 model_state.pth\n",
1322
- "drwxr-xr-x 11 bxia34 4.0K Jul 1 01:04 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
1323
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:23 Tvir4.800000190734863-zeta131.34100341796875-N20000.npy\n",
1324
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:05 Tvir5.4770002365112305-zeta200.0-N20000.npy\n",
1325
  "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:47 Tvir4.698999881744385-zeta30.0-N20000.npy\n",
@@ -1386,115 +1387,7 @@
1386
  {
1387
  "data": {
1388
  "application/vnd.jupyter.widget-view+json": {
1389
- "model_id": "dce3b8a47c404689b08fd2c9e811cd8f",
1390
- "version_major": 2,
1391
- "version_minor": 0
1392
- },
1393
- "text/plain": [
1394
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1395
- ]
1396
- },
1397
- "metadata": {},
1398
- "output_type": "display_data"
1399
- }
1400
- ],
1401
- "source": [
1402
- "ddpm21cm = DDPM21CM()\n",
1403
- "ddpm21cm.sample(\"./outputs/model_state.pth\", params=torch.tensor([4.4, 131.341]))"
1404
- ]
1405
- },
1406
- {
1407
- "cell_type": "code",
1408
- "execution_count": null,
1409
- "metadata": {},
1410
- "outputs": [
1411
- {
1412
- "name": "stdout",
1413
- "output_type": "stream",
1414
- "text": [
1415
- "resumed nn_model from ./outputs/model_state.pth\n",
1416
- "Number of parameters for nn_model: 111048705\n",
1417
- "resumed ema_model from ./outputs/model_state.pth\n",
1418
- "sampling 192 images with normalized params = tensor([[0.8000, 0.0377]])\n",
1419
- "nn_model resumed from ./outputs/model_state.pth\n"
1420
- ]
1421
- },
1422
- {
1423
- "data": {
1424
- "application/vnd.jupyter.widget-view+json": {
1425
- "model_id": "390067ed3a1c463a9573d82a1ba66bcc",
1426
- "version_major": 2,
1427
- "version_minor": 0
1428
- },
1429
- "text/plain": [
1430
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1431
- ]
1432
- },
1433
- "metadata": {},
1434
- "output_type": "display_data"
1435
- }
1436
- ],
1437
- "source": [
1438
- "ddpm21cm = DDPM21CM()\n",
1439
- "ddpm21cm.sample(\"./outputs/model_state.pth\", params=torch.tensor((5.6, 19.037)))"
1440
- ]
1441
- },
1442
- {
1443
- "cell_type": "code",
1444
- "execution_count": null,
1445
- "metadata": {},
1446
- "outputs": [
1447
- {
1448
- "name": "stdout",
1449
- "output_type": "stream",
1450
- "text": [
1451
- "resumed nn_model from ./outputs/model_state.pth\n",
1452
- "Number of parameters for nn_model: 111048705\n",
1453
- "resumed ema_model from ./outputs/model_state.pth\n",
1454
- "sampling 192 images with normalized params = tensor([[0.3495, 0.0833]])\n",
1455
- "nn_model resumed from ./outputs/model_state.pth\n"
1456
- ]
1457
- },
1458
- {
1459
- "data": {
1460
- "application/vnd.jupyter.widget-view+json": {
1461
- "model_id": "5ae0f37108724e6f90136b646bfd691b",
1462
- "version_major": 2,
1463
- "version_minor": 0
1464
- },
1465
- "text/plain": [
1466
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1467
- ]
1468
- },
1469
- "metadata": {},
1470
- "output_type": "display_data"
1471
- }
1472
- ],
1473
- "source": [
1474
- "ddpm21cm = DDPM21CM()\n",
1475
- "ddpm21cm.sample(\"./outputs/model_state.pth\", params=torch.tensor((4.699, 30)))"
1476
- ]
1477
- },
1478
- {
1479
- "cell_type": "code",
1480
- "execution_count": null,
1481
- "metadata": {},
1482
- "outputs": [
1483
- {
1484
- "name": "stdout",
1485
- "output_type": "stream",
1486
- "text": [
1487
- "resumed nn_model from ./outputs/model_state.pth\n",
1488
- "Number of parameters for nn_model: 111048705\n",
1489
- "resumed ema_model from ./outputs/model_state.pth\n",
1490
- "sampling 192 images with normalized params = tensor([[0.7385, 0.7917]])\n",
1491
- "nn_model resumed from ./outputs/model_state.pth\n"
1492
- ]
1493
- },
1494
- {
1495
- "data": {
1496
- "application/vnd.jupyter.widget-view+json": {
1497
- "model_id": "f62cdd2196de4ad6aa237e22c7a45446",
1498
  "version_major": 2,
1499
  "version_minor": 0
1500
  },
@@ -1507,44 +1400,24 @@
1507
  }
1508
  ],
1509
  "source": [
1510
- "ddpm21cm = DDPM21CM()\n",
1511
- "ddpm21cm.sample(\"./outputs/model_state.pth\", params=torch.tensor((5.477, 200)))"
1512
- ]
1513
- },
1514
- {
1515
- "cell_type": "code",
1516
- "execution_count": null,
1517
- "metadata": {},
1518
- "outputs": [
1519
- {
1520
- "name": "stdout",
1521
- "output_type": "stream",
1522
- "text": [
1523
- "resumed nn_model from ./outputs/model_state.pth\n",
1524
- "Number of parameters for nn_model: 111048705\n",
1525
- "resumed ema_model from ./outputs/model_state.pth\n",
1526
- "sampling 192 images with normalized params = tensor([[0.4000, 0.5056]])\n",
1527
- "nn_model resumed from ./outputs/model_state.pth\n"
1528
- ]
1529
- },
1530
- {
1531
- "data": {
1532
- "application/vnd.jupyter.widget-view+json": {
1533
- "model_id": "2668c98390d6464fa9d913643fc3d998",
1534
- "version_major": 2,
1535
- "version_minor": 0
1536
- },
1537
- "text/plain": [
1538
- " 0%| | 0/1000 [00:00<?, ?it/s]"
1539
- ]
1540
- },
1541
- "metadata": {},
1542
- "output_type": "display_data"
1543
- }
1544
- ],
1545
- "source": [
1546
- "ddpm21cm = DDPM21CM()\n",
1547
- "ddpm21cm.sample(\"./outputs/model_state.pth\", params=torch.tensor((4.8, 131.341)))"
1548
  ]
1549
  },
1550
  {
@@ -1556,14 +1429,19 @@
1556
  "name": "stdout",
1557
  "output_type": "stream",
1558
  "text": [
1559
- "total 980M\n",
 
 
 
 
 
 
 
1560
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 10:37 Tvir4.800000190734863-zeta131.34100341796875-N32000.npy\n",
1561
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:47 Tvir5.4770002365112305-zeta200.0-N32000.npy\n",
1562
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:29 Tvir4.698999881744385-zeta30.0-N32000.npy\n",
1563
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:11 Tvir5.599999904632568-zeta19.03700065612793-N32000.npy\n",
1564
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 03:53 Tvir4.400000095367432-zeta131.34100341796875-N32000.npy\n",
1565
- "-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 1 03:35 model_state.pth\n",
1566
- "drwxr-xr-x 11 bxia34 pace-jw254 4.0K Jul 1 01:04 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
1567
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 00:23 Tvir4.800000190734863-zeta131.34100341796875-N20000.npy\n",
1568
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 00:05 Tvir5.4770002365112305-zeta200.0-N20000.npy\n",
1569
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 23:47 Tvir4.698999881744385-zeta30.0-N20000.npy\n",
 
281
  " lrate = 1e-4\n",
282
  " lr_warmup_steps = 0#5#00\n",
283
  " output_dir = \"./outputs/\"\n",
284
+ " save_name = os.path.join(output_dir, 'model_state')\n",
285
  " # save_freq = 1 #10 # the period of saving model\n",
286
  " # cond = True # if training using the conditional information\n",
287
  " # lr_decay = False #True# if using the learning rate decay\n",
 
460
  " 'unet_state_dict': self.nn_model.state_dict(),\n",
461
  " 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
462
  " }\n",
463
+ " torch.save(model_state, self.config.save_name+f\"-N{self.config.num_image}\")\n",
464
+ " print('saved model at ' + self.config.save_name+f\"-N{self.config.num_image}\")\n",
465
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
466
  "\n",
467
  " # def rescale(self, value, type='params', to_ranges=[0,1]):\n",
 
537
  {
538
  "data": {
539
  "application/vnd.jupyter.widget-view+json": {
540
+ "model_id": "6025148aa9024daaaa9f3ba7ea0c784b",
541
  "version_major": 2,
542
  "version_minor": 0
543
  },
 
563
  "output_type": "stream",
564
  "text": [
565
  "-------------------- round 0 ---------------------\n",
 
566
  "Number of parameters for nn_model: 111048705\n",
567
+ "run_name = 0702-1312\n",
 
568
  "Launching training on one GPU.\n",
569
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
570
  "51200 images can be loaded\n",
571
  "field.shape = (64, 64, 514)\n",
572
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
573
+ "loading 1600 images randomly\n",
574
+ "images loaded: (1600, 1, 64, 64)\n"
575
  ]
576
  },
577
  {
 
585
  "name": "stdout",
586
  "output_type": "stream",
587
  "text": [
588
+ "params loaded: (1600, 2)\n",
589
  "images rescaled to [-1.0, 1.1335339546203613]\n",
590
+ "params rescaled to [0.0001702067256199591, 0.9998215201715621]\n"
591
  ]
592
  },
593
  {
594
  "data": {
595
  "application/vnd.jupyter.widget-view+json": {
596
+ "model_id": "82818076b9634cf2821ddc51d00752b7",
597
  "version_major": 2,
598
  "version_minor": 0
599
  },
600
  "text/plain": [
601
+ " 0%| | 0/32 [00:00<?, ?it/s]"
602
  ]
603
  },
604
  "metadata": {},
 
607
  {
608
  "data": {
609
  "application/vnd.jupyter.widget-view+json": {
610
+ "model_id": "e439288486394473af821f74b9df8128",
611
  "version_major": 2,
612
  "version_minor": 0
613
  },
614
  "text/plain": [
615
+ " 0%| | 0/32 [00:00<?, ?it/s]"
616
  ]
617
  },
618
  "metadata": {},
 
621
  {
622
  "data": {
623
  "application/vnd.jupyter.widget-view+json": {
624
+ "model_id": "4e7f1080ae9c44c1bab77db59431976e",
625
  "version_major": 2,
626
  "version_minor": 0
627
  },
628
  "text/plain": [
629
+ " 0%| | 0/32 [00:00<?, ?it/s]"
630
  ]
631
  },
632
  "metadata": {},
 
635
  {
636
  "data": {
637
  "application/vnd.jupyter.widget-view+json": {
638
+ "model_id": "9332319e53044145914249fa79781b55",
639
  "version_major": 2,
640
  "version_minor": 0
641
  },
642
  "text/plain": [
643
+ " 0%| | 0/32 [00:00<?, ?it/s]"
644
  ]
645
  },
646
  "metadata": {},
 
649
  {
650
  "data": {
651
  "application/vnd.jupyter.widget-view+json": {
652
+ "model_id": "9427a193e4c24989a8c24202493991ee",
653
  "version_major": 2,
654
  "version_minor": 0
655
  },
656
  "text/plain": [
657
+ " 0%| | 0/32 [00:00<?, ?it/s]"
658
  ]
659
  },
660
  "metadata": {},
 
663
  {
664
  "data": {
665
  "application/vnd.jupyter.widget-view+json": {
666
+ "model_id": "f2100a02d6874dd39e74265651d68cfb",
667
  "version_major": 2,
668
  "version_minor": 0
669
  },
670
  "text/plain": [
671
+ " 0%| | 0/32 [00:00<?, ?it/s]"
672
  ]
673
  },
674
  "metadata": {},
 
677
  {
678
  "data": {
679
  "application/vnd.jupyter.widget-view+json": {
680
+ "model_id": "8ce81cfbb0044152be24b1c75bf8381a",
681
  "version_major": 2,
682
  "version_minor": 0
683
  },
684
  "text/plain": [
685
+ " 0%| | 0/32 [00:00<?, ?it/s]"
686
  ]
687
  },
688
  "metadata": {},
 
691
  {
692
  "data": {
693
  "application/vnd.jupyter.widget-view+json": {
694
+ "model_id": "015753cf1d6e4b0c92efb17cc4440e37",
695
  "version_major": 2,
696
  "version_minor": 0
697
  },
698
  "text/plain": [
699
+ " 0%| | 0/32 [00:00<?, ?it/s]"
700
  ]
701
  },
702
  "metadata": {},
 
705
  {
706
  "data": {
707
  "application/vnd.jupyter.widget-view+json": {
708
+ "model_id": "bdff45db0dff4c29bb8b7f4ee2f1e92b",
709
  "version_major": 2,
710
  "version_minor": 0
711
  },
712
  "text/plain": [
713
+ " 0%| | 0/32 [00:00<?, ?it/s]"
714
  ]
715
  },
716
  "metadata": {},
 
719
  {
720
  "data": {
721
  "application/vnd.jupyter.widget-view+json": {
722
+ "model_id": "7499ebc69b984c508a1da52e298c5485",
723
  "version_major": 2,
724
  "version_minor": 0
725
  },
726
  "text/plain": [
727
+ " 0%| | 0/32 [00:00<?, ?it/s]"
728
  ]
729
  },
730
  "metadata": {},
 
733
  {
734
  "data": {
735
  "application/vnd.jupyter.widget-view+json": {
736
+ "model_id": "1e5055f2bddd4d13aa386258f7ae9bd9",
737
  "version_major": 2,
738
  "version_minor": 0
739
  },
740
  "text/plain": [
741
+ " 0%| | 0/32 [00:00<?, ?it/s]"
742
  ]
743
  },
744
  "metadata": {},
 
747
  {
748
  "data": {
749
  "application/vnd.jupyter.widget-view+json": {
750
+ "model_id": "adb82dcd1b4a465f87d64c0fe99444e0",
751
  "version_major": 2,
752
  "version_minor": 0
753
  },
754
  "text/plain": [
755
+ " 0%| | 0/32 [00:00<?, ?it/s]"
756
  ]
757
  },
758
  "metadata": {},
 
761
  {
762
  "data": {
763
  "application/vnd.jupyter.widget-view+json": {
764
+ "model_id": "b4206deea8e249138628a01478e8994f",
765
  "version_major": 2,
766
  "version_minor": 0
767
  },
768
  "text/plain": [
769
+ " 0%| | 0/32 [00:00<?, ?it/s]"
770
  ]
771
  },
772
  "metadata": {},
 
775
  {
776
  "data": {
777
  "application/vnd.jupyter.widget-view+json": {
778
+ "model_id": "17c826779a944efd955cc516e9b04c92",
779
  "version_major": 2,
780
  "version_minor": 0
781
  },
782
  "text/plain": [
783
+ " 0%| | 0/32 [00:00<?, ?it/s]"
784
  ]
785
  },
786
  "metadata": {},
 
789
  {
790
  "data": {
791
  "application/vnd.jupyter.widget-view+json": {
792
+ "model_id": "1fed82eb826f433e9d5a85c51aea757a",
793
  "version_major": 2,
794
  "version_minor": 0
795
  },
796
  "text/plain": [
797
+ " 0%| | 0/32 [00:00<?, ?it/s]"
798
  ]
799
  },
800
  "metadata": {},
 
803
  {
804
  "data": {
805
  "application/vnd.jupyter.widget-view+json": {
806
+ "model_id": "01c2f68a24d84757904303fc35681cee",
807
  "version_major": 2,
808
  "version_minor": 0
809
  },
810
  "text/plain": [
811
+ " 0%| | 0/32 [00:00<?, ?it/s]"
812
  ]
813
  },
814
  "metadata": {},
 
817
  {
818
  "data": {
819
  "application/vnd.jupyter.widget-view+json": {
820
+ "model_id": "fbba18c4ea074bf0a7861e7a4a8a0c64",
821
  "version_major": 2,
822
  "version_minor": 0
823
  },
824
  "text/plain": [
825
+ " 0%| | 0/32 [00:00<?, ?it/s]"
826
  ]
827
  },
828
  "metadata": {},
 
831
  {
832
  "data": {
833
  "application/vnd.jupyter.widget-view+json": {
834
+ "model_id": "bd6bf27417c542d19a6e5ced172036db",
835
  "version_major": 2,
836
  "version_minor": 0
837
  },
838
  "text/plain": [
839
+ " 0%| | 0/32 [00:00<?, ?it/s]"
840
  ]
841
  },
842
  "metadata": {},
 
845
  {
846
  "data": {
847
  "application/vnd.jupyter.widget-view+json": {
848
+ "model_id": "79a49db719464cb18283c5225484e050",
849
  "version_major": 2,
850
  "version_minor": 0
851
  },
852
  "text/plain": [
853
+ " 0%| | 0/32 [00:00<?, ?it/s]"
854
  ]
855
  },
856
  "metadata": {},
 
859
  {
860
  "data": {
861
  "application/vnd.jupyter.widget-view+json": {
862
+ "model_id": "97919d69fc3a40cfbfc03286418c4016",
863
  "version_major": 2,
864
  "version_minor": 0
865
  },
866
  "text/plain": [
867
+ " 0%| | 0/32 [00:00<?, ?it/s]"
868
  ]
869
  },
870
  "metadata": {},
 
873
  {
874
  "data": {
875
  "application/vnd.jupyter.widget-view+json": {
876
+ "model_id": "9102d3133c1c49948273e2ae92ae30a5",
877
  "version_major": 2,
878
  "version_minor": 0
879
  },
880
  "text/plain": [
881
+ " 0%| | 0/32 [00:00<?, ?it/s]"
882
  ]
883
  },
884
  "metadata": {},
 
887
  {
888
  "data": {
889
  "application/vnd.jupyter.widget-view+json": {
890
+ "model_id": "5fb87cbf52374eae8f20528ea02ffbcd",
891
  "version_major": 2,
892
  "version_minor": 0
893
  },
894
  "text/plain": [
895
+ " 0%| | 0/32 [00:00<?, ?it/s]"
896
  ]
897
  },
898
  "metadata": {},
 
901
  {
902
  "data": {
903
  "application/vnd.jupyter.widget-view+json": {
904
+ "model_id": "c4317d2231484a68a00b4d458694f289",
905
  "version_major": 2,
906
  "version_minor": 0
907
  },
908
  "text/plain": [
909
+ " 0%| | 0/32 [00:00<?, ?it/s]"
910
  ]
911
  },
912
  "metadata": {},
 
915
  {
916
  "data": {
917
  "application/vnd.jupyter.widget-view+json": {
918
+ "model_id": "423e00e991d241f08d53ed1ab13c5ee1",
919
  "version_major": 2,
920
  "version_minor": 0
921
  },
922
  "text/plain": [
923
+ " 0%| | 0/32 [00:00<?, ?it/s]"
924
  ]
925
  },
926
  "metadata": {},
 
929
  {
930
  "data": {
931
  "application/vnd.jupyter.widget-view+json": {
932
+ "model_id": "d67d402de8fe4f0493527fc53106b1d8",
933
  "version_major": 2,
934
  "version_minor": 0
935
  },
936
  "text/plain": [
937
+ " 0%| | 0/32 [00:00<?, ?it/s]"
938
  ]
939
  },
940
  "metadata": {},
 
943
  {
944
  "data": {
945
  "application/vnd.jupyter.widget-view+json": {
946
+ "model_id": "4df6775e00264b53bd3945860c18dc22",
947
  "version_major": 2,
948
  "version_minor": 0
949
  },
950
  "text/plain": [
951
+ " 0%| | 0/32 [00:00<?, ?it/s]"
952
  ]
953
  },
954
  "metadata": {},
 
957
  {
958
  "data": {
959
  "application/vnd.jupyter.widget-view+json": {
960
+ "model_id": "0445848a37ac46a8b88b24ee34a4fb88",
961
  "version_major": 2,
962
  "version_minor": 0
963
  },
964
  "text/plain": [
965
+ " 0%| | 0/32 [00:00<?, ?it/s]"
966
  ]
967
  },
968
  "metadata": {},
 
971
  {
972
  "data": {
973
  "application/vnd.jupyter.widget-view+json": {
974
+ "model_id": "9d25fe13f99b411293ca0b62b16c9935",
975
  "version_major": 2,
976
  "version_minor": 0
977
  },
978
  "text/plain": [
979
+ " 0%| | 0/32 [00:00<?, ?it/s]"
980
  ]
981
  },
982
  "metadata": {},
 
985
  {
986
  "data": {
987
  "application/vnd.jupyter.widget-view+json": {
988
+ "model_id": "ed5fb6ff03694a199b6d62a66709a745",
989
  "version_major": 2,
990
  "version_minor": 0
991
  },
992
  "text/plain": [
993
+ " 0%| | 0/32 [00:00<?, ?it/s]"
994
  ]
995
  },
996
  "metadata": {},
 
999
  {
1000
  "data": {
1001
  "application/vnd.jupyter.widget-view+json": {
1002
+ "model_id": "e5fd2eb6029d47e89208a08f9df196b6",
1003
  "version_major": 2,
1004
  "version_minor": 0
1005
  },
1006
  "text/plain": [
1007
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1008
  ]
1009
  },
1010
  "metadata": {},
 
1013
  {
1014
  "data": {
1015
  "application/vnd.jupyter.widget-view+json": {
1016
+ "model_id": "f472161eceff48bea9f8ed68c4fb67dc",
1017
  "version_major": 2,
1018
  "version_minor": 0
1019
  },
1020
  "text/plain": [
1021
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1022
  ]
1023
  },
1024
  "metadata": {},
 
1027
  {
1028
  "data": {
1029
  "application/vnd.jupyter.widget-view+json": {
1030
+ "model_id": "67e983dac33841edbcdb91fe331933e8",
1031
  "version_major": 2,
1032
  "version_minor": 0
1033
  },
1034
  "text/plain": [
1035
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1036
  ]
1037
  },
1038
  "metadata": {},
 
1041
  {
1042
  "data": {
1043
  "application/vnd.jupyter.widget-view+json": {
1044
+ "model_id": "61f079cd7523453abf4887be98b0a6a8",
1045
  "version_major": 2,
1046
  "version_minor": 0
1047
  },
1048
  "text/plain": [
1049
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1050
  ]
1051
  },
1052
  "metadata": {},
 
1055
  {
1056
  "data": {
1057
  "application/vnd.jupyter.widget-view+json": {
1058
+ "model_id": "a7bdb93a4bc148dcbd367e05a2ac46d8",
1059
  "version_major": 2,
1060
  "version_minor": 0
1061
  },
1062
  "text/plain": [
1063
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1064
  ]
1065
  },
1066
  "metadata": {},
 
1069
  {
1070
  "data": {
1071
  "application/vnd.jupyter.widget-view+json": {
1072
+ "model_id": "73f1305e213140859346ebb239f9d7bf",
1073
  "version_major": 2,
1074
  "version_minor": 0
1075
  },
1076
  "text/plain": [
1077
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1078
  ]
1079
  },
1080
  "metadata": {},
 
1083
  {
1084
  "data": {
1085
  "application/vnd.jupyter.widget-view+json": {
1086
+ "model_id": "d142bdd2d1c64c4091207338267cea71",
1087
  "version_major": 2,
1088
  "version_minor": 0
1089
  },
1090
  "text/plain": [
1091
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1092
  ]
1093
  },
1094
  "metadata": {},
 
1097
  {
1098
  "data": {
1099
  "application/vnd.jupyter.widget-view+json": {
1100
+ "model_id": "434b794f81984cb8a8ba540d05eb938a",
1101
  "version_major": 2,
1102
  "version_minor": 0
1103
  },
1104
  "text/plain": [
1105
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1106
  ]
1107
  },
1108
  "metadata": {},
 
1111
  {
1112
  "data": {
1113
  "application/vnd.jupyter.widget-view+json": {
1114
+ "model_id": "0a33673984684e8ca65e8cc65d4cc1f5",
1115
  "version_major": 2,
1116
  "version_minor": 0
1117
  },
1118
  "text/plain": [
1119
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1120
  ]
1121
  },
1122
  "metadata": {},
 
1125
  {
1126
  "data": {
1127
  "application/vnd.jupyter.widget-view+json": {
1128
+ "model_id": "b9d5166185a04053a06f369619a8654f",
1129
  "version_major": 2,
1130
  "version_minor": 0
1131
  },
1132
  "text/plain": [
1133
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1134
  ]
1135
  },
1136
  "metadata": {},
 
1139
  {
1140
  "data": {
1141
  "application/vnd.jupyter.widget-view+json": {
1142
+ "model_id": "7fcbd0f7a09e474b8cd04904ae5fd4ab",
1143
  "version_major": 2,
1144
  "version_minor": 0
1145
  },
1146
  "text/plain": [
1147
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1148
  ]
1149
  },
1150
  "metadata": {},
 
1153
  {
1154
  "data": {
1155
  "application/vnd.jupyter.widget-view+json": {
1156
+ "model_id": "a189ea99daa2489c81733a3d25ccadb8",
1157
  "version_major": 2,
1158
  "version_minor": 0
1159
  },
1160
  "text/plain": [
1161
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1162
  ]
1163
  },
1164
  "metadata": {},
 
1167
  {
1168
  "data": {
1169
  "application/vnd.jupyter.widget-view+json": {
1170
+ "model_id": "87e779fce4fc47e580debfd737790b86",
1171
  "version_major": 2,
1172
  "version_minor": 0
1173
  },
1174
  "text/plain": [
1175
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1176
  ]
1177
  },
1178
  "metadata": {},
 
1181
  {
1182
  "data": {
1183
  "application/vnd.jupyter.widget-view+json": {
1184
+ "model_id": "1e44cc8124a94f0dabb5489a06fe3ad3",
1185
  "version_major": 2,
1186
  "version_minor": 0
1187
  },
1188
  "text/plain": [
1189
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1190
  ]
1191
  },
1192
  "metadata": {},
 
1195
  {
1196
  "data": {
1197
  "application/vnd.jupyter.widget-view+json": {
1198
+ "model_id": "97f9d2a71669403ab9672205739af7c8",
1199
  "version_major": 2,
1200
  "version_minor": 0
1201
  },
1202
  "text/plain": [
1203
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1204
  ]
1205
  },
1206
  "metadata": {},
 
1209
  {
1210
  "data": {
1211
  "application/vnd.jupyter.widget-view+json": {
1212
+ "model_id": "be24e4fc75794090980b3501f99d3949",
1213
  "version_major": 2,
1214
  "version_minor": 0
1215
  },
1216
  "text/plain": [
1217
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1218
  ]
1219
  },
1220
  "metadata": {},
 
1223
  {
1224
  "data": {
1225
  "application/vnd.jupyter.widget-view+json": {
1226
+ "model_id": "bcbecbae045746a883fb6cd1aaf0f8ff",
1227
  "version_major": 2,
1228
  "version_minor": 0
1229
  },
1230
  "text/plain": [
1231
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1232
  ]
1233
  },
1234
  "metadata": {},
 
1237
  {
1238
  "data": {
1239
  "application/vnd.jupyter.widget-view+json": {
1240
+ "model_id": "9cb1a9c35dde4aebb1e4ade969e11c57",
1241
  "version_major": 2,
1242
  "version_minor": 0
1243
  },
1244
  "text/plain": [
1245
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1246
  ]
1247
  },
1248
  "metadata": {},
 
1251
  {
1252
  "data": {
1253
  "application/vnd.jupyter.widget-view+json": {
1254
+ "model_id": "7b6fe2498b504053866cd20060e641ac",
1255
  "version_major": 2,
1256
  "version_minor": 0
1257
  },
1258
  "text/plain": [
1259
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1260
  ]
1261
  },
1262
  "metadata": {},
 
1265
  {
1266
  "data": {
1267
  "application/vnd.jupyter.widget-view+json": {
1268
+ "model_id": "12eb61759b264bac902285108cad196a",
1269
  "version_major": 2,
1270
  "version_minor": 0
1271
  },
1272
  "text/plain": [
1273
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1274
  ]
1275
  },
1276
  "metadata": {},
 
1279
  {
1280
  "data": {
1281
  "application/vnd.jupyter.widget-view+json": {
1282
+ "model_id": "378297e998ee40dd88cbff6f6ecdd849",
1283
  "version_major": 2,
1284
  "version_minor": 0
1285
  },
1286
  "text/plain": [
1287
+ " 0%| | 0/32 [00:00<?, ?it/s]"
1288
  ]
1289
  },
1290
  "metadata": {},
 
1292
  }
1293
  ],
1294
  "source": [
1295
+ "num_image_list = [1600,3200]#,6400,12800,25600]\n",
1296
+ "\n",
1297
  "if __name__ == \"__main__\":\n",
1298
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
1299
+ " for i, num_image in enumerate(num_image_list):\n",
 
1300
  " print(f\" round {i} \".center(50, '-'))\n",
1301
  " ddpm21cm = DDPM21CM()\n",
1302
+ " ddpm21cm.config.num_image = num_image\n",
1303
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
1304
  " notebook_launcher(ddpm21cm.train, num_processes=1)"
1305
  ]
 
1313
  "name": "stdout",
1314
  "output_type": "stream",
1315
  "text": [
1316
+ "total 968M\n",
1317
+ "-rw-r--r-- 1 bxia34 848M Jul 1 10:59 model_state.pth\n",
1318
+ "drwxr-xr-x 12 bxia34 4.0K Jul 1 10:50 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
1319
+ "-rw-r--r-- 1 bxia34 3.1M Jul 1 10:37 Tvir4.800000190734863-zeta131.34100341796875-N32000.npy\n",
1320
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:47 Tvir5.4770002365112305-zeta200.0-N32000.npy\n",
1321
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:29 Tvir4.698999881744385-zeta30.0-N32000.npy\n",
1322
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:11 Tvir5.599999904632568-zeta19.03700065612793-N32000.npy\n",
1323
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 03:53 Tvir4.400000095367432-zeta131.34100341796875-N32000.npy\n",
 
 
1324
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:23 Tvir4.800000190734863-zeta131.34100341796875-N20000.npy\n",
1325
  "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:05 Tvir5.4770002365112305-zeta200.0-N20000.npy\n",
1326
  "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:47 Tvir4.698999881744385-zeta30.0-N20000.npy\n",
 
1387
  {
1388
  "data": {
1389
  "application/vnd.jupyter.widget-view+json": {
1390
+ "model_id": "ebb1ef2f95274dbf8ccddd40b8ee931a",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1391
  "version_major": 2,
1392
  "version_minor": 0
1393
  },
 
1400
  }
1401
  ],
1402
  "source": [
1403
+ "if __name__ == \"__main__\":\n",
1404
+ " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
1405
+ " repeat = 800\n",
1406
+ " for i, num_image in enumerate(num_image_list):\n",
1407
+ " ddpm21cm = DDPM21CM()\n",
1408
+ " ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor([4.4, 131.341]), repeat=repeat)\n",
1409
+ "\n",
1410
+ " ddpm21cm = DDPM21CM()\n",
1411
+ " ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.6, 19.037)), repeat=repeat)\n",
1412
+ "\n",
1413
+ " ddpm21cm = DDPM21CM()\n",
1414
+ " ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.699, 30)), repeat=repeat)\n",
1415
+ "\n",
1416
+ " ddpm21cm = DDPM21CM()\n",
1417
+ " ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.477, 200)), repeat=repeat)\n",
1418
+ "\n",
1419
+ " ddpm21cm = DDPM21CM()\n",
1420
+ " ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.8, 131.341)), repeat=repeat)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1421
  ]
1422
  },
1423
  {
 
1429
  "name": "stdout",
1430
  "output_type": "stream",
1431
  "text": [
1432
+ "total 995M\n",
1433
+ "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 12:31 Tvir4.800000190734863-zeta131.34100341796875-N2000.npy\n",
1434
+ "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 12:12 Tvir5.4770002365112305-zeta200.0-N2000.npy\n",
1435
+ "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 11:54 Tvir4.698999881744385-zeta30.0-N2000.npy\n",
1436
+ "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 11:35 Tvir5.599999904632568-zeta19.03700065612793-N2000.npy\n",
1437
+ "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 11:17 Tvir4.400000095367432-zeta131.34100341796875-N2000.npy\n",
1438
+ "-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 1 10:59 model_state.pth\n",
1439
+ "drwxr-xr-x 12 bxia34 pace-jw254 4.0K Jul 1 10:50 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
1440
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 10:37 Tvir4.800000190734863-zeta131.34100341796875-N32000.npy\n",
1441
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:47 Tvir5.4770002365112305-zeta200.0-N32000.npy\n",
1442
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:29 Tvir4.698999881744385-zeta30.0-N32000.npy\n",
1443
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 04:11 Tvir5.599999904632568-zeta19.03700065612793-N32000.npy\n",
1444
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 03:53 Tvir4.400000095367432-zeta131.34100341796875-N32000.npy\n",
 
 
1445
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 00:23 Tvir4.800000190734863-zeta131.34100341796875-N20000.npy\n",
1446
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jul 1 00:05 Tvir5.4770002365112305-zeta200.0-N20000.npy\n",
1447
  "-rw-r--r-- 1 bxia34 pace-jw254 3.1M Jun 30 23:47 Tvir4.698999881744385-zeta30.0-N20000.npy\n",
quantify_results.ipynb CHANGED
The diff for this file is too large to render. See raw diff