0522-1549
Browse files- diffusion.ipynb +149 -75
- load_h5.py +27 -0
diffusion.ipynb
CHANGED
|
@@ -21,7 +21,8 @@
|
|
| 21 |
"- 已擴展爲接受不同維度的情形\n",
|
| 22 |
"- 融合cond, guide_w, drop_out這些參數\n",
|
| 23 |
"- 生成的21cm圖像該暗的地方不夠暗,似乎換成MNIST的數字圖像就沒問題\n",
|
| 24 |
-
"- 我用diffusion模型生成MNIST的數字時發現,儘管生成的數據的範圍也存在負數數值,如-0.1,但畫出來的圖像卻是理想的黑色。數據的分佈與21cm的結果的分佈沒多大差別,我現在打算把代碼退回到21cm
|
|
|
|
| 25 |
]
|
| 26 |
},
|
| 27 |
{
|
|
@@ -32,7 +33,7 @@
|
|
| 32 |
{
|
| 33 |
"data": {
|
| 34 |
"application/vnd.jupyter.widget-view+json": {
|
| 35 |
-
"model_id": "
|
| 36 |
"version_major": 2,
|
| 37 |
"version_minor": 0
|
| 38 |
},
|
|
@@ -85,7 +86,7 @@
|
|
| 85 |
},
|
| 86 |
{
|
| 87 |
"cell_type": "code",
|
| 88 |
-
"execution_count":
|
| 89 |
"metadata": {},
|
| 90 |
"outputs": [],
|
| 91 |
"source": [
|
|
@@ -94,7 +95,7 @@
|
|
| 94 |
},
|
| 95 |
{
|
| 96 |
"cell_type": "code",
|
| 97 |
-
"execution_count":
|
| 98 |
"metadata": {},
|
| 99 |
"outputs": [],
|
| 100 |
"source": [
|
|
@@ -191,7 +192,7 @@
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
| 194 |
-
"execution_count":
|
| 195 |
"metadata": {},
|
| 196 |
"outputs": [],
|
| 197 |
"source": [
|
|
@@ -201,7 +202,7 @@
|
|
| 201 |
},
|
| 202 |
{
|
| 203 |
"cell_type": "code",
|
| 204 |
-
"execution_count":
|
| 205 |
"metadata": {},
|
| 206 |
"outputs": [],
|
| 207 |
"source": [
|
|
@@ -234,7 +235,7 @@
|
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "code",
|
| 237 |
-
"execution_count":
|
| 238 |
"metadata": {},
|
| 239 |
"outputs": [],
|
| 240 |
"source": [
|
|
@@ -341,7 +342,7 @@
|
|
| 341 |
},
|
| 342 |
{
|
| 343 |
"cell_type": "code",
|
| 344 |
-
"execution_count":
|
| 345 |
"metadata": {},
|
| 346 |
"outputs": [],
|
| 347 |
"source": [
|
|
@@ -383,7 +384,7 @@
|
|
| 383 |
},
|
| 384 |
{
|
| 385 |
"cell_type": "code",
|
| 386 |
-
"execution_count":
|
| 387 |
"metadata": {},
|
| 388 |
"outputs": [],
|
| 389 |
"source": [
|
|
@@ -408,7 +409,7 @@
|
|
| 408 |
},
|
| 409 |
{
|
| 410 |
"cell_type": "code",
|
| 411 |
-
"execution_count":
|
| 412 |
"metadata": {},
|
| 413 |
"outputs": [],
|
| 414 |
"source": [
|
|
@@ -438,7 +439,7 @@
|
|
| 438 |
},
|
| 439 |
{
|
| 440 |
"cell_type": "code",
|
| 441 |
-
"execution_count":
|
| 442 |
"metadata": {},
|
| 443 |
"outputs": [],
|
| 444 |
"source": [
|
|
@@ -453,7 +454,7 @@
|
|
| 453 |
},
|
| 454 |
{
|
| 455 |
"cell_type": "code",
|
| 456 |
-
"execution_count":
|
| 457 |
"metadata": {},
|
| 458 |
"outputs": [],
|
| 459 |
"source": [
|
|
@@ -467,7 +468,7 @@
|
|
| 467 |
},
|
| 468 |
{
|
| 469 |
"cell_type": "code",
|
| 470 |
-
"execution_count":
|
| 471 |
"metadata": {},
|
| 472 |
"outputs": [],
|
| 473 |
"source": [
|
|
@@ -485,7 +486,7 @@
|
|
| 485 |
},
|
| 486 |
{
|
| 487 |
"cell_type": "code",
|
| 488 |
-
"execution_count":
|
| 489 |
"metadata": {},
|
| 490 |
"outputs": [],
|
| 491 |
"source": [
|
|
@@ -567,7 +568,7 @@
|
|
| 567 |
},
|
| 568 |
{
|
| 569 |
"cell_type": "code",
|
| 570 |
-
"execution_count":
|
| 571 |
"metadata": {},
|
| 572 |
"outputs": [],
|
| 573 |
"source": [
|
|
@@ -600,7 +601,7 @@
|
|
| 600 |
},
|
| 601 |
{
|
| 602 |
"cell_type": "code",
|
| 603 |
-
"execution_count":
|
| 604 |
"metadata": {},
|
| 605 |
"outputs": [],
|
| 606 |
"source": [
|
|
@@ -649,7 +650,7 @@
|
|
| 649 |
},
|
| 650 |
{
|
| 651 |
"cell_type": "code",
|
| 652 |
-
"execution_count":
|
| 653 |
"metadata": {},
|
| 654 |
"outputs": [],
|
| 655 |
"source": [
|
|
@@ -678,7 +679,7 @@
|
|
| 678 |
},
|
| 679 |
{
|
| 680 |
"cell_type": "code",
|
| 681 |
-
"execution_count":
|
| 682 |
"metadata": {},
|
| 683 |
"outputs": [],
|
| 684 |
"source": [
|
|
@@ -934,7 +935,7 @@
|
|
| 934 |
},
|
| 935 |
{
|
| 936 |
"cell_type": "code",
|
| 937 |
-
"execution_count":
|
| 938 |
"metadata": {},
|
| 939 |
"outputs": [],
|
| 940 |
"source": [
|
|
@@ -964,7 +965,7 @@
|
|
| 964 |
},
|
| 965 |
{
|
| 966 |
"cell_type": "code",
|
| 967 |
-
"execution_count":
|
| 968 |
"metadata": {},
|
| 969 |
"outputs": [],
|
| 970 |
"source": [
|
|
@@ -1030,7 +1031,7 @@
|
|
| 1030 |
},
|
| 1031 |
{
|
| 1032 |
"cell_type": "code",
|
| 1033 |
-
"execution_count":
|
| 1034 |
"metadata": {},
|
| 1035 |
"outputs": [],
|
| 1036 |
"source": [
|
|
@@ -1040,7 +1041,7 @@
|
|
| 1040 |
},
|
| 1041 |
{
|
| 1042 |
"cell_type": "code",
|
| 1043 |
-
"execution_count":
|
| 1044 |
"metadata": {},
|
| 1045 |
"outputs": [],
|
| 1046 |
"source": [
|
|
@@ -1049,7 +1050,7 @@
|
|
| 1049 |
},
|
| 1050 |
{
|
| 1051 |
"cell_type": "code",
|
| 1052 |
-
"execution_count":
|
| 1053 |
"metadata": {},
|
| 1054 |
"outputs": [],
|
| 1055 |
"source": [
|
|
@@ -1073,7 +1074,7 @@
|
|
| 1073 |
},
|
| 1074 |
{
|
| 1075 |
"cell_type": "code",
|
| 1076 |
-
"execution_count":
|
| 1077 |
"metadata": {},
|
| 1078 |
"outputs": [],
|
| 1079 |
"source": [
|
|
@@ -1271,7 +1272,7 @@
|
|
| 1271 |
},
|
| 1272 |
{
|
| 1273 |
"cell_type": "code",
|
| 1274 |
-
"execution_count":
|
| 1275 |
"metadata": {},
|
| 1276 |
"outputs": [
|
| 1277 |
{
|
|
@@ -1410,6 +1411,7 @@
|
|
| 1410 |
" del self.nn_model\n",
|
| 1411 |
" if self.config.ema:\n",
|
| 1412 |
" del self.ema_model\n",
|
|
|
|
| 1413 |
"\n",
|
| 1414 |
" def save(self, ep):\n",
|
| 1415 |
" # save model\n",
|
|
@@ -1438,7 +1440,7 @@
|
|
| 1438 |
" # n_sample = params.shape[0]\n",
|
| 1439 |
" params = params or torch.tensor([0.2,0.8]).repeat(5,1)\n",
|
| 1440 |
" assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
|
| 1441 |
-
" print(\"params
|
| 1442 |
" # print(\"params =\", params)\n",
|
| 1443 |
" # print(\"len(params) =\", len(params))\n",
|
| 1444 |
" # model = self.ema_model if ema else self.nn_model\n",
|
|
@@ -1460,12 +1462,13 @@
|
|
| 1460 |
" # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
|
| 1461 |
" # print(f\"resumed ema_model from {config.resume}\")\n",
|
| 1462 |
"\n",
|
| 1463 |
-
"
|
| 1464 |
-
"
|
| 1465 |
-
"
|
| 1466 |
-
"
|
| 1467 |
-
"
|
| 1468 |
-
"
|
|
|
|
| 1469 |
"\n",
|
| 1470 |
" np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
|
| 1471 |
"\n",
|
|
@@ -1479,7 +1482,7 @@
|
|
| 1479 |
},
|
| 1480 |
{
|
| 1481 |
"cell_type": "code",
|
| 1482 |
-
"execution_count":
|
| 1483 |
"metadata": {},
|
| 1484 |
"outputs": [
|
| 1485 |
{
|
|
@@ -1506,14 +1509,14 @@
|
|
| 1506 |
"output_type": "stream",
|
| 1507 |
"text": [
|
| 1508 |
"params loaded: (200, 2)\n",
|
| 1509 |
-
"images rescaled to [-1.0, 1.
|
| 1510 |
-
"params rescaled to [0.0, 0.
|
| 1511 |
]
|
| 1512 |
},
|
| 1513 |
{
|
| 1514 |
"data": {
|
| 1515 |
"application/vnd.jupyter.widget-view+json": {
|
| 1516 |
-
"model_id": "
|
| 1517 |
"version_major": 2,
|
| 1518 |
"version_minor": 0
|
| 1519 |
},
|
|
@@ -1527,7 +1530,7 @@
|
|
| 1527 |
{
|
| 1528 |
"data": {
|
| 1529 |
"application/vnd.jupyter.widget-view+json": {
|
| 1530 |
-
"model_id": "
|
| 1531 |
"version_major": 2,
|
| 1532 |
"version_minor": 0
|
| 1533 |
},
|
|
@@ -1541,7 +1544,7 @@
|
|
| 1541 |
{
|
| 1542 |
"data": {
|
| 1543 |
"application/vnd.jupyter.widget-view+json": {
|
| 1544 |
-
"model_id": "
|
| 1545 |
"version_major": 2,
|
| 1546 |
"version_minor": 0
|
| 1547 |
},
|
|
@@ -1555,7 +1558,7 @@
|
|
| 1555 |
{
|
| 1556 |
"data": {
|
| 1557 |
"application/vnd.jupyter.widget-view+json": {
|
| 1558 |
-
"model_id": "
|
| 1559 |
"version_major": 2,
|
| 1560 |
"version_minor": 0
|
| 1561 |
},
|
|
@@ -1569,7 +1572,7 @@
|
|
| 1569 |
{
|
| 1570 |
"data": {
|
| 1571 |
"application/vnd.jupyter.widget-view+json": {
|
| 1572 |
-
"model_id": "
|
| 1573 |
"version_major": 2,
|
| 1574 |
"version_minor": 0
|
| 1575 |
},
|
|
@@ -1583,7 +1586,7 @@
|
|
| 1583 |
{
|
| 1584 |
"data": {
|
| 1585 |
"application/vnd.jupyter.widget-view+json": {
|
| 1586 |
-
"model_id": "
|
| 1587 |
"version_major": 2,
|
| 1588 |
"version_minor": 0
|
| 1589 |
},
|
|
@@ -1597,7 +1600,7 @@
|
|
| 1597 |
{
|
| 1598 |
"data": {
|
| 1599 |
"application/vnd.jupyter.widget-view+json": {
|
| 1600 |
-
"model_id": "
|
| 1601 |
"version_major": 2,
|
| 1602 |
"version_minor": 0
|
| 1603 |
},
|
|
@@ -1611,7 +1614,7 @@
|
|
| 1611 |
{
|
| 1612 |
"data": {
|
| 1613 |
"application/vnd.jupyter.widget-view+json": {
|
| 1614 |
-
"model_id": "
|
| 1615 |
"version_major": 2,
|
| 1616 |
"version_minor": 0
|
| 1617 |
},
|
|
@@ -1625,7 +1628,7 @@
|
|
| 1625 |
{
|
| 1626 |
"data": {
|
| 1627 |
"application/vnd.jupyter.widget-view+json": {
|
| 1628 |
-
"model_id": "
|
| 1629 |
"version_major": 2,
|
| 1630 |
"version_minor": 0
|
| 1631 |
},
|
|
@@ -1639,7 +1642,7 @@
|
|
| 1639 |
{
|
| 1640 |
"data": {
|
| 1641 |
"application/vnd.jupyter.widget-view+json": {
|
| 1642 |
-
"model_id": "
|
| 1643 |
"version_major": 2,
|
| 1644 |
"version_minor": 0
|
| 1645 |
},
|
|
@@ -1664,13 +1667,18 @@
|
|
| 1664 |
"name": "stdout",
|
| 1665 |
"output_type": "stream",
|
| 1666 |
"text": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1667 |
"nn_model resumed from ./outputs/model_state_09.pth\n"
|
| 1668 |
]
|
| 1669 |
},
|
| 1670 |
{
|
| 1671 |
"data": {
|
| 1672 |
"application/vnd.jupyter.widget-view+json": {
|
| 1673 |
-
"model_id": "
|
| 1674 |
"version_major": 2,
|
| 1675 |
"version_minor": 0
|
| 1676 |
},
|
|
@@ -1683,27 +1691,24 @@
|
|
| 1683 |
},
|
| 1684 |
{
|
| 1685 |
"ename": "RuntimeError",
|
| 1686 |
-
"evalue": "CUDA out of memory. Tried to allocate
|
| 1687 |
"output_type": "error",
|
| 1688 |
"traceback": [
|
| 1689 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1690 |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 1691 |
-
"Cell \u001b[0;32mIn[
|
| 1692 |
-
"Cell \u001b[0;32mIn[
|
| 1693 |
-
"Cell \u001b[0;32mIn[
|
| 1694 |
-
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1695 |
-
"Cell \u001b[0;32mIn[17], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
|
| 1696 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1697 |
-
"Cell \u001b[0;32mIn[
|
| 1698 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1699 |
-
"Cell \u001b[0;32mIn[13], line
|
| 1700 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1701 |
-
"
|
| 1702 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1703 |
-
"Cell \u001b[0;32mIn[
|
| 1704 |
-
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/
|
| 1705 |
-
"
|
| 1706 |
-
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.64 GiB total capacity; 22.57 GiB already allocated; 12.50 MiB free; 22.60 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
|
| 1707 |
]
|
| 1708 |
}
|
| 1709 |
],
|
|
@@ -1868,39 +1873,108 @@
|
|
| 1868 |
},
|
| 1869 |
{
|
| 1870 |
"cell_type": "code",
|
| 1871 |
-
"execution_count":
|
| 1872 |
"metadata": {},
|
| 1873 |
"outputs": [
|
| 1874 |
{
|
| 1875 |
-
"
|
| 1876 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1877 |
"output_type": "error",
|
| 1878 |
"traceback": [
|
| 1879 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1880 |
-
"\u001b[0;
|
| 1881 |
-
"Cell \u001b[0;32mIn[
|
| 1882 |
-
"\u001b[0;
|
|
|
|
| 1883 |
]
|
| 1884 |
}
|
| 1885 |
],
|
| 1886 |
"source": [
|
|
|
|
| 1887 |
"# ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
|
|
|
|
| 1888 |
"\n",
|
| 1889 |
-
"
|
| 1890 |
-
"
|
| 1891 |
-
"
|
| 1892 |
-
"#
|
| 1893 |
-
"#
|
| 1894 |
-
"
|
| 1895 |
-
"
|
| 1896 |
"\n",
|
| 1897 |
"# n_sample = 20\n",
|
| 1898 |
-
"
|
| 1899 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1900 |
"\n",
|
| 1901 |
-
"
|
| 1902 |
]
|
| 1903 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1904 |
{
|
| 1905 |
"cell_type": "code",
|
| 1906 |
"execution_count": 21,
|
|
|
|
| 21 |
"- 已擴展爲接受不同維度的情形\n",
|
| 22 |
"- 融合cond, guide_w, drop_out這些參數\n",
|
| 23 |
"- 生成的21cm圖像該暗的地方不夠暗,似乎換成MNIST的數字圖像就沒問題\n",
|
| 24 |
+
"- 我用diffusion模型生成MNIST的數字時發現,儘管生成的數據的範圍也存在負數數值,如-0.1,但畫出來的圖像卻是理想的黑色。數據的分佈與21cm的結果的分佈沒多大差別,我現在打算把代碼退回到21cm的情形\n",
|
| 25 |
+
"- 我統一了ddpm21cm這個module,能統一實現訓練和生成樣本,但目前有個bug, sample時總是會cuda out of memory,然而單獨resume model並sample就不會。"
|
| 26 |
]
|
| 27 |
},
|
| 28 |
{
|
|
|
|
| 33 |
{
|
| 34 |
"data": {
|
| 35 |
"application/vnd.jupyter.widget-view+json": {
|
| 36 |
+
"model_id": "4f2bbf6f5e904828bc65afc7ad97df36",
|
| 37 |
"version_major": 2,
|
| 38 |
"version_minor": 0
|
| 39 |
},
|
|
|
|
| 86 |
},
|
| 87 |
{
|
| 88 |
"cell_type": "code",
|
| 89 |
+
"execution_count": 3,
|
| 90 |
"metadata": {},
|
| 91 |
"outputs": [],
|
| 92 |
"source": [
|
|
|
|
| 95 |
},
|
| 96 |
{
|
| 97 |
"cell_type": "code",
|
| 98 |
+
"execution_count": 4,
|
| 99 |
"metadata": {},
|
| 100 |
"outputs": [],
|
| 101 |
"source": [
|
|
|
|
| 192 |
},
|
| 193 |
{
|
| 194 |
"cell_type": "code",
|
| 195 |
+
"execution_count": 5,
|
| 196 |
"metadata": {},
|
| 197 |
"outputs": [],
|
| 198 |
"source": [
|
|
|
|
| 202 |
},
|
| 203 |
{
|
| 204 |
"cell_type": "code",
|
| 205 |
+
"execution_count": 6,
|
| 206 |
"metadata": {},
|
| 207 |
"outputs": [],
|
| 208 |
"source": [
|
|
|
|
| 235 |
},
|
| 236 |
{
|
| 237 |
"cell_type": "code",
|
| 238 |
+
"execution_count": 7,
|
| 239 |
"metadata": {},
|
| 240 |
"outputs": [],
|
| 241 |
"source": [
|
|
|
|
| 342 |
},
|
| 343 |
{
|
| 344 |
"cell_type": "code",
|
| 345 |
+
"execution_count": 8,
|
| 346 |
"metadata": {},
|
| 347 |
"outputs": [],
|
| 348 |
"source": [
|
|
|
|
| 384 |
},
|
| 385 |
{
|
| 386 |
"cell_type": "code",
|
| 387 |
+
"execution_count": 9,
|
| 388 |
"metadata": {},
|
| 389 |
"outputs": [],
|
| 390 |
"source": [
|
|
|
|
| 409 |
},
|
| 410 |
{
|
| 411 |
"cell_type": "code",
|
| 412 |
+
"execution_count": 10,
|
| 413 |
"metadata": {},
|
| 414 |
"outputs": [],
|
| 415 |
"source": [
|
|
|
|
| 439 |
},
|
| 440 |
{
|
| 441 |
"cell_type": "code",
|
| 442 |
+
"execution_count": 11,
|
| 443 |
"metadata": {},
|
| 444 |
"outputs": [],
|
| 445 |
"source": [
|
|
|
|
| 454 |
},
|
| 455 |
{
|
| 456 |
"cell_type": "code",
|
| 457 |
+
"execution_count": 12,
|
| 458 |
"metadata": {},
|
| 459 |
"outputs": [],
|
| 460 |
"source": [
|
|
|
|
| 468 |
},
|
| 469 |
{
|
| 470 |
"cell_type": "code",
|
| 471 |
+
"execution_count": 13,
|
| 472 |
"metadata": {},
|
| 473 |
"outputs": [],
|
| 474 |
"source": [
|
|
|
|
| 486 |
},
|
| 487 |
{
|
| 488 |
"cell_type": "code",
|
| 489 |
+
"execution_count": 14,
|
| 490 |
"metadata": {},
|
| 491 |
"outputs": [],
|
| 492 |
"source": [
|
|
|
|
| 568 |
},
|
| 569 |
{
|
| 570 |
"cell_type": "code",
|
| 571 |
+
"execution_count": 15,
|
| 572 |
"metadata": {},
|
| 573 |
"outputs": [],
|
| 574 |
"source": [
|
|
|
|
| 601 |
},
|
| 602 |
{
|
| 603 |
"cell_type": "code",
|
| 604 |
+
"execution_count": 16,
|
| 605 |
"metadata": {},
|
| 606 |
"outputs": [],
|
| 607 |
"source": [
|
|
|
|
| 650 |
},
|
| 651 |
{
|
| 652 |
"cell_type": "code",
|
| 653 |
+
"execution_count": 17,
|
| 654 |
"metadata": {},
|
| 655 |
"outputs": [],
|
| 656 |
"source": [
|
|
|
|
| 679 |
},
|
| 680 |
{
|
| 681 |
"cell_type": "code",
|
| 682 |
+
"execution_count": 18,
|
| 683 |
"metadata": {},
|
| 684 |
"outputs": [],
|
| 685 |
"source": [
|
|
|
|
| 935 |
},
|
| 936 |
{
|
| 937 |
"cell_type": "code",
|
| 938 |
+
"execution_count": 19,
|
| 939 |
"metadata": {},
|
| 940 |
"outputs": [],
|
| 941 |
"source": [
|
|
|
|
| 965 |
},
|
| 966 |
{
|
| 967 |
"cell_type": "code",
|
| 968 |
+
"execution_count": 20,
|
| 969 |
"metadata": {},
|
| 970 |
"outputs": [],
|
| 971 |
"source": [
|
|
|
|
| 1031 |
},
|
| 1032 |
{
|
| 1033 |
"cell_type": "code",
|
| 1034 |
+
"execution_count": 21,
|
| 1035 |
"metadata": {},
|
| 1036 |
"outputs": [],
|
| 1037 |
"source": [
|
|
|
|
| 1041 |
},
|
| 1042 |
{
|
| 1043 |
"cell_type": "code",
|
| 1044 |
+
"execution_count": 22,
|
| 1045 |
"metadata": {},
|
| 1046 |
"outputs": [],
|
| 1047 |
"source": [
|
|
|
|
| 1050 |
},
|
| 1051 |
{
|
| 1052 |
"cell_type": "code",
|
| 1053 |
+
"execution_count": 23,
|
| 1054 |
"metadata": {},
|
| 1055 |
"outputs": [],
|
| 1056 |
"source": [
|
|
|
|
| 1074 |
},
|
| 1075 |
{
|
| 1076 |
"cell_type": "code",
|
| 1077 |
+
"execution_count": 24,
|
| 1078 |
"metadata": {},
|
| 1079 |
"outputs": [],
|
| 1080 |
"source": [
|
|
|
|
| 1272 |
},
|
| 1273 |
{
|
| 1274 |
"cell_type": "code",
|
| 1275 |
+
"execution_count": 25,
|
| 1276 |
"metadata": {},
|
| 1277 |
"outputs": [
|
| 1278 |
{
|
|
|
|
| 1411 |
" del self.nn_model\n",
|
| 1412 |
" if self.config.ema:\n",
|
| 1413 |
" del self.ema_model\n",
|
| 1414 |
+
" torch.cuda.empty_cache()\n",
|
| 1415 |
"\n",
|
| 1416 |
" def save(self, ep):\n",
|
| 1417 |
" # save model\n",
|
|
|
|
| 1440 |
" # n_sample = params.shape[0]\n",
|
| 1441 |
" params = params or torch.tensor([0.2,0.8]).repeat(5,1)\n",
|
| 1442 |
" assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
|
| 1443 |
+
" print(\"params =\", params)\n",
|
| 1444 |
" # print(\"params =\", params)\n",
|
| 1445 |
" # print(\"len(params) =\", len(params))\n",
|
| 1446 |
" # model = self.ema_model if ema else self.nn_model\n",
|
|
|
|
| 1462 |
" # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
|
| 1463 |
" # print(f\"resumed ema_model from {config.resume}\")\n",
|
| 1464 |
"\n",
|
| 1465 |
+
" with torch.no_grad():\n",
|
| 1466 |
+
" x_last, x_entire = self.ddpm.sample(\n",
|
| 1467 |
+
" nn_model=nn_model, \n",
|
| 1468 |
+
" params=params.to(self.config.device), \n",
|
| 1469 |
+
" device=self.config.device, \n",
|
| 1470 |
+
" guide_w=self.config.guide_w\n",
|
| 1471 |
+
" )\n",
|
| 1472 |
"\n",
|
| 1473 |
" np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
|
| 1474 |
"\n",
|
|
|
|
| 1482 |
},
|
| 1483 |
{
|
| 1484 |
"cell_type": "code",
|
| 1485 |
+
"execution_count": 26,
|
| 1486 |
"metadata": {},
|
| 1487 |
"outputs": [
|
| 1488 |
{
|
|
|
|
| 1509 |
"output_type": "stream",
|
| 1510 |
"text": [
|
| 1511 |
"params loaded: (200, 2)\n",
|
| 1512 |
+
"images rescaled to [-1.0, 1.064338207244873]\n",
|
| 1513 |
+
"params rescaled to [0.0, 0.9988593502151616]\n"
|
| 1514 |
]
|
| 1515 |
},
|
| 1516 |
{
|
| 1517 |
"data": {
|
| 1518 |
"application/vnd.jupyter.widget-view+json": {
|
| 1519 |
+
"model_id": "2e0b629831714bc2b32e25d44a72f4b3",
|
| 1520 |
"version_major": 2,
|
| 1521 |
"version_minor": 0
|
| 1522 |
},
|
|
|
|
| 1530 |
{
|
| 1531 |
"data": {
|
| 1532 |
"application/vnd.jupyter.widget-view+json": {
|
| 1533 |
+
"model_id": "c634a180ede04f3cb09ab74daf0401c6",
|
| 1534 |
"version_major": 2,
|
| 1535 |
"version_minor": 0
|
| 1536 |
},
|
|
|
|
| 1544 |
{
|
| 1545 |
"data": {
|
| 1546 |
"application/vnd.jupyter.widget-view+json": {
|
| 1547 |
+
"model_id": "6f3a0791c42b4d7e958f2a9d57f64de8",
|
| 1548 |
"version_major": 2,
|
| 1549 |
"version_minor": 0
|
| 1550 |
},
|
|
|
|
| 1558 |
{
|
| 1559 |
"data": {
|
| 1560 |
"application/vnd.jupyter.widget-view+json": {
|
| 1561 |
+
"model_id": "9dce2de3e8a14aee83e2b182dc06608f",
|
| 1562 |
"version_major": 2,
|
| 1563 |
"version_minor": 0
|
| 1564 |
},
|
|
|
|
| 1572 |
{
|
| 1573 |
"data": {
|
| 1574 |
"application/vnd.jupyter.widget-view+json": {
|
| 1575 |
+
"model_id": "d4596bdc71cc4d4cb780442b97849883",
|
| 1576 |
"version_major": 2,
|
| 1577 |
"version_minor": 0
|
| 1578 |
},
|
|
|
|
| 1586 |
{
|
| 1587 |
"data": {
|
| 1588 |
"application/vnd.jupyter.widget-view+json": {
|
| 1589 |
+
"model_id": "6e68847216504241b81ebcb71c48f687",
|
| 1590 |
"version_major": 2,
|
| 1591 |
"version_minor": 0
|
| 1592 |
},
|
|
|
|
| 1600 |
{
|
| 1601 |
"data": {
|
| 1602 |
"application/vnd.jupyter.widget-view+json": {
|
| 1603 |
+
"model_id": "830c25eb902a47e7997dcdb40099c5a4",
|
| 1604 |
"version_major": 2,
|
| 1605 |
"version_minor": 0
|
| 1606 |
},
|
|
|
|
| 1614 |
{
|
| 1615 |
"data": {
|
| 1616 |
"application/vnd.jupyter.widget-view+json": {
|
| 1617 |
+
"model_id": "87fdac7b595c4d0ea7258ee8bb35de17",
|
| 1618 |
"version_major": 2,
|
| 1619 |
"version_minor": 0
|
| 1620 |
},
|
|
|
|
| 1628 |
{
|
| 1629 |
"data": {
|
| 1630 |
"application/vnd.jupyter.widget-view+json": {
|
| 1631 |
+
"model_id": "b9f6be95f4bd403d85f6df34756e7b8d",
|
| 1632 |
"version_major": 2,
|
| 1633 |
"version_minor": 0
|
| 1634 |
},
|
|
|
|
| 1642 |
{
|
| 1643 |
"data": {
|
| 1644 |
"application/vnd.jupyter.widget-view+json": {
|
| 1645 |
+
"model_id": "28ec5d881b37440ba5f4c863fc552c17",
|
| 1646 |
"version_major": 2,
|
| 1647 |
"version_minor": 0
|
| 1648 |
},
|
|
|
|
| 1667 |
"name": "stdout",
|
| 1668 |
"output_type": "stream",
|
| 1669 |
"text": [
|
| 1670 |
+
"params = tensor([[0.2000, 0.8000],\n",
|
| 1671 |
+
" [0.2000, 0.8000],\n",
|
| 1672 |
+
" [0.2000, 0.8000],\n",
|
| 1673 |
+
" [0.2000, 0.8000],\n",
|
| 1674 |
+
" [0.2000, 0.8000]])\n",
|
| 1675 |
"nn_model resumed from ./outputs/model_state_09.pth\n"
|
| 1676 |
]
|
| 1677 |
},
|
| 1678 |
{
|
| 1679 |
"data": {
|
| 1680 |
"application/vnd.jupyter.widget-view+json": {
|
| 1681 |
+
"model_id": "58944c3b1e4f42bb8771f776c35a90a7",
|
| 1682 |
"version_major": 2,
|
| 1683 |
"version_minor": 0
|
| 1684 |
},
|
|
|
|
| 1691 |
},
|
| 1692 |
{
|
| 1693 |
"ename": "RuntimeError",
|
| 1694 |
+
"evalue": "CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.65 GiB already allocated; 432.50 MiB free; 22.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
|
| 1695 |
"output_type": "error",
|
| 1696 |
"traceback": [
|
| 1697 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1698 |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 1699 |
+
"Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
|
| 1700 |
+
"Cell \u001b[0;32mIn[25], line 177\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 171\u001b[0m nn_model\u001b[39m.\u001b[39meval()\n\u001b[1;32m 173\u001b[0m \u001b[39m# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[39m# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[39m# print(f\"resumed ema_model from {config.resume}\")\u001b[39;00m\n\u001b[0;32m--> 177\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 178\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 179\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 180\u001b[0m device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 181\u001b[0m guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 182\u001b[0m )\n\u001b[1;32m 184\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 186\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
|
| 1701 |
+
"Cell \u001b[0;32mIn[7], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
|
|
|
|
|
|
|
| 1702 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1703 |
+
"Cell \u001b[0;32mIn[18], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
|
| 1704 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1705 |
+
"Cell \u001b[0;32mIn[13], line 7\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[0;32m----> 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n\u001b[1;32m 8\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m x \u001b[39m=\u001b[39m layer(x)\n",
|
| 1706 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1707 |
+
"Cell \u001b[0;32mIn[16], line 37\u001b[0m, in \u001b[0;36mAttentionBlock.forward\u001b[0;34m(self, x, encoder_out)\u001b[0m\n\u001b[1;32m 35\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mattention(qkv, encoder_out)\n\u001b[1;32m 36\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 37\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattention(qkv)\n\u001b[1;32m 38\u001b[0m \u001b[39m# print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 39\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproj_out(h)\n",
|
| 1708 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1709 |
+
"Cell \u001b[0;32mIn[15], line 21\u001b[0m, in \u001b[0;36mQKVAttention.forward\u001b[0;34m(self, qkv, encoder_kv)\u001b[0m\n\u001b[1;32m 18\u001b[0m v \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([ev,v], dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 20\u001b[0m scale \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m/\u001b[39m math\u001b[39m.\u001b[39msqrt(math\u001b[39m.\u001b[39msqrt(ch))\n\u001b[0;32m---> 21\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49meinsum(\u001b[39m\"\u001b[39;49m\u001b[39mbct,bcs->bts\u001b[39;49m\u001b[39m\"\u001b[39;49m, q\u001b[39m*\u001b[39;49mscale, k\u001b[39m*\u001b[39;49mscale)\n\u001b[1;32m 22\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39msoftmax(weight\u001b[39m.\u001b[39mfloat(), dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\u001b[39m.\u001b[39mtype(weight\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 24\u001b[0m a \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbts,bcs->bct\u001b[39m\u001b[39m\"\u001b[39m, weight, v)\n",
|
| 1710 |
+
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/functional.py:360\u001b[0m, in \u001b[0;36meinsum\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[39m# recurse incase operands contains value that has torch function\u001b[39;00m\n\u001b[1;32m 357\u001b[0m \u001b[39m# in the original implementation this line is omitted\u001b[39;00m\n\u001b[1;32m 358\u001b[0m \u001b[39mreturn\u001b[39;00m einsum(equation, \u001b[39m*\u001b[39m_operands)\n\u001b[0;32m--> 360\u001b[0m \u001b[39mreturn\u001b[39;00m _VF\u001b[39m.\u001b[39;49meinsum(equation, operands)\n",
|
| 1711 |
+
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.65 GiB already allocated; 432.50 MiB free; 22.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
|
|
|
|
| 1712 |
]
|
| 1713 |
}
|
| 1714 |
],
|
|
|
|
| 1873 |
},
|
| 1874 |
{
|
| 1875 |
"cell_type": "code",
|
| 1876 |
+
"execution_count": 36,
|
| 1877 |
"metadata": {},
|
| 1878 |
"outputs": [
|
| 1879 |
{
|
| 1880 |
+
"name": "stdout",
|
| 1881 |
+
"output_type": "stream",
|
| 1882 |
+
"text": [
|
| 1883 |
+
"resuming nn_model\n"
|
| 1884 |
+
]
|
| 1885 |
+
},
|
| 1886 |
+
{
|
| 1887 |
+
"data": {
|
| 1888 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1889 |
+
"model_id": "0e5bf6a8bad7403cab54e0b75464142a",
|
| 1890 |
+
"version_major": 2,
|
| 1891 |
+
"version_minor": 0
|
| 1892 |
+
},
|
| 1893 |
+
"text/plain": [
|
| 1894 |
+
" 0%| | 0/1000 [00:00<?, ?it/s]"
|
| 1895 |
+
]
|
| 1896 |
+
},
|
| 1897 |
+
"metadata": {},
|
| 1898 |
+
"output_type": "display_data"
|
| 1899 |
+
},
|
| 1900 |
+
{
|
| 1901 |
+
"ename": "KeyboardInterrupt",
|
| 1902 |
+
"evalue": "",
|
| 1903 |
"output_type": "error",
|
| 1904 |
"traceback": [
|
| 1905 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1906 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 1907 |
+
"Cell \u001b[0;32mIn[36], line 17\u001b[0m\n\u001b[1;32m 14\u001b[0m params \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor((\u001b[39m0.2\u001b[39m,\u001b[39m0.8\u001b[39m))\u001b[39m.\u001b[39mrepeat(\u001b[39m10\u001b[39m,\u001b[39m1\u001b[39m)\n\u001b[1;32m 15\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m 16\u001b[0m \u001b[39m# x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\u001b[39;00m\n\u001b[0;32m---> 17\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m ddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 18\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 19\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(config\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 20\u001b[0m device\u001b[39m=\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 21\u001b[0m guide_w\u001b[39m=\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 24\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(config\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m_ema.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n",
|
| 1908 |
+
"Cell \u001b[0;32mIn[6], line 59\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 56\u001b[0m pbar_sample\u001b[39m.\u001b[39mset_description(\u001b[39m\"\u001b[39m\u001b[39mSampling\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 57\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mreversed\u001b[39m(\u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_timesteps)):\n\u001b[1;32m 58\u001b[0m \u001b[39m# print(f'sampling timestep {i:4d}',end='\\r')\u001b[39;00m\n\u001b[0;32m---> 59\u001b[0m t_is \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mtensor([i])\u001b[39m.\u001b[39;49mto(device)\n\u001b[1;32m 60\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(n_sample)\n\u001b[1;32m 62\u001b[0m z \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(n_sample, \u001b[39m*\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mimg_shape)\u001b[39m.\u001b[39mto(device) \u001b[39mif\u001b[39;00m i \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m \u001b[39melse\u001b[39;00m \u001b[39m0\u001b[39m\n",
|
| 1909 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
| 1910 |
]
|
| 1911 |
}
|
| 1912 |
],
|
| 1913 |
"source": [
|
| 1914 |
+
"config = TrainConfig()\n",
|
| 1915 |
"# ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
|
| 1916 |
+
"ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
|
| 1917 |
"\n",
|
| 1918 |
+
"nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
|
| 1919 |
+
"print(\"resuming nn_model\")\n",
|
| 1920 |
+
"nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
|
| 1921 |
+
"# nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 1922 |
+
"# nn_model.train()\n",
|
| 1923 |
+
"nn_model.to(ddpm.device)\n",
|
| 1924 |
+
"nn_model.eval()\n",
|
| 1925 |
"\n",
|
| 1926 |
"# n_sample = 20\n",
|
| 1927 |
+
"params = torch.tensor((0.2,0.8)).repeat(10,1)\n",
|
| 1928 |
+
"with torch.no_grad():\n",
|
| 1929 |
+
" # x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
|
| 1930 |
+
" x_last, x_entire = ddpm.sample(\n",
|
| 1931 |
+
" nn_model=nn_model, \n",
|
| 1932 |
+
" params=params.to(config.device), \n",
|
| 1933 |
+
" device=config.device, \n",
|
| 1934 |
+
" guide_w=config.guide_w\n",
|
| 1935 |
+
" )\n",
|
| 1936 |
"\n",
|
| 1937 |
+
"np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last)"
|
| 1938 |
]
|
| 1939 |
},
|
| 1940 |
+
{
|
| 1941 |
+
"cell_type": "code",
|
| 1942 |
+
"execution_count": 32,
|
| 1943 |
+
"metadata": {},
|
| 1944 |
+
"outputs": [
|
| 1945 |
+
{
|
| 1946 |
+
"data": {
|
| 1947 |
+
"text/plain": [
|
| 1948 |
+
"(4, 1, 64, 512)"
|
| 1949 |
+
]
|
| 1950 |
+
},
|
| 1951 |
+
"execution_count": 32,
|
| 1952 |
+
"metadata": {},
|
| 1953 |
+
"output_type": "execute_result"
|
| 1954 |
+
}
|
| 1955 |
+
],
|
| 1956 |
+
"source": [
|
| 1957 |
+
"x_last.shape"
|
| 1958 |
+
]
|
| 1959 |
+
},
|
| 1960 |
+
{
|
| 1961 |
+
"cell_type": "code",
|
| 1962 |
+
"execution_count": 35,
|
| 1963 |
+
"metadata": {},
|
| 1964 |
+
"outputs": [
|
| 1965 |
+
{
|
| 1966 |
+
"data": {
|
| 1967 |
+
"text/plain": [
|
| 1968 |
+
"'cuda'"
|
| 1969 |
+
]
|
| 1970 |
+
},
|
| 1971 |
+
"execution_count": 35,
|
| 1972 |
+
"metadata": {},
|
| 1973 |
+
"output_type": "execute_result"
|
| 1974 |
+
}
|
| 1975 |
+
],
|
| 1976 |
+
"source": []
|
| 1977 |
+
},
|
| 1978 |
{
|
| 1979 |
"cell_type": "code",
|
| 1980 |
"execution_count": 21,
|
load_h5.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import h5py
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import DataLoader, Dataset
|
| 6 |
+
# from datasets import Dataset
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import numpy as np
|
| 9 |
+
import random
|
| 10 |
+
from abc import ABC, abstractmethod
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import math
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import os
|
| 15 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 16 |
+
import copy
|
| 17 |
+
from tqdm.auto import tqdm
|
| 18 |
+
# from torchvision import transforms
|
| 19 |
+
# from diffusers import UNet2DModel#, UNet3DConditionModel
|
| 20 |
+
# from diffusers import DDPMScheduler
|
| 21 |
+
from diffusers.utils import make_image_grid
|
| 22 |
+
import datetime
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from diffusers.optimization import get_cosine_schedule_with_warmup
|
| 25 |
+
from accelerate import notebook_launcher, Accelerator
|
| 26 |
+
from huggingface_hub import create_repo, upload_folder
|
| 27 |
+
|