Xsmos commited on
Commit
d556329
·
verified ·
1 Parent(s): 3f1087b
Files changed (2) hide show
  1. diffusion.ipynb +67 -2
  2. load_h5.py +1 -1
diffusion.ipynb CHANGED
@@ -809,18 +809,83 @@
809
  "cell_type": "code",
810
  "execution_count": 1,
811
  "metadata": {},
 
 
 
 
 
 
 
 
 
 
812
  "outputs": [
813
  {
814
  "name": "stdout",
815
  "output_type": "stream",
816
  "text": [
817
- "1.12.0\n"
 
 
 
 
 
 
 
818
  ]
819
  }
820
  ],
821
  "source": [
822
  "import torch\n",
823
- "print(torch.__version__)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
  ]
825
  }
826
  ],
 
809
  "cell_type": "code",
810
  "execution_count": 1,
811
  "metadata": {},
812
+ "outputs": [],
813
+ "source": [
814
+ "# import torch\n",
815
+ "# print(torch.__version__)"
816
+ ]
817
+ },
818
+ {
819
+ "cell_type": "code",
820
+ "execution_count": 29,
821
+ "metadata": {},
822
  "outputs": [
823
  {
824
  "name": "stdout",
825
  "output_type": "stream",
826
  "text": [
827
+ "outputs/model_state-N500-device0\n",
828
+ "outputs/model_state-N500-device1\n",
829
+ "len(model_states) = 2\n",
830
+ "epoch\n",
831
+ "unet_state_dict\n",
832
+ "epoch 9 9\n",
833
+ "odict_keys(['token_embedding.weight', 'token_embedding.bias', 'time_embed.0.weight', 'time_embed.0.bias', 'time_embed.2.weight', 'time_embed.2.bias', 'input_blocks.0.0.weight', 'input_blocks.0.0.bias', 'input_blocks.1.0.in_layers.0.weight', 'input_blocks.1.0.in_layers.0.bias', 'input_blocks.1.0.in_layers.2.weight', 'input_blocks.1.0.in_layers.2.bias', 'input_blocks.1.0.emb_layers.1.weight', 'input_blocks.1.0.emb_layers.1.bias', 'input_blocks.1.0.out_layers.0.weight', 'input_blocks.1.0.out_layers.0.bias', 'input_blocks.1.0.out_layers.3.weight', 'input_blocks.1.0.out_layers.3.bias', 'input_blocks.1.1.norm.weight', 'input_blocks.1.1.norm.bias', 'input_blocks.1.1.qkv.weight', 'input_blocks.1.1.qkv.bias', 'input_blocks.1.1.proj_out.weight', 'input_blocks.1.1.proj_out.bias', 'input_blocks.2.0.in_layers.0.weight', 'input_blocks.2.0.in_layers.0.bias', 'input_blocks.2.0.in_layers.2.weight', 'input_blocks.2.0.in_layers.2.bias', 'input_blocks.2.0.emb_layers.1.weight', 'input_blocks.2.0.emb_layers.1.bias', 'input_blocks.2.0.out_layers.0.weight', 'input_blocks.2.0.out_layers.0.bias', 'input_blocks.2.0.out_layers.3.weight', 'input_blocks.2.0.out_layers.3.bias', 'input_blocks.2.1.norm.weight', 'input_blocks.2.1.norm.bias', 'input_blocks.2.1.qkv.weight', 'input_blocks.2.1.qkv.bias', 'input_blocks.2.1.proj_out.weight', 'input_blocks.2.1.proj_out.bias', 'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias', 'input_blocks.4.0.in_layers.0.weight', 'input_blocks.4.0.in_layers.0.bias', 'input_blocks.4.0.in_layers.2.weight', 'input_blocks.4.0.in_layers.2.bias', 'input_blocks.4.0.emb_layers.1.weight', 'input_blocks.4.0.emb_layers.1.bias', 'input_blocks.4.0.out_layers.0.weight', 'input_blocks.4.0.out_layers.0.bias', 'input_blocks.4.0.out_layers.3.weight', 'input_blocks.4.0.out_layers.3.bias', 'input_blocks.4.0.skip_connection.weight', 'input_blocks.4.0.skip_connection.bias', 'input_blocks.5.0.in_layers.0.weight', 'input_blocks.5.0.in_layers.0.bias', 'input_blocks.5.0.in_layers.2.weight', 'input_blocks.5.0.in_layers.2.bias', 'input_blocks.5.0.emb_layers.1.weight', 'input_blocks.5.0.emb_layers.1.bias', 'input_blocks.5.0.out_layers.0.weight', 'input_blocks.5.0.out_layers.0.bias', 'input_blocks.5.0.out_layers.3.weight', 'input_blocks.5.0.out_layers.3.bias', 'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias', 'input_blocks.7.0.in_layers.0.weight', 'input_blocks.7.0.in_layers.0.bias', 'input_blocks.7.0.in_layers.2.weight', 'input_blocks.7.0.in_layers.2.bias', 'input_blocks.7.0.emb_layers.1.weight', 'input_blocks.7.0.emb_layers.1.bias', 'input_blocks.7.0.out_layers.0.weight', 'input_blocks.7.0.out_layers.0.bias', 'input_blocks.7.0.out_layers.3.weight', 'input_blocks.7.0.out_layers.3.bias', 'input_blocks.7.0.skip_connection.weight', 'input_blocks.7.0.skip_connection.bias', 'input_blocks.8.0.in_layers.0.weight', 'input_blocks.8.0.in_layers.0.bias', 'input_blocks.8.0.in_layers.2.weight', 'input_blocks.8.0.in_layers.2.bias', 'input_blocks.8.0.emb_layers.1.weight', 'input_blocks.8.0.emb_layers.1.bias', 'input_blocks.8.0.out_layers.0.weight', 'input_blocks.8.0.out_layers.0.bias', 'input_blocks.8.0.out_layers.3.weight', 'input_blocks.8.0.out_layers.3.bias', 'middle_block.0.in_layers.0.weight', 'middle_block.0.in_layers.0.bias', 'middle_block.0.in_layers.2.weight', 'middle_block.0.in_layers.2.bias', 'middle_block.0.emb_layers.1.weight', 'middle_block.0.emb_layers.1.bias', 'middle_block.0.out_layers.0.weight', 'middle_block.0.out_layers.0.bias', 'middle_block.0.out_layers.3.weight', 'middle_block.0.out_layers.3.bias', 'middle_block.1.norm.weight', 'middle_block.1.norm.bias', 'middle_block.1.qkv.weight', 'middle_block.1.qkv.bias', 'middle_block.1.proj_out.weight', 'middle_block.1.proj_out.bias', 'middle_block.2.in_layers.0.weight', 'middle_block.2.in_layers.0.bias', 'middle_block.2.in_layers.2.weight', 'middle_block.2.in_layers.2.bias', 'middle_block.2.emb_layers.1.weight', 'middle_block.2.emb_layers.1.bias', 'middle_block.2.out_layers.0.weight', 'middle_block.2.out_layers.0.bias', 'middle_block.2.out_layers.3.weight', 'middle_block.2.out_layers.3.bias', 'output_blocks.0.0.in_layers.0.weight', 'output_blocks.0.0.in_layers.0.bias', 'output_blocks.0.0.in_layers.2.weight', 'output_blocks.0.0.in_layers.2.bias', 'output_blocks.0.0.emb_layers.1.weight', 'output_blocks.0.0.emb_layers.1.bias', 'output_blocks.0.0.out_layers.0.weight', 'output_blocks.0.0.out_layers.0.bias', 'output_blocks.0.0.out_layers.3.weight', 'output_blocks.0.0.out_layers.3.bias', 'output_blocks.0.0.skip_connection.weight', 'output_blocks.0.0.skip_connection.bias', 'output_blocks.1.0.in_layers.0.weight', 'output_blocks.1.0.in_layers.0.bias', 'output_blocks.1.0.in_layers.2.weight', 'output_blocks.1.0.in_layers.2.bias', 'output_blocks.1.0.emb_layers.1.weight', 'output_blocks.1.0.emb_layers.1.bias', 'output_blocks.1.0.out_layers.0.weight', 'output_blocks.1.0.out_layers.0.bias', 'output_blocks.1.0.out_layers.3.weight', 'output_blocks.1.0.out_layers.3.bias', 'output_blocks.1.0.skip_connection.weight', 'output_blocks.1.0.skip_connection.bias', 'output_blocks.2.0.in_layers.0.weight', 'output_blocks.2.0.in_layers.0.bias', 'output_blocks.2.0.in_layers.2.weight', 'output_blocks.2.0.in_layers.2.bias', 'output_blocks.2.0.emb_layers.1.weight', 'output_blocks.2.0.emb_layers.1.bias', 'output_blocks.2.0.out_layers.0.weight', 'output_blocks.2.0.out_layers.0.bias', 'output_blocks.2.0.out_layers.3.weight', 'output_blocks.2.0.out_layers.3.bias', 'output_blocks.2.0.skip_connection.weight', 'output_blocks.2.0.skip_connection.bias', 'output_blocks.2.1.conv.weight', 'output_blocks.2.1.conv.bias', 'output_blocks.3.0.in_layers.0.weight', 'output_blocks.3.0.in_layers.0.bias', 'output_blocks.3.0.in_layers.2.weight', 'output_blocks.3.0.in_layers.2.bias', 'output_blocks.3.0.emb_layers.1.weight', 'output_blocks.3.0.emb_layers.1.bias', 'output_blocks.3.0.out_layers.0.weight', 'output_blocks.3.0.out_layers.0.bias', 'output_blocks.3.0.out_layers.3.weight', 'output_blocks.3.0.out_layers.3.bias', 'output_blocks.3.0.skip_connection.weight', 'output_blocks.3.0.skip_connection.bias', 'output_blocks.4.0.in_layers.0.weight', 'output_blocks.4.0.in_layers.0.bias', 'output_blocks.4.0.in_layers.2.weight', 'output_blocks.4.0.in_layers.2.bias', 'output_blocks.4.0.emb_layers.1.weight', 'output_blocks.4.0.emb_layers.1.bias', 'output_blocks.4.0.out_layers.0.weight', 'output_blocks.4.0.out_layers.0.bias', 'output_blocks.4.0.out_layers.3.weight', 'output_blocks.4.0.out_layers.3.bias', 'output_blocks.4.0.skip_connection.weight', 'output_blocks.4.0.skip_connection.bias', 'output_blocks.5.0.in_layers.0.weight', 'output_blocks.5.0.in_layers.0.bias', 'output_blocks.5.0.in_layers.2.weight', 'output_blocks.5.0.in_layers.2.bias', 'output_blocks.5.0.emb_layers.1.weight', 'output_blocks.5.0.emb_layers.1.bias', 'output_blocks.5.0.out_layers.0.weight', 'output_blocks.5.0.out_layers.0.bias', 'output_blocks.5.0.out_layers.3.weight', 'output_blocks.5.0.out_layers.3.bias', 'output_blocks.5.0.skip_connection.weight', 'output_blocks.5.0.skip_connection.bias', 'output_blocks.5.1.conv.weight', 'output_blocks.5.1.conv.bias', 'output_blocks.6.0.in_layers.0.weight', 'output_blocks.6.0.in_layers.0.bias', 'output_blocks.6.0.in_layers.2.weight', 'output_blocks.6.0.in_layers.2.bias', 'output_blocks.6.0.emb_layers.1.weight', 'output_blocks.6.0.emb_layers.1.bias', 'output_blocks.6.0.out_layers.0.weight', 'output_blocks.6.0.out_layers.0.bias', 'output_blocks.6.0.out_layers.3.weight', 'output_blocks.6.0.out_layers.3.bias', 'output_blocks.6.0.skip_connection.weight', 'output_blocks.6.0.skip_connection.bias', 'output_blocks.6.1.norm.weight', 'output_blocks.6.1.norm.bias', 'output_blocks.6.1.qkv.weight', 'output_blocks.6.1.qkv.bias', 'output_blocks.6.1.proj_out.weight', 'output_blocks.6.1.proj_out.bias', 'output_blocks.7.0.in_layers.0.weight', 'output_blocks.7.0.in_layers.0.bias', 'output_blocks.7.0.in_layers.2.weight', 'output_blocks.7.0.in_layers.2.bias', 'output_blocks.7.0.emb_layers.1.weight', 'output_blocks.7.0.emb_layers.1.bias', 'output_blocks.7.0.out_layers.0.weight', 'output_blocks.7.0.out_layers.0.bias', 'output_blocks.7.0.out_layers.3.weight', 'output_blocks.7.0.out_layers.3.bias', 'output_blocks.7.0.skip_connection.weight', 'output_blocks.7.0.skip_connection.bias', 'output_blocks.7.1.norm.weight', 'output_blocks.7.1.norm.bias', 'output_blocks.7.1.qkv.weight', 'output_blocks.7.1.qkv.bias', 'output_blocks.7.1.proj_out.weight', 'output_blocks.7.1.proj_out.bias', 'output_blocks.8.0.in_layers.0.weight', 'output_blocks.8.0.in_layers.0.bias', 'output_blocks.8.0.in_layers.2.weight', 'output_blocks.8.0.in_layers.2.bias', 'output_blocks.8.0.emb_layers.1.weight', 'output_blocks.8.0.emb_layers.1.bias', 'output_blocks.8.0.out_layers.0.weight', 'output_blocks.8.0.out_layers.0.bias', 'output_blocks.8.0.out_layers.3.weight', 'output_blocks.8.0.out_layers.3.bias', 'output_blocks.8.0.skip_connection.weight', 'output_blocks.8.0.skip_connection.bias', 'output_blocks.8.1.norm.weight', 'output_blocks.8.1.norm.bias', 'output_blocks.8.1.qkv.weight', 'output_blocks.8.1.qkv.bias', 'output_blocks.8.1.proj_out.weight', 'output_blocks.8.1.proj_out.bias', 'out.0.weight', 'out.0.bias', 'out.2.weight', 'out.2.bias'])\n",
834
+ "exactly same\n"
835
  ]
836
  }
837
  ],
838
  "source": [
839
  "import torch\n",
840
+ "import os\n",
841
+ "\n",
842
+ "def compare_models(num_gpus):\n",
843
+ " model_states = []\n",
844
+ " \n",
845
+ " for gpu_id in range(num_gpus):\n",
846
+ " filename = f\"outputs/model_state-N500-device{gpu_id}\"\n",
847
+ " if os.path.exists(filename):\n",
848
+ " state_dict = torch.load(filename, map_location='cpu')\n",
849
+ " model_states.append(state_dict)\n",
850
+ " print(filename)\n",
851
+ " else:\n",
852
+ " print(f\"File {filename} not found!\")\n",
853
+ " return False\n",
854
+ " \n",
855
+ " # Compare all model state_dicts\n",
856
+ " print(\"len(model_states) =\", len(model_states))\n",
857
+ " base_state = model_states[0]\n",
858
+ " for state in model_states[1:]:\n",
859
+ " for key in base_state.keys():\n",
860
+ " # print(key, base_state[key], state[key])\n",
861
+ " print(key)\n",
862
+ " print(\"epoch\", base_state['epoch'], state['epoch'])\n",
863
+ "\n",
864
+ " print(base_state['unet_state_dict'].keys())\n",
865
+ " for key in base_state['unet_state_dict']:\n",
866
+ " # print(key)\n",
867
+ " if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n",
868
+ " print(\"different\")\n",
869
+ " return \n",
870
+ " # else:\n",
871
+ " print(\"exactly same\")\n",
872
+ "\n",
873
+ " # if key == 'epoch':\n",
874
+ " # print(base_state[key], state[key])\n",
875
+ " # else:\n",
876
+ " # print(base_state[key], state[key])\n",
877
+ " # if not torch.equal(base_state[key], state[key]):\n",
878
+ " # # if not (base_state[key] == state[key]):\n",
879
+ " # print(f\"Mismatch found in parameter {key}\")\n",
880
+ " # return False\n",
881
+ " \n",
882
+ " # print(\"All models are identical!\")\n",
883
+ " # return True\n",
884
+ "\n",
885
+ "if __name__ == \"__main__\":\n",
886
+ " # epoch_to_check = 0 # specify the epoch you want to check\n",
887
+ " num_gpus = 2 # specify the number of GPUs used in training\n",
888
+ " compare_models(num_gpus)"
889
  ]
890
  }
891
  ],
load_h5.py CHANGED
@@ -26,7 +26,7 @@ import datetime
26
  # from huggingface_hub import create_repo, upload_folder
27
 
28
  class Dataset4h5(Dataset):
29
- def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=True, idx=None, num_redshift=512, HII_DIM=64, rescale=True, drop_prob = 0, dim=2, transform=True, ranges_dict=None):
30
  super().__init__()
31
 
32
  self.dir_name = dir_name
 
26
  # from huggingface_hub import create_repo, upload_folder
27
 
28
  class Dataset4h5(Dataset):
29
+ def __init__(self, dir_name, num_image=10, field='brightness_temp', shuffle=False, idx=None, num_redshift=512, HII_DIM=64, rescale=True, drop_prob = 0, dim=2, transform=True, ranges_dict=None):
30
  super().__init__()
31
 
32
  self.dir_name = dir_name