0708-1342
Browse files- context_unet.py +3 -1
- diffusion.ipynb +100 -57
context_unet.py
CHANGED
|
@@ -327,8 +327,10 @@ class ContextUnet(nn.Module):
|
|
| 327 |
channel_mult = (1, 1, 2, 3, 4)
|
| 328 |
elif image_size == 64:
|
| 329 |
channel_mult = (1, 1, 2, 2, 4, 4)#(1, 2, 3, 4)
|
|
|
|
|
|
|
| 330 |
elif image_size == 28:
|
| 331 |
-
channel_mult = (1, 2)#(1, 2, 3, 4)
|
| 332 |
else:
|
| 333 |
raise ValueError(f"unsupported image size: {image_size}")
|
| 334 |
# else:
|
|
|
|
| 327 |
channel_mult = (1, 1, 2, 3, 4)
|
| 328 |
elif image_size == 64:
|
| 329 |
channel_mult = (1, 1, 2, 2, 4, 4)#(1, 2, 3, 4)
|
| 330 |
+
elif image_size == 32:
|
| 331 |
+
channel_mult = (1, 2, 2, 4)
|
| 332 |
elif image_size == 28:
|
| 333 |
+
channel_mult = (1, 2, 4)#(1, 2, 3, 4)
|
| 334 |
else:
|
| 335 |
raise ValueError(f"unsupported image size: {image_size}")
|
| 336 |
# else:
|
diffusion.ipynb
CHANGED
|
@@ -74,24 +74,9 @@
|
|
| 74 |
"cell_type": "code",
|
| 75 |
"execution_count": 2,
|
| 76 |
"metadata": {},
|
| 77 |
-
"outputs": [
|
| 78 |
-
{
|
| 79 |
-
"data": {
|
| 80 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 81 |
-
"model_id": "9d207dc0a7a24bccb2f419954b64ddc0",
|
| 82 |
-
"version_major": 2,
|
| 83 |
-
"version_minor": 0
|
| 84 |
-
},
|
| 85 |
-
"text/plain": [
|
| 86 |
-
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
| 87 |
-
]
|
| 88 |
-
},
|
| 89 |
-
"metadata": {},
|
| 90 |
-
"output_type": "display_data"
|
| 91 |
-
}
|
| 92 |
-
],
|
| 93 |
"source": [
|
| 94 |
-
"notebook_login()"
|
| 95 |
]
|
| 96 |
},
|
| 97 |
{
|
|
@@ -272,12 +257,12 @@
|
|
| 272 |
"\n",
|
| 273 |
" # dim = 2\n",
|
| 274 |
" dim = 3\n",
|
| 275 |
-
" stride = (2,2) if dim == 2 else (2,2,
|
| 276 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 277 |
-
" batch_size = 2#
|
| 278 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 279 |
-
" HII_DIM = 64\n",
|
| 280 |
-
" num_redshift = 128#64#512#256#256#64#512#128\n",
|
| 281 |
" channel = 1\n",
|
| 282 |
" img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
|
| 283 |
"\n",
|
|
@@ -579,17 +564,16 @@
|
|
| 579 |
"name": "stdout",
|
| 580 |
"output_type": "stream",
|
| 581 |
"text": [
|
| 582 |
-
"Number of parameters for nn_model:
|
| 583 |
-
"---------------- num_image =
|
| 584 |
-
"run_name =
|
| 585 |
"Launching training on one GPU.\n",
|
| 586 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 587 |
"51200 images can be loaded\n",
|
| 588 |
"field.shape = (64, 64, 514)\n",
|
| 589 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 590 |
-
"loading
|
| 591 |
-
"images loaded: (
|
| 592 |
-
"params loaded: (1000, 2)\n"
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
|
@@ -603,19 +587,20 @@
|
|
| 603 |
"name": "stdout",
|
| 604 |
"output_type": "stream",
|
| 605 |
"text": [
|
| 606 |
-
"
|
| 607 |
-
"
|
|
|
|
| 608 |
]
|
| 609 |
},
|
| 610 |
{
|
| 611 |
"data": {
|
| 612 |
"application/vnd.jupyter.widget-view+json": {
|
| 613 |
-
"model_id": "
|
| 614 |
"version_major": 2,
|
| 615 |
"version_minor": 0
|
| 616 |
},
|
| 617 |
"text/plain": [
|
| 618 |
-
" 0%| | 0/
|
| 619 |
]
|
| 620 |
},
|
| 621 |
"metadata": {},
|
|
@@ -624,12 +609,12 @@
|
|
| 624 |
{
|
| 625 |
"data": {
|
| 626 |
"application/vnd.jupyter.widget-view+json": {
|
| 627 |
-
"model_id": "
|
| 628 |
"version_major": 2,
|
| 629 |
"version_minor": 0
|
| 630 |
},
|
| 631 |
"text/plain": [
|
| 632 |
-
" 0%| | 0/
|
| 633 |
]
|
| 634 |
},
|
| 635 |
"metadata": {},
|
|
@@ -638,12 +623,12 @@
|
|
| 638 |
{
|
| 639 |
"data": {
|
| 640 |
"application/vnd.jupyter.widget-view+json": {
|
| 641 |
-
"model_id": "
|
| 642 |
"version_major": 2,
|
| 643 |
"version_minor": 0
|
| 644 |
},
|
| 645 |
"text/plain": [
|
| 646 |
-
" 0%| | 0/
|
| 647 |
]
|
| 648 |
},
|
| 649 |
"metadata": {},
|
|
@@ -652,12 +637,12 @@
|
|
| 652 |
{
|
| 653 |
"data": {
|
| 654 |
"application/vnd.jupyter.widget-view+json": {
|
| 655 |
-
"model_id": "
|
| 656 |
"version_major": 2,
|
| 657 |
"version_minor": 0
|
| 658 |
},
|
| 659 |
"text/plain": [
|
| 660 |
-
" 0%| | 0/
|
| 661 |
]
|
| 662 |
},
|
| 663 |
"metadata": {},
|
|
@@ -666,12 +651,12 @@
|
|
| 666 |
{
|
| 667 |
"data": {
|
| 668 |
"application/vnd.jupyter.widget-view+json": {
|
| 669 |
-
"model_id": "
|
| 670 |
"version_major": 2,
|
| 671 |
"version_minor": 0
|
| 672 |
},
|
| 673 |
"text/plain": [
|
| 674 |
-
" 0%| | 0/
|
| 675 |
]
|
| 676 |
},
|
| 677 |
"metadata": {},
|
|
@@ -680,12 +665,12 @@
|
|
| 680 |
{
|
| 681 |
"data": {
|
| 682 |
"application/vnd.jupyter.widget-view+json": {
|
| 683 |
-
"model_id": "
|
| 684 |
"version_major": 2,
|
| 685 |
"version_minor": 0
|
| 686 |
},
|
| 687 |
"text/plain": [
|
| 688 |
-
" 0%| | 0/
|
| 689 |
]
|
| 690 |
},
|
| 691 |
"metadata": {},
|
|
@@ -694,12 +679,12 @@
|
|
| 694 |
{
|
| 695 |
"data": {
|
| 696 |
"application/vnd.jupyter.widget-view+json": {
|
| 697 |
-
"model_id": "
|
| 698 |
"version_major": 2,
|
| 699 |
"version_minor": 0
|
| 700 |
},
|
| 701 |
"text/plain": [
|
| 702 |
-
" 0%| | 0/
|
| 703 |
]
|
| 704 |
},
|
| 705 |
"metadata": {},
|
|
@@ -708,12 +693,12 @@
|
|
| 708 |
{
|
| 709 |
"data": {
|
| 710 |
"application/vnd.jupyter.widget-view+json": {
|
| 711 |
-
"model_id": "
|
| 712 |
"version_major": 2,
|
| 713 |
"version_minor": 0
|
| 714 |
},
|
| 715 |
"text/plain": [
|
| 716 |
-
" 0%| | 0/
|
| 717 |
]
|
| 718 |
},
|
| 719 |
"metadata": {},
|
|
@@ -722,12 +707,12 @@
|
|
| 722 |
{
|
| 723 |
"data": {
|
| 724 |
"application/vnd.jupyter.widget-view+json": {
|
| 725 |
-
"model_id": "
|
| 726 |
"version_major": 2,
|
| 727 |
"version_minor": 0
|
| 728 |
},
|
| 729 |
"text/plain": [
|
| 730 |
-
" 0%| | 0/
|
| 731 |
]
|
| 732 |
},
|
| 733 |
"metadata": {},
|
|
@@ -736,12 +721,12 @@
|
|
| 736 |
{
|
| 737 |
"data": {
|
| 738 |
"application/vnd.jupyter.widget-view+json": {
|
| 739 |
-
"model_id": "
|
| 740 |
"version_major": 2,
|
| 741 |
"version_minor": 0
|
| 742 |
},
|
| 743 |
"text/plain": [
|
| 744 |
-
" 0%| | 0/
|
| 745 |
]
|
| 746 |
},
|
| 747 |
"metadata": {},
|
|
@@ -749,7 +734,7 @@
|
|
| 749 |
}
|
| 750 |
],
|
| 751 |
"source": [
|
| 752 |
-
"num_image_list = [1000]#[200]#[1600,3200,6400,12800,25600]\n",
|
| 753 |
"if __name__ == \"__main__\":\n",
|
| 754 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 755 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
@@ -759,7 +744,9 @@
|
|
| 759 |
" ddpm21cm = DDPM21CM(config)\n",
|
| 760 |
" print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
|
| 761 |
" print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
|
| 762 |
-
" notebook_launcher(
|
|
|
|
|
|
|
| 763 |
]
|
| 764 |
},
|
| 765 |
{
|
|
@@ -767,6 +754,37 @@
|
|
| 767 |
"execution_count": null,
|
| 768 |
"metadata": {},
|
| 769 |
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
"source": [
|
| 771 |
"if __name__ == \"__main__\":\n",
|
| 772 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
|
@@ -792,22 +810,47 @@
|
|
| 792 |
},
|
| 793 |
{
|
| 794 |
"cell_type": "code",
|
| 795 |
-
"execution_count":
|
| 796 |
"metadata": {},
|
| 797 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 798 |
"source": [
|
| 799 |
"ls -lth outputs | head"
|
| 800 |
]
|
| 801 |
},
|
| 802 |
{
|
| 803 |
"cell_type": "code",
|
| 804 |
-
"execution_count":
|
| 805 |
"metadata": {},
|
| 806 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 807 |
"source": [
|
| 808 |
"def plot_grid(samples, c=None, row=1, col=2):\n",
|
| 809 |
" print(\"samples.shape =\", samples.shape)\n",
|
| 810 |
-
" for j in range(samples.shape[
|
| 811 |
" plt.figure(figsize = (12,6), dpi=400)\n",
|
| 812 |
" for i in range(len(samples)):\n",
|
| 813 |
" plt.subplot(row,col,i+1)\n",
|
|
@@ -819,7 +862,7 @@
|
|
| 819 |
" # plt.suptitle('simulations')\n",
|
| 820 |
" plt.tight_layout()\n",
|
| 821 |
" plt.subplots_adjust(wspace=0, hspace=0)\n",
|
| 822 |
-
" plt.savefig(f\"test3D-{j:
|
| 823 |
" plt.close()\n",
|
| 824 |
" # plt.show()\n",
|
| 825 |
" \n",
|
|
|
|
| 74 |
"cell_type": "code",
|
| 75 |
"execution_count": 2,
|
| 76 |
"metadata": {},
|
| 77 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
"source": [
|
| 79 |
+
"# notebook_login()"
|
| 80 |
]
|
| 81 |
},
|
| 82 |
{
|
|
|
|
| 257 |
"\n",
|
| 258 |
" # dim = 2\n",
|
| 259 |
" dim = 3\n",
|
| 260 |
+
" stride = (2,2) if dim == 2 else (2,2,1)\n",
|
| 261 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 262 |
+
" batch_size = 2#50#20#2#100 # 10\n",
|
| 263 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 264 |
+
" HII_DIM = 32#64\n",
|
| 265 |
+
" num_redshift = 4#128#64#512#256#256#64#512#128\n",
|
| 266 |
" channel = 1\n",
|
| 267 |
" img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
|
| 268 |
"\n",
|
|
|
|
| 564 |
"name": "stdout",
|
| 565 |
"output_type": "stream",
|
| 566 |
"text": [
|
| 567 |
+
"Number of parameters for nn_model: 190142209\n",
|
| 568 |
+
"---------------- num_image = 100 -----------------\n",
|
| 569 |
+
"run_name = 0708-1342\n",
|
| 570 |
"Launching training on one GPU.\n",
|
| 571 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 572 |
"51200 images can be loaded\n",
|
| 573 |
"field.shape = (64, 64, 514)\n",
|
| 574 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 575 |
+
"loading 100 images randomly\n",
|
| 576 |
+
"images loaded: (100, 1, 32, 32, 4)\n"
|
|
|
|
| 577 |
]
|
| 578 |
},
|
| 579 |
{
|
|
|
|
| 587 |
"name": "stdout",
|
| 588 |
"output_type": "stream",
|
| 589 |
"text": [
|
| 590 |
+
"params loaded: (100, 2)\n",
|
| 591 |
+
"images rescaled to [-1.0, 1.2789411544799805]\n",
|
| 592 |
+
"params rescaled to [0.004197723271926046, 0.9944779188934443]\n"
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
| 596 |
"data": {
|
| 597 |
"application/vnd.jupyter.widget-view+json": {
|
| 598 |
+
"model_id": "a9bfadac7d3841c9a5a8c3440649c4f0",
|
| 599 |
"version_major": 2,
|
| 600 |
"version_minor": 0
|
| 601 |
},
|
| 602 |
"text/plain": [
|
| 603 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 604 |
]
|
| 605 |
},
|
| 606 |
"metadata": {},
|
|
|
|
| 609 |
{
|
| 610 |
"data": {
|
| 611 |
"application/vnd.jupyter.widget-view+json": {
|
| 612 |
+
"model_id": "9df4310f213742d9a7aae110fca32403",
|
| 613 |
"version_major": 2,
|
| 614 |
"version_minor": 0
|
| 615 |
},
|
| 616 |
"text/plain": [
|
| 617 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 618 |
]
|
| 619 |
},
|
| 620 |
"metadata": {},
|
|
|
|
| 623 |
{
|
| 624 |
"data": {
|
| 625 |
"application/vnd.jupyter.widget-view+json": {
|
| 626 |
+
"model_id": "399101df4a5f4de8a4a3f155b3ade75b",
|
| 627 |
"version_major": 2,
|
| 628 |
"version_minor": 0
|
| 629 |
},
|
| 630 |
"text/plain": [
|
| 631 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 632 |
]
|
| 633 |
},
|
| 634 |
"metadata": {},
|
|
|
|
| 637 |
{
|
| 638 |
"data": {
|
| 639 |
"application/vnd.jupyter.widget-view+json": {
|
| 640 |
+
"model_id": "ea4834c350594a9c9cbd87727a88a6b8",
|
| 641 |
"version_major": 2,
|
| 642 |
"version_minor": 0
|
| 643 |
},
|
| 644 |
"text/plain": [
|
| 645 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 646 |
]
|
| 647 |
},
|
| 648 |
"metadata": {},
|
|
|
|
| 651 |
{
|
| 652 |
"data": {
|
| 653 |
"application/vnd.jupyter.widget-view+json": {
|
| 654 |
+
"model_id": "b7b0d9a8c2ad456387dc1b053550c702",
|
| 655 |
"version_major": 2,
|
| 656 |
"version_minor": 0
|
| 657 |
},
|
| 658 |
"text/plain": [
|
| 659 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 660 |
]
|
| 661 |
},
|
| 662 |
"metadata": {},
|
|
|
|
| 665 |
{
|
| 666 |
"data": {
|
| 667 |
"application/vnd.jupyter.widget-view+json": {
|
| 668 |
+
"model_id": "17e73c3722d64ae895f337a7379b5225",
|
| 669 |
"version_major": 2,
|
| 670 |
"version_minor": 0
|
| 671 |
},
|
| 672 |
"text/plain": [
|
| 673 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 674 |
]
|
| 675 |
},
|
| 676 |
"metadata": {},
|
|
|
|
| 679 |
{
|
| 680 |
"data": {
|
| 681 |
"application/vnd.jupyter.widget-view+json": {
|
| 682 |
+
"model_id": "0a743e1a2db2445d93533c9ec5ed921f",
|
| 683 |
"version_major": 2,
|
| 684 |
"version_minor": 0
|
| 685 |
},
|
| 686 |
"text/plain": [
|
| 687 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 688 |
]
|
| 689 |
},
|
| 690 |
"metadata": {},
|
|
|
|
| 693 |
{
|
| 694 |
"data": {
|
| 695 |
"application/vnd.jupyter.widget-view+json": {
|
| 696 |
+
"model_id": "cfda06e79a0e4f8b8172e9314263fb5b",
|
| 697 |
"version_major": 2,
|
| 698 |
"version_minor": 0
|
| 699 |
},
|
| 700 |
"text/plain": [
|
| 701 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 702 |
]
|
| 703 |
},
|
| 704 |
"metadata": {},
|
|
|
|
| 707 |
{
|
| 708 |
"data": {
|
| 709 |
"application/vnd.jupyter.widget-view+json": {
|
| 710 |
+
"model_id": "8cb50f206a844a9da91197c2a9ed715b",
|
| 711 |
"version_major": 2,
|
| 712 |
"version_minor": 0
|
| 713 |
},
|
| 714 |
"text/plain": [
|
| 715 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 716 |
]
|
| 717 |
},
|
| 718 |
"metadata": {},
|
|
|
|
| 721 |
{
|
| 722 |
"data": {
|
| 723 |
"application/vnd.jupyter.widget-view+json": {
|
| 724 |
+
"model_id": "dbaed562dee44cddb3b0d17f439464b6",
|
| 725 |
"version_major": 2,
|
| 726 |
"version_minor": 0
|
| 727 |
},
|
| 728 |
"text/plain": [
|
| 729 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
| 730 |
]
|
| 731 |
},
|
| 732 |
"metadata": {},
|
|
|
|
| 734 |
}
|
| 735 |
],
|
| 736 |
"source": [
|
| 737 |
+
"num_image_list = [100]#[1000]#[200]#[1600,3200,6400,12800,25600]\n",
|
| 738 |
"if __name__ == \"__main__\":\n",
|
| 739 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 740 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
|
|
| 744 |
" ddpm21cm = DDPM21CM(config)\n",
|
| 745 |
" print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
|
| 746 |
" print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
|
| 747 |
+
" notebook_launcher(\n",
|
| 748 |
+
" ddpm21cm.train, num_processes=1#, mixed_precision='fp16'\n",
|
| 749 |
+
" )"
|
| 750 |
]
|
| 751 |
},
|
| 752 |
{
|
|
|
|
| 754 |
"execution_count": null,
|
| 755 |
"metadata": {},
|
| 756 |
"outputs": [],
|
| 757 |
+
"source": []
|
| 758 |
+
},
|
| 759 |
+
{
|
| 760 |
+
"cell_type": "code",
|
| 761 |
+
"execution_count": 15,
|
| 762 |
+
"metadata": {},
|
| 763 |
+
"outputs": [
|
| 764 |
+
{
|
| 765 |
+
"name": "stdout",
|
| 766 |
+
"output_type": "stream",
|
| 767 |
+
"text": [
|
| 768 |
+
"Number of parameters for nn_model: 306285057\n",
|
| 769 |
+
"sampling 2 images with normalized params = tensor([[0.2000, 0.5056]])\n",
|
| 770 |
+
"nn_model resumed from ./outputs/model_state-N1000\n"
|
| 771 |
+
]
|
| 772 |
+
},
|
| 773 |
+
{
|
| 774 |
+
"data": {
|
| 775 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 776 |
+
"model_id": "69eb2e5d3375414cab966a9c8db91901",
|
| 777 |
+
"version_major": 2,
|
| 778 |
+
"version_minor": 0
|
| 779 |
+
},
|
| 780 |
+
"text/plain": [
|
| 781 |
+
" 0%| | 0/1000 [00:00<?, ?it/s]"
|
| 782 |
+
]
|
| 783 |
+
},
|
| 784 |
+
"metadata": {},
|
| 785 |
+
"output_type": "display_data"
|
| 786 |
+
}
|
| 787 |
+
],
|
| 788 |
"source": [
|
| 789 |
"if __name__ == \"__main__\":\n",
|
| 790 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
|
|
|
| 810 |
},
|
| 811 |
{
|
| 812 |
"cell_type": "code",
|
| 813 |
+
"execution_count": 19,
|
| 814 |
"metadata": {},
|
| 815 |
+
"outputs": [
|
| 816 |
+
{
|
| 817 |
+
"name": "stdout",
|
| 818 |
+
"output_type": "stream",
|
| 819 |
+
"text": [
|
| 820 |
+
"total 13G\n",
|
| 821 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 23:59 Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\n",
|
| 822 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 23:06 model_state-N1000\n",
|
| 823 |
+
"drwxr-xr-x 56 bxia34 pace-jw254 4.0K Jul 6 22:09 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
|
| 824 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 21:39 Tvir4.400000095367432-zeta131.34100341796875-N50.npy\n",
|
| 825 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 21:25 model_state-N50\n",
|
| 826 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 193K Jul 6 20:46 Tvir4.400000095367432-zeta131.34100341796875-N20.npy\n",
|
| 827 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 6 20:45 model_state-N20\n",
|
| 828 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 6.1M Jul 5 14:44 Tvir4.400000095367432-zeta131.34100341796875-N200.npy\n",
|
| 829 |
+
"-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 5 12:20 model_state-N200\n"
|
| 830 |
+
]
|
| 831 |
+
}
|
| 832 |
+
],
|
| 833 |
"source": [
|
| 834 |
"ls -lth outputs | head"
|
| 835 |
]
|
| 836 |
},
|
| 837 |
{
|
| 838 |
"cell_type": "code",
|
| 839 |
+
"execution_count": 21,
|
| 840 |
"metadata": {},
|
| 841 |
+
"outputs": [
|
| 842 |
+
{
|
| 843 |
+
"name": "stdout",
|
| 844 |
+
"output_type": "stream",
|
| 845 |
+
"text": [
|
| 846 |
+
"samples.shape = (2, 1, 64, 64, 128)\n"
|
| 847 |
+
]
|
| 848 |
+
}
|
| 849 |
+
],
|
| 850 |
"source": [
|
| 851 |
"def plot_grid(samples, c=None, row=1, col=2):\n",
|
| 852 |
" print(\"samples.shape =\", samples.shape)\n",
|
| 853 |
+
" for j in range(samples.shape[4]):\n",
|
| 854 |
" plt.figure(figsize = (12,6), dpi=400)\n",
|
| 855 |
" for i in range(len(samples)):\n",
|
| 856 |
" plt.subplot(row,col,i+1)\n",
|
|
|
|
| 862 |
" # plt.suptitle('simulations')\n",
|
| 863 |
" plt.tight_layout()\n",
|
| 864 |
" plt.subplots_adjust(wspace=0, hspace=0)\n",
|
| 865 |
+
" plt.savefig(f\"test3D-{j:03d}.png\")\n",
|
| 866 |
" plt.close()\n",
|
| 867 |
" # plt.show()\n",
|
| 868 |
" \n",
|