0706-2100
Browse files- diffusion.ipynb +90 -40
diffusion.ipynb
CHANGED
|
@@ -259,7 +259,7 @@
|
|
| 259 |
" dim = 3\n",
|
| 260 |
" stride = (2,2) if dim == 2 else (2,2,2)\n",
|
| 261 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 262 |
-
" batch_size =
|
| 263 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 264 |
" HII_DIM = 64\n",
|
| 265 |
" num_redshift = 128#64#512#256#256#64#512#128\n",
|
|
@@ -564,15 +564,16 @@
|
|
| 564 |
"name": "stdout",
|
| 565 |
"output_type": "stream",
|
| 566 |
"text": [
|
| 567 |
-
"Number of parameters for nn_model:
|
| 568 |
-
"----------------- num_image =
|
| 569 |
-
"run_name = 0706-
|
| 570 |
"Launching training on one GPU.\n",
|
| 571 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 572 |
"51200 images can be loaded\n",
|
| 573 |
"field.shape = (64, 64, 514)\n",
|
| 574 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 575 |
-
"loading
|
|
|
|
| 576 |
]
|
| 577 |
},
|
| 578 |
{
|
|
@@ -586,21 +587,20 @@
|
|
| 586 |
"name": "stdout",
|
| 587 |
"output_type": "stream",
|
| 588 |
"text": [
|
| 589 |
-
"
|
| 590 |
-
"
|
| 591 |
-
"
|
| 592 |
-
"params rescaled to [0.005739769005289105, 0.972333144312969]\n"
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
| 596 |
"data": {
|
| 597 |
"application/vnd.jupyter.widget-view+json": {
|
| 598 |
-
"model_id": "
|
| 599 |
"version_major": 2,
|
| 600 |
"version_minor": 0
|
| 601 |
},
|
| 602 |
"text/plain": [
|
| 603 |
-
" 0%| | 0/
|
| 604 |
]
|
| 605 |
},
|
| 606 |
"metadata": {},
|
|
@@ -609,12 +609,12 @@
|
|
| 609 |
{
|
| 610 |
"data": {
|
| 611 |
"application/vnd.jupyter.widget-view+json": {
|
| 612 |
-
"model_id": "
|
| 613 |
"version_major": 2,
|
| 614 |
"version_minor": 0
|
| 615 |
},
|
| 616 |
"text/plain": [
|
| 617 |
-
" 0%| | 0/
|
| 618 |
]
|
| 619 |
},
|
| 620 |
"metadata": {},
|
|
@@ -623,12 +623,12 @@
|
|
| 623 |
{
|
| 624 |
"data": {
|
| 625 |
"application/vnd.jupyter.widget-view+json": {
|
| 626 |
-
"model_id": "
|
| 627 |
"version_major": 2,
|
| 628 |
"version_minor": 0
|
| 629 |
},
|
| 630 |
"text/plain": [
|
| 631 |
-
" 0%| | 0/
|
| 632 |
]
|
| 633 |
},
|
| 634 |
"metadata": {},
|
|
@@ -637,12 +637,12 @@
|
|
| 637 |
{
|
| 638 |
"data": {
|
| 639 |
"application/vnd.jupyter.widget-view+json": {
|
| 640 |
-
"model_id": "
|
| 641 |
"version_major": 2,
|
| 642 |
"version_minor": 0
|
| 643 |
},
|
| 644 |
"text/plain": [
|
| 645 |
-
" 0%| | 0/
|
| 646 |
]
|
| 647 |
},
|
| 648 |
"metadata": {},
|
|
@@ -651,12 +651,12 @@
|
|
| 651 |
{
|
| 652 |
"data": {
|
| 653 |
"application/vnd.jupyter.widget-view+json": {
|
| 654 |
-
"model_id": "
|
| 655 |
"version_major": 2,
|
| 656 |
"version_minor": 0
|
| 657 |
},
|
| 658 |
"text/plain": [
|
| 659 |
-
" 0%| | 0/
|
| 660 |
]
|
| 661 |
},
|
| 662 |
"metadata": {},
|
|
@@ -665,12 +665,12 @@
|
|
| 665 |
{
|
| 666 |
"data": {
|
| 667 |
"application/vnd.jupyter.widget-view+json": {
|
| 668 |
-
"model_id": "
|
| 669 |
"version_major": 2,
|
| 670 |
"version_minor": 0
|
| 671 |
},
|
| 672 |
"text/plain": [
|
| 673 |
-
" 0%| | 0/
|
| 674 |
]
|
| 675 |
},
|
| 676 |
"metadata": {},
|
|
@@ -679,12 +679,12 @@
|
|
| 679 |
{
|
| 680 |
"data": {
|
| 681 |
"application/vnd.jupyter.widget-view+json": {
|
| 682 |
-
"model_id": "
|
| 683 |
"version_major": 2,
|
| 684 |
"version_minor": 0
|
| 685 |
},
|
| 686 |
"text/plain": [
|
| 687 |
-
" 0%| | 0/
|
| 688 |
]
|
| 689 |
},
|
| 690 |
"metadata": {},
|
|
@@ -693,12 +693,12 @@
|
|
| 693 |
{
|
| 694 |
"data": {
|
| 695 |
"application/vnd.jupyter.widget-view+json": {
|
| 696 |
-
"model_id": "
|
| 697 |
"version_major": 2,
|
| 698 |
"version_minor": 0
|
| 699 |
},
|
| 700 |
"text/plain": [
|
| 701 |
-
" 0%| | 0/
|
| 702 |
]
|
| 703 |
},
|
| 704 |
"metadata": {},
|
|
@@ -707,12 +707,12 @@
|
|
| 707 |
{
|
| 708 |
"data": {
|
| 709 |
"application/vnd.jupyter.widget-view+json": {
|
| 710 |
-
"model_id": "
|
| 711 |
"version_major": 2,
|
| 712 |
"version_minor": 0
|
| 713 |
},
|
| 714 |
"text/plain": [
|
| 715 |
-
" 0%| | 0/
|
| 716 |
]
|
| 717 |
},
|
| 718 |
"metadata": {},
|
|
@@ -721,12 +721,12 @@
|
|
| 721 |
{
|
| 722 |
"data": {
|
| 723 |
"application/vnd.jupyter.widget-view+json": {
|
| 724 |
-
"model_id": "
|
| 725 |
"version_major": 2,
|
| 726 |
"version_minor": 0
|
| 727 |
},
|
| 728 |
"text/plain": [
|
| 729 |
-
" 0%| | 0/
|
| 730 |
]
|
| 731 |
},
|
| 732 |
"metadata": {},
|
|
@@ -734,7 +734,7 @@
|
|
| 734 |
}
|
| 735 |
],
|
| 736 |
"source": [
|
| 737 |
-
"num_image_list = [
|
| 738 |
"if __name__ == \"__main__\":\n",
|
| 739 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 740 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
@@ -749,7 +749,7 @@
|
|
| 749 |
},
|
| 750 |
{
|
| 751 |
"cell_type": "code",
|
| 752 |
-
"execution_count":
|
| 753 |
"metadata": {},
|
| 754 |
"outputs": [],
|
| 755 |
"source": [
|
|
@@ -758,13 +758,37 @@
|
|
| 758 |
},
|
| 759 |
{
|
| 760 |
"cell_type": "code",
|
| 761 |
-
"execution_count":
|
| 762 |
"metadata": {},
|
| 763 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
"source": [
|
| 765 |
"if __name__ == \"__main__\":\n",
|
| 766 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
| 767 |
-
" num_image_list = [
|
| 768 |
" # num_image_list = [3200,6400,12800,25600]\n",
|
| 769 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 770 |
" repeat = 6\n",
|
|
@@ -786,18 +810,43 @@
|
|
| 786 |
},
|
| 787 |
{
|
| 788 |
"cell_type": "code",
|
| 789 |
-
"execution_count":
|
| 790 |
"metadata": {},
|
| 791 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
"source": [
|
| 793 |
"ls -lth outputs | head"
|
| 794 |
]
|
| 795 |
},
|
| 796 |
{
|
| 797 |
"cell_type": "code",
|
| 798 |
-
"execution_count":
|
| 799 |
"metadata": {},
|
| 800 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 801 |
"source": [
|
| 802 |
"def plot_grid(samples, c=None, row=2, col=3):\n",
|
| 803 |
" print(\"samples.shape =\", samples.shape)\n",
|
|
@@ -817,9 +866,10 @@
|
|
| 817 |
" plt.close()\n",
|
| 818 |
" # plt.show()\n",
|
| 819 |
" \n",
|
| 820 |
-
"data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-
|
| 821 |
"# print(data.shape)\n",
|
| 822 |
-
"plot_grid(data)"
|
|
|
|
| 823 |
]
|
| 824 |
},
|
| 825 |
{
|
|
|
|
| 259 |
" dim = 3\n",
|
| 260 |
" stride = (2,2) if dim == 2 else (2,2,2)\n",
|
| 261 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 262 |
+
" batch_size = 1#2#50#20#2#100 # 10\n",
|
| 263 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 264 |
" HII_DIM = 64\n",
|
| 265 |
" num_redshift = 128#64#512#256#256#64#512#128\n",
|
|
|
|
| 564 |
"name": "stdout",
|
| 565 |
"output_type": "stream",
|
| 566 |
"text": [
|
| 567 |
+
"Number of parameters for nn_model: 306285057\n",
|
| 568 |
+
"----------------- num_image = 50 -----------------\n",
|
| 569 |
+
"run_name = 0706-2100\n",
|
| 570 |
"Launching training on one GPU.\n",
|
| 571 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 572 |
"51200 images can be loaded\n",
|
| 573 |
"field.shape = (64, 64, 514)\n",
|
| 574 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 575 |
+
"loading 50 images randomly\n",
|
| 576 |
+
"images loaded: (50, 1, 64, 64, 128)\n"
|
| 577 |
]
|
| 578 |
},
|
| 579 |
{
|
|
|
|
| 587 |
"name": "stdout",
|
| 588 |
"output_type": "stream",
|
| 589 |
"text": [
|
| 590 |
+
"params loaded: (50, 2)\n",
|
| 591 |
+
"images rescaled to [-1.0, 1.121453046798706]\n",
|
| 592 |
+
"params rescaled to [0.02178423211262565, 0.9987535256930432]\n"
|
|
|
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
| 596 |
"data": {
|
| 597 |
"application/vnd.jupyter.widget-view+json": {
|
| 598 |
+
"model_id": "aeff9b53f53c454b8d03f7b86f096a66",
|
| 599 |
"version_major": 2,
|
| 600 |
"version_minor": 0
|
| 601 |
},
|
| 602 |
"text/plain": [
|
| 603 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 604 |
]
|
| 605 |
},
|
| 606 |
"metadata": {},
|
|
|
|
| 609 |
{
|
| 610 |
"data": {
|
| 611 |
"application/vnd.jupyter.widget-view+json": {
|
| 612 |
+
"model_id": "c741dbbfd1d14e5d92a0e34acce9ab29",
|
| 613 |
"version_major": 2,
|
| 614 |
"version_minor": 0
|
| 615 |
},
|
| 616 |
"text/plain": [
|
| 617 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 618 |
]
|
| 619 |
},
|
| 620 |
"metadata": {},
|
|
|
|
| 623 |
{
|
| 624 |
"data": {
|
| 625 |
"application/vnd.jupyter.widget-view+json": {
|
| 626 |
+
"model_id": "7878b2bc18bb4765bdd0278201499c1b",
|
| 627 |
"version_major": 2,
|
| 628 |
"version_minor": 0
|
| 629 |
},
|
| 630 |
"text/plain": [
|
| 631 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 632 |
]
|
| 633 |
},
|
| 634 |
"metadata": {},
|
|
|
|
| 637 |
{
|
| 638 |
"data": {
|
| 639 |
"application/vnd.jupyter.widget-view+json": {
|
| 640 |
+
"model_id": "e4a4b9dc4a3a4d0fafabf0f61636e039",
|
| 641 |
"version_major": 2,
|
| 642 |
"version_minor": 0
|
| 643 |
},
|
| 644 |
"text/plain": [
|
| 645 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 646 |
]
|
| 647 |
},
|
| 648 |
"metadata": {},
|
|
|
|
| 651 |
{
|
| 652 |
"data": {
|
| 653 |
"application/vnd.jupyter.widget-view+json": {
|
| 654 |
+
"model_id": "b13fb9930f4d475ea76bc4a70d790fb6",
|
| 655 |
"version_major": 2,
|
| 656 |
"version_minor": 0
|
| 657 |
},
|
| 658 |
"text/plain": [
|
| 659 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 660 |
]
|
| 661 |
},
|
| 662 |
"metadata": {},
|
|
|
|
| 665 |
{
|
| 666 |
"data": {
|
| 667 |
"application/vnd.jupyter.widget-view+json": {
|
| 668 |
+
"model_id": "19337b95fd3f485fa8f88a3440416cd7",
|
| 669 |
"version_major": 2,
|
| 670 |
"version_minor": 0
|
| 671 |
},
|
| 672 |
"text/plain": [
|
| 673 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 674 |
]
|
| 675 |
},
|
| 676 |
"metadata": {},
|
|
|
|
| 679 |
{
|
| 680 |
"data": {
|
| 681 |
"application/vnd.jupyter.widget-view+json": {
|
| 682 |
+
"model_id": "bdf02f0c587d4837b134027ecad6065f",
|
| 683 |
"version_major": 2,
|
| 684 |
"version_minor": 0
|
| 685 |
},
|
| 686 |
"text/plain": [
|
| 687 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 688 |
]
|
| 689 |
},
|
| 690 |
"metadata": {},
|
|
|
|
| 693 |
{
|
| 694 |
"data": {
|
| 695 |
"application/vnd.jupyter.widget-view+json": {
|
| 696 |
+
"model_id": "8066317a80394a3e9bce8ff7fe86e582",
|
| 697 |
"version_major": 2,
|
| 698 |
"version_minor": 0
|
| 699 |
},
|
| 700 |
"text/plain": [
|
| 701 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 702 |
]
|
| 703 |
},
|
| 704 |
"metadata": {},
|
|
|
|
| 707 |
{
|
| 708 |
"data": {
|
| 709 |
"application/vnd.jupyter.widget-view+json": {
|
| 710 |
+
"model_id": "ac4c9f0b5ad149e0a7e1bf8df834c0f3",
|
| 711 |
"version_major": 2,
|
| 712 |
"version_minor": 0
|
| 713 |
},
|
| 714 |
"text/plain": [
|
| 715 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 716 |
]
|
| 717 |
},
|
| 718 |
"metadata": {},
|
|
|
|
| 721 |
{
|
| 722 |
"data": {
|
| 723 |
"application/vnd.jupyter.widget-view+json": {
|
| 724 |
+
"model_id": "8bcf474dcbe2414084f78d489e1102e6",
|
| 725 |
"version_major": 2,
|
| 726 |
"version_minor": 0
|
| 727 |
},
|
| 728 |
"text/plain": [
|
| 729 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 730 |
]
|
| 731 |
},
|
| 732 |
"metadata": {},
|
|
|
|
| 734 |
}
|
| 735 |
],
|
| 736 |
"source": [
|
| 737 |
+
"num_image_list = [50]#[200]#[1600,3200,6400,12800,25600]\n",
|
| 738 |
"if __name__ == \"__main__\":\n",
|
| 739 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 740 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
|
|
| 749 |
},
|
| 750 |
{
|
| 751 |
"cell_type": "code",
|
| 752 |
+
"execution_count": null,
|
| 753 |
"metadata": {},
|
| 754 |
"outputs": [],
|
| 755 |
"source": [
|
|
|
|
| 758 |
},
|
| 759 |
{
|
| 760 |
"cell_type": "code",
|
| 761 |
+
"execution_count": 9,
|
| 762 |
"metadata": {},
|
| 763 |
+
"outputs": [
|
| 764 |
+
{
|
| 765 |
+
"name": "stdout",
|
| 766 |
+
"output_type": "stream",
|
| 767 |
+
"text": [
|
| 768 |
+
"Number of parameters for nn_model: 111048705\n",
|
| 769 |
+
"sampling 6 images with normalized params = tensor([[0.2000, 0.5056]])\n",
|
| 770 |
+
"nn_model resumed from ./outputs/model_state-N20\n"
|
| 771 |
+
]
|
| 772 |
+
},
|
| 773 |
+
{
|
| 774 |
+
"data": {
|
| 775 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 776 |
+
"model_id": "2f7cc524aa6f4b25be2d44943494e36f",
|
| 777 |
+
"version_major": 2,
|
| 778 |
+
"version_minor": 0
|
| 779 |
+
},
|
| 780 |
+
"text/plain": [
|
| 781 |
+
" 0%| | 0/1000 [00:00<?, ?it/s]"
|
| 782 |
+
]
|
| 783 |
+
},
|
| 784 |
+
"metadata": {},
|
| 785 |
+
"output_type": "display_data"
|
| 786 |
+
}
|
| 787 |
+
],
|
| 788 |
"source": [
|
| 789 |
"if __name__ == \"__main__\":\n",
|
| 790 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
| 791 |
+
" num_image_list = [20]\n",
|
| 792 |
" # num_image_list = [3200,6400,12800,25600]\n",
|
| 793 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 794 |
" repeat = 6\n",
|
|
|
|
| 810 |
},
|
| 811 |
{
|
| 812 |
"cell_type": "code",
|
| 813 |
+
"execution_count": 10,
|
| 814 |
"metadata": {},
|
| 815 |
+
"outputs": [
|
| 816 |
+
{
|
| 817 |
+
"name": "stdout",
|
| 818 |
+
"output_type": "stream",
|
| 819 |
+
"text": [
|
| 820 |
+
"total 7.7G\n",
|
| 821 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 193K Jul 6 20:46 Tvir4.400000095367432-zeta131.34100341796875-N20.npy\n",
|
| 822 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 6 20:45 model_state-N20\n",
|
| 823 |
+
"drwxr-xr-x 44 bxia34 pace-jw254 4.0K Jul 6 20:44 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
|
| 824 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 6.1M Jul 5 14:44 Tvir4.400000095367432-zeta131.34100341796875-N200.npy\n",
|
| 825 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 5 12:20 model_state-N200\n",
|
| 826 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 17:05 Tvir4.800000190734863-zeta131.34100341796875-N25600.npy\n",
|
| 827 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 16:46 Tvir5.4770002365112305-zeta200.0-N25600.npy\n",
|
| 828 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 16:28 Tvir4.698999881744385-zeta30.0-N25600.npy\n",
|
| 829 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 13M Jul 3 16:09 Tvir5.599999904632568-zeta19.03700065612793-N25600.npy\n"
|
| 830 |
+
]
|
| 831 |
+
}
|
| 832 |
+
],
|
| 833 |
"source": [
|
| 834 |
"ls -lth outputs | head"
|
| 835 |
]
|
| 836 |
},
|
| 837 |
{
|
| 838 |
"cell_type": "code",
|
| 839 |
+
"execution_count": 14,
|
| 840 |
"metadata": {},
|
| 841 |
+
"outputs": [
|
| 842 |
+
{
|
| 843 |
+
"name": "stdout",
|
| 844 |
+
"output_type": "stream",
|
| 845 |
+
"text": [
|
| 846 |
+
"samples.shape = (6, 1, 1, 64, 128)\n"
|
| 847 |
+
]
|
| 848 |
+
}
|
| 849 |
+
],
|
| 850 |
"source": [
|
| 851 |
"def plot_grid(samples, c=None, row=2, col=3):\n",
|
| 852 |
" print(\"samples.shape =\", samples.shape)\n",
|
|
|
|
| 866 |
" plt.close()\n",
|
| 867 |
" # plt.show()\n",
|
| 868 |
" \n",
|
| 869 |
+
"data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N20.npy\")\n",
|
| 870 |
"# print(data.shape)\n",
|
| 871 |
+
"plot_grid(data)\n",
|
| 872 |
+
"# plt.imshow(data)"
|
| 873 |
]
|
| 874 |
},
|
| 875 |
{
|