Xsmos commited on
Commit
316b361
·
verified ·
1 Parent(s): 144694c
Files changed (3) hide show
  1. diffusion.ipynb +106 -67
  2. diffusion.py +89 -59
  3. load_h5.py +7 -6
diffusion.ipynb CHANGED
@@ -33,7 +33,7 @@
33
  },
34
  {
35
  "cell_type": "code",
36
- "execution_count": 1,
37
  "metadata": {},
38
  "outputs": [],
39
  "source": [
@@ -77,7 +77,7 @@
77
  },
78
  {
79
  "cell_type": "code",
80
- "execution_count": 2,
81
  "metadata": {},
82
  "outputs": [],
83
  "source": [
@@ -95,11 +95,26 @@
95
  },
96
  {
97
  "cell_type": "code",
98
- "execution_count": 3,
99
  "metadata": {},
100
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  "source": [
102
- "# notebook_login()"
103
  ]
104
  },
105
  {
@@ -817,75 +832,99 @@
817
  },
818
  {
819
  "cell_type": "code",
820
- "execution_count": 29,
821
  "metadata": {},
822
- "outputs": [
823
- {
824
- "name": "stdout",
825
- "output_type": "stream",
826
- "text": [
827
- "outputs/model_state-N500-device0\n",
828
- "outputs/model_state-N500-device1\n",
829
- "len(model_states) = 2\n",
830
- "epoch\n",
831
- "unet_state_dict\n",
832
- "epoch 9 9\n",
833
- "odict_keys(['token_embedding.weight', 'token_embedding.bias', 'time_embed.0.weight', 'time_embed.0.bias', 'time_embed.2.weight', 'time_embed.2.bias', 'input_blocks.0.0.weight', 'input_blocks.0.0.bias', 'input_blocks.1.0.in_layers.0.weight', 'input_blocks.1.0.in_layers.0.bias', 'input_blocks.1.0.in_layers.2.weight', 'input_blocks.1.0.in_layers.2.bias', 'input_blocks.1.0.emb_layers.1.weight', 'input_blocks.1.0.emb_layers.1.bias', 'input_blocks.1.0.out_layers.0.weight', 'input_blocks.1.0.out_layers.0.bias', 'input_blocks.1.0.out_layers.3.weight', 'input_blocks.1.0.out_layers.3.bias', 'input_blocks.1.1.norm.weight', 'input_blocks.1.1.norm.bias', 'input_blocks.1.1.qkv.weight', 'input_blocks.1.1.qkv.bias', 'input_blocks.1.1.proj_out.weight', 'input_blocks.1.1.proj_out.bias', 'input_blocks.2.0.in_layers.0.weight', 'input_blocks.2.0.in_layers.0.bias', 'input_blocks.2.0.in_layers.2.weight', 'input_blocks.2.0.in_layers.2.bias', 'input_blocks.2.0.emb_layers.1.weight', 'input_blocks.2.0.emb_layers.1.bias', 'input_blocks.2.0.out_layers.0.weight', 'input_blocks.2.0.out_layers.0.bias', 'input_blocks.2.0.out_layers.3.weight', 'input_blocks.2.0.out_layers.3.bias', 'input_blocks.2.1.norm.weight', 'input_blocks.2.1.norm.bias', 'input_blocks.2.1.qkv.weight', 'input_blocks.2.1.qkv.bias', 'input_blocks.2.1.proj_out.weight', 'input_blocks.2.1.proj_out.bias', 'input_blocks.3.0.op.weight', 'input_blocks.3.0.op.bias', 'input_blocks.4.0.in_layers.0.weight', 'input_blocks.4.0.in_layers.0.bias', 'input_blocks.4.0.in_layers.2.weight', 'input_blocks.4.0.in_layers.2.bias', 'input_blocks.4.0.emb_layers.1.weight', 'input_blocks.4.0.emb_layers.1.bias', 'input_blocks.4.0.out_layers.0.weight', 'input_blocks.4.0.out_layers.0.bias', 'input_blocks.4.0.out_layers.3.weight', 'input_blocks.4.0.out_layers.3.bias', 'input_blocks.4.0.skip_connection.weight', 'input_blocks.4.0.skip_connection.bias', 'input_blocks.5.0.in_layers.0.weight', 'input_blocks.5.0.in_layers.0.bias', 'input_blocks.5.0.in_layers.2.weight', 'input_blocks.5.0.in_layers.2.bias', 'input_blocks.5.0.emb_layers.1.weight', 'input_blocks.5.0.emb_layers.1.bias', 'input_blocks.5.0.out_layers.0.weight', 'input_blocks.5.0.out_layers.0.bias', 'input_blocks.5.0.out_layers.3.weight', 'input_blocks.5.0.out_layers.3.bias', 'input_blocks.6.0.op.weight', 'input_blocks.6.0.op.bias', 'input_blocks.7.0.in_layers.0.weight', 'input_blocks.7.0.in_layers.0.bias', 'input_blocks.7.0.in_layers.2.weight', 'input_blocks.7.0.in_layers.2.bias', 'input_blocks.7.0.emb_layers.1.weight', 'input_blocks.7.0.emb_layers.1.bias', 'input_blocks.7.0.out_layers.0.weight', 'input_blocks.7.0.out_layers.0.bias', 'input_blocks.7.0.out_layers.3.weight', 'input_blocks.7.0.out_layers.3.bias', 'input_blocks.7.0.skip_connection.weight', 'input_blocks.7.0.skip_connection.bias', 'input_blocks.8.0.in_layers.0.weight', 'input_blocks.8.0.in_layers.0.bias', 'input_blocks.8.0.in_layers.2.weight', 'input_blocks.8.0.in_layers.2.bias', 'input_blocks.8.0.emb_layers.1.weight', 'input_blocks.8.0.emb_layers.1.bias', 'input_blocks.8.0.out_layers.0.weight', 'input_blocks.8.0.out_layers.0.bias', 'input_blocks.8.0.out_layers.3.weight', 'input_blocks.8.0.out_layers.3.bias', 'middle_block.0.in_layers.0.weight', 'middle_block.0.in_layers.0.bias', 'middle_block.0.in_layers.2.weight', 'middle_block.0.in_layers.2.bias', 'middle_block.0.emb_layers.1.weight', 'middle_block.0.emb_layers.1.bias', 'middle_block.0.out_layers.0.weight', 'middle_block.0.out_layers.0.bias', 'middle_block.0.out_layers.3.weight', 'middle_block.0.out_layers.3.bias', 'middle_block.1.norm.weight', 'middle_block.1.norm.bias', 'middle_block.1.qkv.weight', 'middle_block.1.qkv.bias', 'middle_block.1.proj_out.weight', 'middle_block.1.proj_out.bias', 'middle_block.2.in_layers.0.weight', 'middle_block.2.in_layers.0.bias', 'middle_block.2.in_layers.2.weight', 'middle_block.2.in_layers.2.bias', 'middle_block.2.emb_layers.1.weight', 'middle_block.2.emb_layers.1.bias', 'middle_block.2.out_layers.0.weight', 'middle_block.2.out_layers.0.bias', 'middle_block.2.out_layers.3.weight', 'middle_block.2.out_layers.3.bias', 'output_blocks.0.0.in_layers.0.weight', 'output_blocks.0.0.in_layers.0.bias', 'output_blocks.0.0.in_layers.2.weight', 'output_blocks.0.0.in_layers.2.bias', 'output_blocks.0.0.emb_layers.1.weight', 'output_blocks.0.0.emb_layers.1.bias', 'output_blocks.0.0.out_layers.0.weight', 'output_blocks.0.0.out_layers.0.bias', 'output_blocks.0.0.out_layers.3.weight', 'output_blocks.0.0.out_layers.3.bias', 'output_blocks.0.0.skip_connection.weight', 'output_blocks.0.0.skip_connection.bias', 'output_blocks.1.0.in_layers.0.weight', 'output_blocks.1.0.in_layers.0.bias', 'output_blocks.1.0.in_layers.2.weight', 'output_blocks.1.0.in_layers.2.bias', 'output_blocks.1.0.emb_layers.1.weight', 'output_blocks.1.0.emb_layers.1.bias', 'output_blocks.1.0.out_layers.0.weight', 'output_blocks.1.0.out_layers.0.bias', 'output_blocks.1.0.out_layers.3.weight', 'output_blocks.1.0.out_layers.3.bias', 'output_blocks.1.0.skip_connection.weight', 'output_blocks.1.0.skip_connection.bias', 'output_blocks.2.0.in_layers.0.weight', 'output_blocks.2.0.in_layers.0.bias', 'output_blocks.2.0.in_layers.2.weight', 'output_blocks.2.0.in_layers.2.bias', 'output_blocks.2.0.emb_layers.1.weight', 'output_blocks.2.0.emb_layers.1.bias', 'output_blocks.2.0.out_layers.0.weight', 'output_blocks.2.0.out_layers.0.bias', 'output_blocks.2.0.out_layers.3.weight', 'output_blocks.2.0.out_layers.3.bias', 'output_blocks.2.0.skip_connection.weight', 'output_blocks.2.0.skip_connection.bias', 'output_blocks.2.1.conv.weight', 'output_blocks.2.1.conv.bias', 'output_blocks.3.0.in_layers.0.weight', 'output_blocks.3.0.in_layers.0.bias', 'output_blocks.3.0.in_layers.2.weight', 'output_blocks.3.0.in_layers.2.bias', 'output_blocks.3.0.emb_layers.1.weight', 'output_blocks.3.0.emb_layers.1.bias', 'output_blocks.3.0.out_layers.0.weight', 'output_blocks.3.0.out_layers.0.bias', 'output_blocks.3.0.out_layers.3.weight', 'output_blocks.3.0.out_layers.3.bias', 'output_blocks.3.0.skip_connection.weight', 'output_blocks.3.0.skip_connection.bias', 'output_blocks.4.0.in_layers.0.weight', 'output_blocks.4.0.in_layers.0.bias', 'output_blocks.4.0.in_layers.2.weight', 'output_blocks.4.0.in_layers.2.bias', 'output_blocks.4.0.emb_layers.1.weight', 'output_blocks.4.0.emb_layers.1.bias', 'output_blocks.4.0.out_layers.0.weight', 'output_blocks.4.0.out_layers.0.bias', 'output_blocks.4.0.out_layers.3.weight', 'output_blocks.4.0.out_layers.3.bias', 'output_blocks.4.0.skip_connection.weight', 'output_blocks.4.0.skip_connection.bias', 'output_blocks.5.0.in_layers.0.weight', 'output_blocks.5.0.in_layers.0.bias', 'output_blocks.5.0.in_layers.2.weight', 'output_blocks.5.0.in_layers.2.bias', 'output_blocks.5.0.emb_layers.1.weight', 'output_blocks.5.0.emb_layers.1.bias', 'output_blocks.5.0.out_layers.0.weight', 'output_blocks.5.0.out_layers.0.bias', 'output_blocks.5.0.out_layers.3.weight', 'output_blocks.5.0.out_layers.3.bias', 'output_blocks.5.0.skip_connection.weight', 'output_blocks.5.0.skip_connection.bias', 'output_blocks.5.1.conv.weight', 'output_blocks.5.1.conv.bias', 'output_blocks.6.0.in_layers.0.weight', 'output_blocks.6.0.in_layers.0.bias', 'output_blocks.6.0.in_layers.2.weight', 'output_blocks.6.0.in_layers.2.bias', 'output_blocks.6.0.emb_layers.1.weight', 'output_blocks.6.0.emb_layers.1.bias', 'output_blocks.6.0.out_layers.0.weight', 'output_blocks.6.0.out_layers.0.bias', 'output_blocks.6.0.out_layers.3.weight', 'output_blocks.6.0.out_layers.3.bias', 'output_blocks.6.0.skip_connection.weight', 'output_blocks.6.0.skip_connection.bias', 'output_blocks.6.1.norm.weight', 'output_blocks.6.1.norm.bias', 'output_blocks.6.1.qkv.weight', 'output_blocks.6.1.qkv.bias', 'output_blocks.6.1.proj_out.weight', 'output_blocks.6.1.proj_out.bias', 'output_blocks.7.0.in_layers.0.weight', 'output_blocks.7.0.in_layers.0.bias', 'output_blocks.7.0.in_layers.2.weight', 'output_blocks.7.0.in_layers.2.bias', 'output_blocks.7.0.emb_layers.1.weight', 'output_blocks.7.0.emb_layers.1.bias', 'output_blocks.7.0.out_layers.0.weight', 'output_blocks.7.0.out_layers.0.bias', 'output_blocks.7.0.out_layers.3.weight', 'output_blocks.7.0.out_layers.3.bias', 'output_blocks.7.0.skip_connection.weight', 'output_blocks.7.0.skip_connection.bias', 'output_blocks.7.1.norm.weight', 'output_blocks.7.1.norm.bias', 'output_blocks.7.1.qkv.weight', 'output_blocks.7.1.qkv.bias', 'output_blocks.7.1.proj_out.weight', 'output_blocks.7.1.proj_out.bias', 'output_blocks.8.0.in_layers.0.weight', 'output_blocks.8.0.in_layers.0.bias', 'output_blocks.8.0.in_layers.2.weight', 'output_blocks.8.0.in_layers.2.bias', 'output_blocks.8.0.emb_layers.1.weight', 'output_blocks.8.0.emb_layers.1.bias', 'output_blocks.8.0.out_layers.0.weight', 'output_blocks.8.0.out_layers.0.bias', 'output_blocks.8.0.out_layers.3.weight', 'output_blocks.8.0.out_layers.3.bias', 'output_blocks.8.0.skip_connection.weight', 'output_blocks.8.0.skip_connection.bias', 'output_blocks.8.1.norm.weight', 'output_blocks.8.1.norm.bias', 'output_blocks.8.1.qkv.weight', 'output_blocks.8.1.qkv.bias', 'output_blocks.8.1.proj_out.weight', 'output_blocks.8.1.proj_out.bias', 'out.0.weight', 'out.0.bias', 'out.2.weight', 'out.2.bias'])\n",
834
- "exactly same\n"
835
- ]
836
- }
837
- ],
838
  "source": [
839
- "import torch\n",
840
- "import os\n",
841
  "\n",
842
- "def compare_models(num_gpus):\n",
843
- " model_states = []\n",
844
  " \n",
845
- " for gpu_id in range(num_gpus):\n",
846
- " filename = f\"outputs/model_state-N500-device{gpu_id}\"\n",
847
- " if os.path.exists(filename):\n",
848
- " state_dict = torch.load(filename, map_location='cpu')\n",
849
- " model_states.append(state_dict)\n",
850
- " print(filename)\n",
851
- " else:\n",
852
- " print(f\"File {filename} not found!\")\n",
853
- " return False\n",
854
  " \n",
855
- " # Compare all model state_dicts\n",
856
- " print(\"len(model_states) =\", len(model_states))\n",
857
- " base_state = model_states[0]\n",
858
- " for state in model_states[1:]:\n",
859
- " for key in base_state.keys():\n",
860
- " # print(key, base_state[key], state[key])\n",
861
- " print(key)\n",
862
- " print(\"epoch\", base_state['epoch'], state['epoch'])\n",
863
- "\n",
864
- " print(base_state['unet_state_dict'].keys())\n",
865
- " for key in base_state['unet_state_dict']:\n",
866
- " # print(key)\n",
867
- " if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n",
868
- " print(\"different\")\n",
869
- " return \n",
870
- " # else:\n",
871
- " print(\"exactly same\")\n",
872
- "\n",
873
- " # if key == 'epoch':\n",
874
- " # print(base_state[key], state[key])\n",
875
- " # else:\n",
876
- " # print(base_state[key], state[key])\n",
877
- " # if not torch.equal(base_state[key], state[key]):\n",
878
- " # # if not (base_state[key] == state[key]):\n",
879
- " # print(f\"Mismatch found in parameter {key}\")\n",
880
- " # return False\n",
881
  " \n",
882
- " # print(\"All models are identical!\")\n",
883
- " # return True\n",
884
  "\n",
885
- "if __name__ == \"__main__\":\n",
886
- " # epoch_to_check = 0 # specify the epoch you want to check\n",
887
- " num_gpus = 2 # specify the number of GPUs used in training\n",
888
- " compare_models(num_gpus)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
889
  ]
890
  }
891
  ],
 
33
  },
34
  {
35
  "cell_type": "code",
36
+ "execution_count": 31,
37
  "metadata": {},
38
  "outputs": [],
39
  "source": [
 
77
  },
78
  {
79
  "cell_type": "code",
80
+ "execution_count": 32,
81
  "metadata": {},
82
  "outputs": [],
83
  "source": [
 
95
  },
96
  {
97
  "cell_type": "code",
98
+ "execution_count": 34,
99
  "metadata": {},
100
+ "outputs": [
101
+ {
102
+ "data": {
103
+ "application/vnd.jupyter.widget-view+json": {
104
+ "model_id": "9bbf7e9db9ce426d9c59d6f6d8e8df29",
105
+ "version_major": 2,
106
+ "version_minor": 0
107
+ },
108
+ "text/plain": [
109
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
110
+ ]
111
+ },
112
+ "metadata": {},
113
+ "output_type": "display_data"
114
+ }
115
+ ],
116
  "source": [
117
+ "notebook_login()"
118
  ]
119
  },
120
  {
 
832
  },
833
  {
834
  "cell_type": "code",
835
+ "execution_count": 9,
836
  "metadata": {},
837
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838
  "source": [
839
+ "# import torch\n",
840
+ "# import os\n",
841
  "\n",
842
+ "# def compare_models(num_gpus):\n",
843
+ "# model_states = []\n",
844
  " \n",
845
+ "# for gpu_id in range(num_gpus):\n",
846
+ "# filename = f\"outputs/model_state-N40-device{gpu_id}\"\n",
847
+ "# if os.path.exists(filename):\n",
848
+ "# state_dict = torch.load(filename, map_location='cpu')\n",
849
+ "# model_states.append(state_dict)\n",
850
+ "# print(filename)\n",
851
+ "# else:\n",
852
+ "# print(f\"File {filename} not found!\")\n",
853
+ "# return False\n",
854
  " \n",
855
+ "# # Compare all model state_dicts\n",
856
+ "# print(\"len(model_states) =\", len(model_states))\n",
857
+ "# base_state = model_states[0]\n",
858
+ "# for state in model_states[1:]:\n",
859
+ "# for key in base_state.keys():\n",
860
+ "# # print(key, base_state[key], state[key])\n",
861
+ "# print(key)\n",
862
+ "# print(\"epoch\", base_state['epoch'], state['epoch'])\n",
863
+ "\n",
864
+ "# print(base_state['unet_state_dict'].keys())\n",
865
+ "# for key in base_state['unet_state_dict']:\n",
866
+ "# # print(key)\n",
867
+ "# if not torch.equal(base_state['unet_state_dict'][key], state['unet_state_dict'][key]):\n",
868
+ "# print(\"different\")\n",
869
+ "# return \n",
870
+ "# # else:\n",
871
+ "# print(\"exactly same\")\n",
872
+ "\n",
873
+ "# # if key == 'epoch':\n",
874
+ "# # print(base_state[key], state[key])\n",
875
+ "# # else:\n",
876
+ "# # print(base_state[key], state[key])\n",
877
+ "# # if not torch.equal(base_state[key], state[key]):\n",
878
+ "# # # if not (base_state[key] == state[key]):\n",
879
+ "# # print(f\"Mismatch found in parameter {key}\")\n",
880
+ "# # return False\n",
881
  " \n",
882
+ "# # print(\"All models are identical!\")\n",
883
+ "# # return True\n",
884
  "\n",
885
+ "# if __name__ == \"__main__\":\n",
886
+ "# # epoch_to_check = 0 # specify the epoch you want to check\n",
887
+ "# num_gpus = torch.cuda.device_count() # specify the number of GPUs used in training\n",
888
+ "# compare_models(num_gpus)"
889
+ ]
890
+ },
891
+ {
892
+ "cell_type": "code",
893
+ "execution_count": 6,
894
+ "metadata": {},
895
+ "outputs": [],
896
+ "source": [
897
+ "import numpy as np\n",
898
+ "test = np.random.normal(0,1,(800,1,64,64,512))"
899
+ ]
900
+ },
901
+ {
902
+ "cell_type": "code",
903
+ "execution_count": 7,
904
+ "metadata": {},
905
+ "outputs": [
906
+ {
907
+ "data": {
908
+ "text/plain": [
909
+ "12.5"
910
+ ]
911
+ },
912
+ "execution_count": 7,
913
+ "metadata": {},
914
+ "output_type": "execute_result"
915
+ }
916
+ ],
917
+ "source": [
918
+ "(test.itemsize*test.size) / 1024/1024/1024"
919
+ ]
920
+ },
921
+ {
922
+ "cell_type": "code",
923
+ "execution_count": 8,
924
+ "metadata": {},
925
+ "outputs": [],
926
+ "source": [
927
+ "del test"
928
  ]
929
  }
930
  ],
diffusion.py CHANGED
@@ -61,6 +61,7 @@ import torch.multiprocessing as mp
61
  from torch.utils.data.distributed import DistributedSampler
62
  from torch.nn.parallel import DistributedDataParallel as DDP
63
  from torch.distributed import init_process_group, destroy_process_group
 
64
 
65
  # %%
66
  def ddp_setup(rank: int, world_size: int):
@@ -180,14 +181,12 @@ class DDPMScheduler(nn.Module):
180
  x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
181
 
182
  pbar_sample.update(1)
183
- # pbar_sample.set_postfix(step=i)
184
 
185
- # print("x_i.shape =", x_i.shape)
186
  # store only part of the intermediate steps
187
- if i%20==0:# or i==0:# or i<8:
188
- x_i_entire.append(x_i.detach().cpu().numpy())
189
- x_i = x_i.detach().cpu().numpy()
190
  x_i_entire = np.array(x_i_entire)
 
191
  return x_i, x_i_entire
192
 
193
 
@@ -225,7 +224,7 @@ class TrainConfig:
225
  ###########################
226
  ## hardcoding these here ##
227
  ###########################
228
- push_to_hub = True
229
  hub_model_id = "Xsmos/ml21cm"
230
  hub_private_repo = False
231
  dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
@@ -265,14 +264,14 @@ class TrainConfig:
265
  # seed = 0
266
  # save_dir = './outputs/'
267
 
268
- save_freq = 0#.1 # the period of sampling
269
  # general parameters for the name and logger
270
  # device = "cuda" if torch.cuda.is_available() else "cpu"
271
  lrate = 1e-4
272
  lr_warmup_steps = 0#5#00
273
  output_dir = "./outputs/"
274
  save_name = os.path.join(output_dir, 'model_state')
275
- # save_freq = 1 #10 # the period of saving model
276
  # cond = True # if training using the conditional information
277
  # lr_decay = False #True# if using the learning rate decay
278
  resume = save_name # if resume from the trained checkpoints
@@ -394,8 +393,8 @@ class DDPM21CM:
394
  # distributed_type="MULTI_GPU",
395
  )
396
  # print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
397
- if self.accelerator.is_main_process:
398
- # if torch.cuda.current_device() == 0:
399
  if self.config.output_dir is not None:
400
  os.makedirs(self.config.output_dir, exist_ok=True)
401
  if self.config.push_to_hub:
@@ -427,7 +426,7 @@ class DDPM21CM:
427
  pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
428
  pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
429
  for i, (x, c) in enumerate(self.dataloader):
430
- print(f"device {torch.cuda.current_device()}, x[:10,0,:2,:2,:2] =", x[:10,0,:2,:2,:2])
431
  with self.accelerator.accumulate(self.nn_model):
432
  x = x.to(self.config.device)
433
  xt, noise, ts = self.ddpm.add_noise(x)
@@ -460,7 +459,7 @@ class DDPM21CM:
460
  self.accelerator.log(logs, step=global_step)
461
  global_step += 1
462
 
463
- # if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:
464
  self.save(ep)
465
 
466
  del self.nn_model
@@ -470,9 +469,9 @@ class DDPM21CM:
470
 
471
  def save(self, ep):
472
  # save model
473
- if self.accelerator.is_main_process:
474
- # if torch.cuda.current_device() == 0:
475
- if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:
476
  self.nn_model.eval()
477
  with torch.no_grad():
478
  if self.config.push_to_hub:
@@ -488,8 +487,9 @@ class DDPM21CM:
488
  'unet_state_dict': self.nn_model.module.state_dict(),
489
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
490
  }
491
- torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}-device{torch.cuda.current_device()}")
492
- print(f'device {torch.cuda.current_device()} saved model at ' + self.config.save_name+f"-N{self.config.num_image}-device{torch.cuda.current_device()}")
 
493
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
494
 
495
  # def rescale(self, value, type='params', to_ranges=[0,1]):
@@ -506,7 +506,7 @@ class DDPM21CM:
506
  value = value * (to[1]-to[0]) + to[0]
507
  return value
508
 
509
- def sample(self, file, params:torch.tensor=None, repeat=192, ema=False, entire=False):
510
  # n_sample = params.shape[0]
511
 
512
  if params is None:
@@ -516,8 +516,8 @@ class DDPM21CM:
516
  params_backup = params.numpy().copy()
517
  params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
518
 
519
- print(f"sampling {repeat} images with normalized params = {params}")
520
- params = params.repeat(repeat,1)
521
  assert params.dim() == 2, "params must be a 2D torch.tensor"
522
  # print("params =", params)
523
  # print("params =", params)
@@ -526,16 +526,16 @@ class DDPM21CM:
526
  # del self.ema_model, self.nn
527
  # params = torch.tile(params, (n_sample,1)).to(device)
528
 
529
- nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
530
  if ema:
531
- nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])
532
  else:
533
- nn_model.load_state_dict(torch.load(file)['unet_state_dict'])
534
  print(f"nn_model resumed from {file}")
535
  # nn_model = ContextUnet(n_param=1, image_size=28)
536
  # nn_model.train()
537
- nn_model.to(self.ddpm.device)
538
- nn_model.eval()
539
 
540
  # self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)
541
  # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f"{config.resume}"))['ema_unet_state_dict'])
@@ -543,27 +543,27 @@ class DDPM21CM:
543
 
544
  with torch.no_grad():
545
  x_last, x_entire = self.ddpm.sample(
546
- nn_model=nn_model,
547
  params=params.to(self.config.device),
548
  device=self.config.device,
549
  guide_w=self.config.guide_w
550
  )
551
 
552
- # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
553
- np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy"), x_last)
554
-
555
- if entire:
556
- np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy"), x_last)
557
- # print("device =", config.device)
558
-
559
  # %%
560
- def main(rank, world_size):
561
  config = TrainConfig()
562
  config.world_size = world_size
563
 
564
  ddp_setup(rank, world_size)
565
 
566
- num_image_list = [500]#[200]#[1600,3200,6400,12800,25600]
567
  for i, num_image in enumerate(num_image_list):
568
  config.num_image = num_image
569
  # config.world_size = world_size
@@ -578,17 +578,11 @@ def main(rank, world_size):
578
  if __name__ == "__main__":
579
  # torch.multiprocessing.set_start_method("spawn")
580
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
581
- world_size = 2#torch.cuda.device_count()
582
 
583
- mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
584
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
585
 
586
- # %%
587
- # torch.cuda.set_device(0)
588
-
589
- # %%
590
- # print(torch.cuda.__dir__())
591
-
592
  # %%
593
  # print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
594
  # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
@@ -601,31 +595,67 @@ if __name__ == "__main__":
601
  # print(torch.cuda.memory())
602
  # print('here')
603
  # print(torch.cuda.memory_summary())
604
-
605
  # %% [markdown]
606
  # # Sampling
607
 
608
  # %%
609
- # if __name__ == "__main__":
610
- # # num_image_list = [1600,3200,6400,12800,25600]
611
- # num_image_list = [1000]
612
- # # num_image_list = [3200,6400,12800,25600]
613
- # # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
614
- # repeat = 2
615
- # config = TrainConfig()
616
- # for i, num_image in enumerate(num_image_list):
617
- # config.num_image = num_image
618
- # ddpm21cm = DDPM21CM(config)
619
 
620
- # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor([4.4, 131.341]), repeat=repeat)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
 
622
- # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.6, 19.037)), repeat=repeat)
 
623
 
624
- # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.699, 30)), repeat=repeat)
625
 
626
- # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.477, 200)), repeat=repeat)
 
 
627
 
628
- # # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.8, 131.341)), repeat=repeat)
629
 
630
  # %%
631
  # ls -lth outputs | head
 
61
  from torch.utils.data.distributed import DistributedSampler
62
  from torch.nn.parallel import DistributedDataParallel as DDP
63
  from torch.distributed import init_process_group, destroy_process_group
64
+ import torch.distributed as dist
65
 
66
  # %%
67
  def ddp_setup(rank: int, world_size: int):
 
181
  x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
182
 
183
  pbar_sample.update(1)
 
184
 
 
185
  # store only part of the intermediate steps
186
+ # if i%20==0:# or i==0:# or i<8:
187
+ # x_i_entire.append(x_i.detach().cpu().numpy())
 
188
  x_i_entire = np.array(x_i_entire)
189
+ x_i = x_i.detach().cpu().numpy()
190
  return x_i, x_i_entire
191
 
192
 
 
224
  ###########################
225
  ## hardcoding these here ##
226
  ###########################
227
+ push_to_hub = True
228
  hub_model_id = "Xsmos/ml21cm"
229
  hub_private_repo = False
230
  dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
 
264
  # seed = 0
265
  # save_dir = './outputs/'
266
 
267
+ save_period = np.infty#.1 # the period of sampling
268
  # general parameters for the name and logger
269
  # device = "cuda" if torch.cuda.is_available() else "cpu"
270
  lrate = 1e-4
271
  lr_warmup_steps = 0#5#00
272
  output_dir = "./outputs/"
273
  save_name = os.path.join(output_dir, 'model_state')
274
+ # save_period = 1 #10 # the period of saving model
275
  # cond = True # if training using the conditional information
276
  # lr_decay = False #True# if using the learning rate decay
277
  resume = save_name # if resume from the trained checkpoints
 
393
  # distributed_type="MULTI_GPU",
394
  )
395
  # print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
396
+ # if self.accelerator.is_main_process:
397
+ if torch.cuda.current_device() == 0:
398
  if self.config.output_dir is not None:
399
  os.makedirs(self.config.output_dir, exist_ok=True)
400
  if self.config.push_to_hub:
 
426
  pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
427
  pbar_train.set_description(f"device {torch.cuda.current_device()}, Epoch {ep}")
428
  for i, (x, c) in enumerate(self.dataloader):
429
+ # print(f"device {torch.cuda.current_device()}, x[:,0,:2,0,0] =", x[:,0,:2,0,0])
430
  with self.accelerator.accumulate(self.nn_model):
431
  x = x.to(self.config.device)
432
  xt, noise, ts = self.ddpm.add_noise(x)
 
459
  self.accelerator.log(logs, step=global_step)
460
  global_step += 1
461
 
462
+ # if ep == config.n_epoch-1 or (ep+1)*config.save_period==1:
463
  self.save(ep)
464
 
465
  del self.nn_model
 
469
 
470
  def save(self, ep):
471
  # save model
472
+ # if self.accelerator.is_main_process:
473
+ if torch.cuda.current_device() == 0:
474
+ if ep == self.config.n_epoch-1 or (ep+1) % self.config.save_period == 0:
475
  self.nn_model.eval()
476
  with torch.no_grad():
477
  if self.config.push_to_hub:
 
487
  'unet_state_dict': self.nn_model.module.state_dict(),
488
  # 'ema_unet_state_dict': self.ema_model.state_dict(),
489
  }
490
+ save_name = self.config.save_name+f"-N{self.config.num_image}-epoch{ep}-device{torch.cuda.current_device()}"
491
+ torch.save(model_state, save_name)
492
+ print(f'device {torch.cuda.current_device()} saved model at ' + save_name)
493
  # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
494
 
495
  # def rescale(self, value, type='params', to_ranges=[0,1]):
 
506
  value = value * (to[1]-to[0]) + to[0]
507
  return value
508
 
509
+ def sample(self, file, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
510
  # n_sample = params.shape[0]
511
 
512
  if params is None:
 
516
  params_backup = params.numpy().copy()
517
  params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
518
 
519
+ print(f"sampling {num_new_img} images with normalized params = {params}")
520
+ params = params.repeat(num_new_img,1)
521
  assert params.dim() == 2, "params must be a 2D torch.tensor"
522
  # print("params =", params)
523
  # print("params =", params)
 
526
  # del self.ema_model, self.nn
527
  # params = torch.tile(params, (n_sample,1)).to(device)
528
 
529
+ # nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
530
  if ema:
531
+ self.nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])
532
  else:
533
+ self.nn_model.load_state_dict(torch.load(file)['unet_state_dict'])
534
  print(f"nn_model resumed from {file}")
535
  # nn_model = ContextUnet(n_param=1, image_size=28)
536
  # nn_model.train()
537
+ # self.nn_model.to(self.ddpm.device)
538
+ self.nn_model.eval()
539
 
540
  # self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)
541
  # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f"{config.resume}"))['ema_unet_state_dict'])
 
543
 
544
  with torch.no_grad():
545
  x_last, x_entire = self.ddpm.sample(
546
+ nn_model=self.nn_model,
547
  params=params.to(self.config.device),
548
  device=self.config.device,
549
  guide_w=self.config.guide_w
550
  )
551
 
552
+ if save:
553
+ # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
554
+ np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy"), x_last)
555
+ if entire:
556
+ np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy"), x_last)
557
+ else:
558
+ return x_last
559
  # %%
560
+ def train(rank, world_size):
561
  config = TrainConfig()
562
  config.world_size = world_size
563
 
564
  ddp_setup(rank, world_size)
565
 
566
+ num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]
567
  for i, num_image in enumerate(num_image_list):
568
  config.num_image = num_image
569
  # config.world_size = world_size
 
578
  if __name__ == "__main__":
579
  # torch.multiprocessing.set_start_method("spawn")
580
  # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
581
+ world_size = torch.cuda.device_count()
582
 
583
+ mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
584
  # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
585
 
 
 
 
 
 
 
586
  # %%
587
  # print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
588
  # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
 
595
  # print(torch.cuda.memory())
596
  # print('here')
597
  # print(torch.cuda.memory_summary())
 
598
  # %% [markdown]
599
  # # Sampling
600
 
601
  # %%
 
 
 
 
 
 
 
 
 
 
602
 
603
+ def generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size):
604
+ samples = []
605
+ for _ in ranges(num_new_img // max_num_img_per_gpu):
606
+ sample = model.module.sample(filename, params=torch.tensor([4.4, 131.341]), num_new_img=max_num_img_per_gpu)
607
+ samples.append(sample)
608
+ # model.sample(filename, params=torch.tensor((5.6, 19.037)), num_new_img=max_num_img_per_gpu)
609
+ # model.sample(filename, params=torch.tensor((4.699, 30)), num_new_img=max_num_img_per_gpu)
610
+ # model.sample(filename, params=torch.tensor((5.477, 200)), num_new_img=max_num_img_per_gpu)
611
+ # model.sample(filename, params=torch.tensor((4.8, 131.341)), num_new_img=max_num_img_per_gpu)
612
+ samples = np.concatenate(samples, axis=0)
613
+
614
+ samples_list = [np.empty_like(samples) for _ in range(world_size)]
615
+ dist.all_gather_object(samples_list, samples)
616
+
617
+ if rank == 0:
618
+ all_samples = np.concatenate(samples_list, axis=0)
619
+ return all_samples
620
+ else:
621
+ return None
622
+
623
+
624
+ def sample(rank, world_size, model, num_new_img, max_num_img_per_gpu, return_dict):
625
+ ddp_setup(rank, world_size)
626
+
627
+ samples = generate_samples(model, num_new_img, max_num_img_per_gpu, rank, world_size)
628
+
629
+ if rank == 0:
630
+ return_dict['samples'] = samples
631
+
632
+ dist.destroy_process_group()
633
+
634
+
635
+ if __name__ == "__main__":
636
+ world_size = torch.cuda.device_count()
637
+ # num_image_list = [1600,3200,6400,12800,25600]
638
+ num_image_list = [1000]
639
+ num_new_img = 12
640
+ max_num_img_per_gpu = 2
641
+
642
+ config = TrainConfig()
643
+ config.world_size = world_size
644
+
645
+ for num_image in num_image_list:
646
+ filename = f"./outputs/model_state-N{num_image}-epoch9-device0"
647
+ config.num_image = num_image
648
+ ddpm21cm = DDPM21CM(config)
649
 
650
+ manager = np.Manager()
651
+ return_dict = manager.dict()
652
 
653
+ mp.spawn(sample, args=(world_size, ddpm21cm, num_new_img, max_num_img_per_gpu, return_dict), nprocs=world_size, join=True)
654
 
655
+ if "samples" in return_dict:
656
+ samples = return_dict["samples"]
657
+ print(f"Generated samples shape: {samples.shape}")
658
 
 
659
 
660
  # %%
661
  # ls -lth outputs | head
load_h5.py CHANGED
@@ -60,6 +60,13 @@ class Dataset4h5(Dataset):
60
  self.images = torch.from_numpy(self.images)
61
  print(f"images rescaled to [{self.images.min()}, {self.images.max()}]")
62
 
 
 
 
 
 
 
 
63
  cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()
64
  self.params = torch.from_numpy(self.params*cond_filter)
65
  print(f"params rescaled to [{self.params.min()}, {self.params.max()}]")
@@ -98,12 +105,6 @@ class Dataset4h5(Dataset):
98
  self.params = f['params']['values'][self.idx]
99
  print("params loaded:", self.params.shape)
100
 
101
- # print("before self.images.shape =", self.images.shape)
102
- self.images = torch.ones_like(torch.from_numpy(self.images)) * torch.arange(len(self.images))[:,None,None,None,None]
103
- self.images = self.images.numpy()
104
- # print("after self.images.shape =", self.images.shape)
105
- print(self.images[:6,0,:2,0,0])
106
- # self.images = self.images.numpy()
107
 
108
  # plt.imshow(self.images[0,0,0])
109
  # plt.show()
 
60
  self.images = torch.from_numpy(self.images)
61
  print(f"images rescaled to [{self.images.min()}, {self.images.max()}]")
62
 
63
+ # print("before self.images.shape =", self.images.shape)
64
+ # self.images = torch.ones_like(self.images) * torch.arange(len(self.images))[:,None,None,None,None]
65
+ # # self.images = self.images.numpy()
66
+ # print("after self.images.shape =", self.images.shape)
67
+ # print(self.images[:6,0,:2,0,0])
68
+ # # self.images = self.images.numpy()
69
+
70
  cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()
71
  self.params = torch.from_numpy(self.params*cond_filter)
72
  print(f"params rescaled to [{self.params.min()}, {self.params.max()}]")
 
105
  self.params = f['params']['values'][self.idx]
106
  print("params loaded:", self.params.shape)
107
 
 
 
 
 
 
 
108
 
109
  # plt.imshow(self.images[0,0,0])
110
  # plt.show()