Xsmos commited on
Commit
2327293
·
verified ·
1 Parent(s): 8807e69
Files changed (1) hide show
  1. diffusion.ipynb +51 -45
diffusion.ipynb CHANGED
@@ -74,9 +74,24 @@
74
  "cell_type": "code",
75
  "execution_count": 2,
76
  "metadata": {},
77
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  "source": [
79
- "# notebook_login()"
80
  ]
81
  },
82
  {
@@ -259,7 +274,7 @@
259
  " dim = 3\n",
260
  " stride = (2,2) if dim == 2 else (2,2,2)\n",
261
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
262
- " batch_size = 1#2#50#20#2#100 # 10\n",
263
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
264
  " HII_DIM = 64\n",
265
  " num_redshift = 128#64#512#256#256#64#512#128\n",
@@ -565,15 +580,16 @@
565
  "output_type": "stream",
566
  "text": [
567
  "Number of parameters for nn_model: 306285057\n",
568
- "----------------- num_image = 50 -----------------\n",
569
- "run_name = 0706-2119\n",
570
  "Launching training on one GPU.\n",
571
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
572
  "51200 images can be loaded\n",
573
  "field.shape = (64, 64, 514)\n",
574
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
575
- "loading 50 images randomly\n",
576
- "images loaded: (50, 1, 64, 64, 128)\n"
 
577
  ]
578
  },
579
  {
@@ -587,20 +603,19 @@
587
  "name": "stdout",
588
  "output_type": "stream",
589
  "text": [
590
- "params loaded: (50, 2)\n",
591
- "images rescaled to [-1.0, 1.1198105812072754]\n",
592
- "params rescaled to [0.0031794774029485495, 0.9969930182712254]\n"
593
  ]
594
  },
595
  {
596
  "data": {
597
  "application/vnd.jupyter.widget-view+json": {
598
- "model_id": "3d6033db849844caa518578cd47e47f1",
599
  "version_major": 2,
600
  "version_minor": 0
601
  },
602
  "text/plain": [
603
- " 0%| | 0/50 [00:00<?, ?it/s]"
604
  ]
605
  },
606
  "metadata": {},
@@ -609,12 +624,12 @@
609
  {
610
  "data": {
611
  "application/vnd.jupyter.widget-view+json": {
612
- "model_id": "d5df662ea2a242f48099eff4ee120160",
613
  "version_major": 2,
614
  "version_minor": 0
615
  },
616
  "text/plain": [
617
- " 0%| | 0/50 [00:00<?, ?it/s]"
618
  ]
619
  },
620
  "metadata": {},
@@ -623,12 +638,12 @@
623
  {
624
  "data": {
625
  "application/vnd.jupyter.widget-view+json": {
626
- "model_id": "ed794eb171ce45be8bc333b21f71adad",
627
  "version_major": 2,
628
  "version_minor": 0
629
  },
630
  "text/plain": [
631
- " 0%| | 0/50 [00:00<?, ?it/s]"
632
  ]
633
  },
634
  "metadata": {},
@@ -637,12 +652,12 @@
637
  {
638
  "data": {
639
  "application/vnd.jupyter.widget-view+json": {
640
- "model_id": "f917cec5bd9e46358337a66fd4f7b364",
641
  "version_major": 2,
642
  "version_minor": 0
643
  },
644
  "text/plain": [
645
- " 0%| | 0/50 [00:00<?, ?it/s]"
646
  ]
647
  },
648
  "metadata": {},
@@ -651,12 +666,12 @@
651
  {
652
  "data": {
653
  "application/vnd.jupyter.widget-view+json": {
654
- "model_id": "7790690aaee44f4cb601b954fc5cf714",
655
  "version_major": 2,
656
  "version_minor": 0
657
  },
658
  "text/plain": [
659
- " 0%| | 0/50 [00:00<?, ?it/s]"
660
  ]
661
  },
662
  "metadata": {},
@@ -665,12 +680,12 @@
665
  {
666
  "data": {
667
  "application/vnd.jupyter.widget-view+json": {
668
- "model_id": "57d0579828ae47ed92afbf52716acee3",
669
  "version_major": 2,
670
  "version_minor": 0
671
  },
672
  "text/plain": [
673
- " 0%| | 0/50 [00:00<?, ?it/s]"
674
  ]
675
  },
676
  "metadata": {},
@@ -679,12 +694,12 @@
679
  {
680
  "data": {
681
  "application/vnd.jupyter.widget-view+json": {
682
- "model_id": "b9d490d81736432ba7a3ec38137cbb6e",
683
  "version_major": 2,
684
  "version_minor": 0
685
  },
686
  "text/plain": [
687
- " 0%| | 0/50 [00:00<?, ?it/s]"
688
  ]
689
  },
690
  "metadata": {},
@@ -693,12 +708,12 @@
693
  {
694
  "data": {
695
  "application/vnd.jupyter.widget-view+json": {
696
- "model_id": "7fac169a311e466095ded759464bd864",
697
  "version_major": 2,
698
  "version_minor": 0
699
  },
700
  "text/plain": [
701
- " 0%| | 0/50 [00:00<?, ?it/s]"
702
  ]
703
  },
704
  "metadata": {},
@@ -707,12 +722,12 @@
707
  {
708
  "data": {
709
  "application/vnd.jupyter.widget-view+json": {
710
- "model_id": "36bcd1e04d3d4d95aa5d9a078233e2c9",
711
  "version_major": 2,
712
  "version_minor": 0
713
  },
714
  "text/plain": [
715
- " 0%| | 0/50 [00:00<?, ?it/s]"
716
  ]
717
  },
718
  "metadata": {},
@@ -721,12 +736,12 @@
721
  {
722
  "data": {
723
  "application/vnd.jupyter.widget-view+json": {
724
- "model_id": "e017f171453b4f8eb91d09de38e3fb05",
725
  "version_major": 2,
726
  "version_minor": 0
727
  },
728
  "text/plain": [
729
- " 0%| | 0/50 [00:00<?, ?it/s]"
730
  ]
731
  },
732
  "metadata": {},
@@ -734,7 +749,7 @@
734
  }
735
  ],
736
  "source": [
737
- "num_image_list = [50]#[200]#[1600,3200,6400,12800,25600]\n",
738
  "if __name__ == \"__main__\":\n",
739
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
740
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
@@ -747,15 +762,6 @@
747
  " notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
748
  ]
749
  },
750
- {
751
- "cell_type": "code",
752
- "execution_count": null,
753
- "metadata": {},
754
- "outputs": [],
755
- "source": [
756
- "# ll -lth outputs"
757
- ]
758
- },
759
  {
760
  "cell_type": "code",
761
  "execution_count": null,
@@ -764,10 +770,10 @@
764
  "source": [
765
  "if __name__ == \"__main__\":\n",
766
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
767
- " num_image_list = [20]\n",
768
  " # num_image_list = [3200,6400,12800,25600]\n",
769
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
770
- " repeat = 6\n",
771
  " config = TrainConfig()\n",
772
  " for i, num_image in enumerate(num_image_list):\n",
773
  " config.num_image = num_image\n",
@@ -799,10 +805,10 @@
799
  "metadata": {},
800
  "outputs": [],
801
  "source": [
802
- "def plot_grid(samples, c=None, row=2, col=3):\n",
803
  " print(\"samples.shape =\", samples.shape)\n",
804
  " for j in range(samples.shape[2]):\n",
805
- " plt.figure(figsize = (9,6), dpi=400)\n",
806
  " for i in range(len(samples)):\n",
807
  " plt.subplot(row,col,i+1)\n",
808
  " plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n",
@@ -817,7 +823,7 @@
817
  " plt.close()\n",
818
  " # plt.show()\n",
819
  " \n",
820
- "data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N20.npy\")\n",
821
  "# print(data.shape)\n",
822
  "plot_grid(data)\n",
823
  "# plt.imshow(data)"
 
74
  "cell_type": "code",
75
  "execution_count": 2,
76
  "metadata": {},
77
+ "outputs": [
78
+ {
79
+ "data": {
80
+ "application/vnd.jupyter.widget-view+json": {
81
+ "model_id": "9d207dc0a7a24bccb2f419954b64ddc0",
82
+ "version_major": 2,
83
+ "version_minor": 0
84
+ },
85
+ "text/plain": [
86
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
87
+ ]
88
+ },
89
+ "metadata": {},
90
+ "output_type": "display_data"
91
+ }
92
+ ],
93
  "source": [
94
+ "notebook_login()"
95
  ]
96
  },
97
  {
 
274
  " dim = 3\n",
275
  " stride = (2,2) if dim == 2 else (2,2,2)\n",
276
  " num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
277
+ " batch_size = 2#2#50#20#2#100 # 10\n",
278
  " n_epoch = 10#50#20#20#2#5#25 # 120\n",
279
  " HII_DIM = 64\n",
280
  " num_redshift = 128#64#512#256#256#64#512#128\n",
 
580
  "output_type": "stream",
581
  "text": [
582
  "Number of parameters for nn_model: 306285057\n",
583
+ "---------------- num_image = 1000 ----------------\n",
584
+ "run_name = 0706-2157\n",
585
  "Launching training on one GPU.\n",
586
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
587
  "51200 images can be loaded\n",
588
  "field.shape = (64, 64, 514)\n",
589
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
590
+ "loading 1000 images randomly\n",
591
+ "images loaded: (1000, 1, 64, 64, 128)\n",
592
+ "params loaded: (1000, 2)\n"
593
  ]
594
  },
595
  {
 
603
  "name": "stdout",
604
  "output_type": "stream",
605
  "text": [
606
+ "images rescaled to [-1.0, 1.2339041233062744]\n",
607
+ "params rescaled to [0.0006788479393025145, 0.9997530171043563]\n"
 
608
  ]
609
  },
610
  {
611
  "data": {
612
  "application/vnd.jupyter.widget-view+json": {
613
+ "model_id": "1d069efc366d480ba2515135ccb18f6c",
614
  "version_major": 2,
615
  "version_minor": 0
616
  },
617
  "text/plain": [
618
+ " 0%| | 0/500 [00:00<?, ?it/s]"
619
  ]
620
  },
621
  "metadata": {},
 
624
  {
625
  "data": {
626
  "application/vnd.jupyter.widget-view+json": {
627
+ "model_id": "93abe7efc5044d3796f67a0d243058b2",
628
  "version_major": 2,
629
  "version_minor": 0
630
  },
631
  "text/plain": [
632
+ " 0%| | 0/500 [00:00<?, ?it/s]"
633
  ]
634
  },
635
  "metadata": {},
 
638
  {
639
  "data": {
640
  "application/vnd.jupyter.widget-view+json": {
641
+ "model_id": "75e22a8998244becb916714cf4b8053e",
642
  "version_major": 2,
643
  "version_minor": 0
644
  },
645
  "text/plain": [
646
+ " 0%| | 0/500 [00:00<?, ?it/s]"
647
  ]
648
  },
649
  "metadata": {},
 
652
  {
653
  "data": {
654
  "application/vnd.jupyter.widget-view+json": {
655
+ "model_id": "7364f0924d2246e49fa02a596e0afde8",
656
  "version_major": 2,
657
  "version_minor": 0
658
  },
659
  "text/plain": [
660
+ " 0%| | 0/500 [00:00<?, ?it/s]"
661
  ]
662
  },
663
  "metadata": {},
 
666
  {
667
  "data": {
668
  "application/vnd.jupyter.widget-view+json": {
669
+ "model_id": "a5c57b0059154c2695e7fc0d426bb0d9",
670
  "version_major": 2,
671
  "version_minor": 0
672
  },
673
  "text/plain": [
674
+ " 0%| | 0/500 [00:00<?, ?it/s]"
675
  ]
676
  },
677
  "metadata": {},
 
680
  {
681
  "data": {
682
  "application/vnd.jupyter.widget-view+json": {
683
+ "model_id": "b0c9ee9c876144eab582e196e71ed83c",
684
  "version_major": 2,
685
  "version_minor": 0
686
  },
687
  "text/plain": [
688
+ " 0%| | 0/500 [00:00<?, ?it/s]"
689
  ]
690
  },
691
  "metadata": {},
 
694
  {
695
  "data": {
696
  "application/vnd.jupyter.widget-view+json": {
697
+ "model_id": "d2956adb8f4c45938f2d8f93baca2998",
698
  "version_major": 2,
699
  "version_minor": 0
700
  },
701
  "text/plain": [
702
+ " 0%| | 0/500 [00:00<?, ?it/s]"
703
  ]
704
  },
705
  "metadata": {},
 
708
  {
709
  "data": {
710
  "application/vnd.jupyter.widget-view+json": {
711
+ "model_id": "1f3031b4cd1040829ce3b457ee30e464",
712
  "version_major": 2,
713
  "version_minor": 0
714
  },
715
  "text/plain": [
716
+ " 0%| | 0/500 [00:00<?, ?it/s]"
717
  ]
718
  },
719
  "metadata": {},
 
722
  {
723
  "data": {
724
  "application/vnd.jupyter.widget-view+json": {
725
+ "model_id": "b4dcf375da4c4c4492d308e4bdc19358",
726
  "version_major": 2,
727
  "version_minor": 0
728
  },
729
  "text/plain": [
730
+ " 0%| | 0/500 [00:00<?, ?it/s]"
731
  ]
732
  },
733
  "metadata": {},
 
736
  {
737
  "data": {
738
  "application/vnd.jupyter.widget-view+json": {
739
+ "model_id": "50c6936f333a451394d6761fa9a085be",
740
  "version_major": 2,
741
  "version_minor": 0
742
  },
743
  "text/plain": [
744
+ " 0%| | 0/500 [00:00<?, ?it/s]"
745
  ]
746
  },
747
  "metadata": {},
 
749
  }
750
  ],
751
  "source": [
752
+ "num_image_list = [1000]#[200]#[1600,3200,6400,12800,25600]\n",
753
  "if __name__ == \"__main__\":\n",
754
  " # torch.multiprocessing.set_start_method(\"spawn\")\n",
755
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
 
762
  " notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
763
  ]
764
  },
 
 
 
 
 
 
 
 
 
765
  {
766
  "cell_type": "code",
767
  "execution_count": null,
 
770
  "source": [
771
  "if __name__ == \"__main__\":\n",
772
  " # num_image_list = [1600,3200,6400,12800,25600]\n",
773
+ " num_image_list = [1000]\n",
774
  " # num_image_list = [3200,6400,12800,25600]\n",
775
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
776
+ " repeat = 2\n",
777
  " config = TrainConfig()\n",
778
  " for i, num_image in enumerate(num_image_list):\n",
779
  " config.num_image = num_image\n",
 
805
  "metadata": {},
806
  "outputs": [],
807
  "source": [
808
+ "def plot_grid(samples, c=None, row=1, col=2):\n",
809
  " print(\"samples.shape =\", samples.shape)\n",
810
  " for j in range(samples.shape[2]):\n",
811
+ " plt.figure(figsize = (12,6), dpi=400)\n",
812
  " for i in range(len(samples)):\n",
813
  " plt.subplot(row,col,i+1)\n",
814
  " plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n",
 
823
  " plt.close()\n",
824
  " # plt.show()\n",
825
  " \n",
826
+ "data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\")\n",
827
  "# print(data.shape)\n",
828
  "plot_grid(data)\n",
829
  "# plt.imshow(data)"