0709-1331
Browse files- diffusion.ipynb +31 -193
diffusion.ipynb
CHANGED
|
@@ -259,9 +259,9 @@
|
|
| 259 |
" dim = 3\n",
|
| 260 |
" stride = (2,2) if dim == 2 else (2,2,1)\n",
|
| 261 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 262 |
-
" batch_size = 2#50#20#2#100 # 10\n",
|
| 263 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 264 |
-
" HII_DIM =
|
| 265 |
" num_redshift = 4#128#64#512#256#256#64#512#128\n",
|
| 266 |
" channel = 1\n",
|
| 267 |
" img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
|
|
@@ -564,16 +564,15 @@
|
|
| 564 |
"name": "stdout",
|
| 565 |
"output_type": "stream",
|
| 566 |
"text": [
|
| 567 |
-
"Number of parameters for nn_model:
|
| 568 |
"---------------- num_image = 100 -----------------\n",
|
| 569 |
-
"run_name =
|
| 570 |
-
"Launching training on one GPU.\n",
|
| 571 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 572 |
"51200 images can be loaded\n",
|
| 573 |
"field.shape = (64, 64, 514)\n",
|
| 574 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 575 |
"loading 100 images randomly\n",
|
| 576 |
-
"images loaded: (100, 1,
|
| 577 |
]
|
| 578 |
},
|
| 579 |
{
|
|
@@ -588,14 +587,14 @@
|
|
| 588 |
"output_type": "stream",
|
| 589 |
"text": [
|
| 590 |
"params loaded: (100, 2)\n",
|
| 591 |
-
"images rescaled to [-1.0, 1.
|
| 592 |
-
"params rescaled to [0.
|
| 593 |
]
|
| 594 |
},
|
| 595 |
{
|
| 596 |
"data": {
|
| 597 |
"application/vnd.jupyter.widget-view+json": {
|
| 598 |
-
"model_id": "
|
| 599 |
"version_major": 2,
|
| 600 |
"version_minor": 0
|
| 601 |
},
|
|
@@ -609,7 +608,7 @@
|
|
| 609 |
{
|
| 610 |
"data": {
|
| 611 |
"application/vnd.jupyter.widget-view+json": {
|
| 612 |
-
"model_id": "
|
| 613 |
"version_major": 2,
|
| 614 |
"version_minor": 0
|
| 615 |
},
|
|
@@ -623,7 +622,7 @@
|
|
| 623 |
{
|
| 624 |
"data": {
|
| 625 |
"application/vnd.jupyter.widget-view+json": {
|
| 626 |
-
"model_id": "
|
| 627 |
"version_major": 2,
|
| 628 |
"version_minor": 0
|
| 629 |
},
|
|
@@ -637,7 +636,7 @@
|
|
| 637 |
{
|
| 638 |
"data": {
|
| 639 |
"application/vnd.jupyter.widget-view+json": {
|
| 640 |
-
"model_id": "
|
| 641 |
"version_major": 2,
|
| 642 |
"version_minor": 0
|
| 643 |
},
|
|
@@ -651,7 +650,7 @@
|
|
| 651 |
{
|
| 652 |
"data": {
|
| 653 |
"application/vnd.jupyter.widget-view+json": {
|
| 654 |
-
"model_id": "
|
| 655 |
"version_major": 2,
|
| 656 |
"version_minor": 0
|
| 657 |
},
|
|
@@ -665,7 +664,7 @@
|
|
| 665 |
{
|
| 666 |
"data": {
|
| 667 |
"application/vnd.jupyter.widget-view+json": {
|
| 668 |
-
"model_id": "
|
| 669 |
"version_major": 2,
|
| 670 |
"version_minor": 0
|
| 671 |
},
|
|
@@ -679,7 +678,7 @@
|
|
| 679 |
{
|
| 680 |
"data": {
|
| 681 |
"application/vnd.jupyter.widget-view+json": {
|
| 682 |
-
"model_id": "
|
| 683 |
"version_major": 2,
|
| 684 |
"version_minor": 0
|
| 685 |
},
|
|
@@ -693,7 +692,7 @@
|
|
| 693 |
{
|
| 694 |
"data": {
|
| 695 |
"application/vnd.jupyter.widget-view+json": {
|
| 696 |
-
"model_id": "
|
| 697 |
"version_major": 2,
|
| 698 |
"version_minor": 0
|
| 699 |
},
|
|
@@ -707,7 +706,7 @@
|
|
| 707 |
{
|
| 708 |
"data": {
|
| 709 |
"application/vnd.jupyter.widget-view+json": {
|
| 710 |
-
"model_id": "
|
| 711 |
"version_major": 2,
|
| 712 |
"version_minor": 0
|
| 713 |
},
|
|
@@ -721,7 +720,7 @@
|
|
| 721 |
{
|
| 722 |
"data": {
|
| 723 |
"application/vnd.jupyter.widget-view+json": {
|
| 724 |
-
"model_id": "
|
| 725 |
"version_major": 2,
|
| 726 |
"version_minor": 0
|
| 727 |
},
|
|
@@ -734,7 +733,7 @@
|
|
| 734 |
}
|
| 735 |
],
|
| 736 |
"source": [
|
| 737 |
-
"num_image_list = [100]#[
|
| 738 |
"if __name__ == \"__main__\":\n",
|
| 739 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 740 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
@@ -744,47 +743,23 @@
|
|
| 744 |
" ddpm21cm = DDPM21CM(config)\n",
|
| 745 |
" print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
|
| 746 |
" print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
|
| 747 |
-
"
|
| 748 |
-
"
|
| 749 |
-
" )"
|
| 750 |
]
|
| 751 |
},
|
| 752 |
{
|
| 753 |
-
"
|
| 754 |
-
"
|
| 755 |
"metadata": {},
|
| 756 |
-
"
|
| 757 |
-
|
|
|
|
| 758 |
},
|
| 759 |
{
|
| 760 |
"cell_type": "code",
|
| 761 |
-
"execution_count":
|
| 762 |
"metadata": {},
|
| 763 |
-
"outputs": [
|
| 764 |
-
{
|
| 765 |
-
"name": "stdout",
|
| 766 |
-
"output_type": "stream",
|
| 767 |
-
"text": [
|
| 768 |
-
"Number of parameters for nn_model: 306285057\n",
|
| 769 |
-
"sampling 2 images with normalized params = tensor([[0.2000, 0.5056]])\n",
|
| 770 |
-
"nn_model resumed from ./outputs/model_state-N1000\n"
|
| 771 |
-
]
|
| 772 |
-
},
|
| 773 |
-
{
|
| 774 |
-
"data": {
|
| 775 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 776 |
-
"model_id": "69eb2e5d3375414cab966a9c8db91901",
|
| 777 |
-
"version_major": 2,
|
| 778 |
-
"version_minor": 0
|
| 779 |
-
},
|
| 780 |
-
"text/plain": [
|
| 781 |
-
" 0%| | 0/1000 [00:00<?, ?it/s]"
|
| 782 |
-
]
|
| 783 |
-
},
|
| 784 |
-
"metadata": {},
|
| 785 |
-
"output_type": "display_data"
|
| 786 |
-
}
|
| 787 |
-
],
|
| 788 |
"source": [
|
| 789 |
"if __name__ == \"__main__\":\n",
|
| 790 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
|
@@ -810,43 +785,18 @@
|
|
| 810 |
},
|
| 811 |
{
|
| 812 |
"cell_type": "code",
|
| 813 |
-
"execution_count":
|
| 814 |
"metadata": {},
|
| 815 |
-
"outputs": [
|
| 816 |
-
{
|
| 817 |
-
"name": "stdout",
|
| 818 |
-
"output_type": "stream",
|
| 819 |
-
"text": [
|
| 820 |
-
"total 13G\n",
|
| 821 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 23:59 Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\n",
|
| 822 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 23:06 model_state-N1000\n",
|
| 823 |
-
"drwxr-xr-x 56 bxia34 pace-jw254 4.0K Jul 6 22:09 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n",
|
| 824 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 4.1M Jul 6 21:39 Tvir4.400000095367432-zeta131.34100341796875-N50.npy\n",
|
| 825 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 6 21:25 model_state-N50\n",
|
| 826 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 193K Jul 6 20:46 Tvir4.400000095367432-zeta131.34100341796875-N20.npy\n",
|
| 827 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 848M Jul 6 20:45 model_state-N20\n",
|
| 828 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 6.1M Jul 5 14:44 Tvir4.400000095367432-zeta131.34100341796875-N200.npy\n",
|
| 829 |
-
"-rw-r--r-- 1 bxia34 pace-jw254 2.3G Jul 5 12:20 model_state-N200\n"
|
| 830 |
-
]
|
| 831 |
-
}
|
| 832 |
-
],
|
| 833 |
"source": [
|
| 834 |
"ls -lth outputs | head"
|
| 835 |
]
|
| 836 |
},
|
| 837 |
{
|
| 838 |
"cell_type": "code",
|
| 839 |
-
"execution_count":
|
| 840 |
"metadata": {},
|
| 841 |
-
"outputs": [
|
| 842 |
-
{
|
| 843 |
-
"name": "stdout",
|
| 844 |
-
"output_type": "stream",
|
| 845 |
-
"text": [
|
| 846 |
-
"samples.shape = (2, 1, 64, 64, 128)\n"
|
| 847 |
-
]
|
| 848 |
-
}
|
| 849 |
-
],
|
| 850 |
"source": [
|
| 851 |
"def plot_grid(samples, c=None, row=1, col=2):\n",
|
| 852 |
" print(\"samples.shape =\", samples.shape)\n",
|
|
@@ -899,118 +849,6 @@
|
|
| 899 |
"# # plt.imshow(images[0,0])\n",
|
| 900 |
"# # plt.show()"
|
| 901 |
]
|
| 902 |
-
},
|
| 903 |
-
{
|
| 904 |
-
"cell_type": "code",
|
| 905 |
-
"execution_count": null,
|
| 906 |
-
"metadata": {},
|
| 907 |
-
"outputs": [],
|
| 908 |
-
"source": [
|
| 909 |
-
"# plot(\"outputs/0528-1433.npy\")\n",
|
| 910 |
-
"# plot(\"outputs/0520-2323.npy\")\n",
|
| 911 |
-
"# plot(\"outputs/0604-2353.npy\")"
|
| 912 |
-
]
|
| 913 |
-
},
|
| 914 |
-
{
|
| 915 |
-
"cell_type": "code",
|
| 916 |
-
"execution_count": null,
|
| 917 |
-
"metadata": {},
|
| 918 |
-
"outputs": [],
|
| 919 |
-
"source": [
|
| 920 |
-
"# x = np.load(\"outputs/0528-1433.npy\")\n",
|
| 921 |
-
"# print(x.shape)"
|
| 922 |
-
]
|
| 923 |
-
},
|
| 924 |
-
{
|
| 925 |
-
"cell_type": "code",
|
| 926 |
-
"execution_count": null,
|
| 927 |
-
"metadata": {},
|
| 928 |
-
"outputs": [],
|
| 929 |
-
"source": [
|
| 930 |
-
"import torch\n",
|
| 931 |
-
"import torch.nn as nn\n",
|
| 932 |
-
"import time\n",
|
| 933 |
-
"\n",
|
| 934 |
-
"class MyModel(nn.Module):\n",
|
| 935 |
-
" def __init__(self):\n",
|
| 936 |
-
" super().__init__()\n",
|
| 937 |
-
" self.fc = nn.Linear(100,50)\n",
|
| 938 |
-
"\n",
|
| 939 |
-
" def forward(self, x):\n",
|
| 940 |
-
" return self.fc(x)\n",
|
| 941 |
-
"\n",
|
| 942 |
-
"model = MyModel()\n",
|
| 943 |
-
"\n",
|
| 944 |
-
"device_count = torch.cuda.device_count()\n",
|
| 945 |
-
"print(\"device_count =\", device_count)\n",
|
| 946 |
-
"\n",
|
| 947 |
-
"if device_count > 1:\n",
|
| 948 |
-
" print(f\"using {device_count} GPUs!\")\n",
|
| 949 |
-
" model = nn.DataParallel(model)\n",
|
| 950 |
-
"\n",
|
| 951 |
-
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 952 |
-
"model.to(device)\n",
|
| 953 |
-
"\n",
|
| 954 |
-
"start_time = time.time()\n",
|
| 955 |
-
"for i in range(10):\n",
|
| 956 |
-
" myinput = torch.randn(10,10,32000,100).to(device)\n",
|
| 957 |
-
" output = model(myinput)\n",
|
| 958 |
-
" print(output.shape)\n",
|
| 959 |
-
"# plt.imshow(myinput.cpu()[0])\n",
|
| 960 |
-
"# plt.show()\n",
|
| 961 |
-
"# plt.imshow(output.detach().cpu().numpy()[0])\n",
|
| 962 |
-
"# plt.show()"
|
| 963 |
-
]
|
| 964 |
-
},
|
| 965 |
-
{
|
| 966 |
-
"cell_type": "code",
|
| 967 |
-
"execution_count": null,
|
| 968 |
-
"metadata": {},
|
| 969 |
-
"outputs": [],
|
| 970 |
-
"source": [
|
| 971 |
-
"# import torch.distributed as dist\n",
|
| 972 |
-
"# dist.init_process_group(backend='nccl')"
|
| 973 |
-
]
|
| 974 |
-
},
|
| 975 |
-
{
|
| 976 |
-
"cell_type": "code",
|
| 977 |
-
"execution_count": null,
|
| 978 |
-
"metadata": {},
|
| 979 |
-
"outputs": [],
|
| 980 |
-
"source": [
|
| 981 |
-
"import numpy as np\n",
|
| 982 |
-
"import torch\n",
|
| 983 |
-
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 984 |
-
"\n",
|
| 985 |
-
"data = torch.randn((64,64,64))\n",
|
| 986 |
-
"\n",
|
| 987 |
-
"num_elements = data.numpy().size\n",
|
| 988 |
-
"element_size = data.numpy().itemsize\n",
|
| 989 |
-
"\n",
|
| 990 |
-
"print(data.dtype)\n",
|
| 991 |
-
"print(num_elements, element_size)\n",
|
| 992 |
-
"print(f\"total size = {num_elements*element_size/1024/1024} MB\")\n",
|
| 993 |
-
"\n",
|
| 994 |
-
"print(\"---\"*30)\n",
|
| 995 |
-
"data = data.to(torch.float64)\n",
|
| 996 |
-
"\n",
|
| 997 |
-
"num_elements = data.numpy().size\n",
|
| 998 |
-
"element_size = data.numpy().itemsize\n",
|
| 999 |
-
"\n",
|
| 1000 |
-
"print(data.dtype)\n",
|
| 1001 |
-
"print(num_elements, element_size)\n",
|
| 1002 |
-
"print(f\"total size = {num_elements*element_size/1024/1024} MB\")\n",
|
| 1003 |
-
"\n",
|
| 1004 |
-
"print(\"---\"*30)\n",
|
| 1005 |
-
"data = data.to(torch.float16)\n",
|
| 1006 |
-
"\n",
|
| 1007 |
-
"num_elements = data.numpy().size\n",
|
| 1008 |
-
"element_size = data.numpy().itemsize\n",
|
| 1009 |
-
"\n",
|
| 1010 |
-
"print(data.dtype)\n",
|
| 1011 |
-
"print(num_elements, element_size)\n",
|
| 1012 |
-
"print(f\"total size = {num_elements*element_size/1024/1024} MB\")"
|
| 1013 |
-
]
|
| 1014 |
}
|
| 1015 |
],
|
| 1016 |
"metadata": {
|
|
|
|
| 259 |
" dim = 3\n",
|
| 260 |
" stride = (2,2) if dim == 2 else (2,2,1)\n",
|
| 261 |
" num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560\n",
|
| 262 |
+
" batch_size = 2#2#50#20#2#100 # 10\n",
|
| 263 |
" n_epoch = 10#50#20#20#2#5#25 # 120\n",
|
| 264 |
+
" HII_DIM = 28#64\n",
|
| 265 |
" num_redshift = 4#128#64#512#256#256#64#512#128\n",
|
| 266 |
" channel = 1\n",
|
| 267 |
" img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
|
|
|
|
| 564 |
"name": "stdout",
|
| 565 |
"output_type": "stream",
|
| 566 |
"text": [
|
| 567 |
+
"Number of parameters for nn_model: 160234497\n",
|
| 568 |
"---------------- num_image = 100 -----------------\n",
|
| 569 |
+
"run_name = 0709-1331\n",
|
|
|
|
| 570 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 571 |
"51200 images can be loaded\n",
|
| 572 |
"field.shape = (64, 64, 514)\n",
|
| 573 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 574 |
"loading 100 images randomly\n",
|
| 575 |
+
"images loaded: (100, 1, 28, 28, 4)\n"
|
| 576 |
]
|
| 577 |
},
|
| 578 |
{
|
|
|
|
| 587 |
"output_type": "stream",
|
| 588 |
"text": [
|
| 589 |
"params loaded: (100, 2)\n",
|
| 590 |
+
"images rescaled to [-1.0, 1.1254141330718994]\n",
|
| 591 |
+
"params rescaled to [0.0022036265313531977, 0.9978807793709957]\n"
|
| 592 |
]
|
| 593 |
},
|
| 594 |
{
|
| 595 |
"data": {
|
| 596 |
"application/vnd.jupyter.widget-view+json": {
|
| 597 |
+
"model_id": "ae9f12def1154f6cb1eb0fc8d1e1871c",
|
| 598 |
"version_major": 2,
|
| 599 |
"version_minor": 0
|
| 600 |
},
|
|
|
|
| 608 |
{
|
| 609 |
"data": {
|
| 610 |
"application/vnd.jupyter.widget-view+json": {
|
| 611 |
+
"model_id": "cae19ac5ef7a4c34b6b57a6478dc159d",
|
| 612 |
"version_major": 2,
|
| 613 |
"version_minor": 0
|
| 614 |
},
|
|
|
|
| 622 |
{
|
| 623 |
"data": {
|
| 624 |
"application/vnd.jupyter.widget-view+json": {
|
| 625 |
+
"model_id": "b448b32948894b3c8e8780f1b6e6bf58",
|
| 626 |
"version_major": 2,
|
| 627 |
"version_minor": 0
|
| 628 |
},
|
|
|
|
| 636 |
{
|
| 637 |
"data": {
|
| 638 |
"application/vnd.jupyter.widget-view+json": {
|
| 639 |
+
"model_id": "0271ab7c081a43ebb830dcfd3db145c1",
|
| 640 |
"version_major": 2,
|
| 641 |
"version_minor": 0
|
| 642 |
},
|
|
|
|
| 650 |
{
|
| 651 |
"data": {
|
| 652 |
"application/vnd.jupyter.widget-view+json": {
|
| 653 |
+
"model_id": "1bdf4c28272840f496288545bfdbdb96",
|
| 654 |
"version_major": 2,
|
| 655 |
"version_minor": 0
|
| 656 |
},
|
|
|
|
| 664 |
{
|
| 665 |
"data": {
|
| 666 |
"application/vnd.jupyter.widget-view+json": {
|
| 667 |
+
"model_id": "01b50340758b4b05891e51a616660eb8",
|
| 668 |
"version_major": 2,
|
| 669 |
"version_minor": 0
|
| 670 |
},
|
|
|
|
| 678 |
{
|
| 679 |
"data": {
|
| 680 |
"application/vnd.jupyter.widget-view+json": {
|
| 681 |
+
"model_id": "20ba64496e5e4467b97428cf6dcdbeb5",
|
| 682 |
"version_major": 2,
|
| 683 |
"version_minor": 0
|
| 684 |
},
|
|
|
|
| 692 |
{
|
| 693 |
"data": {
|
| 694 |
"application/vnd.jupyter.widget-view+json": {
|
| 695 |
+
"model_id": "8023e0bde0c3438fb218126c58fee954",
|
| 696 |
"version_major": 2,
|
| 697 |
"version_minor": 0
|
| 698 |
},
|
|
|
|
| 706 |
{
|
| 707 |
"data": {
|
| 708 |
"application/vnd.jupyter.widget-view+json": {
|
| 709 |
+
"model_id": "22881fbe8aff4ac1a3bde8735ab4fd24",
|
| 710 |
"version_major": 2,
|
| 711 |
"version_minor": 0
|
| 712 |
},
|
|
|
|
| 720 |
{
|
| 721 |
"data": {
|
| 722 |
"application/vnd.jupyter.widget-view+json": {
|
| 723 |
+
"model_id": "3bf15c9ba5144fa69fdddae271c7dbee",
|
| 724 |
"version_major": 2,
|
| 725 |
"version_minor": 0
|
| 726 |
},
|
|
|
|
| 733 |
}
|
| 734 |
],
|
| 735 |
"source": [
|
| 736 |
+
"num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]\n",
|
| 737 |
"if __name__ == \"__main__\":\n",
|
| 738 |
" # torch.multiprocessing.set_start_method(\"spawn\")\n",
|
| 739 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
|
|
|
| 743 |
" ddpm21cm = DDPM21CM(config)\n",
|
| 744 |
" print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
|
| 745 |
" print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
|
| 746 |
+
" ddpm21cm.train()\n",
|
| 747 |
+
" # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
|
|
|
|
| 748 |
]
|
| 749 |
},
|
| 750 |
{
|
| 751 |
+
"attachments": {},
|
| 752 |
+
"cell_type": "markdown",
|
| 753 |
"metadata": {},
|
| 754 |
+
"source": [
|
| 755 |
+
"# Sampling"
|
| 756 |
+
]
|
| 757 |
},
|
| 758 |
{
|
| 759 |
"cell_type": "code",
|
| 760 |
+
"execution_count": null,
|
| 761 |
"metadata": {},
|
| 762 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 763 |
"source": [
|
| 764 |
"if __name__ == \"__main__\":\n",
|
| 765 |
" # num_image_list = [1600,3200,6400,12800,25600]\n",
|
|
|
|
| 785 |
},
|
| 786 |
{
|
| 787 |
"cell_type": "code",
|
| 788 |
+
"execution_count": null,
|
| 789 |
"metadata": {},
|
| 790 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 791 |
"source": [
|
| 792 |
"ls -lth outputs | head"
|
| 793 |
]
|
| 794 |
},
|
| 795 |
{
|
| 796 |
"cell_type": "code",
|
| 797 |
+
"execution_count": null,
|
| 798 |
"metadata": {},
|
| 799 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
"source": [
|
| 801 |
"def plot_grid(samples, c=None, row=1, col=2):\n",
|
| 802 |
" print(\"samples.shape =\", samples.shape)\n",
|
|
|
|
| 849 |
"# # plt.imshow(images[0,0])\n",
|
| 850 |
"# # plt.show()"
|
| 851 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 852 |
}
|
| 853 |
],
|
| 854 |
"metadata": {
|