0522-1710
Browse files- diffusion.ipynb +25 -18
diffusion.ipynb
CHANGED
|
@@ -33,7 +33,7 @@
|
|
| 33 |
{
|
| 34 |
"data": {
|
| 35 |
"application/vnd.jupyter.widget-view+json": {
|
| 36 |
-
"model_id": "
|
| 37 |
"version_major": 2,
|
| 38 |
"version_minor": 0
|
| 39 |
},
|
|
@@ -65,10 +65,10 @@
|
|
| 65 |
"import matplotlib.pyplot as plt\n",
|
| 66 |
"import numpy as np\n",
|
| 67 |
"import random\n",
|
| 68 |
-
"from abc import ABC, abstractmethod\n",
|
| 69 |
"import torch.nn.functional as F\n",
|
| 70 |
"import math\n",
|
| 71 |
-
"from PIL import Image\n",
|
| 72 |
"import os\n",
|
| 73 |
"from torch.utils.tensorboard import SummaryWriter\n",
|
| 74 |
"import copy\n",
|
|
@@ -273,7 +273,7 @@
|
|
| 273 |
"\n",
|
| 274 |
" n_epoch = 10#2#5#25 # 120\n",
|
| 275 |
" num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|
| 276 |
-
" batch_size = 10#
|
| 277 |
" # n_sample = 24 # 64, the number of samples in sampling process\n",
|
| 278 |
" n_param = 2\n",
|
| 279 |
" guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
|
|
@@ -313,7 +313,7 @@
|
|
| 313 |
},
|
| 314 |
{
|
| 315 |
"cell_type": "code",
|
| 316 |
-
"execution_count":
|
| 317 |
"metadata": {},
|
| 318 |
"outputs": [
|
| 319 |
{
|
|
@@ -523,7 +523,7 @@
|
|
| 523 |
},
|
| 524 |
{
|
| 525 |
"cell_type": "code",
|
| 526 |
-
"execution_count":
|
| 527 |
"metadata": {},
|
| 528 |
"outputs": [
|
| 529 |
{
|
|
@@ -550,14 +550,14 @@
|
|
| 550 |
"name": "stdout",
|
| 551 |
"output_type": "stream",
|
| 552 |
"text": [
|
| 553 |
-
"images rescaled to [-1.0, 1.
|
| 554 |
-
"params rescaled to [0.0, 0.
|
| 555 |
]
|
| 556 |
},
|
| 557 |
{
|
| 558 |
"data": {
|
| 559 |
"application/vnd.jupyter.widget-view+json": {
|
| 560 |
-
"model_id": "
|
| 561 |
"version_major": 2,
|
| 562 |
"version_minor": 0
|
| 563 |
},
|
|
@@ -571,7 +571,7 @@
|
|
| 571 |
{
|
| 572 |
"data": {
|
| 573 |
"application/vnd.jupyter.widget-view+json": {
|
| 574 |
-
"model_id": "
|
| 575 |
"version_major": 2,
|
| 576 |
"version_minor": 0
|
| 577 |
},
|
|
@@ -585,7 +585,7 @@
|
|
| 585 |
{
|
| 586 |
"data": {
|
| 587 |
"application/vnd.jupyter.widget-view+json": {
|
| 588 |
-
"model_id": "
|
| 589 |
"version_major": 2,
|
| 590 |
"version_minor": 0
|
| 591 |
},
|
|
@@ -599,7 +599,7 @@
|
|
| 599 |
{
|
| 600 |
"data": {
|
| 601 |
"application/vnd.jupyter.widget-view+json": {
|
| 602 |
-
"model_id": "
|
| 603 |
"version_major": 2,
|
| 604 |
"version_minor": 0
|
| 605 |
},
|
|
@@ -613,7 +613,7 @@
|
|
| 613 |
{
|
| 614 |
"data": {
|
| 615 |
"application/vnd.jupyter.widget-view+json": {
|
| 616 |
-
"model_id": "
|
| 617 |
"version_major": 2,
|
| 618 |
"version_minor": 0
|
| 619 |
},
|
|
@@ -627,7 +627,7 @@
|
|
| 627 |
{
|
| 628 |
"data": {
|
| 629 |
"application/vnd.jupyter.widget-view+json": {
|
| 630 |
-
"model_id": "
|
| 631 |
"version_major": 2,
|
| 632 |
"version_minor": 0
|
| 633 |
},
|
|
@@ -641,7 +641,7 @@
|
|
| 641 |
{
|
| 642 |
"data": {
|
| 643 |
"application/vnd.jupyter.widget-view+json": {
|
| 644 |
-
"model_id": "
|
| 645 |
"version_major": 2,
|
| 646 |
"version_minor": 0
|
| 647 |
},
|
|
@@ -655,7 +655,7 @@
|
|
| 655 |
{
|
| 656 |
"data": {
|
| 657 |
"application/vnd.jupyter.widget-view+json": {
|
| 658 |
-
"model_id": "
|
| 659 |
"version_major": 2,
|
| 660 |
"version_minor": 0
|
| 661 |
},
|
|
@@ -669,7 +669,7 @@
|
|
| 669 |
{
|
| 670 |
"data": {
|
| 671 |
"application/vnd.jupyter.widget-view+json": {
|
| 672 |
-
"model_id": "
|
| 673 |
"version_major": 2,
|
| 674 |
"version_minor": 0
|
| 675 |
},
|
|
@@ -683,7 +683,7 @@
|
|
| 683 |
{
|
| 684 |
"data": {
|
| 685 |
"application/vnd.jupyter.widget-view+json": {
|
| 686 |
-
"model_id": "
|
| 687 |
"version_major": 2,
|
| 688 |
"version_minor": 0
|
| 689 |
},
|
|
@@ -735,6 +735,13 @@
|
|
| 735 |
"ddpm21cm.sample(\"./outputs/model_state_09.pth\")"
|
| 736 |
]
|
| 737 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
{
|
| 739 |
"cell_type": "code",
|
| 740 |
"execution_count": null,
|
|
|
|
| 33 |
{
|
| 34 |
"data": {
|
| 35 |
"application/vnd.jupyter.widget-view+json": {
|
| 36 |
+
"model_id": "b24ea18addd84753a509dd05eeb11cfe",
|
| 37 |
"version_major": 2,
|
| 38 |
"version_minor": 0
|
| 39 |
},
|
|
|
|
| 65 |
"import matplotlib.pyplot as plt\n",
|
| 66 |
"import numpy as np\n",
|
| 67 |
"import random\n",
|
| 68 |
+
"# from abc import ABC, abstractmethod\n",
|
| 69 |
"import torch.nn.functional as F\n",
|
| 70 |
"import math\n",
|
| 71 |
+
"# from PIL import Image\n",
|
| 72 |
"import os\n",
|
| 73 |
"from torch.utils.tensorboard import SummaryWriter\n",
|
| 74 |
"import copy\n",
|
|
|
|
| 273 |
"\n",
|
| 274 |
" n_epoch = 10#2#5#25 # 120\n",
|
| 275 |
" num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|
| 276 |
+
" batch_size = 10#20#2#100 # 10\n",
|
| 277 |
" # n_sample = 24 # 64, the number of samples in sampling process\n",
|
| 278 |
" n_param = 2\n",
|
| 279 |
" guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
|
|
|
|
| 313 |
},
|
| 314 |
{
|
| 315 |
"cell_type": "code",
|
| 316 |
+
"execution_count": 6,
|
| 317 |
"metadata": {},
|
| 318 |
"outputs": [
|
| 319 |
{
|
|
|
|
| 523 |
},
|
| 524 |
{
|
| 525 |
"cell_type": "code",
|
| 526 |
+
"execution_count": 7,
|
| 527 |
"metadata": {},
|
| 528 |
"outputs": [
|
| 529 |
{
|
|
|
|
| 550 |
"name": "stdout",
|
| 551 |
"output_type": "stream",
|
| 552 |
"text": [
|
| 553 |
+
"images rescaled to [-1.0, 1.1875841617584229]\n",
|
| 554 |
+
"params rescaled to [0.0, 0.9999290410760016]\n"
|
| 555 |
]
|
| 556 |
},
|
| 557 |
{
|
| 558 |
"data": {
|
| 559 |
"application/vnd.jupyter.widget-view+json": {
|
| 560 |
+
"model_id": "6795b3f3a0ed4c999019c8be5af4e578",
|
| 561 |
"version_major": 2,
|
| 562 |
"version_minor": 0
|
| 563 |
},
|
|
|
|
| 571 |
{
|
| 572 |
"data": {
|
| 573 |
"application/vnd.jupyter.widget-view+json": {
|
| 574 |
+
"model_id": "54cb12b82b434237af28cbb72f7c071b",
|
| 575 |
"version_major": 2,
|
| 576 |
"version_minor": 0
|
| 577 |
},
|
|
|
|
| 585 |
{
|
| 586 |
"data": {
|
| 587 |
"application/vnd.jupyter.widget-view+json": {
|
| 588 |
+
"model_id": "649033a8004c459f8fe892f52a355d7d",
|
| 589 |
"version_major": 2,
|
| 590 |
"version_minor": 0
|
| 591 |
},
|
|
|
|
| 599 |
{
|
| 600 |
"data": {
|
| 601 |
"application/vnd.jupyter.widget-view+json": {
|
| 602 |
+
"model_id": "be5b7b4a88894d769b06dc0cd3daeed5",
|
| 603 |
"version_major": 2,
|
| 604 |
"version_minor": 0
|
| 605 |
},
|
|
|
|
| 613 |
{
|
| 614 |
"data": {
|
| 615 |
"application/vnd.jupyter.widget-view+json": {
|
| 616 |
+
"model_id": "d9ab19a2535d45ea8f795ace2466bbe2",
|
| 617 |
"version_major": 2,
|
| 618 |
"version_minor": 0
|
| 619 |
},
|
|
|
|
| 627 |
{
|
| 628 |
"data": {
|
| 629 |
"application/vnd.jupyter.widget-view+json": {
|
| 630 |
+
"model_id": "f6f6969943344210b1c530cbced3ea20",
|
| 631 |
"version_major": 2,
|
| 632 |
"version_minor": 0
|
| 633 |
},
|
|
|
|
| 641 |
{
|
| 642 |
"data": {
|
| 643 |
"application/vnd.jupyter.widget-view+json": {
|
| 644 |
+
"model_id": "3924f748eccd4bc6aa6bf5c271f9d8a9",
|
| 645 |
"version_major": 2,
|
| 646 |
"version_minor": 0
|
| 647 |
},
|
|
|
|
| 655 |
{
|
| 656 |
"data": {
|
| 657 |
"application/vnd.jupyter.widget-view+json": {
|
| 658 |
+
"model_id": "07ca3e6371d843fa8a05dc490bcc136a",
|
| 659 |
"version_major": 2,
|
| 660 |
"version_minor": 0
|
| 661 |
},
|
|
|
|
| 669 |
{
|
| 670 |
"data": {
|
| 671 |
"application/vnd.jupyter.widget-view+json": {
|
| 672 |
+
"model_id": "af2aa62cfad948589415315abedb32f1",
|
| 673 |
"version_major": 2,
|
| 674 |
"version_minor": 0
|
| 675 |
},
|
|
|
|
| 683 |
{
|
| 684 |
"data": {
|
| 685 |
"application/vnd.jupyter.widget-view+json": {
|
| 686 |
+
"model_id": "e23db426060c41b490fb09b6412ac258",
|
| 687 |
"version_major": 2,
|
| 688 |
"version_minor": 0
|
| 689 |
},
|
|
|
|
| 735 |
"ddpm21cm.sample(\"./outputs/model_state_09.pth\")"
|
| 736 |
]
|
| 737 |
},
|
| 738 |
+
{
|
| 739 |
+
"cell_type": "code",
|
| 740 |
+
"execution_count": null,
|
| 741 |
+
"metadata": {},
|
| 742 |
+
"outputs": [],
|
| 743 |
+
"source": []
|
| 744 |
+
},
|
| 745 |
{
|
| 746 |
"cell_type": "code",
|
| 747 |
"execution_count": null,
|