0711-1203
Browse files- diffusion.ipynb +67 -2
- load_h5.py +1 -1
diffusion.ipynb
CHANGED
|
@@ -809,18 +809,83 @@
|
|
| 809 |
"cell_type": "code",
|
| 810 |
"execution_count": 1,
|
| 811 |
"metadata": {},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 812 |
"outputs": [
|
| 813 |
{
|
| 814 |
"name": "stdout",
|
| 815 |
"output_type": "stream",
|
| 816 |
"text": [
|
| 817 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 818 |
]
|
| 819 |
}
|
| 820 |
],
|
| 821 |
"source": [
|
| 822 |
"import torch\n",
|
| 823 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 824 |
]
|
| 825 |
}
|
| 826 |
],
|
|
|
|
| 809 |
"cell_type": "code",
|
| 810 |
"execution_count": 1,
|
| 811 |
"metadata": {},
|
| 812 |
+
"outputs": [],
|
| 813 |
+
"source": [
|
| 814 |
+
"# import torch\n",
|
| 815 |
+
"# print(torch.__version__)"
|
| 816 |
+
]
|
| 817 |
+
},
|
| 818 |
+
{
|
| 819 |
+
"cell_type": "code",
|
| 820 |
+
"execution_count": 29,
|
| 821 |
+
"metadata": {},
|
| 822 |
"outputs": [
|
| 823 |
{
|
| 824 |
"name": "stdout",
|
| 825 |
"output_type": "stream",
|
| 826 |
"text": [
|
| 827 |
+
"outputs/model_state-N500-device0\n",
|
| 828 |
+
"outputs/model_state-N500-device1\n",
|
| 829 |
+
"len(model_states) = 2\n",
|
| 830 |
+
"epoch\n",
|
| 831 |
+
"unet_state_dict\n",
|
| 832 |
+
"epoch 9 9\n",
|
| 833 |
+
"odict_keys(['token_embedding.weight', 'token_embedding.bias', 'time_embed.0.weight', 'time_embed.0.bias', 'time_embed.2.weight', 'time_embed.2.bias', 'input_blocks.0.0.weight', 'input_blocks.0.0.bias', 'input_blocks.1.0.in_layers.0.weight', 'input_blocks.1.0.in_layers.0.bias', 'input_blocks.1.0.in_layers.2.weight', 'input_blocks.1.0.in_layers.2.bias', 'input_blocks.1.0.emb_layers.1.weight', 'input_blocks.1.0.emb_layers.1.bias', 'input_blocks.1.0.out_layers.0.weight', 'input_blocks.1.0.out_layers.0.bias', 'input_blocks.1.0.out_layers.3.weight', 'input_blocks.1.0.out_layers.3.bias', 'input_blocks.1.1.norm.weight', 'input_blocks.1.1.norm.bias', 'input_blocks.1.1.qkv.weight', 'input_blocks.1.1.qkv.bias', 'input_blocks.1.1.proj_out.weight', 'input_blocks.1.1.proj_out.bias', 'input_blocks.2.0.in_layers.0.weight', 'input_blocks.2.0.in_layers.0.bias', 'input_blocks.2.0.in_layers.2.weight', 'input_blocks.2.0.in_layers.2.bias', 'input_blocks.2.0.emb_layers.1.weight', 'input_blocks.2.0.emb_layers.1.bias', 'input_blocks.2.0.out_layers.0.weight', 'input_blocks.2.0.out_layers.0.bias', 'input_blocks.2.0.out_layers.3.weight', 'input_blocks.2.0.out_layers.3.bias', 'input_blocks.2.1.norm.weight', 'input_blocks.2.1.norm.bias', 'input_blocks.2.1.qkv.weight', 'input_blocks.2.1.qkv.bias', 'input_blocks.2.1.proj_out.weight', 'input_blocks.2.1.proj_out.bias', 'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias', 'input_blocks.4.0.in_layers.0.weight', 'input_blocks.4.0.in_layers.0.bias', 'input_blocks.4.0.in_layers.2.weight', 'input_blocks.4.0.in_layers.2.bias', 'input_blocks.4.0.emb_layers.1.weight', 'input_blocks.4.0.emb_layers.1.bias', 'input_blocks.4.0.out_layers.0.weight', 'input_blocks.4.0.out_layers.0.bias', 'input_blocks.4.0.out_layers.3.weight', 'input_blocks.4.0.out_layers.3.bias', 'input_blocks.4.0.skip_connection.weight', 'input_blocks.4.0.skip_connection.bias', 'input_blocks.5.0.in_layers.0.weight', 'input_blocks.5.0.in_layers.0.bias', 'input_blocks.5.0.in_layers.2.weight', 'input_blocks.5.0.in_layers.2.bias', 'input_blocks.5.0.emb_layers.1.weight', 'input_blocks.5.0.emb_layers.1.bias', 'input_blocks.5.0.out_layers.0.weight', 'input_blocks.5.0.out_layers.0.bias', 'input_blocks.5.0.out_layers.3.weight', 'input_blocks.5.0.out_layers.3.bias', 'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias', 'input_blocks.7.0.in_layers.0.weight', 'input_blocks.7.0.in_layers.0.bias', 'input_blocks.7.0.in_layers.2.weight', 'input_blocks.7.0.in_layers.2.bias', 'input_blocks.7.0.emb_layers.1.weight', 'input_blocks.7.0.emb_layers.1.bias', 'input_blocks.7.0.out_layers.0.weight', 'input_blocks.7.0.out_layers.0.bias', 'input_blocks.7.0.out_layers.3.weight', 'input_blocks.7.0.out_layers.3.bias', 'input_blocks.7.0.skip_connection.weight', 'input_blocks.7.0.skip_connection.bias', 'input_blocks.8.0.in_layers.0.weight', 'input_blocks.8.0.in_layers.0.bias', 'input_blocks.8.0.in_layers.2.weight', 'input_blocks.8.0.in_layers.2.bias', 'input_blocks.8.0.emb_layers.1.weight', 'input_blocks.8.0.emb_layers.1.bias', 'input_blocks.8.0.out_layers.0.weight', 'input_blocks.8.0.out_layers.0.bias', 'input_blocks.8.0.out_layers.3.weight', 'input_blocks.8.0.out_layers.3.bias', 'middle_block.0.in_layers.0.weight', 'middle_block.0.in_layers.0.bias', 'middle_block.0.in_layers.2.weight', 'middle_block.0.in_layers.2.bias', 'middle_block.0.emb_layers.1.weight', 'middle_block.0.emb_layers.1.bias', 'middle_block.0.out_layers.0.weight', 'middle_block.0.out_layers.0.bias', 'middle_block.0.out_layers.3.weight', 'middle_block.0.out_layers.3.bias', 'middle_block.1.norm.weight', 'middle_block.1.norm.bias', 'middle_block.1.qkv.weight', 'middle_block.1.qkv.bias', 'middle_block.1.proj_out.weight', 'middle_block.1.proj_out.bias', 'middle_block.2.in_layers.0.weight', 'middle_block.2.in_layers.0.bias', 'middle_block.2.in_layers.2.weight', 'middle_block.2.in_layers.2.bias', 'middle_block.2.emb_layers.1.weight', 'middle_block.2.emb_layers.1.bias', 'middle_block.2.out_layers.0.weight', 'middle_block.2.out_layers.0.bias', 'middle_block.2.out_layers.3.weight', 'middle_block.2.out_layers.3.bias', 'output_blocks.0.0.in_layers.0.weight', 'output_blocks.0.0.in_layers.0.bias', 'output_blocks.0.0.in_layers.2.weight', 'output_blocks.0.0.in_layers.2.bias', 'output_blocks.0.0.emb_layers.1.weight', 'output_blocks.0.0.emb_layers.1.bias', 'output_blocks.0.0.out_layers.0.weight', 'output_blocks.0.0.out_layers.0.bias', 'output_blocks.0.0.out_layers.3.weight', 'output_blocks.0.0.out_layers.3.bias', 'output_blocks.0.0.skip_connection.weight', 'output_blocks.0.0.skip_connection.bias', 'output_blocks.1.0.in_layers.0.weight', 'output_blocks.1.0.in_layers.0.bias', 'output_blocks.1.0.in_layers.2.weight', 'output_blocks.1.0.in_layers.2.bias', 'output_blocks.1.0.emb_layers.1.weight', 'output_blocks.1.0.emb_layers.1.bias', 'output_blocks.1.0.out_layers.0.weight', 'output_blocks.1.0.out_layers.0.bias', 'output_blocks.1.0.out_layers.3.weight', 'output_blocks.1.0.out_layers.3.bias', 'output_blocks.1.0.skip_connection.weight', 'output_blocks.1.0.skip_connection.bias', 'output_blocks.2.0.in_layers.0.weight', 'output_blocks.2.0.in_layers.0.bias', 'output_blocks.2.0.in_layers.2.weight', 'output_blocks.2.0.in_layers.2.bias', 'output_blocks.2.0.emb_layers.1.weight', 'output_blocks.2.0.emb_layers.1.bias', 'output_blocks.2.0.out_layers.0.weight', 'output_blocks.2.0.out_layers.0.bias', 'output_blocks.2.0.out_layers.3.weight', 'output_blocks.2.0.out_layers.3.bias', 'output_blocks.2.0.skip_connection.weight', 'output_blocks.2.0.skip_connection.bias', 'output_blocks.2.1.conv.weight', 'output_blocks.2.1.conv.bias', 'output_blocks.3.0.in_layers.0.weight', 'output_blocks.3.0.in_layers.0.bias', 'output_blocks.3.0.in_layers.2.weight', 'output_blocks.3.0.in_layers.2.bias', 'output_blocks.3.0.emb_layers.1.weight', 'output_blocks.3.0.emb_layers.1.bias', 'output_blocks.3.0.out_layers.0.weight', 'output_blocks.3.0.out_layers.0.bias', 'output_blocks.3.0.out_layers.3.weight', 'output_blocks.3.0.out_layers.3.bias', 'output_blocks.3.0.skip_connection.weight', 'output_blocks.3.0.skip_connection.bias', 'output_blocks.4.0.in_layers.0.weight', 'output_blocks.4.0.in_layers.0.bias', 'output_blocks.4.0.in_layers.2.weight', 'output_blocks.4.0.in_layers.2.bias', 'output_blocks.4.0.emb_layers.1.weight', 'output_blocks.4.0.emb_layers.1.bias', 'output_blocks.4.0.out_layers.0.weight', 'output_blocks.4.0.out_layers.0.bias', 'output_blocks.4.0.out_layers.3.weight', 'output_blocks.4.0.out_layers.3.bias', 'output_blocks.4.0.skip_connection.weight', 'output_blocks.4.0.skip_connection.bias', 'output_blocks.5.0.in_layers.0.weight', 'output_blocks.5.0.in_layers.0.bias', 'output_blocks.5.0.in_layers.2.weight', 'output_blocks.5.0.in_layers.2.bias', 'output_blocks.5.0.emb_layers.1.weight', 'output_blocks.5.0.emb_layers.1.bias', 'output_blocks.5.0.out_layers.0.weight', 'output_blocks.5.0.out_layers.0.bias', 'output_blocks.5.0.out_layers.3.weight', 'output_blocks.5.0.out_layers.3.bias', 'output_blocks.5.0.skip_connection.weight', 'output_blocks.5.0.skip_connection.bias', 'output_blocks.5.1.conv.weight', 'output_blocks.5.1.conv.bias', 'output_blocks.6.0.in_layers.0.weight', 'output_blocks.6.0.in_layers.0.bias', 'output_blocks.6.0.in_layers.2.weight', 'output_blocks.6.0.in_layers.2.bias', 'output_blocks.6.0.emb_layers.1.weight', 'output_blocks.6.0.emb_layers.1.bias', 'output_blocks.6.0.out_layers.0.weight', 'output_blocks.6.0.out_layers.0.bias', 'output_blocks.6.0.out_layers.3.weight', 'output_blocks.6.0.out_layers.3.bias', 'output_blocks.6.0.skip_connection.weight', 'output_blocks.6.0.skip_connection.bias', 'output_blocks.6.1.norm.weight', 'output_blocks.6.1.norm.bias', 'output_blocks.6.1.qkv.weight', 'output_blocks.6.1.qkv.bias', 'output_blocks.6.1.proj_out.weight', 'output_blocks.6.1.proj_out.bias', 'output_blocks.7.0.in_layers.0.weight', 'output_blocks.7.0.in_layers.0.bias', 'output_blocks.7.0.in_layers.2.weight', 'output_blocks.7.0.in_layers.2.bias', 'output_blocks.7.0.emb_layers.1.weight', 'output_blocks.7.0.emb_layers.1.bias', 'output_blocks.7.0.out_layers.0.weight', 'output_blocks.7.0.out_layers.0.bias', 'output_blocks.7.0.out_layers.3.weight', 'output_blocks.7.0.out_layers.3.bias', 'output_blocks.7.0.skip_connection.weight', 'output_blocks.7.0.skip_connection.bias', 'output_blocks.7.1.norm.weight', 'output_blocks.7.1.norm.bias', 'output_blocks.7.1.qkv.weight', 'output_blocks.7.1.qkv.bias', 'output_blocks.7.1.proj_out.weight', 'output_blocks.7.1.proj_out.bias', 'output_blocks.8.0.in_layers.0.weight', 'output_blocks.8.0.in_layers.0.bias', 'output_blocks.8.0.in_layers.2.weight', 'output_blocks.8.0.in_layers.2.bias', 'output_blocks.8.0.emb_layers.1.weight', 'output_blocks.8.0.emb_layers.1.bias', 'output_blocks.8.0.out_layers.0.weight', 'output_blocks.8.0.out_layers.0.bias', 'output_blocks.8.0.out_layers.3.weight', 'output_blocks.8.0.out_layers.3.bias', 'output_blocks.8.0.skip_connection.weight', 'output_blocks.8.0.skip_connection.bias', 'output_blocks.8.1.norm.weight', 'output_blocks.8.1.norm.bias', 'output_blocks.8.1.qkv.weight', 'output_blocks.8.1.qkv.bias', 'output_blocks.8.1.proj_out.weight', 'output_blocks.8.1.proj_out.bias', 'out.0.weight', 'out.0.bias', 'out.2.weight', 'out.2.bias'])\n",
|
| 834 |
+
"exactly same\n"
|
| 835 |
]
|
| 836 |
}
|
| 837 |
],
|
| 838 |
"source": [
|
| 839 |
"import torch\n",
|
| 840 |
+
"import os\n",
|
| 841 |
+
"\n",
|
| 842 |
+
"def compare_models(num_gpus):\n",
|
| 843 |
+
" model_states = []\n",
|
| 844 |
+
" \n",
|
| 845 |
+
" for gpu_id in range(num_gpus):\n",
|
| 846 |
+
" filename = f\"outputs/model_state-N500-device{gpu_id}\"\n",
|
| 847 |
+
" if os.path.exists(filename):\n",
|
| 848 |
+
" state_dict = torch.load(filename, map_location='cpu')\n",
|
| 849 |
+
" model_states.append(state_dict)\n",
|
| 850 |
+
" print(filename)\n",
|
| 851 |
+
" else:\n",
|
| 852 |
+
" print(f\"File {filename} not found!\")\n",
|
| 853 |
+
" return False\n",
|
| 854 |
+
" \n",
|
| 855 |
+
" # Compare all model state_dicts\n",
|
| 856 |
+
" print(\"len(model_states) =\", len(model_states))\n",
|
| 857 |
+
" base_state = model_states[0]\n",
|
| 858 |
+
" for state in model_states[1:]:\n",
|
| 859 |
+
" for key in base_state.keys():\n",
|
| 860 |
+
" # print(key, base_state[key], state[key])\n",
|
| 861 |
+
" print(key)\n",
|
| 862 |
+
" print(\"epoch\", base_state['epoch'], state['epoch'])\n",
|
| 863 |
+
"\n",
|
| 864 |
+
" print(base_state['unet_state_dict'].keys())\n",
|
| 865 |
+
" for key in base_state['unet_state_dict']:\n",
|
| 866 |
+
" # print(key)\n",
|
| 867 |
+
" if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n",
|
| 868 |
+
" print(\"different\")\n",
|
| 869 |
+
" return \n",
|
| 870 |
+
" # else:\n",
|
| 871 |
+
" print(\"exactly same\")\n",
|
| 872 |
+
"\n",
|
| 873 |
+
" # if key == 'epoch':\n",
|
| 874 |
+
" # print(base_state[key], state[key])\n",
|
| 875 |
+
" # else:\n",
|
| 876 |
+
" # print(base_state[key], state[key])\n",
|
| 877 |
+
" # if not torch.equal(base_state[key], state[key]):\n",
|
| 878 |
+
" # # if not (base_state[key] == state[key]):\n",
|
| 879 |
+
" # print(f\"Mismatch found in parameter {key}\")\n",
|
| 880 |
+
" # return False\n",
|
| 881 |
+
" \n",
|
| 882 |
+
" # print(\"All models are identical!\")\n",
|
| 883 |
+
" # return True\n",
|
| 884 |
+
"\n",
|
| 885 |
+
"if __name__ == \"__main__\":\n",
|
| 886 |
+
" # epoch_to_check = 0 # specify the epoch you want to check\n",
|
| 887 |
+
" num_gpus = 2 # specify the number of GPUs used in training\n",
|
| 888 |
+
" compare_models(num_gpus)"
|
| 889 |
]
|
| 890 |
}
|
| 891 |
],
|
load_h5.py
CHANGED
|
@@ -26,7 +26,7 @@ import datetime
|
|
| 26 |
# from huggingface_hub import create_repo, upload_folder
|
| 27 |
|
| 28 |
class Dataset4h5(Dataset):
|
| 29 |
-
def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=
|
| 30 |
super().__init__()
|
| 31 |
|
| 32 |
self.dir_name = dir_name
|
|
|
|
| 26 |
# from huggingface_hub import create_repo, upload_folder
|
| 27 |
|
| 28 |
class Dataset4h5(Dataset):
|
| 29 |
+
def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=False, idx=None, num_redshift=512, HII_DIM=64, rescale=True, drop_prob = 0, dim=2, transform=True, ranges_dict=None):
|
| 30 |
super().__init__()
|
| 31 |
|
| 32 |
self.dir_name = dir_name
|