0712-1513
Browse files- diffusion.ipynb +106 -67
- diffusion.py +89 -59
- load_h5.py +7 -6
diffusion.ipynb
CHANGED
|
@@ -33,7 +33,7 @@
|
|
| 33 |
},
|
| 34 |
{
|
| 35 |
"cell_type": "code",
|
| 36 |
-
"execution_count":
|
| 37 |
"metadata": {},
|
| 38 |
"outputs": [],
|
| 39 |
"source": [
|
|
@@ -77,7 +77,7 @@
|
|
| 77 |
},
|
| 78 |
{
|
| 79 |
"cell_type": "code",
|
| 80 |
-
"execution_count":
|
| 81 |
"metadata": {},
|
| 82 |
"outputs": [],
|
| 83 |
"source": [
|
|
@@ -95,11 +95,26 @@
|
|
| 95 |
},
|
| 96 |
{
|
| 97 |
"cell_type": "code",
|
| 98 |
-
"execution_count":
|
| 99 |
"metadata": {},
|
| 100 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
"source": [
|
| 102 |
-
"
|
| 103 |
]
|
| 104 |
},
|
| 105 |
{
|
|
@@ -817,75 +832,99 @@
|
|
| 817 |
},
|
| 818 |
{
|
| 819 |
"cell_type": "code",
|
| 820 |
-
"execution_count":
|
| 821 |
"metadata": {},
|
| 822 |
-
"outputs": [
|
| 823 |
-
{
|
| 824 |
-
"name": "stdout",
|
| 825 |
-
"output_type": "stream",
|
| 826 |
-
"text": [
|
| 827 |
-
"outputs/model_state-N500-device0\n",
|
| 828 |
-
"outputs/model_state-N500-device1\n",
|
| 829 |
-
"len(model_states) = 2\n",
|
| 830 |
-
"epoch\n",
|
| 831 |
-
"unet_state_dict\n",
|
| 832 |
-
"epoch 9 9\n",
|
| 833 |
-
"odict_keys(['token_embedding.weight', 'token_embedding.bias', 'time_embed.0.weight', 'time_embed.0.bias', 'time_embed.2.weight', 'time_embed.2.bias', 'input_blocks.0.0.weight', 'input_blocks.0.0.bias', 'input_blocks.1.0.in_layers.0.weight', 'input_blocks.1.0.in_layers.0.bias', 'input_blocks.1.0.in_layers.2.weight', 'input_blocks.1.0.in_layers.2.bias', 'input_blocks.1.0.emb_layers.1.weight', 'input_blocks.1.0.emb_layers.1.bias', 'input_blocks.1.0.out_layers.0.weight', 'input_blocks.1.0.out_layers.0.bias', 'input_blocks.1.0.out_layers.3.weight', 'input_blocks.1.0.out_layers.3.bias', 'input_blocks.1.1.norm.weight', 'input_blocks.1.1.norm.bias', 'input_blocks.1.1.qkv.weight', 'input_blocks.1.1.qkv.bias', 'input_blocks.1.1.proj_out.weight', 'input_blocks.1.1.proj_out.bias', 'input_blocks.2.0.in_layers.0.weight', 'input_blocks.2.0.in_layers.0.bias', 'input_blocks.2.0.in_layers.2.weight', 'input_blocks.2.0.in_layers.2.bias', 'input_blocks.2.0.emb_layers.1.weight', 'input_blocks.2.0.emb_layers.1.bias', 'input_blocks.2.0.out_layers.0.weight', 'input_blocks.2.0.out_layers.0.bias', 'input_blocks.2.0.out_layers.3.weight', 'input_blocks.2.0.out_layers.3.bias', 'input_blocks.2.1.norm.weight', 'input_blocks.2.1.norm.bias', 'input_blocks.2.1.qkv.weight', 'input_blocks.2.1.qkv.bias', 'input_blocks.2.1.proj_out.weight', 'input_blocks.2.1.proj_out.bias', 'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias', 'input_blocks.4.0.in_layers.0.weight', 'input_blocks.4.0.in_layers.0.bias', 'input_blocks.4.0.in_layers.2.weight', 'input_blocks.4.0.in_layers.2.bias', 'input_blocks.4.0.emb_layers.1.weight', 'input_blocks.4.0.emb_layers.1.bias', 'input_blocks.4.0.out_layers.0.weight', 'input_blocks.4.0.out_layers.0.bias', 'input_blocks.4.0.out_layers.3.weight', 'input_blocks.4.0.out_layers.3.bias', 'input_blocks.4.0.skip_connection.weight', 'input_blocks.4.0.skip_connection.bias', 'input_blocks.5.0.in_layers.0.weight', 'input_blocks.5.0.in_layers.0.bias', 
'input_blocks.5.0.in_layers.2.weight', 'input_blocks.5.0.in_layers.2.bias', 'input_blocks.5.0.emb_layers.1.weight', 'input_blocks.5.0.emb_layers.1.bias', 'input_blocks.5.0.out_layers.0.weight', 'input_blocks.5.0.out_layers.0.bias', 'input_blocks.5.0.out_layers.3.weight', 'input_blocks.5.0.out_layers.3.bias', 'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias', 'input_blocks.7.0.in_layers.0.weight', 'input_blocks.7.0.in_layers.0.bias', 'input_blocks.7.0.in_layers.2.weight', 'input_blocks.7.0.in_layers.2.bias', 'input_blocks.7.0.emb_layers.1.weight', 'input_blocks.7.0.emb_layers.1.bias', 'input_blocks.7.0.out_layers.0.weight', 'input_blocks.7.0.out_layers.0.bias', 'input_blocks.7.0.out_layers.3.weight', 'input_blocks.7.0.out_layers.3.bias', 'input_blocks.7.0.skip_connection.weight', 'input_blocks.7.0.skip_connection.bias', 'input_blocks.8.0.in_layers.0.weight', 'input_blocks.8.0.in_layers.0.bias', 'input_blocks.8.0.in_layers.2.weight', 'input_blocks.8.0.in_layers.2.bias', 'input_blocks.8.0.emb_layers.1.weight', 'input_blocks.8.0.emb_layers.1.bias', 'input_blocks.8.0.out_layers.0.weight', 'input_blocks.8.0.out_layers.0.bias', 'input_blocks.8.0.out_layers.3.weight', 'input_blocks.8.0.out_layers.3.bias', 'middle_block.0.in_layers.0.weight', 'middle_block.0.in_layers.0.bias', 'middle_block.0.in_layers.2.weight', 'middle_block.0.in_layers.2.bias', 'middle_block.0.emb_layers.1.weight', 'middle_block.0.emb_layers.1.bias', 'middle_block.0.out_layers.0.weight', 'middle_block.0.out_layers.0.bias', 'middle_block.0.out_layers.3.weight', 'middle_block.0.out_layers.3.bias', 'middle_block.1.norm.weight', 'middle_block.1.norm.bias', 'middle_block.1.qkv.weight', 'middle_block.1.qkv.bias', 'middle_block.1.proj_out.weight', 'middle_block.1.proj_out.bias', 'middle_block.2.in_layers.0.weight', 'middle_block.2.in_layers.0.bias', 'middle_block.2.in_layers.2.weight', 'middle_block.2.in_layers.2.bias', 'middle_block.2.emb_layers.1.weight', 'middle_block.2.emb_layers.1.bias', 
'middle_block.2.out_layers.0.weight', 'middle_block.2.out_layers.0.bias', 'middle_block.2.out_layers.3.weight', 'middle_block.2.out_layers.3.bias', 'output_blocks.0.0.in_layers.0.weight', 'output_blocks.0.0.in_layers.0.bias', 'output_blocks.0.0.in_layers.2.weight', 'output_blocks.0.0.in_layers.2.bias', 'output_blocks.0.0.emb_layers.1.weight', 'output_blocks.0.0.emb_layers.1.bias', 'output_blocks.0.0.out_layers.0.weight', 'output_blocks.0.0.out_layers.0.bias', 'output_blocks.0.0.out_layers.3.weight', 'output_blocks.0.0.out_layers.3.bias', 'output_blocks.0.0.skip_connection.weight', 'output_blocks.0.0.skip_connection.bias', 'output_blocks.1.0.in_layers.0.weight', 'output_blocks.1.0.in_layers.0.bias', 'output_blocks.1.0.in_layers.2.weight', 'output_blocks.1.0.in_layers.2.bias', 'output_blocks.1.0.emb_layers.1.weight', 'output_blocks.1.0.emb_layers.1.bias', 'output_blocks.1.0.out_layers.0.weight', 'output_blocks.1.0.out_layers.0.bias', 'output_blocks.1.0.out_layers.3.weight', 'output_blocks.1.0.out_layers.3.bias', 'output_blocks.1.0.skip_connection.weight', 'output_blocks.1.0.skip_connection.bias', 'output_blocks.2.0.in_layers.0.weight', 'output_blocks.2.0.in_layers.0.bias', 'output_blocks.2.0.in_layers.2.weight', 'output_blocks.2.0.in_layers.2.bias', 'output_blocks.2.0.emb_layers.1.weight', 'output_blocks.2.0.emb_layers.1.bias', 'output_blocks.2.0.out_layers.0.weight', 'output_blocks.2.0.out_layers.0.bias', 'output_blocks.2.0.out_layers.3.weight', 'output_blocks.2.0.out_layers.3.bias', 'output_blocks.2.0.skip_connection.weight', 'output_blocks.2.0.skip_connection.bias', 'output_blocks.2.1.conv.weight', 'output_blocks.2.1.conv.bias', 'output_blocks.3.0.in_layers.0.weight', 'output_blocks.3.0.in_layers.0.bias', 'output_blocks.3.0.in_layers.2.weight', 'output_blocks.3.0.in_layers.2.bias', 'output_blocks.3.0.emb_layers.1.weight', 'output_blocks.3.0.emb_layers.1.bias', 'output_blocks.3.0.out_layers.0.weight', 'output_blocks.3.0.out_layers.0.bias', 
'output_blocks.3.0.out_layers.3.weight', 'output_blocks.3.0.out_layers.3.bias', 'output_blocks.3.0.skip_connection.weight', 'output_blocks.3.0.skip_connection.bias', 'output_blocks.4.0.in_layers.0.weight', 'output_blocks.4.0.in_layers.0.bias', 'output_blocks.4.0.in_layers.2.weight', 'output_blocks.4.0.in_layers.2.bias', 'output_blocks.4.0.emb_layers.1.weight', 'output_blocks.4.0.emb_layers.1.bias', 'output_blocks.4.0.out_layers.0.weight', 'output_blocks.4.0.out_layers.0.bias', 'output_blocks.4.0.out_layers.3.weight', 'output_blocks.4.0.out_layers.3.bias', 'output_blocks.4.0.skip_connection.weight', 'output_blocks.4.0.skip_connection.bias', 'output_blocks.5.0.in_layers.0.weight', 'output_blocks.5.0.in_layers.0.bias', 'output_blocks.5.0.in_layers.2.weight', 'output_blocks.5.0.in_layers.2.bias', 'output_blocks.5.0.emb_layers.1.weight', 'output_blocks.5.0.emb_layers.1.bias', 'output_blocks.5.0.out_layers.0.weight', 'output_blocks.5.0.out_layers.0.bias', 'output_blocks.5.0.out_layers.3.weight', 'output_blocks.5.0.out_layers.3.bias', 'output_blocks.5.0.skip_connection.weight', 'output_blocks.5.0.skip_connection.bias', 'output_blocks.5.1.conv.weight', 'output_blocks.5.1.conv.bias', 'output_blocks.6.0.in_layers.0.weight', 'output_blocks.6.0.in_layers.0.bias', 'output_blocks.6.0.in_layers.2.weight', 'output_blocks.6.0.in_layers.2.bias', 'output_blocks.6.0.emb_layers.1.weight', 'output_blocks.6.0.emb_layers.1.bias', 'output_blocks.6.0.out_layers.0.weight', 'output_blocks.6.0.out_layers.0.bias', 'output_blocks.6.0.out_layers.3.weight', 'output_blocks.6.0.out_layers.3.bias', 'output_blocks.6.0.skip_connection.weight', 'output_blocks.6.0.skip_connection.bias', 'output_blocks.6.1.norm.weight', 'output_blocks.6.1.norm.bias', 'output_blocks.6.1.qkv.weight', 'output_blocks.6.1.qkv.bias', 'output_blocks.6.1.proj_out.weight', 'output_blocks.6.1.proj_out.bias', 'output_blocks.7.0.in_layers.0.weight', 'output_blocks.7.0.in_layers.0.bias', 'output_blocks.7.0.in_layers.2.weight', 
'output_blocks.7.0.in_layers.2.bias', 'output_blocks.7.0.emb_layers.1.weight', 'output_blocks.7.0.emb_layers.1.bias', 'output_blocks.7.0.out_layers.0.weight', 'output_blocks.7.0.out_layers.0.bias', 'output_blocks.7.0.out_layers.3.weight', 'output_blocks.7.0.out_layers.3.bias', 'output_blocks.7.0.skip_connection.weight', 'output_blocks.7.0.skip_connection.bias', 'output_blocks.7.1.norm.weight', 'output_blocks.7.1.norm.bias', 'output_blocks.7.1.qkv.weight', 'output_blocks.7.1.qkv.bias', 'output_blocks.7.1.proj_out.weight', 'output_blocks.7.1.proj_out.bias', 'output_blocks.8.0.in_layers.0.weight', 'output_blocks.8.0.in_layers.0.bias', 'output_blocks.8.0.in_layers.2.weight', 'output_blocks.8.0.in_layers.2.bias', 'output_blocks.8.0.emb_layers.1.weight', 'output_blocks.8.0.emb_layers.1.bias', 'output_blocks.8.0.out_layers.0.weight', 'output_blocks.8.0.out_layers.0.bias', 'output_blocks.8.0.out_layers.3.weight', 'output_blocks.8.0.out_layers.3.bias', 'output_blocks.8.0.skip_connection.weight', 'output_blocks.8.0.skip_connection.bias', 'output_blocks.8.1.norm.weight', 'output_blocks.8.1.norm.bias', 'output_blocks.8.1.qkv.weight', 'output_blocks.8.1.qkv.bias', 'output_blocks.8.1.proj_out.weight', 'output_blocks.8.1.proj_out.bias', 'out.0.weight', 'out.0.bias', 'out.2.weight', 'out.2.bias'])\n",
|
| 834 |
-
"exactly same\n"
|
| 835 |
-
]
|
| 836 |
-
}
|
| 837 |
-
],
|
| 838 |
"source": [
|
| 839 |
-
"import torch\n",
|
| 840 |
-
"import os\n",
|
| 841 |
"\n",
|
| 842 |
-
"def compare_models(num_gpus):\n",
|
| 843 |
-
"
|
| 844 |
" \n",
|
| 845 |
-
"
|
| 846 |
-
"
|
| 847 |
-
"
|
| 848 |
-
"
|
| 849 |
-
"
|
| 850 |
-
"
|
| 851 |
-
"
|
| 852 |
-
"
|
| 853 |
-
"
|
| 854 |
" \n",
|
| 855 |
-
"
|
| 856 |
-
"
|
| 857 |
-
"
|
| 858 |
-
"
|
| 859 |
-
"
|
| 860 |
-
"
|
| 861 |
-
"
|
| 862 |
-
"
|
| 863 |
-
"\n",
|
| 864 |
-
"
|
| 865 |
-
"
|
| 866 |
-
"
|
| 867 |
-
"
|
| 868 |
-
"
|
| 869 |
-
"
|
| 870 |
-
"
|
| 871 |
-
"
|
| 872 |
-
"\n",
|
| 873 |
-
"
|
| 874 |
-
"
|
| 875 |
-
"
|
| 876 |
-
"
|
| 877 |
-
"
|
| 878 |
-
"
|
| 879 |
-
"
|
| 880 |
-
"
|
| 881 |
" \n",
|
| 882 |
-
"
|
| 883 |
-
"
|
| 884 |
"\n",
|
| 885 |
-
"if __name__ == \"__main__\":\n",
|
| 886 |
-
"
|
| 887 |
-
"
|
| 888 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 889 |
]
|
| 890 |
}
|
| 891 |
],
|
|
|
|
| 33 |
},
|
| 34 |
{
|
| 35 |
"cell_type": "code",
|
| 36 |
+
"execution_count": 31,
|
| 37 |
"metadata": {},
|
| 38 |
"outputs": [],
|
| 39 |
"source": [
|
|
|
|
| 77 |
},
|
| 78 |
{
|
| 79 |
"cell_type": "code",
|
| 80 |
+
"execution_count": 32,
|
| 81 |
"metadata": {},
|
| 82 |
"outputs": [],
|
| 83 |
"source": [
|
|
|
|
| 95 |
},
|
| 96 |
{
|
| 97 |
"cell_type": "code",
|
| 98 |
+
"execution_count": 34,
|
| 99 |
"metadata": {},
|
| 100 |
+
"outputs": [
|
| 101 |
+
{
|
| 102 |
+
"data": {
|
| 103 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 104 |
+
"model_id": "9bbf7e9db9ce426d9c59d6f6d8e8df29",
|
| 105 |
+
"version_major": 2,
|
| 106 |
+
"version_minor": 0
|
| 107 |
+
},
|
| 108 |
+
"text/plain": [
|
| 109 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"output_type": "display_data"
|
| 114 |
+
}
|
| 115 |
+
],
|
| 116 |
"source": [
|
| 117 |
+
"notebook_login()"
|
| 118 |
]
|
| 119 |
},
|
| 120 |
{
|
|
|
|
| 832 |
},
|
| 833 |
{
|
| 834 |
"cell_type": "code",
|
| 835 |
+
"execution_count": 9,
|
| 836 |
"metadata": {},
|
| 837 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 838 |
"source": [
|
| 839 |
+
"# import torch\n",
|
| 840 |
+
"# import os\n",
|
| 841 |
"\n",
|
| 842 |
+
"# def compare_models(num_gpus):\n",
|
| 843 |
+
"# model_states = []\n",
|
| 844 |
" \n",
|
| 845 |
+
"# for gpu_id in range(num_gpus):\n",
|
| 846 |
+
"# filename = f\"outputs/model_state-N40-device{gpu_id}\"\n",
|
| 847 |
+
"# if os.path.exists(filename):\n",
|
| 848 |
+
"# state_dict = torch.load(filename, map_location='cpu')\n",
|
| 849 |
+
"# model_states.append(state_dict)\n",
|
| 850 |
+
"# print(filename)\n",
|
| 851 |
+
"# else:\n",
|
| 852 |
+
"# print(f\"File {filename} not found!\")\n",
|
| 853 |
+
"# return False\n",
|
| 854 |
" \n",
|
| 855 |
+
"# # Compare all model state_dicts\n",
|
| 856 |
+
"# print(\"len(model_states) =\", len(model_states))\n",
|
| 857 |
+
"# base_state = model_states[0]\n",
|
| 858 |
+
"# for state in model_states[1:]:\n",
|
| 859 |
+
"# for key in base_state.keys():\n",
|
| 860 |
+
"# # print(key, base_state[key], state[key])\n",
|
| 861 |
+
"# print(key)\n",
|
| 862 |
+
"# print(\"epoch\", base_state['epoch'], state['epoch'])\n",
|
| 863 |
+
"\n",
|
| 864 |
+
"# print(base_state['unet_state_dict'].keys())\n",
|
| 865 |
+
"# for key in base_state['unet_state_dict']:\n",
|
| 866 |
+
"# # print(key)\n",
|
| 867 |
+
"# if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n",
|
| 868 |
+
"# print(\"different\")\n",
|
| 869 |
+
"# return \n",
|
| 870 |
+
"# # else:\n",
|
| 871 |
+
"# print(\"exactly same\")\n",
|
| 872 |
+
"\n",
|
| 873 |
+
"# # if key == 'epoch':\n",
|
| 874 |
+
"# # print(base_state[key], state[key])\n",
|
| 875 |
+
"# # else:\n",
|
| 876 |
+
"# # print(base_state[key], state[key])\n",
|
| 877 |
+
"# # if not torch.equal(base_state[key], state[key]):\n",
|
| 878 |
+
"# # # if not (base_state[key] == state[key]):\n",
|
| 879 |
+
"# # print(f\"Mismatch found in parameter {key}\")\n",
|
| 880 |
+
"# # return False\n",
|
| 881 |
" \n",
|
| 882 |
+
"# # print(\"All models are identical!\")\n",
|
| 883 |
+
"# # return True\n",
|
| 884 |
"\n",
|
| 885 |
+
"# if __name__ == \"__main__\":\n",
|
| 886 |
+
"# # epoch_to_check = 0 # specify the epoch you want to check\n",
|
| 887 |
+
"# num_gpus = torch.cuda.device_count() # specify the number of GPUs used in training\n",
|
| 888 |
+
"# compare_models(num_gpus)"
|
| 889 |
+
]
|
| 890 |
+
},
|
| 891 |
+
{
|
| 892 |
+
"cell_type": "code",
|
| 893 |
+
"execution_count": 6,
|
| 894 |
+
"metadata": {},
|
| 895 |
+
"outputs": [],
|
| 896 |
+
"source": [
|
| 897 |
+
"import numpy as np\n",
|
| 898 |
+
"test = np.random.normal(0,1,(800,1,64,64,512))"
|
| 899 |
+
]
|
| 900 |
+
},
|
| 901 |
+
{
|
| 902 |
+
"cell_type": "code",
|
| 903 |
+
"execution_count": 7,
|
| 904 |
+
"metadata": {},
|
| 905 |
+
"outputs": [
|
| 906 |
+
{
|
| 907 |
+
"data": {
|
| 908 |
+
"text/plain": [
|
| 909 |
+
"12.5"
|
| 910 |
+
]
|
| 911 |
+
},
|
| 912 |
+
"execution_count": 7,
|
| 913 |
+
"metadata": {},
|
| 914 |
+
"output_type": "execute_result"
|
| 915 |
+
}
|
| 916 |
+
],
|
| 917 |
+
"source": [
|
| 918 |
+
"(test.itemsize*test.size) / 1024/1024/1024"
|
| 919 |
+
]
|
| 920 |
+
},
|
| 921 |
+
{
|
| 922 |
+
"cell_type": "code",
|
| 923 |
+
"execution_count": 8,
|
| 924 |
+
"metadata": {},
|
| 925 |
+
"outputs": [],
|
| 926 |
+
"source": [
|
| 927 |
+
"del test"
|
| 928 |
]
|
| 929 |
}
|
| 930 |
],
|
diffusion.py
CHANGED
|
@@ -61,6 +61,7 @@ import torch.multiprocessing as mp
|
|
| 61 |
from torch.utils.data.distributed import DistributedSampler
|
| 62 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 63 |
from torch.distributed import init_process_group, destroy_process_group
|
|
|
|
| 64 |
|
| 65 |
# %%
|
| 66 |
def ddp_setup(rank: int, world_size: int):
|
|
@@ -180,14 +181,12 @@ class DDPMScheduler(nn.Module):
|
|
| 180 |
x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 181 |
|
| 182 |
pbar_sample.update(1)
|
| 183 |
-
# pbar_sample.set_postfix(step=i)
|
| 184 |
|
| 185 |
-
# print("x_i.shape =", x_i.shape)
|
| 186 |
# store only part of the intermediate steps
|
| 187 |
-
if i%20==0:# or i==0:# or i<8:
|
| 188 |
-
|
| 189 |
-
x_i = x_i.detach().cpu().numpy()
|
| 190 |
x_i_entire = np.array(x_i_entire)
|
|
|
|
| 191 |
return x_i, x_i_entire
|
| 192 |
|
| 193 |
|
|
@@ -225,7 +224,7 @@ class TrainConfig:
|
|
| 225 |
###########################
|
| 226 |
## hardcoding these here ##
|
| 227 |
###########################
|
| 228 |
-
push_to_hub = True
|
| 229 |
hub_model_id = "Xsmos/ml21cm"
|
| 230 |
hub_private_repo = False
|
| 231 |
dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
|
|
@@ -265,14 +264,14 @@ class TrainConfig:
|
|
| 265 |
# seed = 0
|
| 266 |
# save_dir = './outputs/'
|
| 267 |
|
| 268 |
-
|
| 269 |
# general parameters for the name and logger
|
| 270 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 271 |
lrate = 1e-4
|
| 272 |
lr_warmup_steps = 0#5#00
|
| 273 |
output_dir = "./outputs/"
|
| 274 |
save_name = os.path.join(output_dir, 'model_state')
|
| 275 |
-
#
|
| 276 |
# cond = True # if training using the conditional information
|
| 277 |
# lr_decay = False #True# if using the learning rate decay
|
| 278 |
resume = save_name # if resume from the trained checkpoints
|
|
@@ -394,8 +393,8 @@ class DDPM21CM:
|
|
| 394 |
# distributed_type="MULTI_GPU",
|
| 395 |
)
|
| 396 |
# print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
|
| 397 |
-
if self.accelerator.is_main_process:
|
| 398 |
-
|
| 399 |
if self.config.output_dir is not None:
|
| 400 |
os.makedirs(self.config.output_dir, exist_ok=True)
|
| 401 |
if self.config.push_to_hub:
|
|
@@ -427,7 +426,7 @@ class DDPM21CM:
|
|
| 427 |
pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
|
| 428 |
pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
|
| 429 |
for i, (x, c) in enumerate(self.dataloader):
|
| 430 |
-
print(f"device {torch.cuda.current_device()}, x[
|
| 431 |
with self.accelerator.accumulate(self.nn_model):
|
| 432 |
x = x.to(self.config.device)
|
| 433 |
xt, noise, ts = self.ddpm.add_noise(x)
|
|
@@ -460,7 +459,7 @@ class DDPM21CM:
|
|
| 460 |
self.accelerator.log(logs, step=global_step)
|
| 461 |
global_step += 1
|
| 462 |
|
| 463 |
-
# if ep == config.n_epoch-1 or (ep+1)*config.
|
| 464 |
self.save(ep)
|
| 465 |
|
| 466 |
del self.nn_model
|
|
@@ -470,9 +469,9 @@ class DDPM21CM:
|
|
| 470 |
|
| 471 |
def save(self, ep):
|
| 472 |
# save model
|
| 473 |
-
if self.accelerator.is_main_process:
|
| 474 |
-
|
| 475 |
-
if ep == self.config.n_epoch-1 or (ep+1)
|
| 476 |
self.nn_model.eval()
|
| 477 |
with torch.no_grad():
|
| 478 |
if self.config.push_to_hub:
|
|
@@ -488,8 +487,9 @@ class DDPM21CM:
|
|
| 488 |
'unet_state_dict': self.nn_model.module.state_dict(),
|
| 489 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 490 |
}
|
| 491 |
-
|
| 492 |
-
|
|
|
|
| 493 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
| 494 |
|
| 495 |
# def rescale(self, value, type='params', to_ranges=[0,1]):
|
|
@@ -506,7 +506,7 @@ class DDPM21CM:
|
|
| 506 |
value = value * (to[1]-to[0]) + to[0]
|
| 507 |
return value
|
| 508 |
|
| 509 |
-
def sample(self, file, params:torch.tensor=None,
|
| 510 |
# n_sample = params.shape[0]
|
| 511 |
|
| 512 |
if params is None:
|
|
@@ -516,8 +516,8 @@ class DDPM21CM:
|
|
| 516 |
params_backup = params.numpy().copy()
|
| 517 |
params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
|
| 518 |
|
| 519 |
-
print(f"sampling {
|
| 520 |
-
params = params.repeat(
|
| 521 |
assert params.dim() == 2, "params must be a 2D torch.tensor"
|
| 522 |
# print("params =", params)
|
| 523 |
# print("params =", params)
|
|
@@ -526,16 +526,16 @@ class DDPM21CM:
|
|
| 526 |
# del self.ema_model, self.nn
|
| 527 |
# params = torch.tile(params, (n_sample,1)).to(device)
|
| 528 |
|
| 529 |
-
nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
|
| 530 |
if ema:
|
| 531 |
-
nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])
|
| 532 |
else:
|
| 533 |
-
nn_model.load_state_dict(torch.load(file)['unet_state_dict'])
|
| 534 |
print(f"nn_model resumed from {file}")
|
| 535 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 536 |
# nn_model.train()
|
| 537 |
-
nn_model.to(self.ddpm.device)
|
| 538 |
-
nn_model.eval()
|
| 539 |
|
| 540 |
# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)
|
| 541 |
# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f"{config.resume}"))['ema_unet_state_dict'])
|
|
@@ -543,27 +543,27 @@ class DDPM21CM:
|
|
| 543 |
|
| 544 |
with torch.no_grad():
|
| 545 |
x_last, x_entire = self.ddpm.sample(
|
| 546 |
-
nn_model=nn_model,
|
| 547 |
params=params.to(self.config.device),
|
| 548 |
device=self.config.device,
|
| 549 |
guide_w=self.config.guide_w
|
| 550 |
)
|
| 551 |
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
| 559 |
# %%
|
| 560 |
-
def
|
| 561 |
config = TrainConfig()
|
| 562 |
config.world_size = world_size
|
| 563 |
|
| 564 |
ddp_setup(rank, world_size)
|
| 565 |
|
| 566 |
-
num_image_list = [
|
| 567 |
for i, num_image in enumerate(num_image_list):
|
| 568 |
config.num_image = num_image
|
| 569 |
# config.world_size = world_size
|
|
@@ -578,17 +578,11 @@ def main(rank, world_size):
|
|
| 578 |
if __name__ == "__main__":
|
| 579 |
# torch.multiprocessing.set_start_method("spawn")
|
| 580 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 581 |
-
world_size =
|
| 582 |
|
| 583 |
-
mp.spawn(
|
| 584 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
| 585 |
|
| 586 |
-
# %%
|
| 587 |
-
# torch.cuda.set_device(0)
|
| 588 |
-
|
| 589 |
-
# %%
|
| 590 |
-
# print(torch.cuda.__dir__())
|
| 591 |
-
|
| 592 |
# %%
|
| 593 |
# print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
|
| 594 |
# print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
|
|
@@ -601,31 +595,67 @@ if __name__ == "__main__":
|
|
| 601 |
# print(torch.cuda.memory())
|
| 602 |
# print('here')
|
| 603 |
# print(torch.cuda.memory_summary())
|
| 604 |
-
|
| 605 |
# %% [markdown]
|
| 606 |
# # Sampling
|
| 607 |
|
| 608 |
# %%
|
| 609 |
-
# if __name__ == "__main__":
|
| 610 |
-
# # num_image_list = [1600,3200,6400,12800,25600]
|
| 611 |
-
# num_image_list = [1000]
|
| 612 |
-
# # num_image_list = [3200,6400,12800,25600]
|
| 613 |
-
# # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 614 |
-
# repeat = 2
|
| 615 |
-
# config = TrainConfig()
|
| 616 |
-
# for i, num_image in enumerate(num_image_list):
|
| 617 |
-
# config.num_image = num_image
|
| 618 |
-
# ddpm21cm = DDPM21CM(config)
|
| 619 |
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
|
| 622 |
-
|
|
|
|
| 623 |
|
| 624 |
-
|
| 625 |
|
| 626 |
-
|
|
|
|
|
|
|
| 627 |
|
| 628 |
-
# # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.8, 131.341)), repeat=repeat)
|
| 629 |
|
| 630 |
# %%
|
| 631 |
# ls -lth outputs | head
|
|
|
|
| 61 |
from torch.utils.data.distributed import DistributedSampler
|
| 62 |
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 63 |
from torch.distributed import init_process_group, destroy_process_group
|
| 64 |
+
import torch.distributed as dist
|
| 65 |
|
| 66 |
# %%
|
| 67 |
def ddp_setup(rank: int, world_size: int):
|
|
|
|
| 181 |
x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
|
| 182 |
|
| 183 |
pbar_sample.update(1)
|
|
|
|
| 184 |
|
|
|
|
| 185 |
# store only part of the intermediate steps
|
| 186 |
+
# if i%20==0:# or i==0:# or i<8:
|
| 187 |
+
# x_i_entire.append(x_i.detach().cpu().numpy())
|
|
|
|
| 188 |
x_i_entire = np.array(x_i_entire)
|
| 189 |
+
x_i = x_i.detach().cpu().numpy()
|
| 190 |
return x_i, x_i_entire
|
| 191 |
|
| 192 |
|
|
|
|
| 224 |
###########################
|
| 225 |
## hardcoding these here ##
|
| 226 |
###########################
|
| 227 |
+
push_to_hub = True
|
| 228 |
hub_model_id = "Xsmos/ml21cm"
|
| 229 |
hub_private_repo = False
|
| 230 |
dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
|
|
|
|
| 264 |
# seed = 0
|
| 265 |
# save_dir = './outputs/'
|
| 266 |
|
| 267 |
+
save_period = np.infty#.1 # the period of sampling
|
| 268 |
# general parameters for the name and logger
|
| 269 |
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 270 |
lrate = 1e-4
|
| 271 |
lr_warmup_steps = 0#5#00
|
| 272 |
output_dir = "./outputs/"
|
| 273 |
save_name = os.path.join(output_dir, 'model_state')
|
| 274 |
+
# save_period = 1 #10 # the period of saving model
|
| 275 |
# cond = True # if training using the conditional information
|
| 276 |
# lr_decay = False #True# if using the learning rate decay
|
| 277 |
resume = save_name # if resume from the trained checkpoints
|
|
|
|
| 393 |
# distributed_type="MULTI_GPU",
|
| 394 |
)
|
| 395 |
# print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
|
| 396 |
+
# if self.accelerator.is_main_process:
|
| 397 |
+
if torch.cuda.current_device() == 0:
|
| 398 |
if self.config.output_dir is not None:
|
| 399 |
os.makedirs(self.config.output_dir, exist_ok=True)
|
| 400 |
if self.config.push_to_hub:
|
|
|
|
| 426 |
pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
|
| 427 |
pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
|
| 428 |
for i, (x, c) in enumerate(self.dataloader):
|
| 429 |
+
# print(f"device {torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
|
| 430 |
with self.accelerator.accumulate(self.nn_model):
|
| 431 |
x = x.to(self.config.device)
|
| 432 |
xt, noise, ts = self.ddpm.add_noise(x)
|
|
|
|
| 459 |
self.accelerator.log(logs, step=global_step)
|
| 460 |
global_step += 1
|
| 461 |
|
| 462 |
+
# if ep == config.n_epoch-1 or (ep+1)*config.save_period==1:
|
| 463 |
self.save(ep)
|
| 464 |
|
| 465 |
del self.nn_model
|
|
|
|
| 469 |
|
| 470 |
def save(self, ep):
|
| 471 |
# save model
|
| 472 |
+
# if self.accelerator.is_main_process:
|
| 473 |
+
if torch.cuda.current_device() == 0:
|
| 474 |
+
if ep == self.config.n_epoch-1 or (ep+1) % self.config.save_period == 0:
|
| 475 |
self.nn_model.eval()
|
| 476 |
with torch.no_grad():
|
| 477 |
if self.config.push_to_hub:
|
|
|
|
| 487 |
'unet_state_dict': self.nn_model.module.state_dict(),
|
| 488 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 489 |
}
|
| 490 |
+
save_name = self.config.save_name+f"-N{self.config.num_image}-epoch{ep}-device{torch.cuda.current_device()}"
|
| 491 |
+
torch.save(model_state, save_name)
|
| 492 |
+
print(f'device {torch.cuda.current_device()} saved model at ' + save_name)
|
| 493 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
| 494 |
|
| 495 |
# def rescale(self, value, type='params', to_ranges=[0,1]):
|
|
|
|
| 506 |
value = value * (to[1]-to[0]) + to[0]
|
| 507 |
return value
|
| 508 |
|
| 509 |
+
def sample(self, file, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
|
| 510 |
# n_sample = params.shape[0]
|
| 511 |
|
| 512 |
if params is None:
|
|
|
|
| 516 |
params_backup = params.numpy().copy()
|
| 517 |
params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
|
| 518 |
|
| 519 |
+
print(f"sampling {num_new_img} images with normalized params = {params}")
|
| 520 |
+
params = params.repeat(num_new_img,1)
|
| 521 |
assert params.dim() == 2, "params must be a 2D torch.tensor"
|
| 522 |
# print("params =", params)
|
| 523 |
# print("params =", params)
|
|
|
|
| 526 |
# del self.ema_model, self.nn
|
| 527 |
# params = torch.tile(params, (n_sample,1)).to(device)
|
| 528 |
|
| 529 |
+
# nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
|
| 530 |
if ema:
|
| 531 |
+
self.nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])
|
| 532 |
else:
|
| 533 |
+
self.nn_model.load_state_dict(torch.load(file)['unet_state_dict'])
|
| 534 |
print(f"nn_model resumed from {file}")
|
| 535 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 536 |
# nn_model.train()
|
| 537 |
+
# self.nn_model.to(self.ddpm.device)
|
| 538 |
+
self.nn_model.eval()
|
| 539 |
|
| 540 |
# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)
|
| 541 |
# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f"{config.resume}"))['ema_unet_state_dict'])
|
|
|
|
| 543 |
|
| 544 |
with torch.no_grad():
|
| 545 |
x_last, x_entire = self.ddpm.sample(
|
| 546 |
+
nn_model=self.nn_model,
|
| 547 |
params=params.to(self.config.device),
|
| 548 |
device=self.config.device,
|
| 549 |
guide_w=self.config.guide_w
|
| 550 |
)
|
| 551 |
|
| 552 |
+
if save:
|
| 553 |
+
# np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
|
| 554 |
+
np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy"), x_last)
|
| 555 |
+
if entire:
|
| 556 |
+
np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy"), x_last)
|
| 557 |
+
else:
|
| 558 |
+
return x_last
|
| 559 |
# %%
|
| 560 |
+
def train(rank, world_size):
|
| 561 |
config = TrainConfig()
|
| 562 |
config.world_size = world_size
|
| 563 |
|
| 564 |
ddp_setup(rank, world_size)
|
| 565 |
|
| 566 |
+
num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]
|
| 567 |
for i, num_image in enumerate(num_image_list):
|
| 568 |
config.num_image = num_image
|
| 569 |
# config.world_size = world_size
|
|
|
|
| 578 |
if __name__ == "__main__":
|
| 579 |
# torch.multiprocessing.set_start_method("spawn")
|
| 580 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 581 |
+
world_size = torch.cuda.device_count()
|
| 582 |
|
| 583 |
+
mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
|
| 584 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
| 585 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 586 |
# %%
|
| 587 |
# print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
|
| 588 |
# print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
|
|
|
|
| 595 |
# print(torch.cuda.memory())
|
| 596 |
# print('here')
|
| 597 |
# print(torch.cuda.memory_summary())
|
|
|
|
| 598 |
# %% [markdown]
|
| 599 |
# # Sampling
|
| 600 |
|
| 601 |
# %%
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
|
| 603 |
+
def generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size):
    """Generate images on this rank in batches and gather them from every rank.

    Args:
        model: Object exposing
            ``model.module.sample(filename, params=..., num_new_img=...)``.
        num_new_img: Number of images generated by this rank.
        max_num_img_per_gpu: Largest batch requested from a single
            ``sample`` call.
        rank: Index of the calling process; only rank 0 receives the result.
        world_size: Number of processes taking part in the gather.

    Returns:
        On rank 0, the arrays from all ranks concatenated along axis 0;
        ``None`` on every other rank.

    Raises:
        ValueError: If ``num_new_img`` or ``max_num_img_per_gpu`` is not
            positive.
    """
    if num_new_img <= 0 or max_num_img_per_gpu <= 0:
        raise ValueError(
            "num_new_img and max_num_img_per_gpu must be positive, got "
            f"{num_new_img} and {max_num_img_per_gpu}"
        )

    # Full-size batches first, then one smaller batch for the remainder so the
    # requested count is honoured even when it is not a multiple of the limit.
    full_batches, remainder = divmod(num_new_img, max_num_img_per_gpu)
    batch_sizes = [max_num_img_per_gpu] * full_batches
    if remainder:
        batch_sizes.append(remainder)

    # NOTE(review): `filename` is read from module scope. It is only assigned
    # inside the `if __name__ == "__main__":` driver, so confirm it is actually
    # defined in the worker processes started by mp.spawn.
    # Other parameter pairs that were tried here:
    #   (5.6, 19.037), (4.699, 30), (5.477, 200), (4.8, 131.341)
    samples = []
    for batch_size in batch_sizes:
        # A fresh tensor per call, in case sample() modifies its argument.
        sample = model.module.sample(
            filename,
            params=torch.tensor([4.4, 131.341]),
            num_new_img=batch_size,
        )
        samples.append(sample)
    samples = np.concatenate(samples, axis=0)

    # all_gather_object overwrites every slot of the output list, so plain
    # placeholders are enough; no need to allocate arrays up front.
    samples_list = [None] * world_size
    dist.all_gather_object(samples_list, samples)

    if rank != 0:
        return None
    return np.concatenate(samples_list, axis=0)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def sample(rank, world_size, model, num_new_img, max_num_img_per_gpu, return_dict):
    """Worker entry point: set up the process group, sample, publish, tear down.

    Args:
        rank: Index of this worker process.
        world_size: Total number of worker processes.
        model: Model object forwarded to ``generate_samples``.
        num_new_img: Forwarded to ``generate_samples``.
        max_num_img_per_gpu: Forwarded to ``generate_samples``.
        return_dict: Mapping shared with the parent process; rank 0 stores the
            gathered result under the key ``'samples'``.
    """
    ddp_setup(rank, world_size)

    try:
        samples = generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size)

        # generate_samples returns the gathered array on rank 0 only.
        if rank == 0:
            return_dict['samples'] = samples
    finally:
        # Always release the process group, even if sampling raised.
        dist.destroy_process_group()
|
| 633 |
+
|
| 634 |
+
|
| 635 |
+
if __name__ == "__main__":
    # Sampling driver: for each training-set size, build the model wrapper,
    # spawn one worker per CUDA device and report the shape of what rank 0
    # gathered.
    world_size = torch.cuda.device_count()
    # num_image_list = [1600,3200,6400,12800,25600]
    num_image_list = [1000]
    num_new_img = 12
    max_num_img_per_gpu = 2

    config = TrainConfig()
    config.world_size = world_size

    for num_image in num_image_list:
        # Read as a module-level name by generate_samples().
        filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
        config.num_image = num_image
        ddpm21cm = DDPM21CM(config)

        # The shared dict must come from a multiprocessing manager; NumPy has
        # no `Manager`, so `np.Manager()` raised AttributeError.
        manager = mp.Manager()
        return_dict = manager.dict()

        mp.spawn(
            sample,
            args=(world_size, ddpm21cm, num_new_img, max_num_img_per_gpu, return_dict),
            nprocs=world_size,
            join=True,
        )

        if "samples" in return_dict:
            samples = return_dict["samples"]
            print(f"Generated samples shape: {samples.shape}")
|
| 658 |
|
|
|
|
| 659 |
|
| 660 |
# %%
|
| 661 |
# ls -lth outputs | head
|
load_h5.py
CHANGED
|
@@ -60,6 +60,13 @@ class Dataset4h5(Dataset):
|
|
| 60 |
self.images = torch.from_numpy(self.images)
|
| 61 |
print(f"images rescaled to [{self.images.min()}, {self.images.max()}]")
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()
|
| 64 |
self.params = torch.from_numpy(self.params*cond_filter)
|
| 65 |
print(f"params rescaled to [{self.params.min()}, {self.params.max()}]")
|
|
@@ -98,12 +105,6 @@ class Dataset4h5(Dataset):
|
|
| 98 |
self.params = f['params']['values'][self.idx]
|
| 99 |
print("params loaded:", self.params.shape)
|
| 100 |
|
| 101 |
-
# print("before self.images.shape =", self.images.shape)
|
| 102 |
-
self.images = torch.ones_like(torch.from_numpy(self.images)) * torch.arange(len(self.images))[:,None,None,None,None]
|
| 103 |
-
self.images = self.images.numpy()
|
| 104 |
-
# print("after self.images.shape =", self.images.shape)
|
| 105 |
-
print(self.images[:6,0,:2,0,0])
|
| 106 |
-
# self.images = self.images.numpy()
|
| 107 |
|
| 108 |
# plt.imshow(self.images[0,0,0])
|
| 109 |
# plt.show()
|
|
|
|
| 60 |
self.images = torch.from_numpy(self.images)
|
| 61 |
print(f"images rescaled to [{self.images.min()}, {self.images.max()}]")
|
| 62 |
|
| 63 |
+
# print("before self.images.shape =", self.images.shape)
|
| 64 |
+
# self.images = torch.ones_like(self.images) * torch.arange(len(self.images))[:,None,None,None,None]
|
| 65 |
+
# # self.images = self.images.numpy()
|
| 66 |
+
# print("after self.images.shape =", self.images.shape)
|
| 67 |
+
# print(self.images[:6,0,:2,0,0])
|
| 68 |
+
# # self.images = self.images.numpy()
|
| 69 |
+
|
| 70 |
cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()
|
| 71 |
self.params = torch.from_numpy(self.params*cond_filter)
|
| 72 |
print(f"params rescaled to [{self.params.min()}, {self.params.max()}]")
|
|
|
|
| 105 |
self.params = f['params']['values'][self.idx]
|
| 106 |
print("params loaded:", self.params.shape)
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
# plt.imshow(self.images[0,0,0])
|
| 110 |
# plt.show()
|