Xsmos committed
Commit d130d8c · verified · 1 Parent(s): 7447107
Files changed (3)
  1. context_unet.py +1 -0
  2. diffusion.ipynb +130 -178
  3. diffusion.py +645 -0
context_unet.py CHANGED
@@ -516,6 +516,7 @@ class ContextUnet(nn.Module):
516
 
517
  def forward(self, x, timesteps, y=None):
518
  hs = []
519
  emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
520
  if y != None:
521
  text_outputs = self.token_embedding(y.float())
 
516
 
517
  def forward(self, x, timesteps, y=None):
518
  hs = []
519
+ print("device of timesteps, self.model_channels:", timesteps.device, self.model_channels)
520
  emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
521
  if y != None:
522
  text_outputs = self.token_embedding(y.float())
diffusion.ipynb CHANGED
@@ -67,7 +67,12 @@
67
  "from load_h5 import Dataset4h5\n",
68
  "from context_unet import ContextUnet\n",
69
  "\n",
70
- "from huggingface_hub import notebook_login"
71
  ]
72
  },
73
  {
@@ -75,6 +80,24 @@
75
  "execution_count": 2,
76
  "metadata": {},
77
  "outputs": [],
78
  "source": [
79
  "# notebook_login()"
80
  ]
@@ -101,7 +124,7 @@
101
  },
102
  {
103
  "cell_type": "code",
104
- "execution_count": 3,
105
  "metadata": {},
106
  "outputs": [],
107
  "source": [
@@ -208,7 +231,7 @@
208
  },
209
  {
210
  "cell_type": "code",
211
- "execution_count": 4,
212
  "metadata": {},
213
  "outputs": [],
214
  "source": [
@@ -239,7 +262,7 @@
239
  },
240
  {
241
  "cell_type": "code",
242
- "execution_count": 5,
243
  "metadata": {},
244
  "outputs": [],
245
  "source": [
@@ -253,6 +276,7 @@
253
  " hub_private_repo = False\n",
254
  " dataset_name = \"/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5\"\n",
255
  " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
256
  " # repeat = 2\n",
257
  "\n",
258
  " # dim = 2\n",
@@ -316,7 +340,7 @@
316
  },
317
  {
318
  "cell_type": "code",
319
- "execution_count": 6,
320
  "metadata": {},
321
  "outputs": [],
322
  "source": [
@@ -330,7 +354,7 @@
330
  },
331
  {
332
  "cell_type": "code",
333
- "execution_count": 7,
334
  "metadata": {},
335
  "outputs": [],
336
  "source": [
@@ -399,7 +423,7 @@
399
  " ## training loop ##\n",
400
  " ###################\n",
401
  " # plot_unet = True\n",
402
- " \n",
403
  " self.load()\n",
404
  " self.accelerator = Accelerator(\n",
405
  " mixed_precision=self.config.mixed_precision,\n",
@@ -559,195 +583,123 @@
559
  },
560
  {
561
  "cell_type": "code",
562
- "execution_count": 8,
563
  "metadata": {},
564
  "outputs": [
565
- {
566
- "name": "stdout",
567
- "output_type": "stream",
568
- "text": [
569
- "Number of parameters for nn_model: 160234497\n",
570
- "---------------- num_image = 100 -----------------\n",
571
- "run_name = 0709-1355\n",
572
- "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
573
- "51200 images can be loaded\n",
574
- "field.shape = (64, 64, 514)\n",
575
- "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
576
- "loading 100 images randomly\n",
577
- "images loaded: (100, 1, 28, 28, 4)\n"
578
- ]
579
- },
580
  {
581
  "name": "stderr",
582
  "output_type": "stream",
583
  "text": [
584
- "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
585
  ]
586
  },
587
  {
588
- "name": "stdout",
589
- "output_type": "stream",
590
- "text": [
591
- "params loaded: (100, 2)\n",
592
- "images rescaled to [-1.0, 1.1527893543243408]\n",
593
- "params rescaled to [0.01349925332487425, 0.9922290052754472]\n",
594
- "self.accelerator.is_main_process: True\n"
595
  ]
596
- },
597
- {
598
- "data": {
599
- "application/vnd.jupyter.widget-view+json": {
600
- "model_id": "00624ec4da1045659dcdab30773ac988",
601
- "version_major": 2,
602
- "version_minor": 0
603
- },
604
- "text/plain": [
605
- " 0%| | 0/50 [00:00<?, ?it/s]"
606
- ]
607
- },
608
- "metadata": {},
609
- "output_type": "display_data"
610
- },
611
- {
612
- "data": {
613
- "application/vnd.jupyter.widget-view+json": {
614
- "model_id": "817bb9f1a9534a6d9aeb353178cbada7",
615
- "version_major": 2,
616
- "version_minor": 0
617
- },
618
- "text/plain": [
619
- " 0%| | 0/50 [00:00<?, ?it/s]"
620
- ]
621
- },
622
- "metadata": {},
623
- "output_type": "display_data"
624
- },
625
- {
626
- "data": {
627
- "application/vnd.jupyter.widget-view+json": {
628
- "model_id": "076537a3f8ae45a888d6c57f71cad35c",
629
- "version_major": 2,
630
- "version_minor": 0
631
- },
632
- "text/plain": [
633
- " 0%| | 0/50 [00:00<?, ?it/s]"
634
- ]
635
- },
636
- "metadata": {},
637
- "output_type": "display_data"
638
- },
639
- {
640
- "data": {
641
- "application/vnd.jupyter.widget-view+json": {
642
- "model_id": "4fdf9b37d0f049538be616bea1c1ee15",
643
- "version_major": 2,
644
- "version_minor": 0
645
- },
646
- "text/plain": [
647
- " 0%| | 0/50 [00:00<?, ?it/s]"
648
- ]
649
- },
650
- "metadata": {},
651
- "output_type": "display_data"
652
- },
653
- {
654
- "data": {
655
- "application/vnd.jupyter.widget-view+json": {
656
- "model_id": "92922681843d42b2a0d56f67eaa32506",
657
- "version_major": 2,
658
- "version_minor": 0
659
- },
660
- "text/plain": [
661
- " 0%| | 0/50 [00:00<?, ?it/s]"
662
- ]
663
- },
664
- "metadata": {},
665
- "output_type": "display_data"
666
- },
667
- {
668
- "data": {
669
- "application/vnd.jupyter.widget-view+json": {
670
- "model_id": "8bf7c44087e244da85573e2146f18c74",
671
- "version_major": 2,
672
- "version_minor": 0
673
- },
674
- "text/plain": [
675
- " 0%| | 0/50 [00:00<?, ?it/s]"
676
- ]
677
- },
678
- "metadata": {},
679
- "output_type": "display_data"
680
- },
681
- {
682
- "data": {
683
- "application/vnd.jupyter.widget-view+json": {
684
- "model_id": "bbeb4c2aa6f742e8bda3176f53f1b0d8",
685
- "version_major": 2,
686
- "version_minor": 0
687
- },
688
- "text/plain": [
689
- " 0%| | 0/50 [00:00<?, ?it/s]"
690
- ]
691
- },
692
- "metadata": {},
693
- "output_type": "display_data"
694
- },
695
- {
696
- "data": {
697
- "application/vnd.jupyter.widget-view+json": {
698
- "model_id": "6d5b5b24015247c48544e995dfa917ec",
699
- "version_major": 2,
700
- "version_minor": 0
701
- },
702
- "text/plain": [
703
- " 0%| | 0/50 [00:00<?, ?it/s]"
704
- ]
705
- },
706
- "metadata": {},
707
- "output_type": "display_data"
708
- },
709
- {
710
- "data": {
711
- "application/vnd.jupyter.widget-view+json": {
712
- "model_id": "d0e3399ca4a54a9bae5e63c11bbd361f",
713
- "version_major": 2,
714
- "version_minor": 0
715
- },
716
- "text/plain": [
717
- " 0%| | 0/50 [00:00<?, ?it/s]"
718
- ]
719
- },
720
- "metadata": {},
721
- "output_type": "display_data"
722
- },
723
- {
724
- "data": {
725
- "application/vnd.jupyter.widget-view+json": {
726
- "model_id": "a9939b9288c84fd99091702efe3a2e8a",
727
- "version_major": 2,
728
- "version_minor": 0
729
- },
730
- "text/plain": [
731
- " 0%| | 0/50 [00:00<?, ?it/s]"
732
- ]
733
- },
734
- "metadata": {},
735
- "output_type": "display_data"
736
  }
737
  ],
738
  "source": [
739
- "num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]\n",
740
- "if __name__ == \"__main__\":\n",
741
- " # torch.multiprocessing.set_start_method(\"spawn\")\n",
742
- " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
743
  " config = TrainConfig()\n",
744
  " for i, num_image in enumerate(num_image_list):\n",
745
  " config.num_image = num_image\n",
746
  " ddpm21cm = DDPM21CM(config)\n",
747
  " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
748
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
749
  " ddpm21cm.train()\n",
750
- " # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
751
  ]
752
  },
753
  {
@@ -788,11 +740,11 @@
788
  },
789
  {
790
  "cell_type": "code",
791
- "execution_count": null,
792
  "metadata": {},
793
  "outputs": [],
794
  "source": [
795
- "ls -lth outputs | head"
796
  ]
797
  },
798
  {
 
67
  "from load_h5 import Dataset4h5\n",
68
  "from context_unet import ContextUnet\n",
69
  "\n",
70
+ "from huggingface_hub import notebook_login\n",
71
+ "\n",
72
+ "import torch.multiprocessing as mp\n",
73
+ "from torch.utils.data.distributed import DistributedSampler\n",
74
+ "from torch.nn.parallel import DistributedDataParallel as DDP\n",
75
+ "from torch.distributed import init_process_group, destroy_process_group"
76
  ]
77
  },
78
  {
 
80
  "execution_count": 2,
81
  "metadata": {},
82
  "outputs": [],
83
+ "source": [
84
+ "def ddp_setup(rank: int, world_size: int):\n",
85
+ " \"\"\"\n",
86
+ " Args:\n",
87
+ " rank: Unique identifier of each process\n",
88
+ " world_size: Total number of processes\n",
89
+ " \"\"\"\n",
90
+ " os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
91
+ " os.environ[\"MASTER_PORT\"] = \"12355\"\n",
92
+ " torch.cuda.set_device(rank)\n",
93
+ " init_process_group(backend=\"nccl\", rank=rank, world_size=world_size)"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": 3,
99
+ "metadata": {},
100
+ "outputs": [],
101
  "source": [
102
  "# notebook_login()"
103
  ]
 
124
  },
125
  {
126
  "cell_type": "code",
127
+ "execution_count": 4,
128
  "metadata": {},
129
  "outputs": [],
130
  "source": [
 
231
  },
232
  {
233
  "cell_type": "code",
234
+ "execution_count": 5,
235
  "metadata": {},
236
  "outputs": [],
237
  "source": [
 
262
  },
263
  {
264
  "cell_type": "code",
265
+ "execution_count": 6,
266
  "metadata": {},
267
  "outputs": [],
268
  "source": [
 
276
  " hub_private_repo = False\n",
277
  " dataset_name = \"/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5\"\n",
278
  " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
279
+ " # world_size = torch.cuda.device_count()\n",
280
  " # repeat = 2\n",
281
  "\n",
282
  " # dim = 2\n",
 
340
  },
341
  {
342
  "cell_type": "code",
343
+ "execution_count": 7,
344
  "metadata": {},
345
  "outputs": [],
346
  "source": [
 
354
  },
355
  {
356
  "cell_type": "code",
357
+ "execution_count": 8,
358
  "metadata": {},
359
  "outputs": [],
360
  "source": [
 
423
  " ## training loop ##\n",
424
  " ###################\n",
425
  " # plot_unet = True\n",
426
+ "\n",
427
  " self.load()\n",
428
  " self.accelerator = Accelerator(\n",
429
  " mixed_precision=self.config.mixed_precision,\n",
 
583
  },
584
  {
585
  "cell_type": "code",
586
+ "execution_count": 12,
587
  "metadata": {},
588
  "outputs": [
589
  {
590
  "name": "stderr",
591
  "output_type": "stream",
592
  "text": [
593
+ "Traceback (most recent call last):\n",
594
+ " File \"<string>\", line 1, in <module>\n",
595
+ " File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 116, in spawn_main\n",
596
+ " exitcode = _main(fd, parent_sentinel)\n",
597
+ " File \"/storage/home/hcoda1/3/bxia34/.conda/envs/diffusers/lib/python3.9/multiprocessing/spawn.py\", line 126, in _main\n",
598
+ " self = reduction.pickle.load(from_parent)\n",
599
+ "AttributeError: Can't get attribute 'single_main' on <module '__main__' (built-in)>\n"
600
  ]
601
  },
602
  {
603
+ "ename": "ProcessExitedException",
604
+ "evalue": "process 0 terminated with exit code 1",
605
+ "output_type": "error",
606
+ "traceback": [
607
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
608
+ "\u001b[0;31mProcessExitedException\u001b[0m Traceback (most recent call last)",
609
+ "Cell \u001b[0;32mIn[12], line 21\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m__name__\u001b[39m \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m__main__\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 17\u001b[0m \u001b[39m# torch.multiprocessing.set_start_method(\"spawn\")\u001b[39;00m\n\u001b[1;32m 18\u001b[0m \u001b[39m# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\u001b[39;00m\n\u001b[1;32m 19\u001b[0m world_size \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\u001b[39m#torch.cuda.device_count()\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m mp\u001b[39m.\u001b[39;49mspawn(single_main, args\u001b[39m=\u001b[39;49m(world_size,), nprocs\u001b[39m=\u001b[39;49mworld_size)\n\u001b[1;32m 22\u001b[0m \u001b[39m# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')\u001b[39;00m\n",
610
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:240\u001b[0m, in \u001b[0;36mspawn\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 236\u001b[0m msg \u001b[39m=\u001b[39m (\u001b[39m'\u001b[39m\u001b[39mThis method only supports start_method=spawn (got: \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m).\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 237\u001b[0m \u001b[39m'\u001b[39m\u001b[39mTo use a different start_method use:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[1;32m 238\u001b[0m \u001b[39m'\u001b[39m\u001b[39m torch.multiprocessing.start_processes(...)\u001b[39m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m start_method)\n\u001b[1;32m 239\u001b[0m warnings\u001b[39m.\u001b[39mwarn(msg)\n\u001b[0;32m--> 240\u001b[0m \u001b[39mreturn\u001b[39;00m start_processes(fn, args, nprocs, join, daemon, start_method\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mspawn\u001b[39;49m\u001b[39m'\u001b[39;49m)\n",
611
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:198\u001b[0m, in \u001b[0;36mstart_processes\u001b[0;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[39mreturn\u001b[39;00m context\n\u001b[1;32m 197\u001b[0m \u001b[39m# Loop on join until it returns True or raises an exception.\u001b[39;00m\n\u001b[0;32m--> 198\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m context\u001b[39m.\u001b[39;49mjoin():\n\u001b[1;32m 199\u001b[0m \u001b[39mpass\u001b[39;00m\n",
612
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/multiprocessing/spawn.py:149\u001b[0m, in \u001b[0;36mProcessContext.join\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 141\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with signal \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 142\u001b[0m (error_index, name),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 146\u001b[0m signal_name\u001b[39m=\u001b[39mname\n\u001b[1;32m 147\u001b[0m )\n\u001b[1;32m 148\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m--> 149\u001b[0m \u001b[39mraise\u001b[39;00m ProcessExitedException(\n\u001b[1;32m 150\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mprocess \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with exit code \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m\n\u001b[1;32m 151\u001b[0m (error_index, exitcode),\n\u001b[1;32m 152\u001b[0m error_index\u001b[39m=\u001b[39merror_index,\n\u001b[1;32m 153\u001b[0m error_pid\u001b[39m=\u001b[39mfailed_process\u001b[39m.\u001b[39mpid,\n\u001b[1;32m 154\u001b[0m exit_code\u001b[39m=\u001b[39mexitcode\n\u001b[1;32m 155\u001b[0m )\n\u001b[1;32m 157\u001b[0m original_trace \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39merror_queues[error_index]\u001b[39m.\u001b[39mget()\n\u001b[1;32m 158\u001b[0m msg \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\n\u001b[39;00m\u001b[39m-- Process \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with the following error:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m error_index\n",
613
+ "\u001b[0;31mProcessExitedException\u001b[0m: process 0 terminated with exit code 1"
614
  ]
615
  }
616
  ],
617
  "source": [
618
+ "def single_main(rank, world_size):\n",
619
  " config = TrainConfig()\n",
620
+ " ddp_setup(rank, world_size)\n",
621
+ " \n",
622
+ " num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]\n",
623
  " for i, num_image in enumerate(num_image_list):\n",
624
  " config.num_image = num_image\n",
625
+ " # config.world_size = world_size\n",
626
+ " \n",
627
  " ddpm21cm = DDPM21CM(config)\n",
628
  " print(f\" num_image = {ddpm21cm.config.num_image} \".center(50, '-'))\n",
629
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
630
  " ddpm21cm.train()\n",
631
+ "\n",
632
+ " \n",
633
+ "if __name__ == \"__main__\":\n",
634
+ " # torch.multiprocessing.set_start_method(\"spawn\")\n",
635
+ " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
636
+ " world_size = 1#torch.cuda.device_count()\n",
637
+ "\n",
638
+ " mp.spawn(single_main, args=(world_size,), nprocs=world_size)\n",
639
+ " # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')"
640
+ ]
641
+ },
642
+ {
643
+ "cell_type": "code",
644
+ "execution_count": null,
645
+ "metadata": {},
646
+ "outputs": [],
647
+ "source": [
648
+ "# torch.cuda.set_device(0)"
649
+ ]
650
+ },
651
+ {
652
+ "cell_type": "code",
653
+ "execution_count": null,
654
+ "metadata": {},
655
+ "outputs": [
656
+ {
657
+ "name": "stdout",
658
+ "output_type": "stream",
659
+ "text": [
660
+ "True\n",
661
+ "2\n",
662
+ "['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__path__', '__file__', '__cached__', '__builtins__', '__annotations__', 'contextlib', 'os', 'torch', 'Device', 'traceback', 'warnings', 'threading', 'List', 'Optional', 'Tuple', 'Union', 'Any', '_utils', '_get_device_index', '_dummy_type', 'classproperty', 'graphs', 'CUDAGraph', 'graph_pool_handle', 'graph', 'make_graphed_callables', 'is_current_stream_capturing', 'streams', 'ExternalStream', 'Stream', 'Event', '_device', '_cudart', '_initialized', '_tls', '_initialization_lock', '_queued_calls', '_is_in_bad_fork', '_device_t', '_LazySeedTracker', '_lazy_seed_tracker', '_CudaDeviceProperties', 'has_magma', 'has_half', 'default_generators', 'is_available', 'is_bf16_supported', '_sleep', '_check_capability', '_check_cubins', 'is_initialized', '_lazy_call', 'DeferredCudaCallError', 'init', '_lazy_init', 'cudart', 'cudaStatus', 'CudaError', 'check_error', 'device', 'device_of', 'set_device', 'get_device_name', 'get_device_capability', 'get_device_properties', 'can_device_access_peer', 'StreamContext', 'stream', 'set_stream', 'device_count', 'get_arch_list', 'get_gencode_flags', 'current_device', 'synchronize', 'ipc_collect', 'current_stream', 'default_stream', 'current_blas_handle', 'set_sync_debug_mode', 'get_sync_debug_mode', 'memory_usage', 'utilization', 'memory', 'caching_allocator_alloc', 'caching_allocator_delete', 'set_per_process_memory_fraction', 'empty_cache', 'memory_stats', 'memory_stats_as_nested_dict', 'reset_accumulated_memory_stats', 'reset_peak_memory_stats', 'reset_max_memory_allocated', 'reset_max_memory_cached', 'memory_allocated', 'max_memory_allocated', 'memory_reserved', 'max_memory_reserved', 'memory_cached', 'max_memory_cached', 'memory_snapshot', 'memory_summary', 'list_gpu_processes', 'mem_get_info', 'random', 'get_rng_state', 'get_rng_state_all', 'set_rng_state', 'set_rng_state_all', 'manual_seed', 'manual_seed_all', 'seed', 'seed_all', 'initial_seed', '_lazy_new', '_CudaBase', 'ByteStorage', 'DoubleStorage', 'FloatStorage', 'HalfStorage', 'LongStorage', 'IntStorage', 'ShortStorage', 'CharStorage', 'BoolStorage', 'BFloat16Storage', 'ComplexDoubleStorage', 'ComplexFloatStorage', 'sparse', 'profiler', 'nvtx', 'amp', 'jiterator', 'ByteTensor', 'CharTensor', 'DoubleTensor', 'FloatTensor', 'IntTensor', 'LongTensor', 'ShortTensor', 'HalfTensor', 'BoolTensor', 'BFloat16Tensor', 'nccl', '_get_device_properties']\n"
663
+ ]
664
+ }
665
+ ],
666
+ "source": [
667
+ "print(torch.cuda.is_available())\n",
668
+ "print(torch.cuda.device_count())\n",
669
+ "print(torch.cuda.__dir__())"
670
+ ]
671
+ },
672
+ {
673
+ "cell_type": "code",
674
+ "execution_count": 17,
675
+ "metadata": {},
676
+ "outputs": [
677
+ {
678
+ "name": "stdout",
679
+ "output_type": "stream",
680
+ "text": [
681
+ "True\n",
682
+ "<class 'torch.cuda.device'>\n",
683
+ "Quadro RTX 6000\n",
684
+ "0\n",
685
+ "(7, 5)\n",
686
+ "_CudaDeviceProperties(name='Quadro RTX 6000', major=7, minor=5, total_memory=24212MB, multi_processor_count=72)\n"
687
+ ]
688
+ }
689
+ ],
690
+ "source": [
691
+ "print(torch.cuda.is_initialized())\n",
692
+ "print(torch.cuda.device)\n",
693
+ "print(torch.cuda.get_device_name())\n",
694
+ "print(torch.cuda.current_device())\n",
695
+ "print(torch.cuda.get_device_capability())\n",
696
+ "print(torch.cuda.get_device_properties(torch.cuda.device))\n",
697
+ "# print('here')\n",
698
+ "# print(torch.cuda.memory_usage())\n",
699
+ "# print(torch.cuda.utilization())\n",
700
+ "# print(torch.cuda.memory())\n",
701
+ "# print('here')\n",
702
+ "# print(torch.cuda.memory_summary())"
703
  ]
704
  },
705
  {
 
740
  },
741
  {
742
  "cell_type": "code",
743
+ "execution_count": 13,
744
  "metadata": {},
745
  "outputs": [],
746
  "source": [
747
+ "# ls -lth outputs | head"
748
  ]
749
  },
750
  {
diffusion.py ADDED
@@ -0,0 +1,645 @@
1
+ # %% [markdown]
2
+ # ## Adapt ContextUnet and the related code so that it first works for the 2D case; compare it against diffusers.Unet2DModel and optimize; finally rewrite it for the 3D case.
3
+ # - After trying diffusers' Unet2DModel, the loss dropped from 0.3 to 0.2 but was still high, indicating issues outside Unet2DModel that can be optimized.
4
+ # - After switching to diffusers' DDPMScheduler and DDPMPipeline, the loss fell below 0.1, sometimes as low as 0.004, so my code's problem is mainly in the DDPM part. The DDPMScheduler part is short and seems fine, so the issue should be some piece of code in DDPMPipeline that mine is missing.
5
+ # - I had a typo in my DDPMScheduler that kept beta_t very small; after the fix the loss dropped from 0.2 to 0.02 and stays below 0.1.
6
+ # - diffusers' DDPMScheduler seems to work a bit better: its loss is always slightly lower than my DDPMScheduler's. At epoch 19 the former reaches about 0.02 versus about 0.07 for mine. It also supports noising 3D images, so I might as well reuse the existing wheel. Still, I want to know why my loss is higher.
7
+ # - I realized the third-party DDPMScheduler's sample function does not accept input parameters, so in the end I still need my own DDPMScheduler. I can, however, use theirs first to debug my ContextUnet.
8
+ # - I need to extend my ContextUnet to support images of different dimensionalities; after all, I have to finish the comparison with the original paper before extending to the 3D case.
9
+ # - I converted my ContextUnet to 2D mode; compared with diffusers.Unet2DModel's loss of 0.037, my Unet's loss is 0.07. The images my Unet generates also look strange, so my Unet has problems too. I need to roll the code back to the original Unet and locate the issue.
10
+ # - I limited the number of pixels along the redshift direction to 64 to compare the two Unets. Results:\
11
+ # Unet2DModel loss: 0.03, 0.0655, 0.05, 0.02, 0.05\
12
+ # ContextUnet loss: 0.1, 0.16, 0.1, 0.2186, 0.06
13
+ # - I rolled ContextUnet back to the original author's version; the loss is 0.05 and the output images look good. My main change was reverting to his original normalization function, which also has a swish parameter. When I have time I can study exactly which part affected the training results. I also found that for the tensorboard curves to display separately and cleanly, they need to be placed in different folders.
14
+ # - Verified: GroupNorm works better than BatchNorm.
15
+ # - Extended to accept inputs of different dimensionalities.
16
+ # - Merged the cond, guide_w, and drop_out parameters.
17
+ # - The generated 21cm images are not dark enough where they should be dark; switching to MNIST digit images seems to avoid the problem.
18
+ # - When generating MNIST digits with the diffusion model, I found that although the generated data also contain negative values such as -0.1, the plotted images are still ideally black. The data distribution is not very different from that of the 21cm results, so I now plan to roll the code back to the 21cm case.
19
+ # - I unified the ddpm21cm module so it handles both training and sampling, but there is currently a bug: sampling always hits CUDA out of memory, yet resuming the model separately and sampling does not.
20
+ # - Solved: the problem was that I forgot to wrap sampling in with torch.no_grad():
21
+ # - Next: generate 800 lightcones, and meanwhile study how to compute the global signal and the power spectrum.
22
+ # - When the number of training images reaches 5000, the generated images closely resemble the test data.
23
+ # - It takes 62 minutes to generate 8 images of shape (64,64,64), which is even slower than the simulation (~5 minutes per image). Besides, the batch_size during training and the number of images to be generated are limited to 2 and 8, respectively.
24
+ # - The slowness can be addressed with multiple GPUs, and the limit on the number of images with mixed precision and multiple GPUs.
25
+ # - In addition, DDPM's output can look better compared to the computation-intensive simulations.
26
+
27
+ # %%
28
+ from dataclasses import dataclass
29
+ import h5py
30
+ import torch
31
+ import torch.nn as nn
32
+ from torch.utils.data import DataLoader, Dataset
33
+ # from datasets import Dataset
34
+ import matplotlib.pyplot as plt
35
+ import numpy as np
36
+ import random
37
+ # from abc import ABC, abstractmethod
38
+ import torch.nn.functional as F
39
+ import math
40
+ # from PIL import Image
41
+ import os
42
+ from torch.utils.tensorboard import SummaryWriter
43
+ import copy
44
+ from tqdm.auto import tqdm
45
+ # from torchvision import transforms
46
+ # from diffusers import UNet2DModel#, UNet3DConditionModel
47
+ # from diffusers import DDPMScheduler
48
+ from diffusers.utils import make_image_grid
49
+ import datetime
50
+ from pathlib import Path
51
+ from diffusers.optimization import get_cosine_schedule_with_warmup
52
+ from accelerate import notebook_launcher, Accelerator
53
+ from huggingface_hub import create_repo, upload_folder
54
+
55
+ from load_h5 import Dataset4h5
56
+ from context_unet import ContextUnet
57
+
58
+ from huggingface_hub import notebook_login
59
+
60
+ import torch.multiprocessing as mp
61
+ from torch.utils.data.distributed import DistributedSampler
62
+ from torch.nn.parallel import DistributedDataParallel as DDP
63
+ from torch.distributed import init_process_group, destroy_process_group
64
+
65
+ # %%
66
+ def ddp_setup(rank: int, world_size: int):
67
+ """
68
+ Args:
69
+ rank: Unique identifier of each process
70
+ world_size: Total number of processes
71
+ """
72
+ os.environ["MASTER_ADDR"] = "localhost"
73
+ os.environ["MASTER_PORT"] = "12355"
74
+ torch.cuda.set_device(rank)
75
+ init_process_group(backend="nccl", rank=rank, world_size=world_size)
76
+
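Note that destroy_process_group is imported above but never called in this file; a minimal sketch (the worker name is hypothetical) of how a spawn target usually pairs ddp_setup with teardown:

def _worker(rank: int, world_size: int):
    ddp_setup(rank, world_size)
    try:
        pass  # training would run here
    finally:
        destroy_process_group()  # release the NCCL process group on exit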
77
+ # %%
78
+ # notebook_login()
79
+
80
+ # %% [markdown]
81
+ # # Add noise:
82
+ #
83
+ # \begin{align*}
84
+ # x_t &\sim \mathcal N\left(\sqrt{1-\beta_t}\ x_{t-1},\ \beta_t \right) \\
85
+ # x_t &\equiv \sqrt{1-\beta_t}\ x_{t-1} + \sqrt{\beta_t}\ \epsilon\\
86
+ # \epsilon &\sim \mathcal N(0,1)\\
87
+ # \alpha_t & \equiv 1 - \beta_t\\
88
+ # & ...\\
89
+ # x_t &= \sqrt{\bar {\alpha_t}} x_0 + \epsilon\ \sqrt{1 - \bar{\alpha_t}}\\
90
+ # \bar {\alpha_t} &\equiv \prod_{i=1}^t \alpha_i\\
91
+ # &= \exp\left({\ln{\prod_{i=1}^t \alpha_i}}\right)\\
92
+ # &= \exp\left({\sum_{i=1}^t\ln{ \alpha_i}}\right)
93
+ # \end{align*}
94
+
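A quick numerical check of the closed form derived above (an editorial sketch in NumPy, not part of the commit, using the same beta schedule as the DDPMScheduler below): iterating the one-step kernel and jumping directly with bar_alpha agree in distribution.

import numpy as np

T = 1000
beta = np.linspace(1e-4, 0.02, T)
alpha = 1.0 - beta
bar_alpha = np.cumprod(alpha)  # bar_alpha_t = prod_i alpha_i

rng = np.random.default_rng(0)
x0 = rng.normal(size=100_000)  # toy x_0 values drawn from N(0, 1)

x = x0.copy()
for t in range(T):  # iterate x_t = sqrt(1-beta_t) x_{t-1} + sqrt(beta_t) eps
    x = np.sqrt(1 - beta[t]) * x + np.sqrt(beta[t]) * rng.normal(size=x.shape)

# jump straight to step T with the closed form
x_jump = np.sqrt(bar_alpha[-1]) * x0 + np.sqrt(1 - bar_alpha[-1]) * rng.normal(size=x0.shape)

print(x.std(), x_jump.std())  # both ~1.0: the two processes share the same marginal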
95
+ # %%
96
+ class DDPMScheduler(nn.Module):
97
+ def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu'):
98
+ super().__init__()
99
+
100
+ beta_1, beta_T = betas
101
+ assert 0 < beta_1 <= beta_T <= 1, "ensure 0 < beta_1 <= beta_T <= 1"
102
+ self.device = device
103
+ self.num_timesteps = num_timesteps
104
+ self.img_shape = img_shape
105
+ self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1
106
+ self.beta_t = self.beta_t.to(self.device)
107
+
108
+ # self.drop_prob = drop_prob
109
+ # self.cond = cond
110
+ self.alpha_t = 1 - self.beta_t
111
+ # self.bar_alpha_t = torch.exp(torch.cumsum(torch.log(self.alpha_t), dim=0))
112
+ self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)
113
+
114
+ def add_noise(self, clean_images):
115
+ shape = clean_images.shape
116
+ expand = torch.ones(len(shape)-1, dtype=int)
117
+ # ts_expand = ts.view(ts.shape[0], *expand.tolist())
118
+ # expand = [1 for i in range(len(shape)-1)]
119
+
120
+ noise = torch.randn_like(clean_images).to(self.device)
121
+ ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)
122
+
123
+ # test_expand = test.view(test.shape[0],*expand)
124
+ # extend_dim = [None for i in range(shape.dim()-1)]
125
+ noisy_images = (
126
+ clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())
127
+ + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())
128
+ )
129
+ # print(x_t.shape)
130
+
131
+ return noisy_images, noise, ts
132
+
133
+ def sample(self, nn_model, params, device, guide_w = 0):
134
+ n_sample = len(params) #params.shape[0]
135
+ # print("params.shape[0], len(params)", params.shape[0], len(params))
136
+ x_i = torch.randn(n_sample, *self.img_shape).to(device)
137
+ # print("x_i.shape =", x_i.shape)
138
+ # print("x_i.shape =", x_i.shape)
139
+ if guide_w != -1:
140
+ c_i = params
141
+ uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)
142
+ # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)
143
+ # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)
144
+ c_i = torch.cat((c_i, uncond_tokens), 0)
145
+
146
+ x_i_entire = [] # keep track of generated steps in case want to plot something
147
+ # print("self.num_timesteps =", self.num_timesteps)
148
+ # for i in range(self.num_timesteps, 0, -1):
149
+ # print(f'sampling!!!')
150
+ pbar_sample = tqdm(total=self.num_timesteps)
151
+ pbar_sample.set_description("Sampling")
152
+ for i in reversed(range(0, self.num_timesteps)):
153
+ # print(f'sampling timestep {i:4d}',end='\r')
154
+ t_is = torch.tensor([i]).to(device)
155
+ t_is = t_is.repeat(n_sample)
156
+
157
+ z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0
158
+
159
+ if guide_w == -1:
160
+ # eps = nn_model(x_i, t_is, return_dict=False)[0]
161
+ eps = nn_model(x_i, t_is)
162
+ # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
163
+ else:
164
+ # double batch
165
+ x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())
166
+ t_is = t_is.repeat(2)
167
+
168
+ # split predictions and compute weighting
169
+ # print("nn_model input shape", x_i.shape, t_is.shape, c_i.shape)
170
+ eps = nn_model(x_i, t_is, c_i)
171
+ eps1 = eps[:n_sample]
172
+ eps2 = eps[n_sample:]
173
+ eps = eps1 + guide_w*(eps1 - eps2)
174
+ # eps = (1+guide_w)*eps1 - guide_w*eps2
175
+ x_i = x_i[:n_sample]
176
+ # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
177
+
178
+ # print("x_i.shape =", x_i.shape)
179
+ x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z
180
+
181
+ pbar_sample.update(1)
182
+ # pbar_sample.set_postfix(step=i)
183
+
184
+ # print("x_i.shape =", x_i.shape)
185
+ # store only part of the intermediate steps
186
+ if i%20==0:# or i==0:# or i<8:
187
+ x_i_entire.append(x_i.detach().cpu().numpy())
188
+ x_i = x_i.detach().cpu().numpy()
189
+ x_i_entire = np.array(x_i_entire)
190
+ return x_i, x_i_entire
191
+
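A note on the guidance step in sample() above (an editorial sketch with hypothetical tensors): eps = eps1 + guide_w*(eps1 - eps2) is the classifier-free-guidance combination and is algebraically identical to the commented variant (1+guide_w)*eps1 - guide_w*eps2.

import torch

guide_w = 0.5
eps_cond = torch.randn(4, 1, 8, 8)    # stands in for eps1 (conditional prediction)
eps_uncond = torch.randn(4, 1, 8, 8)  # stands in for eps2 (unconditional prediction)

combined_a = eps_cond + guide_w * (eps_cond - eps_uncond)
combined_b = (1 + guide_w) * eps_cond - guide_w * eps_uncond
assert torch.allclose(combined_a, combined_b)  # the two forms agree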
192
+
193
+ # ddpm_scheduler = DDPMScheduler((1e-4,0.02),10)
194
+ # noisy_images, noise, ts = ddpm_scheduler.add_noise(images)
195
+
196
+ # %%
197
+ class EMA:
198
+ def __init__(self, beta):
199
+ super().__init__()
200
+ self.beta = beta
201
+ self.step = 0
202
+
203
+ def update_model_average(self, ma_model, current_model):
204
+ for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
205
+ old_weight, up_weight = ma_params.data, current_params.data
206
+ ma_params.data = self.update_average(old_weight, up_weight)
207
+
208
+ def update_average(self, old, new):
209
+ if old is None:
210
+ return new
211
+ return old * self.beta + (1 - self.beta) * new
212
+
213
+ def step_ema(self, ema_model, model):
214
+ self.update_model_average(ema_model, model)
215
+ self.step += 1
216
+
217
+ def reset_parameters(self, ema_model, model):
218
+ ema_model.load_state_dict(model.state_dict())
219
+
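The recurrence in update_average, ma = beta*ma + (1-beta)*new, is a standard exponential moving average with an effective horizon of roughly 1/(1-beta) steps; a scalar sketch using the beta = 0.995 set as ema_rate in TrainConfig below:

beta = 0.995  # same value as TrainConfig.ema_rate
ma = 0.0
for step in range(2000):
    ma = beta * ma + (1 - beta) * 1.0  # track a constant signal of 1.0
print(ma)  # equals 1 - beta**2000, about 0.99996; horizon ~ 1/(1-beta) = 200 steps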
220
+
221
+ # %%
222
+ @dataclass
223
+ class TrainConfig:
224
+ ###########################
225
+ ## hardcoding these here ##
226
+ ###########################
227
+ push_to_hub = True
228
+ hub_model_id = "Xsmos/ml21cm"
229
+ hub_private_repo = False
230
+ dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
231
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
232
+ # world_size = torch.cuda.device_count()
233
+ # repeat = 2
234
+
235
+ # dim = 2
236
+ dim = 3
237
+ stride = (2,2) if dim == 2 else (2,2,1)
238
+ num_image = 2000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
239
+ batch_size = 2#2#50#20#2#100 # 10
240
+ n_epoch = 10#50#20#20#2#5#25 # 120
241
+ HII_DIM = 28#64
242
+ num_redshift = 4#128#64#512#256#256#64#512#128
243
+ channel = 1
244
+ img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
245
+
246
+ ranges_dict = dict(
247
+ params = {
248
+ 0: [4, 6], # ION_Tvir_MIN
249
+ 1: [10, 250], # HII_EFF_FACTOR
250
+ },
251
+ images = {
252
+ 0: [0, 80], # brightness_temp
253
+ }
254
+ )
255
+
256
+ num_timesteps = 1000#1000 # 1000, 500; DDPM time steps
257
+ # n_sample = 24 # 64, the number of samples in sampling process
258
+ n_param = 2
259
+ guide_w = 0#-1#0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance
260
+ drop_prob = 0#0.28 # only takes effect when guide_w != -1
261
+ ema=True # whether to use ema
262
+ ema_rate=0.995
263
+
264
+ # seed = 0
265
+ # save_dir = './outputs/'
266
+
267
+ save_freq = 0#.1 # the period of sampling
268
+ # general parameters for the name and logger
269
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
270
+ lrate = 1e-4
271
+ lr_warmup_steps = 0#5#00
272
+ output_dir = "./outputs/"
273
+ save_name = os.path.join(output_dir, 'model_state')
274
+ # save_freq = 1 #10 # the period of saving model
275
+ # cond = True # if training using the conditional information
276
+ # lr_decay = False #True# if using the learning rate decay
277
+ resume = save_name # if resume from the trained checkpoints
278
+ # params_single = torch.tensor([0.2,0.80000023])
279
+ # params = torch.tile(params_single,(n_sample,1)).to(device)
280
+ # params = params
281
+ # data_dir = './data' # data directory
282
+
283
+
284
+ mixed_precision = "fp16"
285
+ gradient_accumulation_steps = 1
286
+
287
+ # date = datetime.datetime.now().strftime("%m%d-%H%M")
288
+ # run_name = f'{date}' # the unique name of each experiment
289
+
290
+ # config = TrainConfig()
291
+ # print("device =", config.device)
292
+
293
+ # %%
294
+ # import os
295
+ # print(os.cpu_count())
296
+ # print(len(os.sched_getaffinity(0)))
297
+ # import torch
298
+ # data = torch.randn((64,64))
299
+ # print(data.dtype)
300
+
301
+ # %%
302
+ # @dataclass
303
+ class DDPM21CM:
304
+ def __init__(self, config):
305
+ # config = TrainConfig()
306
+ # date = datetime.datetime.now().strftime("%m%d-%H%M")
307
+ config.run_name = datetime.datetime.now().strftime("%m%d-%H%M") # the unique name of each experiment
308
+ self.config = config
309
+ # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)
310
+ # # self.shape_loaded = dataset.images.shape
311
+ # # print("shape_loaded =", self.shape_loaded)
312
+ # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
313
+ # del dataset
314
+ self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)
315
+
316
+ # initialize the unet
317
+ self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)
318
+
319
+ if config.resume and os.path.exists(config.resume):
320
+ # resume_file = os.path.join(config.output_dir, f"{config.resume}")
321
+ self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
322
+ print(f"resumed nn_model from {config.resume}")
323
+ # nn_model = ContextUnet(n_param=1, image_size=28)
324
+ self.nn_model.train()
325
+ self.nn_model.to(self.ddpm.device)
326
+ # print("nn_model.device =", ddpm.device)
327
+ # number of parameters to be trained
328
+ self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
329
+ print(f"Number of parameters for nn_model: {self.number_of_params}")
330
+
331
+ # whether to use ema
332
+ if config.ema:
333
+ self.ema = EMA(config.ema_rate)
334
+ if config.resume and os.path.exists(config.resume):
335
+ self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)
336
+ self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])
337
+ print(f"resumed ema_model from {config.resume}")
338
+ else:
339
+ self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)
340
+
341
+ self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)
342
+ self.lr_scheduler = get_cosine_schedule_with_warmup(
343
+ optimizer=self.optimizer,
344
+ num_warmup_steps=config.lr_warmup_steps,
345
+ num_training_steps=(int(config.num_image/config.batch_size) * config.n_epoch),
346
+ # num_training_steps=(len(self.dataloader) * config.n_epoch),
347
+ )
348
+
349
+ self.ranges_dict = config.ranges_dict
350
+
351
+ def load(self):
352
+ dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
353
+ # self.shape_loaded = dataset.images.shape
354
+ # print("shape_loaded =", self.shape_loaded)
355
+ self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True, num_workers=len(os.sched_getaffinity(0)), pin_memory=True)
356
+ # del dataset
357
+ # self.accelerate(self.config)
358
+ del dataset
359
+
360
+ # def accelerate(self):
361
+
362
+ def train(self):
363
+ ###################
364
+ ## training loop ##
365
+ ###################
366
+ # plot_unet = True
367
+
368
+ self.load()
369
+ self.accelerator = Accelerator(
370
+ mixed_precision=self.config.mixed_precision,
371
+ gradient_accumulation_steps=self.config.gradient_accumulation_steps,
372
+ log_with="tensorboard",
373
+ project_dir=os.path.join(self.config.output_dir, "logs"),
374
+ )
375
+ print("self.accelerator.is_main_process:", self.accelerator.is_main_process)
376
+ if self.accelerator.is_main_process:
377
+ if self.config.output_dir is not None:
378
+ os.makedirs(self.config.output_dir, exist_ok=True)
379
+ if self.config.push_to_hub:
380
+ self.repo_id = create_repo(
381
+ repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True
382
+ ).repo_id
383
+ self.accelerator.init_trackers(f"{self.config.run_name}")
384
+
385
+ self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \
386
+ self.accelerator.prepare(
387
+ self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler
388
+ )
389
+
390
+ global_step = 0
391
+ for ep in range(self.config.n_epoch):
392
+ self.ddpm.train()
393
+
394
+ pbar_train = tqdm(total=len(self.dataloader), disable=not self.accelerator.is_local_main_process)
395
+ pbar_train.set_description(f"Epoch {ep}")
396
+ for i, (x, c) in enumerate(self.dataloader):
397
+ with self.accelerator.accumulate(self.nn_model):
398
+ x = x.to(self.config.device)
399
+ xt, noise, ts = self.ddpm.add_noise(x)
400
+
401
+ if self.config.guide_w == -1:
402
+ noise_pred = self.nn_model(xt, ts)
403
+ else:
404
+ c = c.to(self.config.device)
405
+ noise_pred = self.nn_model(xt, ts, c)
406
+
407
+ loss = F.mse_loss(noise, noise_pred)
408
+ self.accelerator.backward(loss)
409
+ self.accelerator.clip_grad_norm_(self.nn_model.parameters(), 1)
410
+ self.optimizer.step()
411
+ self.lr_scheduler.step()
412
+ self.optimizer.zero_grad()
413
+
414
+ # ema update
415
+ if self.config.ema:
416
+ self.ema.step_ema(self.ema_model, self.nn_model)
417
+
418
+ pbar_train.update(1)
419
+ logs = dict(
420
+ loss=loss.detach().item(),
421
+ lr=self.optimizer.param_groups[0]['lr'],
422
+ step=global_step
423
+ )
424
+ pbar_train.set_postfix(**logs)
425
+
426
+ self.accelerator.log(logs, step=global_step)
427
+ global_step += 1
428
+
429
+ # if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:
430
+ self.save(ep)
431
+
432
+ del self.nn_model
433
+ if self.config.ema:
434
+ del self.ema_model
435
+ torch.cuda.empty_cache()
436
+
437
+ def save(self, ep):
438
+ # save model
439
+ if self.accelerator.is_main_process:
440
+ if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:
441
+ self.nn_model.eval()
442
+ with torch.no_grad():
443
+ if self.config.push_to_hub:
444
+ upload_folder(
445
+ repo_id = self.repo_id,
446
+ folder_path = ".",#config.output_dir,
447
+ commit_message = f"{self.config.run_name}",
448
+ ignore_patterns = ["step_*", "epoch_*", "*.npy", "__pycache__"],
449
+ )
450
+ if self.config.save_name:
451
+ model_state = {
452
+ 'epoch': ep,
453
+ 'unet_state_dict': self.nn_model.state_dict(),
454
+ 'ema_unet_state_dict': self.ema_model.state_dict(),
455
+ }
456
+ torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
457
+ print('saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
458
+ # print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
459
+
460
+ # def rescale(self, value, type='params', to_ranges=[0,1]):
461
+ # for i, from_ranges in self.ranges_dict[type].items():
462
+ # value[i] = (value[i] - from_ranges[0])/(from_ranges[1]-from_ranges[0]) # normalize
463
+ # value[i] =
464
+ def rescale(self, value, ranges, to: list):
465
+ if value.ndim == 1:
466
+ value = value.view(-1,len(value))
467
+
468
+ for i in range(np.shape(value)[1]):
469
+ value[:,i] = (value[:,i] - ranges[i][0]) / (ranges[i][1]-ranges[i][0])
470
+ # print(f"i = {i}, value.min = {value[:,i].min()}, value.max = {value[:,i].max()}")
471
+ value = value * (to[1]-to[0]) + to[0]
472
+ return value
473
+
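A small usage sketch of rescale, using the ranges from TrainConfig.ranges_dict['params']; the output matches the default normalized params hard-coded in sample below:

import torch

ranges = {0: [4, 6], 1: [10, 250]}       # ION_Tvir_MIN, HII_EFF_FACTOR
params = torch.tensor([[4.4, 131.341]])  # physical parameter values
for i in range(params.shape[1]):
    params[:, i] = (params[:, i] - ranges[i][0]) / (ranges[i][1] - ranges[i][0])
print(params)  # tensor([[0.2000, 0.5056]]), i.e. values normalized to [0, 1]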
474
+ def sample(self, file, params:torch.tensor=None, repeat=192, ema=False, entire=False):
475
+ # n_sample = params.shape[0]
476
+
477
+ if params is None:
478
+ params = torch.tensor([0.20000000000000018, 0.5055875000000001])
479
+ params_backup = params.numpy().copy()
480
+ else:
481
+ params_backup = params.numpy().copy()
482
+ params = self.rescale(params, self.ranges_dict['params'], to=[0,1])
483
+
484
+ print(f"sampling {repeat} images with normalized params = {params}")
485
+ params = params.repeat(repeat,1)
486
+ assert params.dim() == 2, "params must be a 2D torch.tensor"
487
+ # print("params =", params)
488
+ # print("params =", params)
489
+ # print("len(params) =", len(params))
490
+ # model = self.ema_model if ema else self.nn_model
491
+ # del self.ema_model, self.nn
492
+ # params = torch.tile(params, (n_sample,1)).to(device)
493
+
494
+ nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
495
+ if ema:
496
+ nn_model.load_state_dict(torch.load(file)['ema_unet_state_dict'])
497
+ else:
498
+ nn_model.load_state_dict(torch.load(file)['unet_state_dict'])
499
+ print(f"nn_model resumed from {file}")
500
+ # nn_model = ContextUnet(n_param=1, image_size=28)
501
+ # nn_model.train()
502
+ nn_model.to(self.ddpm.device)
503
+ nn_model.eval()
504
+
505
+ # self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)
506
+ # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f"{config.resume}"))['ema_unet_state_dict'])
507
+ # print(f"resumed ema_model from {config.resume}")
508
+
509
+ with torch.no_grad():
510
+ x_last, x_entire = self.ddpm.sample(
511
+ nn_model=nn_model,
512
+ params=params.to(self.config.device),
513
+ device=self.config.device,
514
+ guide_w=self.config.guide_w
515
+ )
516
+
517
+ # np.save(os.path.join(self.config.output_dir, f"{self.config.run_name}{'ema' if ema else ''}.npy"), x_last)
518
+ np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}.npy"), x_last)
519
+
520
+ if entire:
521
+ np.save(os.path.join(self.config.output_dir, f"Tvir{params_backup[0]}-zeta{params_backup[1]}-N{self.config.num_image}{'ema' if ema else ''}_entire.npy"), x_entire)  # save the stored intermediate steps, not the final sample
522
+ # print("device =", config.device)
523
+
524
+ # %%
525
+ def single_main(rank, world_size):
526
+ config = TrainConfig()
527
+ ddp_setup(rank, world_size)
528
+
529
+ num_image_list = [100]#[200]#[1600,3200,6400,12800,25600]
530
+ for i, num_image in enumerate(num_image_list):
531
+ config.num_image = num_image
532
+ # config.world_size = world_size
533
+
534
+ ddpm21cm = DDPM21CM(config)
535
+ print(f" num_image = {ddpm21cm.config.num_image} ".center(50, '-'))
536
+ print(f"run_name = {ddpm21cm.config.run_name}")
537
+ ddpm21cm.train()
538
+
539
+
540
+ if __name__ == "__main__":
541
+ # torch.multiprocessing.set_start_method("spawn")
542
+ # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
543
+ world_size = 1#torch.cuda.device_count()
544
+
545
+ mp.spawn(single_main, args=(world_size,), nprocs=world_size)
546
+ # notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
547
+
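A likely reason this script exists: the AttributeError in the diffusion.ipynb traceback above ("Can't get attribute 'single_main' on <module '__main__'>") is the standard failure of mp.spawn inside Jupyter, where spawned children re-import __main__ and cannot find functions defined in notebook cells. Running the same code as a plain module, e.g. python diffusion.py, avoids it.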
548
+ # %%
549
+ # torch.cuda.set_device(0)
550
+
551
+ # %%
552
+ print("torch.cuda.is_available() =", torch.cuda.is_available())
553
+ print("torch.cuda.device_count() =", torch.cuda.device_count())
554
+ # print(torch.cuda.__dir__())
555
+
556
+ # %%
557
+ print("torch.cuda.is_initialized() =", torch.cuda.is_initialized())
558
+ # print("torch.cuda.get_device_name() =", torch.cuda.get_device_name())
559
+ print("torch.cuda.current_device() =", torch.cuda.current_device())
560
+ # print("torch.cuda.get_device_capability() =", torch.cuda.get_device_capability())
561
+ # print("torch.cuda.get_device_properties(torch.cuda.device) =", torch.cuda.get_device_properties(torch.cuda.device))
562
+ # print('here')
563
+ # print(torch.cuda.memory_usage())
564
+ # print(torch.cuda.utilization())
565
+ # print(torch.cuda.memory())
566
+ # print('here')
567
+ # print(torch.cuda.memory_summary())
568
+
569
+ # %% [markdown]
570
+ # # Sampling
571
+
572
+ # %%
573
+ if __name__ == "__main__":
574
+ # num_image_list = [1600,3200,6400,12800,25600]
575
+ num_image_list = [1000]
576
+ # num_image_list = [3200,6400,12800,25600]
577
+ # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
578
+ repeat = 2
579
+ config = TrainConfig()
580
+ for i, num_image in enumerate(num_image_list):
581
+ config.num_image = num_image
582
+ ddpm21cm = DDPM21CM(config)
583
+
584
+ ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor([4.4, 131.341]), repeat=repeat)
585
+
586
+ # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.6, 19.037)), repeat=repeat)
587
+
588
+ # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.699, 30)), repeat=repeat)
589
+
590
+ # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((5.477, 200)), repeat=repeat)
591
+
592
+ # ddpm21cm.sample(f"./outputs/model_state-N{num_image}", params=torch.tensor((4.8, 131.341)), repeat=repeat)
593
+
594
+ # %%
595
+ # ls -lth outputs | head
596
+
597
+ # # %%
598
+ # def plot_grid(samples, c=None, row=1, col=2):
599
+ # print("samples.shape =", samples.shape)
600
+ # for j in range(samples.shape[4]):
601
+ # plt.figure(figsize = (12,6), dpi=400)
602
+ # for i in range(len(samples)):
603
+ # plt.subplot(row,col,i+1)
604
+ # plt.imshow(samples[i,0,:,:,j], cmap='gray')#, vmin=-1, vmax=1)
605
+ # plt.xticks([])
606
+ # plt.yticks([])
607
+ # # plt.suptitle(f"ION_Tvir_MIN = {c[0][0]}, HII_EFF_FACTOR = {c[0][1]}")
608
+ # # plt.show()
609
+ # # plt.suptitle('simulations')
610
+ # plt.tight_layout()
611
+ # plt.subplots_adjust(wspace=0, hspace=0)
612
+ # plt.savefig(f"test3D-{j:03d}.png")
613
+ # plt.close()
614
+ # # plt.show()
615
+
616
+ # data = np.load("outputs/Tvir4.400000095367432-zeta131.34100341796875-N1000.npy")
617
+ # # print(data.shape)
618
+ # plot_grid(data)
619
+ # plt.imshow(data)
620
+
621
+ # %%
622
+ # config = TrainConfig()
623
+ # def plot(filename, row=4, col=6):
624
+ # samples = np.load(filename)
625
+ # params = filename.split('guide_w')[-1][:-4]
626
+ # print("plotting", samples.shape, params)
627
+ # plt.figure(figsize = (8,8))
628
+ # for i in range(24):
629
+ # plt.subplot(row,col,i+1)
630
+ # plt.imshow(samples[i,0,:,:], cmap='gray')#, vmin=-1, vmax=1)
631
+ # plt.xticks([])
632
+ # plt.yticks([])
633
+ # # plt.show()
634
+ # plt.suptitle(params)
635
+ # plt.tight_layout()
636
+ # plt.subplots_adjust(wspace=0, hspace=0)
637
+ # plt.show()
638
+ # # plt.savefig('outputs/'+params+'.png')
639
+ # # plt.close()
640
+ # # plt.imshow(images[0,0])
641
+ # # plt.show()
642
+
643
+ # %%
644
+
645
+