Xsmos commited on
Commit
531f5d7
·
verified ·
1 Parent(s): 59fef97

0706-2100

Browse files
Files changed (1) hide show
  1. diffusion.ipynb +90 -40
diffusion.ipynb CHANGED
@@ -259,7 +259,7 @@
259
  " dim = 3\n",
260
  " stride = (2,2) if dim == 2 else (2,2,2)\n",
261
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
262
- " batch_size = 2#2#50#20#2#100 # 10\n",
263
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
264
  " HII_DIM = 64\n",
265
  " num_redshift = 128#64#512#256#256#64#512#128\n",
@@ -564,15 +564,16 @@
564
  "name": "stdout",
565
  "output_type": "stream",
566
  "text": [
567
- "Number of parameters for nn_model: 111048705\n",
568
- "----------------- num_image = 20 -----------------\n",
569
- "run_name = 0706-2044\n",
570
  "Launching training on one GPU.\n",
571
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
572
  "51200 images can be loaded\n",
573
  "field.shape = (64, 64, 514)\n",
574
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
575
- "loading 20 images randomly\n"
 
576
  ]
577
  },
578
  {
@@ -586,21 +587,20 @@
586
  "name": "stdout",
587
  "output_type": "stream",
588
  "text": [
589
- "images loaded: (20, 1, 64, 128)\n",
590
- "params loaded: (20, 2)\n",
591
- "images rescaled to [-1.0, 1.0278661251068115]\n",
592
- "params rescaled to [0.005739769005289105, 0.972333144312969]\n"
593
  ]
594
  },
595
  {
596
  "data": {
597
  "application/vnd.jupyter.widget-view+json": {
598
- "model_id": "432face5514f46dbab6a0bd6752ee9dd",
599
  "version_major": 2,
600
  "version_minor": 0
601
  },
602
  "text/plain": [
603
- " 0%| | 0/10 [00:00<?, ?it/s]"
604
  ]
605
  },
606
  "metadata": {},
@@ -609,12 +609,12 @@
609
  {
610
  "data": {
611
  "application/vnd.jupyter.widget-view+json": {
612
- "model_id": "43bc633a08524f1ea75546311244f8aa",
613
  "version_major": 2,
614
  "version_minor": 0
615
  },
616
  "text/plain": [
617
- " 0%| | 0/10 [00:00<?, ?it/s]"
618
  ]
619
  },
620
  "metadata": {},
@@ -623,12 +623,12 @@
623
  {
624
  "data": {
625
  "application/vnd.jupyter.widget-view+json": {
626
- "model_id": "884b273aa5d9468ba1facd835f4de2df",
627
  "version_major": 2,
628
  "version_minor": 0
629
  },
630
  "text/plain": [
631
- " 0%| | 0/10 [00:00<?, ?it/s]"
632
  ]
633
  },
634
  "metadata": {},
@@ -637,12 +637,12 @@
637
  {
638
  "data": {
639
  "application/vnd.jupyter.widget-view+json": {
640
- "model_id": "7eabba7c6741415ea5e6fd097e33dbb2",
641
  "version_major": 2,
642
  "version_minor": 0
643
  },
644
  "text/plain": [
645
- " 0%| | 0/10 [00:00<?, ?it/s]"
646
  ]
647
  },
648
  "metadata": {},
@@ -651,12 +651,12 @@
651
  {
652
  "data": {
653
  "application/vnd.jupyter.widget-view+json": {
654
- "model_id": "8cfc863995a74070b23bfcd73d314cca",
655
  "version_major": 2,
656
  "version_minor": 0
657
  },
658
  "text/plain": [
659
- " 0%| | 0/10 [00:00<?, ?it/s]"
660
  ]
661
  },
662
  "metadata": {},
@@ -665,12 +665,12 @@
665
  {
666
  "data": {
667
  "application/vnd.jupyter.widget-view+json": {
668
- "model_id": "7c53f6b374644535aa8e2bce7e961d80",
669
  "version_major": 2,
670
  "version_minor": 0
671
  },
672
  "text/plain": [
673
- " 0%| | 0/10 [00:00<?, ?it/s]"
674
  ]
675
  },
676
  "metadata": {},
@@ -679,12 +679,12 @@
679
  {
680
  "data": {
681
  "application/vnd.jupyter.widget-view+json": {
682
- "model_id": "2494d9b658f7493088b64eab1006f2c4",
683
  "version_major": 2,
684
  "version_minor": 0
685
  },
686
  "text/plain": [
687
- " 0%| | 0/10 [00:00<?, ?it/s]"
688
  ]
689
  },
690
  "metadata": {},
@@ -693,12 +693,12 @@
693
  {
694
  "data": {
695
  "application/vnd.jupyter.widget-view+json": {
696
- "model_id": "070d1da1b37043f8aac851d3bf5c22f3",
697
  "version_major": 2,
698
  "version_minor": 0
699
  },
700
  "text/plain": [
701
- " 0%| | 0/10 [00:00<?, ?it/s]"
702
  ]
703
  },
704
  "metadata": {},
@@ -707,12 +707,12 @@
707
  {
708
  "data": {
709
  "application/vnd.jupyter.widget-view+json": {
710
- "model_id": "063fc075914d4e7bb5ffaf9f0ac14285",
711
  "version_major": 2,
712
  "version_minor": 0
713
  },
714
  "text/plain": [
715
- " 0%| | 0/10 [00:00<?, ?it/s]"
716
  ]
717
  },
718
  "metadata": {},
@@ -721,12 +721,12 @@
721
  {
722
  "data": {
723
  "application/vnd.jupyter.widget-view+json": {
724
- "model_id": "01c9c1d56a394f3caf0c6fec1b072bcd",
725
  "version_major": 2,
726
  "version_minor": 0
727
  },
728
  "text/plain": [
729
- " 0%| | 0/10 [00:00<?, ?it/s]"
730
  ]
731
  },
732
  "metadata": {},
@@ -734,7 +734,7 @@
734
  }
735
  ],
736
  "source": [
737
- "num_image_list = [20]#[200]#[1600,3200,6400,12800,25600]\n",
738
  "if __name__ == \"__main__\":\n",
739
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
740
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
@@ -749,7 +749,7 @@
749
  },
750
  {
751
  "cell_type": "code",
752
- "execution_count": 9,
753
  "metadata": {},
754
  "outputs": [],
755
  "source": [
@@ -758,13 +758,37 @@
758
  },
759
  {
760
  "cell_type": "code",
761
- "execution_count": null,
762
  "metadata": {},
763
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764
  "source": [
765
  "if __name__ == \"__main__\":\n",
766
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
767
- " num_image_list = [200]\n",
768
  " # num_image_list = [3200,6400,12800,25600]\n",
769
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
770
  " repeat = 6\n",
@@ -786,18 +810,43 @@
786
  },
787
  {
788
  "cell_type": "code",
789
- "execution_count": null,
790
  "metadata": {},
791
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792
  "source": [
793
  "ls -lth outputs | head"
794
  ]
795
  },
796
  {
797
  "cell_type": "code",
798
- "execution_count": null,
799
  "metadata": {},
800
- "outputs": [],
 
 
 
 
 
 
 
 
801
  "source": [
802
  "def plot_grid(samples, c=None, row=2, col=3):\n",
803
  " print(\"samples.shape =\", samples.shape)\n",
@@ -817,9 +866,10 @@
817
  " plt.close()\n",
818
  " # plt.show()\n",
819
  " \n",
820
- "data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N200.npy\")\n",
821
  "# print(data.shape)\n",
822
- "plot_grid(data)"
 
823
  ]
824
  },
825
  {
 
259
  " dim = 3\n",
260
  " stride = (2,2) if dim == 2 else (2,2,2)\n",
261
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
262
+ " batch_size = 1#2#50#20#2#100 # 10\n",
263
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
264
  " HII_DIM = 64\n",
265
  " num_redshift = 128#64#512#256#256#64#512#128\n",
 
564
  "name": "stdout",
565
  "output_type": "stream",
566
  "text": [
567
+ "Number of parameters for nn_model: 306285057\n",
568
+ "----------------- num_image = 50 -----------------\n",
569
+ "run_name = 0706-2100\n",
570
  "Launching training on one GPU.\n",
571
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
572
  "51200 images can be loaded\n",
573
  "field.shape = (64, 64, 514)\n",
574
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
575
+ "loading 50 images randomly\n",
576
+ "images loaded: (50, 1, 64, 64, 128)\n"
577
  ]
578
  },
579
  {
 
587
  "name": "stdout",
588
  "output_type": "stream",
589
  "text": [
590
+ "params loaded: (50, 2)\n",
591
+ "images rescaled to [-1.0, 1.121453046798706]\n",
592
+ "params rescaled to [0.02178423211262565, 0.9987535256930432]\n"
 
593
  ]
594
  },
595
  {
596
  "data": {
597
  "application/vnd.jupyter.widget-view+json": {
598
+ "model_id": "aeff9b53f53c454b8d03f7b86f096a66",
599
  "version_major": 2,
600
  "version_minor": 0
601
  },
602
  "text/plain": [
603
+ " 0%| | 0/50 [00:00<?, ?it/s]"
604
  ]
605
  },
606
  "metadata": {},
 
609
  {
610
  "data": {
611
  "application/vnd.jupyter.widget-view+json": {
612
+ "model_id": "c741dbbfd1d14e5d92a0e34acce9ab29",
613
  "version_major": 2,
614
  "version_minor": 0
615
  },
616
  "text/plain": [
617
+ " 0%| | 0/50 [00:00<?, ?it/s]"
618
  ]
619
  },
620
  "metadata": {},
 
623
  {
624
  "data": {
625
  "application/vnd.jupyter.widget-view+json": {
626
+ "model_id": "7878b2bc18bb4765bdd0278201499c1b",
627
  "version_major": 2,
628
  "version_minor": 0
629
  },
630
  "text/plain": [
631
+ " 0%| | 0/50 [00:00<?, ?it/s]"
632
  ]
633
  },
634
  "metadata": {},
 
637
  {
638
  "data": {
639
  "application/vnd.jupyter.widget-view+json": {
640
+ "model_id": "e4a4b9dc4a3a4d0fafabf0f61636e039",
641
  "version_major": 2,
642
  "version_minor": 0
643
  },
644
  "text/plain": [
645
+ " 0%| | 0/50 [00:00<?, ?it/s]"
646
  ]
647
  },
648
  "metadata": {},
 
651
  {
652
  "data": {
653
  "application/vnd.jupyter.widget-view+json": {
654
+ "model_id": "b13fb9930f4d475ea76bc4a70d790fb6",
655
  "version_major": 2,
656
  "version_minor": 0
657
  },
658
  "text/plain": [
659
+ " 0%| | 0/50 [00:00<?, ?it/s]"
660
  ]
661
  },
662
  "metadata": {},
 
665
  {
666
  "data": {
667
  "application/vnd.jupyter.widget-view+json": {
668
+ "model_id": "19337b95fd3f485fa8f88a3440416cd7",
669
  "version_major": 2,
670
  "version_minor": 0
671
  },
672
  "text/plain": [
673
+ " 0%| | 0/50 [00:00<?, ?it/s]"
674
  ]
675
  },
676
  "metadata": {},
 
679
  {
680
  "data": {
681
  "application/vnd.jupyter.widget-view+json": {
682
+ "model_id": "bdf02f0c587d4837b134027ecad6065f",
683
  "version_major": 2,
684
  "version_minor": 0
685
  },
686
  "text/plain": [
687
+ " 0%| | 0/50 [00:00<?, ?it/s]"
688
  ]
689
  },
690
  "metadata": {},
 
693
  {
694
  "data": {
695
  "application/vnd.jupyter.widget-view+json": {
696
+ "model_id": "8066317a80394a3e9bce8ff7fe86e582",
697
  "version_major": 2,
698
  "version_minor": 0
699
  },
700
  "text/plain": [
701
+ " 0%| | 0/50 [00:00<?, ?it/s]"
702
  ]
703
  },
704
  "metadata": {},
 
707
  {
708
  "data": {
709
  "application/vnd.jupyter.widget-view+json": {
710
+ "model_id": "ac4c9f0b5ad149e0a7e1bf8df834c0f3",
711
  "version_major": 2,
712
  "version_minor": 0
713
  },
714
  "text/plain": [
715
+ " 0%| | 0/50 [00:00<?, ?it/s]"
716
  ]
717
  },
718
  "metadata": {},
 
721
  {
722
  "data": {
723
  "application/vnd.jupyter.widget-view+json": {
724
+ "model_id": "8bcf474dcbe2414084f78d489e1102e6",
725
  "version_major": 2,
726
  "version_minor": 0
727
  },
728
  "text/plain": [
729
+ " 0%| | 0/50 [00:00<?, ?it/s]"
730
  ]
731
  },
732
  "metadata": {},
 
734
  }
735
  ],
736
  "source": [
737
+ "num_image_list = [50]#[200]#[1600,3200,6400,12800,25600]\n",
738
  "if __name__ == \"__main__\":\n",
739
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
740
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
 
749
  },
750
  {
751
  "cell_type": "code",
752
+ "execution_count": null,
753
  "metadata": {},
754
  "outputs": [],
755
  "source": [
 
758
  },
759
  {
760
  "cell_type": "code",
761
+ "execution_count": 9,
762
  "metadata": {},
763
+ "outputs": [
764
+ {
765
+ "name": "stdout",
766
+ "output_type": "stream",
767
+ "text": [
768
+ "Number of parameters for nn_model: 111048705\n",
769
+ "sampling 6 images with normalized params = tensor([[0.2000, 0.5056]])\n",
770
+ "nn_model resumed from ./outputs/model_state-N20\n"
771
+ ]
772
+ },
773
+ {
774
+ "data": {
775
+ "application/vnd.jupyter.widget-view+json": {
776
+ "model_id": "2f7cc524aa6f4b25be2d44943494e36f",
777
+ "version_major": 2,
778
+ "version_minor": 0
779
+ },
780
+ "text/plain": [
781
+ " 0%| | 0/1000 [00:00<?, ?it/s]"
782
+ ]
783
+ },
784
+ "metadata": {},
785
+ "output_type": "display_data"
786
+ }
787
+ ],
788
  "source": [
789
  "if __name__ == \"__main__\":\n",
790
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
791
+ " num_image_list = [20]\n",
792
  " # num_image_list = [3200,6400,12800,25600]\n",
793
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
794
  " repeat = 6\n",
 
810
  },
811
  {
812
  "cell_type": "code",
813
+ "execution_count": 10,
814
  "metadata": {},
815
+ "outputs": [
816
+ {
817
+ "name": "stdout",
818
+ "output_type": "stream",
819
+ "text": [
820
+ "total 7.7G\n",
821
+ "-rw-r--r-- 1 bxia34 pace-jw254 193K Jul 6 20:46 Tvir4.400000095367432-zeta131.34100341796875-N20.npy\n",
822
+ "-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 6 20:45 model_state-N20\n",
823
+ "drwxr-xr-x 44 bxia34 pace-jw254 4.0K Jul 6 20:44 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
824
+ "-rw-r--r-- 1 bxia34 pace-jw254 6.1M Jul 5 14:44 Tvir4.400000095367432-zeta131.34100341796875-N200.npy\n",
825
+ "-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 5 12:20 model_state-N200\n",
826
+ "-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 17:05 Tvir4.800000190734863-zeta131.34100341796875-N25600.npy\n",
827
+ "-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 16:46 Tvir5.4770002365112305-zeta200.0-N25600.npy\n",
828
+ "-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 16:28 Tvir4.698999881744385-zeta30.0-N25600.npy\n",
829
+ "-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 16:09 Tvir5.599999904632568-zeta19.03700065612793-N25600.npy\n"
830
+ ]
831
+ }
832
+ ],
833
  "source": [
834
  "ls -lth outputs | head"
835
  ]
836
  },
837
  {
838
  "cell_type": "code",
839
+ "execution_count": 14,
840
  "metadata": {},
841
+ "outputs": [
842
+ {
843
+ "name": "stdout",
844
+ "output_type": "stream",
845
+ "text": [
846
+ "samples.shape = (6, 1, 1, 64, 128)\n"
847
+ ]
848
+ }
849
+ ],
850
  "source": [
851
  "def plot_grid(samples, c=None, row=2, col=3):\n",
852
  " print(\"samples.shape =\", samples.shape)\n",
 
866
  " plt.close()\n",
867
  " # plt.show()\n",
868
  " \n",
869
+ "data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N20.npy\")\n",
870
  "# print(data.shape)\n",
871
+ "plot_grid(data)\n",
872
+ "# plt.imshow(data)"
873
  ]
874
  },
875
  {