0523-1704
Browse files- diffusion.ipynb +43 -279
diffusion.ipynb
CHANGED
|
@@ -244,13 +244,13 @@
|
|
| 244 |
" # dim = 2\n",
|
| 245 |
" dim = 2\n",
|
| 246 |
" stride = (2,2) if dim == 2 else (2,2,4)\n",
|
| 247 |
-
" num_image =
|
| 248 |
" HII_DIM = 64\n",
|
| 249 |
" num_redshift = 512#256#256#64#512#128\n",
|
| 250 |
" channel = 1\n",
|
| 251 |
" img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
|
| 252 |
"\n",
|
| 253 |
-
" n_epoch =
|
| 254 |
" num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|
| 255 |
" batch_size = 10#20#2#100 # 10\n",
|
| 256 |
" # n_sample = 24 # 64, the number of samples in sampling process\n",
|
|
@@ -268,17 +268,17 @@
|
|
| 268 |
" # device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 269 |
" lrate = 1e-4\n",
|
| 270 |
" lr_warmup_steps = 0#5#00\n",
|
| 271 |
-
"
|
|
|
|
| 272 |
" # save_freq = 1 #10 # the period of saving model\n",
|
| 273 |
" # cond = True # if training using the conditional information\n",
|
| 274 |
" # lr_decay = False #True# if using the learning rate decay\n",
|
| 275 |
-
" resume =
|
| 276 |
" # params_single = torch.tensor([0.2,0.80000023])\n",
|
| 277 |
" # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
|
| 278 |
" # params = params\n",
|
| 279 |
" # data_dir = './data' # data directory\n",
|
| 280 |
"\n",
|
| 281 |
-
" output_dir = \"./outputs/\"\n",
|
| 282 |
"\n",
|
| 283 |
" mixed_precision = \"fp16\"\n",
|
| 284 |
" gradient_accumulation_steps = 1\n",
|
|
@@ -313,8 +313,9 @@
|
|
| 313 |
" # initialize the unet\n",
|
| 314 |
" self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
|
| 315 |
"\n",
|
| 316 |
-
" if config.resume:\n",
|
| 317 |
-
"
|
|
|
|
| 318 |
" print(f\"resumed nn_model from {config.resume}\")\n",
|
| 319 |
" # nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 320 |
" self.nn_model.train()\n",
|
|
@@ -327,12 +328,12 @@
|
|
| 327 |
" # whether to use ema\n",
|
| 328 |
" if config.ema:\n",
|
| 329 |
" self.ema = EMA(config.ema_rate)\n",
|
| 330 |
-
" if config.resume:\n",
|
| 331 |
" self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
|
| 332 |
-
" self.ema_model.load_state_dict(torch.load(
|
| 333 |
" print(f\"resumed ema_model from {config.resume}\")\n",
|
| 334 |
" else:\n",
|
| 335 |
-
" self.ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
|
| 336 |
"\n",
|
| 337 |
" self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)\n",
|
| 338 |
" self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
|
|
@@ -439,14 +440,14 @@
|
|
| 439 |
" commit_message = f\"{self.config.run_name}\",\n",
|
| 440 |
" ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
|
| 441 |
" )\n",
|
| 442 |
-
" if self.config.
|
| 443 |
" model_state = {\n",
|
| 444 |
" 'epoch': ep,\n",
|
| 445 |
" 'unet_state_dict': self.nn_model.state_dict(),\n",
|
| 446 |
" 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
|
| 447 |
" }\n",
|
| 448 |
-
" torch.save(model_state, self.config.
|
| 449 |
-
" print('saved model at ' + self.config.
|
| 450 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 451 |
"\n",
|
| 452 |
" def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
|
|
@@ -498,7 +499,7 @@
|
|
| 498 |
{
|
| 499 |
"data": {
|
| 500 |
"application/vnd.jupyter.widget-view+json": {
|
| 501 |
-
"model_id": "
|
| 502 |
"version_major": 2,
|
| 503 |
"version_minor": 0
|
| 504 |
},
|
|
@@ -508,196 +509,31 @@
|
|
| 508 |
},
|
| 509 |
"metadata": {},
|
| 510 |
"output_type": "display_data"
|
| 511 |
-
}
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
"51200 images can be loaded\n",
|
| 523 |
-
"field.shape = (64, 64, 514)\n",
|
| 524 |
-
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 525 |
-
"loading 240 images randomly\n",
|
| 526 |
-
"images loaded: (240, 1, 64, 512)\n"
|
| 527 |
-
]
|
| 528 |
-
},
|
| 529 |
-
{
|
| 530 |
-
"name": "stderr",
|
| 531 |
-
"output_type": "stream",
|
| 532 |
-
"text": [
|
| 533 |
-
"Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
|
| 534 |
-
]
|
| 535 |
-
},
|
| 536 |
-
{
|
| 537 |
-
"name": "stdout",
|
| 538 |
-
"output_type": "stream",
|
| 539 |
-
"text": [
|
| 540 |
-
"params loaded: (240, 2)\n",
|
| 541 |
-
"images rescaled to [-1.0, 1.1240839958190918]\n",
|
| 542 |
-
"params rescaled to [0.0, 0.9972546078293054]\n"
|
| 543 |
-
]
|
| 544 |
-
},
|
| 545 |
-
{
|
| 546 |
-
"data": {
|
| 547 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 548 |
-
"model_id": "15d75d83ca9f4f49be17a89f6ddd58e1",
|
| 549 |
-
"version_major": 2,
|
| 550 |
-
"version_minor": 0
|
| 551 |
-
},
|
| 552 |
-
"text/plain": [
|
| 553 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 554 |
-
]
|
| 555 |
-
},
|
| 556 |
-
"metadata": {},
|
| 557 |
-
"output_type": "display_data"
|
| 558 |
-
},
|
| 559 |
-
{
|
| 560 |
-
"data": {
|
| 561 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 562 |
-
"model_id": "66959c994f6b40649ab527212de8d3c2",
|
| 563 |
-
"version_major": 2,
|
| 564 |
-
"version_minor": 0
|
| 565 |
-
},
|
| 566 |
-
"text/plain": [
|
| 567 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 568 |
-
]
|
| 569 |
-
},
|
| 570 |
-
"metadata": {},
|
| 571 |
-
"output_type": "display_data"
|
| 572 |
-
},
|
| 573 |
-
{
|
| 574 |
-
"data": {
|
| 575 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 576 |
-
"model_id": "564f6d85e359481f973a49f75b180440",
|
| 577 |
-
"version_major": 2,
|
| 578 |
-
"version_minor": 0
|
| 579 |
-
},
|
| 580 |
-
"text/plain": [
|
| 581 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 582 |
-
]
|
| 583 |
-
},
|
| 584 |
-
"metadata": {},
|
| 585 |
-
"output_type": "display_data"
|
| 586 |
-
},
|
| 587 |
-
{
|
| 588 |
-
"data": {
|
| 589 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 590 |
-
"model_id": "079a2325ab83494282c83b76ffb8e52e",
|
| 591 |
-
"version_major": 2,
|
| 592 |
-
"version_minor": 0
|
| 593 |
-
},
|
| 594 |
-
"text/plain": [
|
| 595 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 596 |
-
]
|
| 597 |
-
},
|
| 598 |
-
"metadata": {},
|
| 599 |
-
"output_type": "display_data"
|
| 600 |
-
},
|
| 601 |
-
{
|
| 602 |
-
"data": {
|
| 603 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 604 |
-
"model_id": "fefa0f8dbfeb474d90e0aaf55f8ca5e8",
|
| 605 |
-
"version_major": 2,
|
| 606 |
-
"version_minor": 0
|
| 607 |
-
},
|
| 608 |
-
"text/plain": [
|
| 609 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 610 |
-
]
|
| 611 |
-
},
|
| 612 |
-
"metadata": {},
|
| 613 |
-
"output_type": "display_data"
|
| 614 |
-
},
|
| 615 |
-
{
|
| 616 |
-
"data": {
|
| 617 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 618 |
-
"model_id": "b216c0bb3bd4457f9230b32b8d2ede1f",
|
| 619 |
-
"version_major": 2,
|
| 620 |
-
"version_minor": 0
|
| 621 |
-
},
|
| 622 |
-
"text/plain": [
|
| 623 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 624 |
-
]
|
| 625 |
-
},
|
| 626 |
-
"metadata": {},
|
| 627 |
-
"output_type": "display_data"
|
| 628 |
-
},
|
| 629 |
-
{
|
| 630 |
-
"data": {
|
| 631 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 632 |
-
"model_id": "78d4bdad3dc34ba18f3074802c67bf61",
|
| 633 |
-
"version_major": 2,
|
| 634 |
-
"version_minor": 0
|
| 635 |
-
},
|
| 636 |
-
"text/plain": [
|
| 637 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 638 |
-
]
|
| 639 |
-
},
|
| 640 |
-
"metadata": {},
|
| 641 |
-
"output_type": "display_data"
|
| 642 |
-
},
|
| 643 |
-
{
|
| 644 |
-
"data": {
|
| 645 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 646 |
-
"model_id": "e78d2d3247b442b78f06b38b65944887",
|
| 647 |
-
"version_major": 2,
|
| 648 |
-
"version_minor": 0
|
| 649 |
-
},
|
| 650 |
-
"text/plain": [
|
| 651 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 652 |
-
]
|
| 653 |
-
},
|
| 654 |
-
"metadata": {},
|
| 655 |
-
"output_type": "display_data"
|
| 656 |
-
},
|
| 657 |
-
{
|
| 658 |
-
"data": {
|
| 659 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 660 |
-
"model_id": "5e1d909d5f3f4c26a11bd40978c57f4e",
|
| 661 |
-
"version_major": 2,
|
| 662 |
-
"version_minor": 0
|
| 663 |
-
},
|
| 664 |
-
"text/plain": [
|
| 665 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 666 |
-
]
|
| 667 |
-
},
|
| 668 |
-
"metadata": {},
|
| 669 |
-
"output_type": "display_data"
|
| 670 |
-
},
|
| 671 |
-
{
|
| 672 |
-
"data": {
|
| 673 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 674 |
-
"model_id": "d1f56418378049b59ba1f9de7c5676f1",
|
| 675 |
-
"version_major": 2,
|
| 676 |
-
"version_minor": 0
|
| 677 |
-
},
|
| 678 |
-
"text/plain": [
|
| 679 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 680 |
-
]
|
| 681 |
-
},
|
| 682 |
-
"metadata": {},
|
| 683 |
-
"output_type": "display_data"
|
| 684 |
-
},
|
| 685 |
{
|
| 686 |
"name": "stdout",
|
| 687 |
"output_type": "stream",
|
| 688 |
"text": [
|
| 689 |
-
"saved model at ./outputs/model_state_09.pth\n",
|
| 690 |
-
"resumed nn_model from model_state.pth\n",
|
| 691 |
"Number of parameters for nn_model: 111048705\n",
|
| 692 |
-
"
|
| 693 |
-
"run_name = 0523-1624\n",
|
| 694 |
"Launching training on one GPU.\n",
|
| 695 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 696 |
"51200 images can be loaded\n",
|
| 697 |
"field.shape = (64, 64, 514)\n",
|
| 698 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 699 |
-
"loading
|
| 700 |
-
"images loaded: (
|
|
|
|
| 701 |
]
|
| 702 |
},
|
| 703 |
{
|
|
@@ -711,90 +547,19 @@
|
|
| 711 |
"name": "stdout",
|
| 712 |
"output_type": "stream",
|
| 713 |
"text": [
|
| 714 |
-
"
|
| 715 |
-
"
|
| 716 |
-
"params rescaled to [0.0, 0.9999922179553216]\n"
|
| 717 |
]
|
| 718 |
},
|
| 719 |
{
|
| 720 |
"data": {
|
| 721 |
"application/vnd.jupyter.widget-view+json": {
|
| 722 |
-
"model_id": "
|
| 723 |
-
"version_major": 2,
|
| 724 |
-
"version_minor": 0
|
| 725 |
-
},
|
| 726 |
-
"text/plain": [
|
| 727 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 728 |
-
]
|
| 729 |
-
},
|
| 730 |
-
"metadata": {},
|
| 731 |
-
"output_type": "display_data"
|
| 732 |
-
},
|
| 733 |
-
{
|
| 734 |
-
"data": {
|
| 735 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 736 |
-
"model_id": "c7e04c734ba3482eb44344f7a4e37916",
|
| 737 |
-
"version_major": 2,
|
| 738 |
-
"version_minor": 0
|
| 739 |
-
},
|
| 740 |
-
"text/plain": [
|
| 741 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 742 |
-
]
|
| 743 |
-
},
|
| 744 |
-
"metadata": {},
|
| 745 |
-
"output_type": "display_data"
|
| 746 |
-
},
|
| 747 |
-
{
|
| 748 |
-
"data": {
|
| 749 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 750 |
-
"model_id": "9a8c702e844f4fbaa295fb8f6d21503b",
|
| 751 |
-
"version_major": 2,
|
| 752 |
-
"version_minor": 0
|
| 753 |
-
},
|
| 754 |
-
"text/plain": [
|
| 755 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 756 |
-
]
|
| 757 |
-
},
|
| 758 |
-
"metadata": {},
|
| 759 |
-
"output_type": "display_data"
|
| 760 |
-
},
|
| 761 |
-
{
|
| 762 |
-
"data": {
|
| 763 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 764 |
-
"model_id": "062be80bffee4540b396159acc223e6e",
|
| 765 |
-
"version_major": 2,
|
| 766 |
-
"version_minor": 0
|
| 767 |
-
},
|
| 768 |
-
"text/plain": [
|
| 769 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 770 |
-
]
|
| 771 |
-
},
|
| 772 |
-
"metadata": {},
|
| 773 |
-
"output_type": "display_data"
|
| 774 |
-
},
|
| 775 |
-
{
|
| 776 |
-
"data": {
|
| 777 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 778 |
-
"model_id": "3849ae228e284a1a8c01235ffe2691aa",
|
| 779 |
-
"version_major": 2,
|
| 780 |
-
"version_minor": 0
|
| 781 |
-
},
|
| 782 |
-
"text/plain": [
|
| 783 |
-
" 0%| | 0/24 [00:00<?, ?it/s]"
|
| 784 |
-
]
|
| 785 |
-
},
|
| 786 |
-
"metadata": {},
|
| 787 |
-
"output_type": "display_data"
|
| 788 |
-
},
|
| 789 |
-
{
|
| 790 |
-
"data": {
|
| 791 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 792 |
-
"model_id": "b359bb4eeb8b4ec58be692424d352164",
|
| 793 |
"version_major": 2,
|
| 794 |
"version_minor": 0
|
| 795 |
},
|
| 796 |
"text/plain": [
|
| 797 |
-
" 0%| | 0/
|
| 798 |
]
|
| 799 |
},
|
| 800 |
"metadata": {},
|
|
@@ -803,12 +568,12 @@
|
|
| 803 |
{
|
| 804 |
"data": {
|
| 805 |
"application/vnd.jupyter.widget-view+json": {
|
| 806 |
-
"model_id": "
|
| 807 |
"version_major": 2,
|
| 808 |
"version_minor": 0
|
| 809 |
},
|
| 810 |
"text/plain": [
|
| 811 |
-
" 0%| | 0/
|
| 812 |
]
|
| 813 |
},
|
| 814 |
"metadata": {},
|
|
@@ -817,12 +582,12 @@
|
|
| 817 |
{
|
| 818 |
"data": {
|
| 819 |
"application/vnd.jupyter.widget-view+json": {
|
| 820 |
-
"model_id": "
|
| 821 |
"version_major": 2,
|
| 822 |
"version_minor": 0
|
| 823 |
},
|
| 824 |
"text/plain": [
|
| 825 |
-
" 0%| | 0/
|
| 826 |
]
|
| 827 |
},
|
| 828 |
"metadata": {},
|
|
@@ -831,12 +596,12 @@
|
|
| 831 |
{
|
| 832 |
"data": {
|
| 833 |
"application/vnd.jupyter.widget-view+json": {
|
| 834 |
-
"model_id": "
|
| 835 |
"version_major": 2,
|
| 836 |
"version_minor": 0
|
| 837 |
},
|
| 838 |
"text/plain": [
|
| 839 |
-
" 0%| | 0/
|
| 840 |
]
|
| 841 |
},
|
| 842 |
"metadata": {},
|
|
@@ -845,12 +610,12 @@
|
|
| 845 |
{
|
| 846 |
"data": {
|
| 847 |
"application/vnd.jupyter.widget-view+json": {
|
| 848 |
-
"model_id": "
|
| 849 |
"version_major": 2,
|
| 850 |
"version_minor": 0
|
| 851 |
},
|
| 852 |
"text/plain": [
|
| 853 |
-
" 0%| | 0/
|
| 854 |
]
|
| 855 |
},
|
| 856 |
"metadata": {},
|
|
@@ -860,8 +625,7 @@
|
|
| 860 |
"source": [
|
| 861 |
"if __name__ == \"__main__\":\n",
|
| 862 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 863 |
-
"
|
| 864 |
-
" repeat = 2\n",
|
| 865 |
" for i in range(repeat):\n",
|
| 866 |
" ddpm21cm = DDPM21CM()\n",
|
| 867 |
" print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
|
|
|
|
| 244 |
" # dim = 2\n",
|
| 245 |
" dim = 2\n",
|
| 246 |
" stride = (2,2) if dim == 2 else (2,2,4)\n",
|
| 247 |
+
" num_image = 2560\n",
|
| 248 |
" HII_DIM = 64\n",
|
| 249 |
" num_redshift = 512#256#256#64#512#128\n",
|
| 250 |
" channel = 1\n",
|
| 251 |
" img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)\n",
|
| 252 |
"\n",
|
| 253 |
+
" n_epoch = 5#2#5#25 # 120\n",
|
| 254 |
" num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|
| 255 |
" batch_size = 10#20#2#100 # 10\n",
|
| 256 |
" # n_sample = 24 # 64, the number of samples in sampling process\n",
|
|
|
|
| 268 |
" # device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 269 |
" lrate = 1e-4\n",
|
| 270 |
" lr_warmup_steps = 0#5#00\n",
|
| 271 |
+
" output_dir = \"./outputs/\"\n",
|
| 272 |
+
" save_name = os.path.join(output_dir, 'model_state.pth')\n",
|
| 273 |
" # save_freq = 1 #10 # the period of saving model\n",
|
| 274 |
" # cond = True # if training using the conditional information\n",
|
| 275 |
" # lr_decay = False #True# if using the learning rate decay\n",
|
| 276 |
+
" resume = save_name # if resume from the trained checkpoints\n",
|
| 277 |
" # params_single = torch.tensor([0.2,0.80000023])\n",
|
| 278 |
" # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
|
| 279 |
" # params = params\n",
|
| 280 |
" # data_dir = './data' # data directory\n",
|
| 281 |
"\n",
|
|
|
|
| 282 |
"\n",
|
| 283 |
" mixed_precision = \"fp16\"\n",
|
| 284 |
" gradient_accumulation_steps = 1\n",
|
|
|
|
| 313 |
" # initialize the unet\n",
|
| 314 |
" self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
|
| 315 |
"\n",
|
| 316 |
+
" if config.resume and os.path.exists(config.resume):\n",
|
| 317 |
+
" # resume_file = os.path.join(config.output_dir, f\"{config.resume}\")\n",
|
| 318 |
+
" self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])\n",
|
| 319 |
" print(f\"resumed nn_model from {config.resume}\")\n",
|
| 320 |
" # nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 321 |
" self.nn_model.train()\n",
|
|
|
|
| 328 |
" # whether to use ema\n",
|
| 329 |
" if config.ema:\n",
|
| 330 |
" self.ema = EMA(config.ema_rate)\n",
|
| 331 |
+
" if config.resume and os.path.exists(config.resume):\n",
|
| 332 |
" self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
|
| 333 |
+
" self.ema_model.load_state_dict(torch.load(config.resume)['ema_unet_state_dict'])\n",
|
| 334 |
" print(f\"resumed ema_model from {config.resume}\")\n",
|
| 335 |
" else:\n",
|
| 336 |
+
" self.ema_model = copy.deepcopy(self.nn_model).eval().requires_grad_(False)\n",
|
| 337 |
"\n",
|
| 338 |
" self.optimizer = torch.optim.AdamW(self.nn_model.parameters(), lr=config.lrate)\n",
|
| 339 |
" self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
|
|
|
|
| 440 |
" commit_message = f\"{self.config.run_name}\",\n",
|
| 441 |
" ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\", \"__pycache__\"],\n",
|
| 442 |
" )\n",
|
| 443 |
+
" if self.config.save_name:\n",
|
| 444 |
" model_state = {\n",
|
| 445 |
" 'epoch': ep,\n",
|
| 446 |
" 'unet_state_dict': self.nn_model.state_dict(),\n",
|
| 447 |
" 'ema_unet_state_dict': self.ema_model.state_dict(),\n",
|
| 448 |
" }\n",
|
| 449 |
+
" torch.save(model_state, self.config.save_name)\n",
|
| 450 |
+
" print('saved model at ' + self.config.save_name)\n",
|
| 451 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 452 |
"\n",
|
| 453 |
" def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
|
|
|
|
| 499 |
{
|
| 500 |
"data": {
|
| 501 |
"application/vnd.jupyter.widget-view+json": {
|
| 502 |
+
"model_id": "e0f355a0bc8b4592952af6c1ccd5d2fb",
|
| 503 |
"version_major": 2,
|
| 504 |
"version_minor": 0
|
| 505 |
},
|
|
|
|
| 509 |
},
|
| 510 |
"metadata": {},
|
| 511 |
"output_type": "display_data"
|
| 512 |
+
}
|
| 513 |
+
],
|
| 514 |
+
"source": [
|
| 515 |
+
"notebook_login()"
|
| 516 |
+
]
|
| 517 |
+
},
|
| 518 |
+
{
|
| 519 |
+
"cell_type": "code",
|
| 520 |
+
"execution_count": 7,
|
| 521 |
+
"metadata": {},
|
| 522 |
+
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
{
|
| 524 |
"name": "stdout",
|
| 525 |
"output_type": "stream",
|
| 526 |
"text": [
|
|
|
|
|
|
|
| 527 |
"Number of parameters for nn_model: 111048705\n",
|
| 528 |
+
"run_name = 0523-1704\n",
|
|
|
|
| 529 |
"Launching training on one GPU.\n",
|
| 530 |
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 531 |
"51200 images can be loaded\n",
|
| 532 |
"field.shape = (64, 64, 514)\n",
|
| 533 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 534 |
+
"loading 2560 images randomly\n",
|
| 535 |
+
"images loaded: (2560, 1, 64, 512)\n",
|
| 536 |
+
"params loaded: (2560, 2)\n"
|
| 537 |
]
|
| 538 |
},
|
| 539 |
{
|
|
|
|
| 547 |
"name": "stdout",
|
| 548 |
"output_type": "stream",
|
| 549 |
"text": [
|
| 550 |
+
"images rescaled to [-1.0, 1.1378462314605713]\n",
|
| 551 |
+
"params rescaled to [0.0, 0.9995994165819857]\n"
|
|
|
|
| 552 |
]
|
| 553 |
},
|
| 554 |
{
|
| 555 |
"data": {
|
| 556 |
"application/vnd.jupyter.widget-view+json": {
|
| 557 |
+
"model_id": "4d787d2fbdcf4575b7b17a6e5161f5ec",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
"version_major": 2,
|
| 559 |
"version_minor": 0
|
| 560 |
},
|
| 561 |
"text/plain": [
|
| 562 |
+
" 0%| | 0/256 [00:00<?, ?it/s]"
|
| 563 |
]
|
| 564 |
},
|
| 565 |
"metadata": {},
|
|
|
|
| 568 |
{
|
| 569 |
"data": {
|
| 570 |
"application/vnd.jupyter.widget-view+json": {
|
| 571 |
+
"model_id": "e67439e56e594ecfb3967edbfb3f0d60",
|
| 572 |
"version_major": 2,
|
| 573 |
"version_minor": 0
|
| 574 |
},
|
| 575 |
"text/plain": [
|
| 576 |
+
" 0%| | 0/256 [00:00<?, ?it/s]"
|
| 577 |
]
|
| 578 |
},
|
| 579 |
"metadata": {},
|
|
|
|
| 582 |
{
|
| 583 |
"data": {
|
| 584 |
"application/vnd.jupyter.widget-view+json": {
|
| 585 |
+
"model_id": "9ca7cb14960348fa8d83c90d773057ac",
|
| 586 |
"version_major": 2,
|
| 587 |
"version_minor": 0
|
| 588 |
},
|
| 589 |
"text/plain": [
|
| 590 |
+
" 0%| | 0/256 [00:00<?, ?it/s]"
|
| 591 |
]
|
| 592 |
},
|
| 593 |
"metadata": {},
|
|
|
|
| 596 |
{
|
| 597 |
"data": {
|
| 598 |
"application/vnd.jupyter.widget-view+json": {
|
| 599 |
+
"model_id": "a6368ae7b9fb4505b6b62d51c5d675ed",
|
| 600 |
"version_major": 2,
|
| 601 |
"version_minor": 0
|
| 602 |
},
|
| 603 |
"text/plain": [
|
| 604 |
+
" 0%| | 0/256 [00:00<?, ?it/s]"
|
| 605 |
]
|
| 606 |
},
|
| 607 |
"metadata": {},
|
|
|
|
| 610 |
{
|
| 611 |
"data": {
|
| 612 |
"application/vnd.jupyter.widget-view+json": {
|
| 613 |
+
"model_id": "d5a391c5bbfb4f6481c1f2ad6e754e24",
|
| 614 |
"version_major": 2,
|
| 615 |
"version_minor": 0
|
| 616 |
},
|
| 617 |
"text/plain": [
|
| 618 |
+
" 0%| | 0/256 [00:00<?, ?it/s]"
|
| 619 |
]
|
| 620 |
},
|
| 621 |
"metadata": {},
|
|
|
|
| 625 |
"source": [
|
| 626 |
"if __name__ == \"__main__\":\n",
|
| 627 |
" # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n",
|
| 628 |
+
" repeat = 30\n",
|
|
|
|
| 629 |
" for i in range(repeat):\n",
|
| 630 |
" ddpm21cm = DDPM21CM()\n",
|
| 631 |
" print(f\"run_name = {ddpm21cm.config.run_name}\")\n",
|