Xsmos commited on
Commit
245c27a
·
verified ·
1 Parent(s): d69a5d7
Files changed (1) hide show
  1. diffusion.ipynb +43 -279
diffusion.ipynb CHANGED
@@ -244,13 +244,13 @@
244
  " # dim = 2\n",
245
  " dim = 2\n",
246
  " stride = (2,2) if dim == 2 else (2,2,4)\n",
247
- " num_image = 240#0\n",
248
  " HII_DIM = 64\n",
249
  " num_redshift = 512#256#256#64#512#128\n",
250
  " channel = 1\n",
251
  " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
252
  "\n",
253
- " n_epoch = 10#2#5#25 # 120\n",
254
  " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
255
  " batch_size = 10#20#2#100 # 10\n",
256
  " # n_sample = 24 # 64, the number of samples in sampling process\n",
@@ -268,17 +268,17 @@
268
  " # device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
269
  " lrate = 1e-4\n",
270
  " lr_warmup_steps = 0#5#00\n",
271
- " save_model = True\n",
 
272
  " # save_freq = 1 #10 # the period of saving model\n",
273
  " # cond = True # if training using the conditional information\n",
274
  " # lr_decay = False #True# if using the learning rate decay\n",
275
- " resume = 'model_state.pth' # if resume from the trained checkpoints\n",
276
  " # params_single = torch.tensor([0.2,0.80000023])\n",
277
  " # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
278
  " # params = params\n",
279
  " # data_dir = './data' # data directory\n",
280
  "\n",
281
- " output_dir = \"./outputs/\"\n",
282
  "\n",
283
  " mixed_precision = \"fp16\"\n",
284
  " gradient_accumulation_steps = 1\n",
@@ -313,8 +313,9 @@
313
  " # initialize the unet\n",
314
  " self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
315
  "\n",
316
- " if config.resume:\n",
317
- " self.nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['unet_state_dict'])\n",
 
318
  " print(f\"resumed nn_model from {config.resume}\")\n",
319
  " # nn_model = ContextUnet(n_param=1, image_size=28)\n",
320
  " self.nn_model.train()\n",
@@ -327,12 +328,12 @@
327
  " # whether to use ema\n",
328
  " if config.ema:\n",
329
  " self.ema = EMA(config.ema_rate)\n",
330
- " if config.resume:\n",
331
  " self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
332
- " self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
333
  " print(f\"resumed ema_model from {config.resume}\")\n",
334
  " else:\n",
335
- " self.ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
336
  "\n",
337
  " self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)\n",
338
  " self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
@@ -439,14 +440,14 @@
439
  " commit_message = f\"{self.config.run_name}\",\n",
440
  " ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
441
  " )\n",
442
- " if self.config.save_model:\n",
443
  " model_state = {\n",
444
  " 'epoch': ep,\n",
445
  " 'unet_state_dict': self.nn_model.state_dict(),\n",
446
  " 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
447
  " }\n",
448
- " torch.save(model_state, self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
449
- " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
450
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
451
  "\n",
452
  " def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
@@ -498,7 +499,7 @@
498
  {
499
  "data": {
500
  "application/vnd.jupyter.widget-view+json": {
501
- "model_id": "6dca1df1da3148f28c71fed756c7abc9",
502
  "version_major": 2,
503
  "version_minor": 0
504
  },
@@ -508,196 +509,31 @@
508
  },
509
  "metadata": {},
510
  "output_type": "display_data"
511
- },
512
- {
513
- "name": "stdout",
514
- "output_type": "stream",
515
- "text": [
516
- "resumed nn_model from model_state.pth\n",
517
- "Number of parameters for nn_model: 111048705\n",
518
- "resumed ema_model from model_state.pth\n",
519
- "run_name = 0523-1621\n",
520
- "Launching training on one GPU.\n",
521
- "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
522
- "51200 images can be loaded\n",
523
- "field.shape = (64, 64, 514)\n",
524
- "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
525
- "loading 240 images randomly\n",
526
- "images loaded: (240, 1, 64, 512)\n"
527
- ]
528
- },
529
- {
530
- "name": "stderr",
531
- "output_type": "stream",
532
- "text": [
533
- "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
534
- ]
535
- },
536
- {
537
- "name": "stdout",
538
- "output_type": "stream",
539
- "text": [
540
- "params loaded: (240, 2)\n",
541
- "images rescaled to [-1.0, 1.1240839958190918]\n",
542
- "params rescaled to [0.0, 0.9972546078293054]\n"
543
- ]
544
- },
545
- {
546
- "data": {
547
- "application/vnd.jupyter.widget-view+json": {
548
- "model_id": "15d75d83ca9f4f49be17a89f6ddd58e1",
549
- "version_major": 2,
550
- "version_minor": 0
551
- },
552
- "text/plain": [
553
- " 0%| | 0/24 [00:00<?, ?it/s]"
554
- ]
555
- },
556
- "metadata": {},
557
- "output_type": "display_data"
558
- },
559
- {
560
- "data": {
561
- "application/vnd.jupyter.widget-view+json": {
562
- "model_id": "66959c994f6b40649ab527212de8d3c2",
563
- "version_major": 2,
564
- "version_minor": 0
565
- },
566
- "text/plain": [
567
- " 0%| | 0/24 [00:00<?, ?it/s]"
568
- ]
569
- },
570
- "metadata": {},
571
- "output_type": "display_data"
572
- },
573
- {
574
- "data": {
575
- "application/vnd.jupyter.widget-view+json": {
576
- "model_id": "564f6d85e359481f973a49f75b180440",
577
- "version_major": 2,
578
- "version_minor": 0
579
- },
580
- "text/plain": [
581
- " 0%| | 0/24 [00:00<?, ?it/s]"
582
- ]
583
- },
584
- "metadata": {},
585
- "output_type": "display_data"
586
- },
587
- {
588
- "data": {
589
- "application/vnd.jupyter.widget-view+json": {
590
- "model_id": "079a2325ab83494282c83b76ffb8e52e",
591
- "version_major": 2,
592
- "version_minor": 0
593
- },
594
- "text/plain": [
595
- " 0%| | 0/24 [00:00<?, ?it/s]"
596
- ]
597
- },
598
- "metadata": {},
599
- "output_type": "display_data"
600
- },
601
- {
602
- "data": {
603
- "application/vnd.jupyter.widget-view+json": {
604
- "model_id": "fefa0f8dbfeb474d90e0aaf55f8ca5e8",
605
- "version_major": 2,
606
- "version_minor": 0
607
- },
608
- "text/plain": [
609
- " 0%| | 0/24 [00:00<?, ?it/s]"
610
- ]
611
- },
612
- "metadata": {},
613
- "output_type": "display_data"
614
- },
615
- {
616
- "data": {
617
- "application/vnd.jupyter.widget-view+json": {
618
- "model_id": "b216c0bb3bd4457f9230b32b8d2ede1f",
619
- "version_major": 2,
620
- "version_minor": 0
621
- },
622
- "text/plain": [
623
- " 0%| | 0/24 [00:00<?, ?it/s]"
624
- ]
625
- },
626
- "metadata": {},
627
- "output_type": "display_data"
628
- },
629
- {
630
- "data": {
631
- "application/vnd.jupyter.widget-view+json": {
632
- "model_id": "78d4bdad3dc34ba18f3074802c67bf61",
633
- "version_major": 2,
634
- "version_minor": 0
635
- },
636
- "text/plain": [
637
- " 0%| | 0/24 [00:00<?, ?it/s]"
638
- ]
639
- },
640
- "metadata": {},
641
- "output_type": "display_data"
642
- },
643
- {
644
- "data": {
645
- "application/vnd.jupyter.widget-view+json": {
646
- "model_id": "e78d2d3247b442b78f06b38b65944887",
647
- "version_major": 2,
648
- "version_minor": 0
649
- },
650
- "text/plain": [
651
- " 0%| | 0/24 [00:00<?, ?it/s]"
652
- ]
653
- },
654
- "metadata": {},
655
- "output_type": "display_data"
656
- },
657
- {
658
- "data": {
659
- "application/vnd.jupyter.widget-view+json": {
660
- "model_id": "5e1d909d5f3f4c26a11bd40978c57f4e",
661
- "version_major": 2,
662
- "version_minor": 0
663
- },
664
- "text/plain": [
665
- " 0%| | 0/24 [00:00<?, ?it/s]"
666
- ]
667
- },
668
- "metadata": {},
669
- "output_type": "display_data"
670
- },
671
- {
672
- "data": {
673
- "application/vnd.jupyter.widget-view+json": {
674
- "model_id": "d1f56418378049b59ba1f9de7c5676f1",
675
- "version_major": 2,
676
- "version_minor": 0
677
- },
678
- "text/plain": [
679
- " 0%| | 0/24 [00:00<?, ?it/s]"
680
- ]
681
- },
682
- "metadata": {},
683
- "output_type": "display_data"
684
- },
685
  {
686
  "name": "stdout",
687
  "output_type": "stream",
688
  "text": [
689
- "saved model at ./outputs/model_state_09.pth\n",
690
- "resumed nn_model from model_state.pth\n",
691
  "Number of parameters for nn_model: 111048705\n",
692
- "resumed ema_model from model_state.pth\n",
693
- "run_name = 0523-1624\n",
694
  "Launching training on one GPU.\n",
695
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
696
  "51200 images can be loaded\n",
697
  "field.shape = (64, 64, 514)\n",
698
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
699
- "loading 240 images randomly\n",
700
- "images loaded: (240, 1, 64, 512)\n"
 
701
  ]
702
  },
703
  {
@@ -711,90 +547,19 @@
711
  "name": "stdout",
712
  "output_type": "stream",
713
  "text": [
714
- "params loaded: (240, 2)\n",
715
- "images rescaled to [-1.0, 1.186443567276001]\n",
716
- "params rescaled to [0.0, 0.9999922179553216]\n"
717
  ]
718
  },
719
  {
720
  "data": {
721
  "application/vnd.jupyter.widget-view+json": {
722
- "model_id": "75b0459ffb784e1d9e4070b7a424a506",
723
- "version_major": 2,
724
- "version_minor": 0
725
- },
726
- "text/plain": [
727
- " 0%| | 0/24 [00:00<?, ?it/s]"
728
- ]
729
- },
730
- "metadata": {},
731
- "output_type": "display_data"
732
- },
733
- {
734
- "data": {
735
- "application/vnd.jupyter.widget-view+json": {
736
- "model_id": "c7e04c734ba3482eb44344f7a4e37916",
737
- "version_major": 2,
738
- "version_minor": 0
739
- },
740
- "text/plain": [
741
- " 0%| | 0/24 [00:00<?, ?it/s]"
742
- ]
743
- },
744
- "metadata": {},
745
- "output_type": "display_data"
746
- },
747
- {
748
- "data": {
749
- "application/vnd.jupyter.widget-view+json": {
750
- "model_id": "9a8c702e844f4fbaa295fb8f6d21503b",
751
- "version_major": 2,
752
- "version_minor": 0
753
- },
754
- "text/plain": [
755
- " 0%| | 0/24 [00:00<?, ?it/s]"
756
- ]
757
- },
758
- "metadata": {},
759
- "output_type": "display_data"
760
- },
761
- {
762
- "data": {
763
- "application/vnd.jupyter.widget-view+json": {
764
- "model_id": "062be80bffee4540b396159acc223e6e",
765
- "version_major": 2,
766
- "version_minor": 0
767
- },
768
- "text/plain": [
769
- " 0%| | 0/24 [00:00<?, ?it/s]"
770
- ]
771
- },
772
- "metadata": {},
773
- "output_type": "display_data"
774
- },
775
- {
776
- "data": {
777
- "application/vnd.jupyter.widget-view+json": {
778
- "model_id": "3849ae228e284a1a8c01235ffe2691aa",
779
- "version_major": 2,
780
- "version_minor": 0
781
- },
782
- "text/plain": [
783
- " 0%| | 0/24 [00:00<?, ?it/s]"
784
- ]
785
- },
786
- "metadata": {},
787
- "output_type": "display_data"
788
- },
789
- {
790
- "data": {
791
- "application/vnd.jupyter.widget-view+json": {
792
- "model_id": "b359bb4eeb8b4ec58be692424d352164",
793
  "version_major": 2,
794
  "version_minor": 0
795
  },
796
  "text/plain": [
797
- " 0%| | 0/24 [00:00<?, ?it/s]"
798
  ]
799
  },
800
  "metadata": {},
@@ -803,12 +568,12 @@
803
  {
804
  "data": {
805
  "application/vnd.jupyter.widget-view+json": {
806
- "model_id": "06c56d6f2e1443fd87bfa949f092b8f0",
807
  "version_major": 2,
808
  "version_minor": 0
809
  },
810
  "text/plain": [
811
- " 0%| | 0/24 [00:00<?, ?it/s]"
812
  ]
813
  },
814
  "metadata": {},
@@ -817,12 +582,12 @@
817
  {
818
  "data": {
819
  "application/vnd.jupyter.widget-view+json": {
820
- "model_id": "b5b613300a7046c8a1a62b5237ab5b4e",
821
  "version_major": 2,
822
  "version_minor": 0
823
  },
824
  "text/plain": [
825
- " 0%| | 0/24 [00:00<?, ?it/s]"
826
  ]
827
  },
828
  "metadata": {},
@@ -831,12 +596,12 @@
831
  {
832
  "data": {
833
  "application/vnd.jupyter.widget-view+json": {
834
- "model_id": "6e76a18a0bce4ee9acd9e8344a81fd65",
835
  "version_major": 2,
836
  "version_minor": 0
837
  },
838
  "text/plain": [
839
- " 0%| | 0/24 [00:00<?, ?it/s]"
840
  ]
841
  },
842
  "metadata": {},
@@ -845,12 +610,12 @@
845
  {
846
  "data": {
847
  "application/vnd.jupyter.widget-view+json": {
848
- "model_id": "3ecf321272464c988e4a291b19d164e0",
849
  "version_major": 2,
850
  "version_minor": 0
851
  },
852
  "text/plain": [
853
- " 0%| | 0/24 [00:00<?, ?it/s]"
854
  ]
855
  },
856
  "metadata": {},
@@ -860,8 +625,7 @@
860
  "source": [
861
  "if __name__ == \"__main__\":\n",
862
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
863
- " notebook_login()\n",
864
- " repeat = 2\n",
865
  " for i in range(repeat):\n",
866
  " ddpm21cm = DDPM21CM()\n",
867
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
 
244
  " # dim = 2\n",
245
  " dim = 2\n",
246
  " stride = (2,2) if dim == 2 else (2,2,4)\n",
247
+ " num_image = 2560\n",
248
  " HII_DIM = 64\n",
249
  " num_redshift = 512#256#256#64#512#128\n",
250
  " channel = 1\n",
251
  " img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
252
  "\n",
253
+ " n_epoch = 5#2#5#25 # 120\n",
254
  " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
255
  " batch_size = 10#20#2#100 # 10\n",
256
  " # n_sample = 24 # 64, the number of samples in sampling process\n",
 
268
  " # device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
269
  " lrate = 1e-4\n",
270
  " lr_warmup_steps = 0#5#00\n",
271
+ " output_dir = \"./outputs/\"\n",
272
+ " save_name = os.path.join(output_dir, 'model_state.pth')\n",
273
  " # save_freq = 1 #10 # the period of saving model\n",
274
  " # cond = True # if training using the conditional information\n",
275
  " # lr_decay = False #True# if using the learning rate decay\n",
276
+ " resume = save_name # if resume from the trained checkpoints\n",
277
  " # params_single = torch.tensor([0.2,0.80000023])\n",
278
  " # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
279
  " # params = params\n",
280
  " # data_dir = './data' # data directory\n",
281
  "\n",
 
282
  "\n",
283
  " mixed_precision = \"fp16\"\n",
284
  " gradient_accumulation_steps = 1\n",
 
313
  " # initialize the unet\n",
314
  " self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
315
  "\n",
316
+ " if config.resume and os.path.exists(config.resume):\n",
317
+ " # resume_file = os.path.join(config.output_dir, f\"{config.resume}\")\n",
318
+ " self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])\n",
319
  " print(f\"resumed nn_model from {config.resume}\")\n",
320
  " # nn_model = ContextUnet(n_param=1, image_size=28)\n",
321
  " self.nn_model.train()\n",
 
328
  " # whether to use ema\n",
329
  " if config.ema:\n",
330
  " self.ema = EMA(config.ema_rate)\n",
331
+ " if config.resume and os.path.exists(config.resume):\n",
332
  " self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
333
+ " self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])\n",
334
  " print(f\"resumed ema_model from {config.resume}\")\n",
335
  " else:\n",
336
+ " self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)\n",
337
  "\n",
338
  " self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)\n",
339
  " self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
 
440
  " commit_message = f\"{self.config.run_name}\",\n",
441
  " ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
442
  " )\n",
443
+ " if self.config.save_name:\n",
444
  " model_state = {\n",
445
  " 'epoch': ep,\n",
446
  " 'unet_state_dict': self.nn_model.state_dict(),\n",
447
  " 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
448
  " }\n",
449
+ " torch.save(model_state, self.config.save_name)\n",
450
+ " print('saved model at ' + self.config.save_name)\n",
451
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
452
  "\n",
453
  " def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
 
499
  {
500
  "data": {
501
  "application/vnd.jupyter.widget-view+json": {
502
+ "model_id": "e0f355a0bc8b4592952af6c1ccd5d2fb",
503
  "version_major": 2,
504
  "version_minor": 0
505
  },
 
509
  },
510
  "metadata": {},
511
  "output_type": "display_data"
512
+ }
513
+ ],
514
+ "source": [
515
+ "notebook_login()"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": 7,
521
+ "metadata": {},
522
+ "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  {
524
  "name": "stdout",
525
  "output_type": "stream",
526
  "text": [
 
 
527
  "Number of parameters for nn_model: 111048705\n",
528
+ "run_name = 0523-1704\n",
 
529
  "Launching training on one GPU.\n",
530
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
531
  "51200 images can be loaded\n",
532
  "field.shape = (64, 64, 514)\n",
533
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
534
+ "loading 2560 images randomly\n",
535
+ "images loaded: (2560, 1, 64, 512)\n",
536
+ "params loaded: (2560, 2)\n"
537
  ]
538
  },
539
  {
 
547
  "name": "stdout",
548
  "output_type": "stream",
549
  "text": [
550
+ "images rescaled to [-1.0, 1.1378462314605713]\n",
551
+ "params rescaled to [0.0, 0.9995994165819857]\n"
 
552
  ]
553
  },
554
  {
555
  "data": {
556
  "application/vnd.jupyter.widget-view+json": {
557
+ "model_id": "4d787d2fbdcf4575b7b17a6e5161f5ec",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
558
  "version_major": 2,
559
  "version_minor": 0
560
  },
561
  "text/plain": [
562
+ " 0%| | 0/256 [00:00<?, ?it/s]"
563
  ]
564
  },
565
  "metadata": {},
 
568
  {
569
  "data": {
570
  "application/vnd.jupyter.widget-view+json": {
571
+ "model_id": "e67439e56e594ecfb3967edbfb3f0d60",
572
  "version_major": 2,
573
  "version_minor": 0
574
  },
575
  "text/plain": [
576
+ " 0%| | 0/256 [00:00<?, ?it/s]"
577
  ]
578
  },
579
  "metadata": {},
 
582
  {
583
  "data": {
584
  "application/vnd.jupyter.widget-view+json": {
585
+ "model_id": "9ca7cb14960348fa8d83c90d773057ac",
586
  "version_major": 2,
587
  "version_minor": 0
588
  },
589
  "text/plain": [
590
+ " 0%| | 0/256 [00:00<?, ?it/s]"
591
  ]
592
  },
593
  "metadata": {},
 
596
  {
597
  "data": {
598
  "application/vnd.jupyter.widget-view+json": {
599
+ "model_id": "a6368ae7b9fb4505b6b62d51c5d675ed",
600
  "version_major": 2,
601
  "version_minor": 0
602
  },
603
  "text/plain": [
604
+ " 0%| | 0/256 [00:00<?, ?it/s]"
605
  ]
606
  },
607
  "metadata": {},
 
610
  {
611
  "data": {
612
  "application/vnd.jupyter.widget-view+json": {
613
+ "model_id": "d5a391c5bbfb4f6481c1f2ad6e754e24",
614
  "version_major": 2,
615
  "version_minor": 0
616
  },
617
  "text/plain": [
618
+ " 0%| | 0/256 [00:00<?, ?it/s]"
619
  ]
620
  },
621
  "metadata": {},
 
625
  "source": [
626
  "if __name__ == \"__main__\":\n",
627
  " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
628
+ " repeat = 30\n",
 
629
  " for i in range(repeat):\n",
630
  " ddpm21cm = DDPM21CM()\n",
631
  " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",