0706-2157
Browse files- diffusion.ipynb +51 -45
diffusion.ipynb
CHANGED
|
@@ -74,9 +74,24 @@
|
|
| 74 |
"cell_type": "code",
|
| 75 |
"execution_count": 2,
|
| 76 |
"metadata": {},
|
| 77 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"source": [
|
| 79 |
-
"
|
| 80 |
]
|
| 81 |
},
|
| 82 |
{
|
|
@@ -259,7 +274,7 @@
|
|
| 259 |
" dim = 3\n",
|
| 260 |
" stride = (2,2) if dim == 2 else (2,2,2)\n",
|
| 261 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 262 |
-
" batch_size =
|
| 263 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 264 |
" HII_DIM = 64\n",
|
| 265 |
" num_redshift = 128#64#512#256#256#64#512#128\n",
|
|
@@ -565,15 +580,16 @@
|
|
| 565 |
"output_type": "stream",
|
| 566 |
"text": [
|
| 567 |
"Number of parameters for nn_model: 306285057\n",
|
| 568 |
-
"
|
| 569 |
-
"run_name = 0706-
|
| 570 |
"Launching training on one GPU.\n",
|
| 571 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 572 |
"51200 images can be loaded\n",
|
| 573 |
"field.shape = (64, 64, 514)\n",
|
| 574 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 575 |
-
"loading
|
| 576 |
-
"images loaded: (
|
|
|
|
| 577 |
]
|
| 578 |
},
|
| 579 |
{
|
|
@@ -587,20 +603,19 @@
|
|
| 587 |
"name": "stdout",
|
| 588 |
"output_type": "stream",
|
| 589 |
"text": [
|
| 590 |
-
"
|
| 591 |
-
"
|
| 592 |
-
"params rescaled to [0.0031794774029485495, 0.9969930182712254]\n"
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
| 596 |
"data": {
|
| 597 |
"application/vnd.jupyter.widget-view+json": {
|
| 598 |
-
"model_id": "
|
| 599 |
"version_major": 2,
|
| 600 |
"version_minor": 0
|
| 601 |
},
|
| 602 |
"text/plain": [
|
| 603 |
-
" 0%| | 0/
|
| 604 |
]
|
| 605 |
},
|
| 606 |
"metadata": {},
|
|
@@ -609,12 +624,12 @@
|
|
| 609 |
{
|
| 610 |
"data": {
|
| 611 |
"application/vnd.jupyter.widget-view+json": {
|
| 612 |
-
"model_id": "
|
| 613 |
"version_major": 2,
|
| 614 |
"version_minor": 0
|
| 615 |
},
|
| 616 |
"text/plain": [
|
| 617 |
-
" 0%| | 0/
|
| 618 |
]
|
| 619 |
},
|
| 620 |
"metadata": {},
|
|
@@ -623,12 +638,12 @@
|
|
| 623 |
{
|
| 624 |
"data": {
|
| 625 |
"application/vnd.jupyter.widget-view+json": {
|
| 626 |
-
"model_id": "
|
| 627 |
"version_major": 2,
|
| 628 |
"version_minor": 0
|
| 629 |
},
|
| 630 |
"text/plain": [
|
| 631 |
-
" 0%| | 0/
|
| 632 |
]
|
| 633 |
},
|
| 634 |
"metadata": {},
|
|
@@ -637,12 +652,12 @@
|
|
| 637 |
{
|
| 638 |
"data": {
|
| 639 |
"application/vnd.jupyter.widget-view+json": {
|
| 640 |
-
"model_id": "
|
| 641 |
"version_major": 2,
|
| 642 |
"version_minor": 0
|
| 643 |
},
|
| 644 |
"text/plain": [
|
| 645 |
-
" 0%| | 0/
|
| 646 |
]
|
| 647 |
},
|
| 648 |
"metadata": {},
|
|
@@ -651,12 +666,12 @@
|
|
| 651 |
{
|
| 652 |
"data": {
|
| 653 |
"application/vnd.jupyter.widget-view+json": {
|
| 654 |
-
"model_id": "
|
| 655 |
"version_major": 2,
|
| 656 |
"version_minor": 0
|
| 657 |
},
|
| 658 |
"text/plain": [
|
| 659 |
-
" 0%| | 0/
|
| 660 |
]
|
| 661 |
},
|
| 662 |
"metadata": {},
|
|
@@ -665,12 +680,12 @@
|
|
| 665 |
{
|
| 666 |
"data": {
|
| 667 |
"application/vnd.jupyter.widget-view+json": {
|
| 668 |
-
"model_id": "
|
| 669 |
"version_major": 2,
|
| 670 |
"version_minor": 0
|
| 671 |
},
|
| 672 |
"text/plain": [
|
| 673 |
-
" 0%| | 0/
|
| 674 |
]
|
| 675 |
},
|
| 676 |
"metadata": {},
|
|
@@ -679,12 +694,12 @@
|
|
| 679 |
{
|
| 680 |
"data": {
|
| 681 |
"application/vnd.jupyter.widget-view+json": {
|
| 682 |
-
"model_id": "
|
| 683 |
"version_major": 2,
|
| 684 |
"version_minor": 0
|
| 685 |
},
|
| 686 |
"text/plain": [
|
| 687 |
-
" 0%| | 0/
|
| 688 |
]
|
| 689 |
},
|
| 690 |
"metadata": {},
|
|
@@ -693,12 +708,12 @@
|
|
| 693 |
{
|
| 694 |
"data": {
|
| 695 |
"application/vnd.jupyter.widget-view+json": {
|
| 696 |
-
"model_id": "
|
| 697 |
"version_major": 2,
|
| 698 |
"version_minor": 0
|
| 699 |
},
|
| 700 |
"text/plain": [
|
| 701 |
-
" 0%| | 0/
|
| 702 |
]
|
| 703 |
},
|
| 704 |
"metadata": {},
|
|
@@ -707,12 +722,12 @@
|
|
| 707 |
{
|
| 708 |
"data": {
|
| 709 |
"application/vnd.jupyter.widget-view+json": {
|
| 710 |
-
"model_id": "
|
| 711 |
"version_major": 2,
|
| 712 |
"version_minor": 0
|
| 713 |
},
|
| 714 |
"text/plain": [
|
| 715 |
-
" 0%| | 0/
|
| 716 |
]
|
| 717 |
},
|
| 718 |
"metadata": {},
|
|
@@ -721,12 +736,12 @@
|
|
| 721 |
{
|
| 722 |
"data": {
|
| 723 |
"application/vnd.jupyter.widget-view+json": {
|
| 724 |
-
"model_id": "
|
| 725 |
"version_major": 2,
|
| 726 |
"version_minor": 0
|
| 727 |
},
|
| 728 |
"text/plain": [
|
| 729 |
-
" 0%| | 0/
|
| 730 |
]
|
| 731 |
},
|
| 732 |
"metadata": {},
|
|
@@ -734,7 +749,7 @@
|
|
| 734 |
}
|
| 735 |
],
|
| 736 |
"source": [
|
| 737 |
-
"num_image_list = [
|
| 738 |
"if __name__ == \"__main__\":\n",
|
| 739 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 740 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
@@ -747,15 +762,6 @@
|
|
| 747 |
" notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
|
| 748 |
]
|
| 749 |
},
|
| 750 |
-
{
|
| 751 |
-
"cell_type": "code",
|
| 752 |
-
"execution_count": null,
|
| 753 |
-
"metadata": {},
|
| 754 |
-
"outputs": [],
|
| 755 |
-
"source": [
|
| 756 |
-
"# ll -lth outputs"
|
| 757 |
-
]
|
| 758 |
-
},
|
| 759 |
{
|
| 760 |
"cell_type": "code",
|
| 761 |
"execution_count": null,
|
|
@@ -764,10 +770,10 @@
|
|
| 764 |
"source": [
|
| 765 |
"if __name__ == \"__main__\":\n",
|
| 766 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
| 767 |
-
" num_image_list = [
|
| 768 |
" # num_image_list = [3200,6400,12800,25600]\n",
|
| 769 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 770 |
-
" repeat =
|
| 771 |
" config = TrainConfig()\n",
|
| 772 |
" for i, num_image in enumerate(num_image_list):\n",
|
| 773 |
" config.num_image = num_image\n",
|
|
@@ -799,10 +805,10 @@
|
|
| 799 |
"metadata": {},
|
| 800 |
"outputs": [],
|
| 801 |
"source": [
|
| 802 |
-
"def plot_grid(samples, c=None, row=
|
| 803 |
" print(\"samples.shape =\", samples.shape)\n",
|
| 804 |
" for j in range(samples.shape[2]):\n",
|
| 805 |
-
" plt.figure(figsize = (
|
| 806 |
" for i in range(len(samples)):\n",
|
| 807 |
" plt.subplot(row,col,i+1)\n",
|
| 808 |
" plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n",
|
|
@@ -817,7 +823,7 @@
|
|
| 817 |
" plt.close()\n",
|
| 818 |
" # plt.show()\n",
|
| 819 |
" \n",
|
| 820 |
-
"data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-
|
| 821 |
"# print(data.shape)\n",
|
| 822 |
"plot_grid(data)\n",
|
| 823 |
"# plt.imshow(data)"
|
|
|
|
| 74 |
"cell_type": "code",
|
| 75 |
"execution_count": 2,
|
| 76 |
"metadata": {},
|
| 77 |
+
"outputs": [
|
| 78 |
+
{
|
| 79 |
+
"data": {
|
| 80 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 81 |
+
"model_id": "9d207dc0a7a24bccb2f419954b64ddc0",
|
| 82 |
+
"version_major": 2,
|
| 83 |
+
"version_minor": 0
|
| 84 |
+
},
|
| 85 |
+
"text/plain": [
|
| 86 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
| 87 |
+
]
|
| 88 |
+
},
|
| 89 |
+
"metadata": {},
|
| 90 |
+
"output_type": "display_data"
|
| 91 |
+
}
|
| 92 |
+
],
|
| 93 |
"source": [
|
| 94 |
+
"notebook_login()"
|
| 95 |
]
|
| 96 |
},
|
| 97 |
{
|
|
|
|
| 274 |
" dim = 3\n",
|
| 275 |
" stride = (2,2) if dim == 2 else (2,2,2)\n",
|
| 276 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 277 |
+
" batch_size = 2#2#50#20#2#100 # 10\n",
|
| 278 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 279 |
" HII_DIM = 64\n",
|
| 280 |
" num_redshift = 128#64#512#256#256#64#512#128\n",
|
|
|
|
| 580 |
"output_type": "stream",
|
| 581 |
"text": [
|
| 582 |
"Number of parameters for nn_model: 306285057\n",
|
| 583 |
+
"---------------- num_image = 1000 ----------------\n",
|
| 584 |
+
"run_name = 0706-2157\n",
|
| 585 |
"Launching training on one GPU.\n",
|
| 586 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 587 |
"51200 images can be loaded\n",
|
| 588 |
"field.shape = (64, 64, 514)\n",
|
| 589 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 590 |
+
"loading 1000 images randomly\n",
|
| 591 |
+
"images loaded: (1000, 1, 64, 64, 128)\n",
|
| 592 |
+
"params loaded: (1000, 2)\n"
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
|
|
|
| 603 |
"name": "stdout",
|
| 604 |
"output_type": "stream",
|
| 605 |
"text": [
|
| 606 |
+
"images rescaled to [-1.0, 1.2339041233062744]\n",
|
| 607 |
+
"params rescaled to [0.0006788479393025145, 0.9997530171043563]\n"
|
|
|
|
| 608 |
]
|
| 609 |
},
|
| 610 |
{
|
| 611 |
"data": {
|
| 612 |
"application/vnd.jupyter.widget-view+json": {
|
| 613 |
+
"model_id": "1d069efc366d480ba2515135ccb18f6c",
|
| 614 |
"version_major": 2,
|
| 615 |
"version_minor": 0
|
| 616 |
},
|
| 617 |
"text/plain": [
|
| 618 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 619 |
]
|
| 620 |
},
|
| 621 |
"metadata": {},
|
|
|
|
| 624 |
{
|
| 625 |
"data": {
|
| 626 |
"application/vnd.jupyter.widget-view+json": {
|
| 627 |
+
"model_id": "93abe7efc5044d3796f67a0d243058b2",
|
| 628 |
"version_major": 2,
|
| 629 |
"version_minor": 0
|
| 630 |
},
|
| 631 |
"text/plain": [
|
| 632 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 633 |
]
|
| 634 |
},
|
| 635 |
"metadata": {},
|
|
|
|
| 638 |
{
|
| 639 |
"data": {
|
| 640 |
"application/vnd.jupyter.widget-view+json": {
|
| 641 |
+
"model_id": "75e22a8998244becb916714cf4b8053e",
|
| 642 |
"version_major": 2,
|
| 643 |
"version_minor": 0
|
| 644 |
},
|
| 645 |
"text/plain": [
|
| 646 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 647 |
]
|
| 648 |
},
|
| 649 |
"metadata": {},
|
|
|
|
| 652 |
{
|
| 653 |
"data": {
|
| 654 |
"application/vnd.jupyter.widget-view+json": {
|
| 655 |
+
"model_id": "7364f0924d2246e49fa02a596e0afde8",
|
| 656 |
"version_major": 2,
|
| 657 |
"version_minor": 0
|
| 658 |
},
|
| 659 |
"text/plain": [
|
| 660 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 661 |
]
|
| 662 |
},
|
| 663 |
"metadata": {},
|
|
|
|
| 666 |
{
|
| 667 |
"data": {
|
| 668 |
"application/vnd.jupyter.widget-view+json": {
|
| 669 |
+
"model_id": "a5c57b0059154c2695e7fc0d426bb0d9",
|
| 670 |
"version_major": 2,
|
| 671 |
"version_minor": 0
|
| 672 |
},
|
| 673 |
"text/plain": [
|
| 674 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 675 |
]
|
| 676 |
},
|
| 677 |
"metadata": {},
|
|
|
|
| 680 |
{
|
| 681 |
"data": {
|
| 682 |
"application/vnd.jupyter.widget-view+json": {
|
| 683 |
+
"model_id": "b0c9ee9c876144eab582e196e71ed83c",
|
| 684 |
"version_major": 2,
|
| 685 |
"version_minor": 0
|
| 686 |
},
|
| 687 |
"text/plain": [
|
| 688 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 689 |
]
|
| 690 |
},
|
| 691 |
"metadata": {},
|
|
|
|
| 694 |
{
|
| 695 |
"data": {
|
| 696 |
"application/vnd.jupyter.widget-view+json": {
|
| 697 |
+
"model_id": "d2956adb8f4c45938f2d8f93baca2998",
|
| 698 |
"version_major": 2,
|
| 699 |
"version_minor": 0
|
| 700 |
},
|
| 701 |
"text/plain": [
|
| 702 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 703 |
]
|
| 704 |
},
|
| 705 |
"metadata": {},
|
|
|
|
| 708 |
{
|
| 709 |
"data": {
|
| 710 |
"application/vnd.jupyter.widget-view+json": {
|
| 711 |
+
"model_id": "1f3031b4cd1040829ce3b457ee30e464",
|
| 712 |
"version_major": 2,
|
| 713 |
"version_minor": 0
|
| 714 |
},
|
| 715 |
"text/plain": [
|
| 716 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 717 |
]
|
| 718 |
},
|
| 719 |
"metadata": {},
|
|
|
|
| 722 |
{
|
| 723 |
"data": {
|
| 724 |
"application/vnd.jupyter.widget-view+json": {
|
| 725 |
+
"model_id": "b4dcf375da4c4c4492d308e4bdc19358",
|
| 726 |
"version_major": 2,
|
| 727 |
"version_minor": 0
|
| 728 |
},
|
| 729 |
"text/plain": [
|
| 730 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 731 |
]
|
| 732 |
},
|
| 733 |
"metadata": {},
|
|
|
|
| 736 |
{
|
| 737 |
"data": {
|
| 738 |
"application/vnd.jupyter.widget-view+json": {
|
| 739 |
+
"model_id": "50c6936f333a451394d6761fa9a085be",
|
| 740 |
"version_major": 2,
|
| 741 |
"version_minor": 0
|
| 742 |
},
|
| 743 |
"text/plain": [
|
| 744 |
+
" 0%| | 0/500 [00:00<?, ?it/s]"
|
| 745 |
]
|
| 746 |
},
|
| 747 |
"metadata": {},
|
|
|
|
| 749 |
}
|
| 750 |
],
|
| 751 |
"source": [
|
| 752 |
+
"num_image_list = [1000]#[200]#[1600,3200,6400,12800,25600]\n",
|
| 753 |
"if __name__ == \"__main__\":\n",
|
| 754 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 755 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
|
|
| 762 |
" notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
|
| 763 |
]
|
| 764 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 765 |
{
|
| 766 |
"cell_type": "code",
|
| 767 |
"execution_count": null,
|
|
|
|
| 770 |
"source": [
|
| 771 |
"if __name__ == \"__main__\":\n",
|
| 772 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
| 773 |
+
" num_image_list = [1000]\n",
|
| 774 |
" # num_image_list = [3200,6400,12800,25600]\n",
|
| 775 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 776 |
+
" repeat = 2\n",
|
| 777 |
" config = TrainConfig()\n",
|
| 778 |
" for i, num_image in enumerate(num_image_list):\n",
|
| 779 |
" config.num_image = num_image\n",
|
|
|
|
| 805 |
"metadata": {},
|
| 806 |
"outputs": [],
|
| 807 |
"source": [
|
| 808 |
+
"def plot_grid(samples, c=None, row=1, col=2):\n",
|
| 809 |
" print(\"samples.shape =\", samples.shape)\n",
|
| 810 |
" for j in range(samples.shape[2]):\n",
|
| 811 |
+
" plt.figure(figsize = (12,6), dpi=400)\n",
|
| 812 |
" for i in range(len(samples)):\n",
|
| 813 |
" plt.subplot(row,col,i+1)\n",
|
| 814 |
" plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)\n",
|
|
|
|
| 823 |
" plt.close()\n",
|
| 824 |
" # plt.show()\n",
|
| 825 |
" \n",
|
| 826 |
+
"data = np.load(\"outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\")\n",
|
| 827 |
"# print(data.shape)\n",
|
| 828 |
"plot_grid(data)\n",
|
| 829 |
"# plt.imshow(data)"
|