Xsmos commited on
Commit
35298ab
·
verified ·
1 Parent(s): 5cb0c22

0523-1621

Browse files
Files changed (1) hide show
  1. diffusion.ipynb +23 -202
diffusion.ipynb CHANGED
@@ -283,8 +283,8 @@
283
  " mixed_precision = \"fp16\"\n",
284
  " gradient_accumulation_steps = 1\n",
285
  "\n",
286
- " date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n",
287
- " run_name = f'{date}' # the unique name of each experiment\n",
288
  "\n",
289
  "# config = TrainConfig()\n",
290
  "# print(\"device =\", config.device)"
@@ -294,22 +294,14 @@
294
  "cell_type": "code",
295
  "execution_count": 5,
296
  "metadata": {},
297
- "outputs": [
298
- {
299
- "name": "stdout",
300
- "output_type": "stream",
301
- "text": [
302
- "resumed nn_model from model_state.pth\n",
303
- "Number of parameters for nn_model: 111048705\n",
304
- "resumed ema_model from model_state.pth\n"
305
- ]
306
- }
307
- ],
308
  "source": [
309
  "# @dataclass\n",
310
  "class DDPM21CM:\n",
311
  " def __init__(self):\n",
312
  " config = TrainConfig()\n",
 
 
313
  " self.config = config\n",
314
  " # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
315
  " # # self.shape_loaded = dataset.images.shape\n",
@@ -380,7 +372,7 @@
380
  " self.repo_id = create_repo(\n",
381
  " repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True\n",
382
  " ).repo_id\n",
383
- " self.accelerator.init_trackers(f\"{self.config.date}\")\n",
384
  "\n",
385
  " self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \\\n",
386
  " self.accelerator.prepare(\n",
@@ -444,7 +436,7 @@
444
  " upload_folder(\n",
445
  " repo_id = self.repo_id,\n",
446
  " folder_path = \".\",#config.output_dir,\n",
447
- " commit_message = f\"{self.config.date}\",\n",
448
  " ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
449
  " )\n",
450
  " if self.config.save_model:\n",
@@ -500,13 +492,13 @@
500
  },
501
  {
502
  "cell_type": "code",
503
- "execution_count": 12,
504
  "metadata": {},
505
  "outputs": [
506
  {
507
  "data": {
508
  "application/vnd.jupyter.widget-view+json": {
509
- "model_id": "a5f2462b91824a309b66da5c9e46905c",
510
  "version_major": 2,
511
  "version_minor": 0
512
  },
@@ -524,6 +516,7 @@
524
  "resumed nn_model from model_state.pth\n",
525
  "Number of parameters for nn_model: 111048705\n",
526
  "resumed ema_model from model_state.pth\n",
 
527
  "Launching training on one GPU.\n",
528
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
529
  "51200 images can be loaded\n",
@@ -545,187 +538,14 @@
545
  "output_type": "stream",
546
  "text": [
547
  "params loaded: (240, 2)\n",
548
- "images rescaled to [-1.0, 1.1086735725402832]\n",
549
- "params rescaled to [0.0, 0.9959690281993576]\n"
550
- ]
551
- },
552
- {
553
- "data": {
554
- "application/vnd.jupyter.widget-view+json": {
555
- "model_id": "ca18a0ce55d643dbaf33a148f579c7a7",
556
- "version_major": 2,
557
- "version_minor": 0
558
- },
559
- "text/plain": [
560
- " 0%| | 0/24 [00:00<?, ?it/s]"
561
- ]
562
- },
563
- "metadata": {},
564
- "output_type": "display_data"
565
- },
566
- {
567
- "data": {
568
- "application/vnd.jupyter.widget-view+json": {
569
- "model_id": "1034bb856742454398d95d08087ad46b",
570
- "version_major": 2,
571
- "version_minor": 0
572
- },
573
- "text/plain": [
574
- " 0%| | 0/24 [00:00<?, ?it/s]"
575
- ]
576
- },
577
- "metadata": {},
578
- "output_type": "display_data"
579
- },
580
- {
581
- "data": {
582
- "application/vnd.jupyter.widget-view+json": {
583
- "model_id": "feb8f77c25fe4a478a4b04885109b934",
584
- "version_major": 2,
585
- "version_minor": 0
586
- },
587
- "text/plain": [
588
- " 0%| | 0/24 [00:00<?, ?it/s]"
589
- ]
590
- },
591
- "metadata": {},
592
- "output_type": "display_data"
593
- },
594
- {
595
- "data": {
596
- "application/vnd.jupyter.widget-view+json": {
597
- "model_id": "0b517c548fe24cf493beac54b559027d",
598
- "version_major": 2,
599
- "version_minor": 0
600
- },
601
- "text/plain": [
602
- " 0%| | 0/24 [00:00<?, ?it/s]"
603
- ]
604
- },
605
- "metadata": {},
606
- "output_type": "display_data"
607
- },
608
- {
609
- "data": {
610
- "application/vnd.jupyter.widget-view+json": {
611
- "model_id": "fe21226d2ea14120b926ac1ea2b2f48a",
612
- "version_major": 2,
613
- "version_minor": 0
614
- },
615
- "text/plain": [
616
- " 0%| | 0/24 [00:00<?, ?it/s]"
617
- ]
618
- },
619
- "metadata": {},
620
- "output_type": "display_data"
621
- },
622
- {
623
- "data": {
624
- "application/vnd.jupyter.widget-view+json": {
625
- "model_id": "2526c949e3484788b4818cb4f3590750",
626
- "version_major": 2,
627
- "version_minor": 0
628
- },
629
- "text/plain": [
630
- " 0%| | 0/24 [00:00<?, ?it/s]"
631
- ]
632
- },
633
- "metadata": {},
634
- "output_type": "display_data"
635
- },
636
- {
637
- "data": {
638
- "application/vnd.jupyter.widget-view+json": {
639
- "model_id": "fbc6f3ddea3343d498d19ca32b877f84",
640
- "version_major": 2,
641
- "version_minor": 0
642
- },
643
- "text/plain": [
644
- " 0%| | 0/24 [00:00<?, ?it/s]"
645
- ]
646
- },
647
- "metadata": {},
648
- "output_type": "display_data"
649
- },
650
- {
651
- "data": {
652
- "application/vnd.jupyter.widget-view+json": {
653
- "model_id": "ec74582b77764d8387bfaec48a30ff6b",
654
- "version_major": 2,
655
- "version_minor": 0
656
- },
657
- "text/plain": [
658
- " 0%| | 0/24 [00:00<?, ?it/s]"
659
- ]
660
- },
661
- "metadata": {},
662
- "output_type": "display_data"
663
- },
664
- {
665
- "data": {
666
- "application/vnd.jupyter.widget-view+json": {
667
- "model_id": "52099c31f7cb4190b773f1bfeefa4d75",
668
- "version_major": 2,
669
- "version_minor": 0
670
- },
671
- "text/plain": [
672
- " 0%| | 0/24 [00:00<?, ?it/s]"
673
- ]
674
- },
675
- "metadata": {},
676
- "output_type": "display_data"
677
- },
678
- {
679
- "data": {
680
- "application/vnd.jupyter.widget-view+json": {
681
- "model_id": "c2ebf165487d4a06aa5c49afb7f27572",
682
- "version_major": 2,
683
- "version_minor": 0
684
- },
685
- "text/plain": [
686
- " 0%| | 0/24 [00:00<?, ?it/s]"
687
- ]
688
- },
689
- "metadata": {},
690
- "output_type": "display_data"
691
- },
692
- {
693
- "name": "stdout",
694
- "output_type": "stream",
695
- "text": [
696
- "saved model at ./outputs/model_state_09.pth\n",
697
- "resumed nn_model from model_state.pth\n",
698
- "Number of parameters for nn_model: 111048705\n",
699
- "resumed ema_model from model_state.pth\n",
700
- "Launching training on one GPU.\n",
701
- "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
702
- "51200 images can be loaded\n",
703
- "field.shape = (64, 64, 514)\n",
704
- "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
705
- "loading 240 images randomly\n",
706
- "images loaded: (240, 1, 64, 512)\n"
707
- ]
708
- },
709
- {
710
- "name": "stderr",
711
- "output_type": "stream",
712
- "text": [
713
- "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
714
- ]
715
- },
716
- {
717
- "name": "stdout",
718
- "output_type": "stream",
719
- "text": [
720
- "params loaded: (240, 2)\n",
721
- "images rescaled to [-1.0, 1.1578054428100586]\n",
722
- "params rescaled to [0.0, 0.9981726090056542]\n"
723
  ]
724
  },
725
  {
726
  "data": {
727
  "application/vnd.jupyter.widget-view+json": {
728
- "model_id": "e9240b35009f40f29dd25cf6f45e90f1",
729
  "version_major": 2,
730
  "version_minor": 0
731
  },
@@ -739,7 +559,7 @@
739
  {
740
  "data": {
741
  "application/vnd.jupyter.widget-view+json": {
742
- "model_id": "adf6eb441a3b40bd84aa7c274be78f3a",
743
  "version_major": 2,
744
  "version_minor": 0
745
  },
@@ -753,7 +573,7 @@
753
  {
754
  "data": {
755
  "application/vnd.jupyter.widget-view+json": {
756
- "model_id": "8153ead286dc4b07801b22897b38ff84",
757
  "version_major": 2,
758
  "version_minor": 0
759
  },
@@ -767,7 +587,7 @@
767
  {
768
  "data": {
769
  "application/vnd.jupyter.widget-view+json": {
770
- "model_id": "21f4e4b41dfb44928c23aa89d2c14b36",
771
  "version_major": 2,
772
  "version_minor": 0
773
  },
@@ -781,7 +601,7 @@
781
  {
782
  "data": {
783
  "application/vnd.jupyter.widget-view+json": {
784
- "model_id": "54bb7a5c70f84fc484336be7155c8215",
785
  "version_major": 2,
786
  "version_minor": 0
787
  },
@@ -795,7 +615,7 @@
795
  {
796
  "data": {
797
  "application/vnd.jupyter.widget-view+json": {
798
- "model_id": "e0332693a3fc4611907cd10d4f0ba467",
799
  "version_major": 2,
800
  "version_minor": 0
801
  },
@@ -809,7 +629,7 @@
809
  {
810
  "data": {
811
  "application/vnd.jupyter.widget-view+json": {
812
- "model_id": "7e22d84514b6499397fd5216542ea317",
813
  "version_major": 2,
814
  "version_minor": 0
815
  },
@@ -823,7 +643,7 @@
823
  {
824
  "data": {
825
  "application/vnd.jupyter.widget-view+json": {
826
- "model_id": "9193803414974584bd5ce5819022eddc",
827
  "version_major": 2,
828
  "version_minor": 0
829
  },
@@ -837,7 +657,7 @@
837
  {
838
  "data": {
839
  "application/vnd.jupyter.widget-view+json": {
840
- "model_id": "39aa6a86ec44471a95ffece6c40a2fb0",
841
  "version_major": 2,
842
  "version_minor": 0
843
  },
@@ -851,7 +671,7 @@
851
  {
852
  "data": {
853
  "application/vnd.jupyter.widget-view+json": {
854
- "model_id": "6a815218832c4c4c855d06737e6131b6",
855
  "version_major": 2,
856
  "version_minor": 0
857
  },
@@ -870,6 +690,7 @@
870
  " repeat = 2\n",
871
  " for i in range(repeat):\n",
872
  " ddpm21cm = DDPM21CM()\n",
 
873
  " notebook_launcher(ddpm21cm.train, num_processes=1)"
874
  ]
875
  },
 
283
  " mixed_precision = \"fp16\"\n",
284
  " gradient_accumulation_steps = 1\n",
285
  "\n",
286
+ " # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n",
287
+ " # run_name = f'{date}' # the unique name of each experiment\n",
288
  "\n",
289
  "# config = TrainConfig()\n",
290
  "# print(\"device =\", config.device)"
 
294
  "cell_type": "code",
295
  "execution_count": 5,
296
  "metadata": {},
297
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
298
  "source": [
299
  "# @dataclass\n",
300
  "class DDPM21CM:\n",
301
  " def __init__(self):\n",
302
  " config = TrainConfig()\n",
303
+ " # date = datetime.datetime.now().strftime(\"%m%d-%H%M\")\n",
304
+ " config.run_name = datetime.datetime.now().strftime(\"%m%d-%H%M\") # the unique name of each experiment\n",
305
  " self.config = config\n",
306
  " # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
307
  " # # self.shape_loaded = dataset.images.shape\n",
 
372
  " self.repo_id = create_repo(\n",
373
  " repo_id=self.config.hub_model_id or Path(self.config.output_dir).name, exist_ok=True\n",
374
  " ).repo_id\n",
375
+ " self.accelerator.init_trackers(f\"{self.config.run_name}\")\n",
376
  "\n",
377
  " self.nn_model, self.optimizer, self.dataloader, self.lr_scheduler = \\\n",
378
  " self.accelerator.prepare(\n",
 
436
  " upload_folder(\n",
437
  " repo_id = self.repo_id,\n",
438
  " folder_path = \".\",#config.output_dir,\n",
439
+ " commit_message = f\"{self.config.run_name}\",\n",
440
  " ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
441
  " )\n",
442
  " if self.config.save_model:\n",
 
492
  },
493
  {
494
  "cell_type": "code",
495
+ "execution_count": 6,
496
  "metadata": {},
497
  "outputs": [
498
  {
499
  "data": {
500
  "application/vnd.jupyter.widget-view+json": {
501
+ "model_id": "6dca1df1da3148f28c71fed756c7abc9",
502
  "version_major": 2,
503
  "version_minor": 0
504
  },
 
516
  "resumed nn_model from model_state.pth\n",
517
  "Number of parameters for nn_model: 111048705\n",
518
  "resumed ema_model from model_state.pth\n",
519
+ "run_name = 0523-1621\n",
520
  "Launching training on one GPU.\n",
521
  "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
522
  "51200 images can be loaded\n",
 
538
  "output_type": "stream",
539
  "text": [
540
  "params loaded: (240, 2)\n",
541
+ "images rescaled to [-1.0, 1.1240839958190918]\n",
542
+ "params rescaled to [0.0, 0.9972546078293054]\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  ]
544
  },
545
  {
546
  "data": {
547
  "application/vnd.jupyter.widget-view+json": {
548
+ "model_id": "15d75d83ca9f4f49be17a89f6ddd58e1",
549
  "version_major": 2,
550
  "version_minor": 0
551
  },
 
559
  {
560
  "data": {
561
  "application/vnd.jupyter.widget-view+json": {
562
+ "model_id": "66959c994f6b40649ab527212de8d3c2",
563
  "version_major": 2,
564
  "version_minor": 0
565
  },
 
573
  {
574
  "data": {
575
  "application/vnd.jupyter.widget-view+json": {
576
+ "model_id": "564f6d85e359481f973a49f75b180440",
577
  "version_major": 2,
578
  "version_minor": 0
579
  },
 
587
  {
588
  "data": {
589
  "application/vnd.jupyter.widget-view+json": {
590
+ "model_id": "079a2325ab83494282c83b76ffb8e52e",
591
  "version_major": 2,
592
  "version_minor": 0
593
  },
 
601
  {
602
  "data": {
603
  "application/vnd.jupyter.widget-view+json": {
604
+ "model_id": "fefa0f8dbfeb474d90e0aaf55f8ca5e8",
605
  "version_major": 2,
606
  "version_minor": 0
607
  },
 
615
  {
616
  "data": {
617
  "application/vnd.jupyter.widget-view+json": {
618
+ "model_id": "b216c0bb3bd4457f9230b32b8d2ede1f",
619
  "version_major": 2,
620
  "version_minor": 0
621
  },
 
629
  {
630
  "data": {
631
  "application/vnd.jupyter.widget-view+json": {
632
+ "model_id": "78d4bdad3dc34ba18f3074802c67bf61",
633
  "version_major": 2,
634
  "version_minor": 0
635
  },
 
643
  {
644
  "data": {
645
  "application/vnd.jupyter.widget-view+json": {
646
+ "model_id": "e78d2d3247b442b78f06b38b65944887",
647
  "version_major": 2,
648
  "version_minor": 0
649
  },
 
657
  {
658
  "data": {
659
  "application/vnd.jupyter.widget-view+json": {
660
+ "model_id": "5e1d909d5f3f4c26a11bd40978c57f4e",
661
  "version_major": 2,
662
  "version_minor": 0
663
  },
 
671
  {
672
  "data": {
673
  "application/vnd.jupyter.widget-view+json": {
674
+ "model_id": "d1f56418378049b59ba1f9de7c5676f1",
675
  "version_major": 2,
676
  "version_minor": 0
677
  },
 
690
  " repeat = 2\n",
691
  " for i in range(repeat):\n",
692
  " ddpm21cm = DDPM21CM()\n",
693
+ " print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
694
  " notebook_launcher(ddpm21cm.train, num_processes=1)"
695
  ]
696
  },