0522-0050
Browse files- diffusion.ipynb +61 -66
diffusion.ipynb
CHANGED
|
@@ -32,7 +32,7 @@
|
|
| 32 |
{
|
| 33 |
"data": {
|
| 34 |
"application/vnd.jupyter.widget-view+json": {
|
| 35 |
-
"model_id": "
|
| 36 |
"version_major": 2,
|
| 37 |
"version_minor": 0
|
| 38 |
},
|
|
@@ -51,7 +51,7 @@
|
|
| 51 |
},
|
| 52 |
{
|
| 53 |
"cell_type": "code",
|
| 54 |
-
"execution_count":
|
| 55 |
"metadata": {},
|
| 56 |
"outputs": [],
|
| 57 |
"source": [
|
|
@@ -85,7 +85,7 @@
|
|
| 85 |
},
|
| 86 |
{
|
| 87 |
"cell_type": "code",
|
| 88 |
-
"execution_count":
|
| 89 |
"metadata": {},
|
| 90 |
"outputs": [],
|
| 91 |
"source": [
|
|
@@ -94,7 +94,7 @@
|
|
| 94 |
},
|
| 95 |
{
|
| 96 |
"cell_type": "code",
|
| 97 |
-
"execution_count":
|
| 98 |
"metadata": {},
|
| 99 |
"outputs": [],
|
| 100 |
"source": [
|
|
@@ -191,7 +191,7 @@
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
| 194 |
-
"execution_count":
|
| 195 |
"metadata": {},
|
| 196 |
"outputs": [],
|
| 197 |
"source": [
|
|
@@ -201,7 +201,7 @@
|
|
| 201 |
},
|
| 202 |
{
|
| 203 |
"cell_type": "code",
|
| 204 |
-
"execution_count":
|
| 205 |
"metadata": {},
|
| 206 |
"outputs": [],
|
| 207 |
"source": [
|
|
@@ -234,7 +234,7 @@
|
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "code",
|
| 237 |
-
"execution_count":
|
| 238 |
"metadata": {},
|
| 239 |
"outputs": [],
|
| 240 |
"source": [
|
|
@@ -279,7 +279,7 @@
|
|
| 279 |
" n_sample = len(params) #params.shape[0]\n",
|
| 280 |
" # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
|
| 281 |
" x_i = torch.randn(n_sample, *self.img_shape).to(device)\n",
|
| 282 |
-
" print(\"x_i.shape =\", x_i.shape)\n",
|
| 283 |
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 284 |
" if guide_w != -1:\n",
|
| 285 |
" c_i = params\n",
|
|
@@ -311,7 +311,7 @@
|
|
| 311 |
" t_is = t_is.repeat(2)\n",
|
| 312 |
"\n",
|
| 313 |
" # split predictions and compute weighting\n",
|
| 314 |
-
" print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\n",
|
| 315 |
" eps = nn_model(x_i, t_is, c_i)\n",
|
| 316 |
" eps1 = eps[:n_sample]\n",
|
| 317 |
" eps2 = eps[n_sample:]\n",
|
|
@@ -341,7 +341,7 @@
|
|
| 341 |
},
|
| 342 |
{
|
| 343 |
"cell_type": "code",
|
| 344 |
-
"execution_count":
|
| 345 |
"metadata": {},
|
| 346 |
"outputs": [],
|
| 347 |
"source": [
|
|
@@ -383,7 +383,7 @@
|
|
| 383 |
},
|
| 384 |
{
|
| 385 |
"cell_type": "code",
|
| 386 |
-
"execution_count":
|
| 387 |
"metadata": {},
|
| 388 |
"outputs": [],
|
| 389 |
"source": [
|
|
@@ -408,7 +408,7 @@
|
|
| 408 |
},
|
| 409 |
{
|
| 410 |
"cell_type": "code",
|
| 411 |
-
"execution_count":
|
| 412 |
"metadata": {},
|
| 413 |
"outputs": [],
|
| 414 |
"source": [
|
|
@@ -438,7 +438,7 @@
|
|
| 438 |
},
|
| 439 |
{
|
| 440 |
"cell_type": "code",
|
| 441 |
-
"execution_count":
|
| 442 |
"metadata": {},
|
| 443 |
"outputs": [],
|
| 444 |
"source": [
|
|
@@ -453,7 +453,7 @@
|
|
| 453 |
},
|
| 454 |
{
|
| 455 |
"cell_type": "code",
|
| 456 |
-
"execution_count":
|
| 457 |
"metadata": {},
|
| 458 |
"outputs": [],
|
| 459 |
"source": [
|
|
@@ -467,7 +467,7 @@
|
|
| 467 |
},
|
| 468 |
{
|
| 469 |
"cell_type": "code",
|
| 470 |
-
"execution_count":
|
| 471 |
"metadata": {},
|
| 472 |
"outputs": [],
|
| 473 |
"source": [
|
|
@@ -485,7 +485,7 @@
|
|
| 485 |
},
|
| 486 |
{
|
| 487 |
"cell_type": "code",
|
| 488 |
-
"execution_count":
|
| 489 |
"metadata": {},
|
| 490 |
"outputs": [],
|
| 491 |
"source": [
|
|
@@ -567,7 +567,7 @@
|
|
| 567 |
},
|
| 568 |
{
|
| 569 |
"cell_type": "code",
|
| 570 |
-
"execution_count":
|
| 571 |
"metadata": {},
|
| 572 |
"outputs": [],
|
| 573 |
"source": [
|
|
@@ -600,7 +600,7 @@
|
|
| 600 |
},
|
| 601 |
{
|
| 602 |
"cell_type": "code",
|
| 603 |
-
"execution_count":
|
| 604 |
"metadata": {},
|
| 605 |
"outputs": [],
|
| 606 |
"source": [
|
|
@@ -649,7 +649,7 @@
|
|
| 649 |
},
|
| 650 |
{
|
| 651 |
"cell_type": "code",
|
| 652 |
-
"execution_count":
|
| 653 |
"metadata": {},
|
| 654 |
"outputs": [],
|
| 655 |
"source": [
|
|
@@ -678,7 +678,7 @@
|
|
| 678 |
},
|
| 679 |
{
|
| 680 |
"cell_type": "code",
|
| 681 |
-
"execution_count":
|
| 682 |
"metadata": {},
|
| 683 |
"outputs": [],
|
| 684 |
"source": [
|
|
@@ -934,7 +934,7 @@
|
|
| 934 |
},
|
| 935 |
{
|
| 936 |
"cell_type": "code",
|
| 937 |
-
"execution_count":
|
| 938 |
"metadata": {},
|
| 939 |
"outputs": [],
|
| 940 |
"source": [
|
|
@@ -964,7 +964,7 @@
|
|
| 964 |
},
|
| 965 |
{
|
| 966 |
"cell_type": "code",
|
| 967 |
-
"execution_count":
|
| 968 |
"metadata": {},
|
| 969 |
"outputs": [],
|
| 970 |
"source": [
|
|
@@ -1030,7 +1030,7 @@
|
|
| 1030 |
},
|
| 1031 |
{
|
| 1032 |
"cell_type": "code",
|
| 1033 |
-
"execution_count":
|
| 1034 |
"metadata": {},
|
| 1035 |
"outputs": [],
|
| 1036 |
"source": [
|
|
@@ -1040,7 +1040,7 @@
|
|
| 1040 |
},
|
| 1041 |
{
|
| 1042 |
"cell_type": "code",
|
| 1043 |
-
"execution_count":
|
| 1044 |
"metadata": {},
|
| 1045 |
"outputs": [],
|
| 1046 |
"source": [
|
|
@@ -1049,7 +1049,7 @@
|
|
| 1049 |
},
|
| 1050 |
{
|
| 1051 |
"cell_type": "code",
|
| 1052 |
-
"execution_count":
|
| 1053 |
"metadata": {},
|
| 1054 |
"outputs": [],
|
| 1055 |
"source": [
|
|
@@ -1073,7 +1073,7 @@
|
|
| 1073 |
},
|
| 1074 |
{
|
| 1075 |
"cell_type": "code",
|
| 1076 |
-
"execution_count":
|
| 1077 |
"metadata": {},
|
| 1078 |
"outputs": [],
|
| 1079 |
"source": [
|
|
@@ -1271,7 +1271,7 @@
|
|
| 1271 |
},
|
| 1272 |
{
|
| 1273 |
"cell_type": "code",
|
| 1274 |
-
"execution_count":
|
| 1275 |
"metadata": {},
|
| 1276 |
"outputs": [
|
| 1277 |
{
|
|
@@ -1336,6 +1336,7 @@
|
|
| 1336 |
" self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True)\n",
|
| 1337 |
" # del dataset\n",
|
| 1338 |
" self.accelerate(self.config)\n",
|
|
|
|
| 1339 |
"\n",
|
| 1340 |
" def accelerate(self, config):\n",
|
| 1341 |
" self.accelerator = Accelerator(\n",
|
|
@@ -1429,11 +1430,12 @@
|
|
| 1429 |
" print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
|
| 1430 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 1431 |
"\n",
|
| 1432 |
-
" def sample(self, file, params=None, ema=False, entire=False):\n",
|
| 1433 |
" # n_sample = params.shape[0]\n",
|
| 1434 |
-
" params = params or torch.tensor([0.2,0.8]).repeat(
|
| 1435 |
" assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
|
| 1436 |
-
" print(\"params.shape =\", params.shape)\n",
|
|
|
|
| 1437 |
" # print(\"len(params) =\", len(params))\n",
|
| 1438 |
" model = self.ema_model if ema else self.nn_model\n",
|
| 1439 |
" # params = torch.tile(params, (n_sample,1)).to(device)\n",
|
|
@@ -1451,7 +1453,7 @@
|
|
| 1451 |
},
|
| 1452 |
{
|
| 1453 |
"cell_type": "code",
|
| 1454 |
-
"execution_count":
|
| 1455 |
"metadata": {},
|
| 1456 |
"outputs": [
|
| 1457 |
{
|
|
@@ -1478,14 +1480,14 @@
|
|
| 1478 |
"output_type": "stream",
|
| 1479 |
"text": [
|
| 1480 |
"params loaded: (200, 2)\n",
|
| 1481 |
-
"images rescaled to [-1.0, 1.
|
| 1482 |
-
"params rescaled to [0.0, 0.
|
| 1483 |
]
|
| 1484 |
},
|
| 1485 |
{
|
| 1486 |
"data": {
|
| 1487 |
"application/vnd.jupyter.widget-view+json": {
|
| 1488 |
-
"model_id": "
|
| 1489 |
"version_major": 2,
|
| 1490 |
"version_minor": 0
|
| 1491 |
},
|
|
@@ -1499,7 +1501,7 @@
|
|
| 1499 |
{
|
| 1500 |
"data": {
|
| 1501 |
"application/vnd.jupyter.widget-view+json": {
|
| 1502 |
-
"model_id": "
|
| 1503 |
"version_major": 2,
|
| 1504 |
"version_minor": 0
|
| 1505 |
},
|
|
@@ -1513,7 +1515,7 @@
|
|
| 1513 |
{
|
| 1514 |
"data": {
|
| 1515 |
"application/vnd.jupyter.widget-view+json": {
|
| 1516 |
-
"model_id": "
|
| 1517 |
"version_major": 2,
|
| 1518 |
"version_minor": 0
|
| 1519 |
},
|
|
@@ -1527,7 +1529,7 @@
|
|
| 1527 |
{
|
| 1528 |
"data": {
|
| 1529 |
"application/vnd.jupyter.widget-view+json": {
|
| 1530 |
-
"model_id": "
|
| 1531 |
"version_major": 2,
|
| 1532 |
"version_minor": 0
|
| 1533 |
},
|
|
@@ -1541,7 +1543,7 @@
|
|
| 1541 |
{
|
| 1542 |
"data": {
|
| 1543 |
"application/vnd.jupyter.widget-view+json": {
|
| 1544 |
-
"model_id": "
|
| 1545 |
"version_major": 2,
|
| 1546 |
"version_minor": 0
|
| 1547 |
},
|
|
@@ -1555,7 +1557,7 @@
|
|
| 1555 |
{
|
| 1556 |
"data": {
|
| 1557 |
"application/vnd.jupyter.widget-view+json": {
|
| 1558 |
-
"model_id": "
|
| 1559 |
"version_major": 2,
|
| 1560 |
"version_minor": 0
|
| 1561 |
},
|
|
@@ -1569,7 +1571,7 @@
|
|
| 1569 |
{
|
| 1570 |
"data": {
|
| 1571 |
"application/vnd.jupyter.widget-view+json": {
|
| 1572 |
-
"model_id": "
|
| 1573 |
"version_major": 2,
|
| 1574 |
"version_minor": 0
|
| 1575 |
},
|
|
@@ -1583,7 +1585,7 @@
|
|
| 1583 |
{
|
| 1584 |
"data": {
|
| 1585 |
"application/vnd.jupyter.widget-view+json": {
|
| 1586 |
-
"model_id": "
|
| 1587 |
"version_major": 2,
|
| 1588 |
"version_minor": 0
|
| 1589 |
},
|
|
@@ -1597,7 +1599,7 @@
|
|
| 1597 |
{
|
| 1598 |
"data": {
|
| 1599 |
"application/vnd.jupyter.widget-view+json": {
|
| 1600 |
-
"model_id": "
|
| 1601 |
"version_major": 2,
|
| 1602 |
"version_minor": 0
|
| 1603 |
},
|
|
@@ -1611,7 +1613,7 @@
|
|
| 1611 |
{
|
| 1612 |
"data": {
|
| 1613 |
"application/vnd.jupyter.widget-view+json": {
|
| 1614 |
-
"model_id": "
|
| 1615 |
"version_major": 2,
|
| 1616 |
"version_minor": 0
|
| 1617 |
},
|
|
@@ -1629,21 +1631,24 @@
|
|
| 1629 |
},
|
| 1630 |
{
|
| 1631 |
"cell_type": "code",
|
| 1632 |
-
"execution_count":
|
| 1633 |
"metadata": {},
|
| 1634 |
"outputs": [
|
| 1635 |
{
|
| 1636 |
"name": "stdout",
|
| 1637 |
"output_type": "stream",
|
| 1638 |
"text": [
|
| 1639 |
-
"params
|
| 1640 |
-
"
|
|
|
|
|
|
|
|
|
|
| 1641 |
]
|
| 1642 |
},
|
| 1643 |
{
|
| 1644 |
"data": {
|
| 1645 |
"application/vnd.jupyter.widget-view+json": {
|
| 1646 |
-
"model_id": "
|
| 1647 |
"version_major": 2,
|
| 1648 |
"version_minor": 0
|
| 1649 |
},
|
|
@@ -1654,35 +1659,25 @@
|
|
| 1654 |
"metadata": {},
|
| 1655 |
"output_type": "display_data"
|
| 1656 |
},
|
| 1657 |
-
{
|
| 1658 |
-
"name": "stdout",
|
| 1659 |
-
"output_type": "stream",
|
| 1660 |
-
"text": [
|
| 1661 |
-
"nn_model input shape torch.Size([20, 1, 64, 512]) torch.Size([20]) torch.Size([20, 2])\n"
|
| 1662 |
-
]
|
| 1663 |
-
},
|
| 1664 |
{
|
| 1665 |
"ename": "RuntimeError",
|
| 1666 |
-
"evalue": "CUDA out of memory. Tried to allocate
|
| 1667 |
"output_type": "error",
|
| 1668 |
"traceback": [
|
| 1669 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1670 |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 1671 |
-
"Cell \u001b[0;32mIn[
|
| 1672 |
-
"Cell \u001b[0;32mIn[
|
| 1673 |
-
"Cell \u001b[0;32mIn[
|
| 1674 |
-
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1675 |
-
"Cell \u001b[0;32mIn[122], line 230\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[39m# print(\"0,h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minput_blocks:\n\u001b[0;32m--> 230\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 231\u001b[0m hs\u001b[39m.\u001b[39mappend(h)\n\u001b[1;32m 232\u001b[0m \u001b[39m# print(\"module encoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 233\u001b[0m \u001b[39m# print(\"2,h.shape =\", h.shape)\u001b[39;00m\n",
|
| 1676 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1677 |
-
"Cell \u001b[0;32mIn[
|
| 1678 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1679 |
-
"Cell \u001b[0;32mIn[
|
| 1680 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1681 |
-
"
|
| 1682 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1683 |
-
"
|
| 1684 |
-
"
|
| 1685 |
-
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 23.64 GiB total capacity; 21.71 GiB already allocated; 170.50 MiB free; 22.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
|
| 1686 |
]
|
| 1687 |
}
|
| 1688 |
],
|
|
|
|
| 32 |
{
|
| 33 |
"data": {
|
| 34 |
"application/vnd.jupyter.widget-view+json": {
|
| 35 |
+
"model_id": "237469b68706428ea8491187899b7943",
|
| 36 |
"version_major": 2,
|
| 37 |
"version_minor": 0
|
| 38 |
},
|
|
|
|
| 51 |
},
|
| 52 |
{
|
| 53 |
"cell_type": "code",
|
| 54 |
+
"execution_count": 1,
|
| 55 |
"metadata": {},
|
| 56 |
"outputs": [],
|
| 57 |
"source": [
|
|
|
|
| 85 |
},
|
| 86 |
{
|
| 87 |
"cell_type": "code",
|
| 88 |
+
"execution_count": 2,
|
| 89 |
"metadata": {},
|
| 90 |
"outputs": [],
|
| 91 |
"source": [
|
|
|
|
| 94 |
},
|
| 95 |
{
|
| 96 |
"cell_type": "code",
|
| 97 |
+
"execution_count": 3,
|
| 98 |
"metadata": {},
|
| 99 |
"outputs": [],
|
| 100 |
"source": [
|
|
|
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"cell_type": "code",
|
| 194 |
+
"execution_count": 4,
|
| 195 |
"metadata": {},
|
| 196 |
"outputs": [],
|
| 197 |
"source": [
|
|
|
|
| 201 |
},
|
| 202 |
{
|
| 203 |
"cell_type": "code",
|
| 204 |
+
"execution_count": 5,
|
| 205 |
"metadata": {},
|
| 206 |
"outputs": [],
|
| 207 |
"source": [
|
|
|
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "code",
|
| 237 |
+
"execution_count": 6,
|
| 238 |
"metadata": {},
|
| 239 |
"outputs": [],
|
| 240 |
"source": [
|
|
|
|
| 279 |
" n_sample = len(params) #params.shape[0]\n",
|
| 280 |
" # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
|
| 281 |
" x_i = torch.randn(n_sample, *self.img_shape).to(device)\n",
|
| 282 |
+
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 283 |
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 284 |
" if guide_w != -1:\n",
|
| 285 |
" c_i = params\n",
|
|
|
|
| 311 |
" t_is = t_is.repeat(2)\n",
|
| 312 |
"\n",
|
| 313 |
" # split predictions and compute weighting\n",
|
| 314 |
+
" # print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\n",
|
| 315 |
" eps = nn_model(x_i, t_is, c_i)\n",
|
| 316 |
" eps1 = eps[:n_sample]\n",
|
| 317 |
" eps2 = eps[n_sample:]\n",
|
|
|
|
| 341 |
},
|
| 342 |
{
|
| 343 |
"cell_type": "code",
|
| 344 |
+
"execution_count": 7,
|
| 345 |
"metadata": {},
|
| 346 |
"outputs": [],
|
| 347 |
"source": [
|
|
|
|
| 383 |
},
|
| 384 |
{
|
| 385 |
"cell_type": "code",
|
| 386 |
+
"execution_count": 8,
|
| 387 |
"metadata": {},
|
| 388 |
"outputs": [],
|
| 389 |
"source": [
|
|
|
|
| 408 |
},
|
| 409 |
{
|
| 410 |
"cell_type": "code",
|
| 411 |
+
"execution_count": 9,
|
| 412 |
"metadata": {},
|
| 413 |
"outputs": [],
|
| 414 |
"source": [
|
|
|
|
| 438 |
},
|
| 439 |
{
|
| 440 |
"cell_type": "code",
|
| 441 |
+
"execution_count": 10,
|
| 442 |
"metadata": {},
|
| 443 |
"outputs": [],
|
| 444 |
"source": [
|
|
|
|
| 453 |
},
|
| 454 |
{
|
| 455 |
"cell_type": "code",
|
| 456 |
+
"execution_count": 11,
|
| 457 |
"metadata": {},
|
| 458 |
"outputs": [],
|
| 459 |
"source": [
|
|
|
|
| 467 |
},
|
| 468 |
{
|
| 469 |
"cell_type": "code",
|
| 470 |
+
"execution_count": 12,
|
| 471 |
"metadata": {},
|
| 472 |
"outputs": [],
|
| 473 |
"source": [
|
|
|
|
| 485 |
},
|
| 486 |
{
|
| 487 |
"cell_type": "code",
|
| 488 |
+
"execution_count": 13,
|
| 489 |
"metadata": {},
|
| 490 |
"outputs": [],
|
| 491 |
"source": [
|
|
|
|
| 567 |
},
|
| 568 |
{
|
| 569 |
"cell_type": "code",
|
| 570 |
+
"execution_count": 14,
|
| 571 |
"metadata": {},
|
| 572 |
"outputs": [],
|
| 573 |
"source": [
|
|
|
|
| 600 |
},
|
| 601 |
{
|
| 602 |
"cell_type": "code",
|
| 603 |
+
"execution_count": 15,
|
| 604 |
"metadata": {},
|
| 605 |
"outputs": [],
|
| 606 |
"source": [
|
|
|
|
| 649 |
},
|
| 650 |
{
|
| 651 |
"cell_type": "code",
|
| 652 |
+
"execution_count": 16,
|
| 653 |
"metadata": {},
|
| 654 |
"outputs": [],
|
| 655 |
"source": [
|
|
|
|
| 678 |
},
|
| 679 |
{
|
| 680 |
"cell_type": "code",
|
| 681 |
+
"execution_count": 17,
|
| 682 |
"metadata": {},
|
| 683 |
"outputs": [],
|
| 684 |
"source": [
|
|
|
|
| 934 |
},
|
| 935 |
{
|
| 936 |
"cell_type": "code",
|
| 937 |
+
"execution_count": 18,
|
| 938 |
"metadata": {},
|
| 939 |
"outputs": [],
|
| 940 |
"source": [
|
|
|
|
| 964 |
},
|
| 965 |
{
|
| 966 |
"cell_type": "code",
|
| 967 |
+
"execution_count": 19,
|
| 968 |
"metadata": {},
|
| 969 |
"outputs": [],
|
| 970 |
"source": [
|
|
|
|
| 1030 |
},
|
| 1031 |
{
|
| 1032 |
"cell_type": "code",
|
| 1033 |
+
"execution_count": 20,
|
| 1034 |
"metadata": {},
|
| 1035 |
"outputs": [],
|
| 1036 |
"source": [
|
|
|
|
| 1040 |
},
|
| 1041 |
{
|
| 1042 |
"cell_type": "code",
|
| 1043 |
+
"execution_count": 21,
|
| 1044 |
"metadata": {},
|
| 1045 |
"outputs": [],
|
| 1046 |
"source": [
|
|
|
|
| 1049 |
},
|
| 1050 |
{
|
| 1051 |
"cell_type": "code",
|
| 1052 |
+
"execution_count": 22,
|
| 1053 |
"metadata": {},
|
| 1054 |
"outputs": [],
|
| 1055 |
"source": [
|
|
|
|
| 1073 |
},
|
| 1074 |
{
|
| 1075 |
"cell_type": "code",
|
| 1076 |
+
"execution_count": 23,
|
| 1077 |
"metadata": {},
|
| 1078 |
"outputs": [],
|
| 1079 |
"source": [
|
|
|
|
| 1271 |
},
|
| 1272 |
{
|
| 1273 |
"cell_type": "code",
|
| 1274 |
+
"execution_count": 24,
|
| 1275 |
"metadata": {},
|
| 1276 |
"outputs": [
|
| 1277 |
{
|
|
|
|
| 1336 |
" self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True)\n",
|
| 1337 |
" # del dataset\n",
|
| 1338 |
" self.accelerate(self.config)\n",
|
| 1339 |
+
" del dataset\n",
|
| 1340 |
"\n",
|
| 1341 |
" def accelerate(self, config):\n",
|
| 1342 |
" self.accelerator = Accelerator(\n",
|
|
|
|
| 1430 |
" print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
|
| 1431 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 1432 |
"\n",
|
| 1433 |
+
" def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
|
| 1434 |
" # n_sample = params.shape[0]\n",
|
| 1435 |
+
" params = params or torch.tensor([0.2,0.8]).repeat(2,1)\n",
|
| 1436 |
" assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
|
| 1437 |
+
" # print(\"params.shape =\", params.shape)\n",
|
| 1438 |
+
" print(\"params =\", params)\n",
|
| 1439 |
" # print(\"len(params) =\", len(params))\n",
|
| 1440 |
" model = self.ema_model if ema else self.nn_model\n",
|
| 1441 |
" # params = torch.tile(params, (n_sample,1)).to(device)\n",
|
|
|
|
| 1453 |
},
|
| 1454 |
{
|
| 1455 |
"cell_type": "code",
|
| 1456 |
+
"execution_count": 25,
|
| 1457 |
"metadata": {},
|
| 1458 |
"outputs": [
|
| 1459 |
{
|
|
|
|
| 1480 |
"output_type": "stream",
|
| 1481 |
"text": [
|
| 1482 |
"params loaded: (200, 2)\n",
|
| 1483 |
+
"images rescaled to [-1.0, 1.093963384628296]\n",
|
| 1484 |
+
"params rescaled to [0.0, 0.9996232858956116]\n"
|
| 1485 |
]
|
| 1486 |
},
|
| 1487 |
{
|
| 1488 |
"data": {
|
| 1489 |
"application/vnd.jupyter.widget-view+json": {
|
| 1490 |
+
"model_id": "86cf95c56e904304b70355f9cb7c9ed1",
|
| 1491 |
"version_major": 2,
|
| 1492 |
"version_minor": 0
|
| 1493 |
},
|
|
|
|
| 1501 |
{
|
| 1502 |
"data": {
|
| 1503 |
"application/vnd.jupyter.widget-view+json": {
|
| 1504 |
+
"model_id": "ed9760495855476fbd8cd1a7d8ad6128",
|
| 1505 |
"version_major": 2,
|
| 1506 |
"version_minor": 0
|
| 1507 |
},
|
|
|
|
| 1515 |
{
|
| 1516 |
"data": {
|
| 1517 |
"application/vnd.jupyter.widget-view+json": {
|
| 1518 |
+
"model_id": "f1d47d8f48e34900a36729a878892d25",
|
| 1519 |
"version_major": 2,
|
| 1520 |
"version_minor": 0
|
| 1521 |
},
|
|
|
|
| 1529 |
{
|
| 1530 |
"data": {
|
| 1531 |
"application/vnd.jupyter.widget-view+json": {
|
| 1532 |
+
"model_id": "d8d9e44be0ca4492ac1fdd0850008bed",
|
| 1533 |
"version_major": 2,
|
| 1534 |
"version_minor": 0
|
| 1535 |
},
|
|
|
|
| 1543 |
{
|
| 1544 |
"data": {
|
| 1545 |
"application/vnd.jupyter.widget-view+json": {
|
| 1546 |
+
"model_id": "8f163de4249f46ae94ce868269b3fa59",
|
| 1547 |
"version_major": 2,
|
| 1548 |
"version_minor": 0
|
| 1549 |
},
|
|
|
|
| 1557 |
{
|
| 1558 |
"data": {
|
| 1559 |
"application/vnd.jupyter.widget-view+json": {
|
| 1560 |
+
"model_id": "a9d040b034734589acdbfe235d8d22ce",
|
| 1561 |
"version_major": 2,
|
| 1562 |
"version_minor": 0
|
| 1563 |
},
|
|
|
|
| 1571 |
{
|
| 1572 |
"data": {
|
| 1573 |
"application/vnd.jupyter.widget-view+json": {
|
| 1574 |
+
"model_id": "cb7c7f4ebde5475d9b21db9838cce7d7",
|
| 1575 |
"version_major": 2,
|
| 1576 |
"version_minor": 0
|
| 1577 |
},
|
|
|
|
| 1585 |
{
|
| 1586 |
"data": {
|
| 1587 |
"application/vnd.jupyter.widget-view+json": {
|
| 1588 |
+
"model_id": "dcfb2b71bc98404db01b10c1151fcc0b",
|
| 1589 |
"version_major": 2,
|
| 1590 |
"version_minor": 0
|
| 1591 |
},
|
|
|
|
| 1599 |
{
|
| 1600 |
"data": {
|
| 1601 |
"application/vnd.jupyter.widget-view+json": {
|
| 1602 |
+
"model_id": "e46ce03a41854df081bd24b257de3b8a",
|
| 1603 |
"version_major": 2,
|
| 1604 |
"version_minor": 0
|
| 1605 |
},
|
|
|
|
| 1613 |
{
|
| 1614 |
"data": {
|
| 1615 |
"application/vnd.jupyter.widget-view+json": {
|
| 1616 |
+
"model_id": "0833d662b39d487386cec129c48548b4",
|
| 1617 |
"version_major": 2,
|
| 1618 |
"version_minor": 0
|
| 1619 |
},
|
|
|
|
| 1631 |
},
|
| 1632 |
{
|
| 1633 |
"cell_type": "code",
|
| 1634 |
+
"execution_count": null,
|
| 1635 |
"metadata": {},
|
| 1636 |
"outputs": [
|
| 1637 |
{
|
| 1638 |
"name": "stdout",
|
| 1639 |
"output_type": "stream",
|
| 1640 |
"text": [
|
| 1641 |
+
"params = tensor([[0.2000, 0.8000],\n",
|
| 1642 |
+
" [0.2000, 0.8000],\n",
|
| 1643 |
+
" [0.2000, 0.8000],\n",
|
| 1644 |
+
" [0.2000, 0.8000],\n",
|
| 1645 |
+
" [0.2000, 0.8000]])\n"
|
| 1646 |
]
|
| 1647 |
},
|
| 1648 |
{
|
| 1649 |
"data": {
|
| 1650 |
"application/vnd.jupyter.widget-view+json": {
|
| 1651 |
+
"model_id": "12c488597076479fbd2857b20d69da29",
|
| 1652 |
"version_major": 2,
|
| 1653 |
"version_minor": 0
|
| 1654 |
},
|
|
|
|
| 1659 |
"metadata": {},
|
| 1660 |
"output_type": "display_data"
|
| 1661 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1662 |
{
|
| 1663 |
"ename": "RuntimeError",
|
| 1664 |
+
"evalue": "CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.86 GiB already allocated; 222.50 MiB free; 22.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
|
| 1665 |
"output_type": "error",
|
| 1666 |
"traceback": [
|
| 1667 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1668 |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
| 1669 |
+
"Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
|
| 1670 |
+
"Cell \u001b[0;32mIn[24], line 156\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 153\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 154\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 156\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 158\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 159\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
|
| 1671 |
+
"Cell \u001b[0;32mIn[6], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
|
|
|
|
|
|
|
| 1672 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1673 |
+
"Cell \u001b[0;32mIn[17], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
|
| 1674 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1675 |
+
"Cell \u001b[0;32mIn[12], line 7\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[0;32m----> 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n\u001b[1;32m 8\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m x \u001b[39m=\u001b[39m layer(x)\n",
|
| 1676 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1677 |
+
"Cell \u001b[0;32mIn[15], line 37\u001b[0m, in \u001b[0;36mAttentionBlock.forward\u001b[0;34m(self, x, encoder_out)\u001b[0m\n\u001b[1;32m 35\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mattention(qkv, encoder_out)\n\u001b[1;32m 36\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 37\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattention(qkv)\n\u001b[1;32m 38\u001b[0m \u001b[39m# print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 39\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproj_out(h)\n",
|
| 1678 |
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
|
| 1679 |
+
"Cell \u001b[0;32mIn[14], line 22\u001b[0m, in \u001b[0;36mQKVAttention.forward\u001b[0;34m(self, qkv, encoder_kv)\u001b[0m\n\u001b[1;32m 20\u001b[0m scale \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m/\u001b[39m math\u001b[39m.\u001b[39msqrt(math\u001b[39m.\u001b[39msqrt(ch))\n\u001b[1;32m 21\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbct,bcs->bts\u001b[39m\u001b[39m\"\u001b[39m, q\u001b[39m*\u001b[39mscale, k\u001b[39m*\u001b[39mscale)\n\u001b[0;32m---> 22\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49msoftmax(weight\u001b[39m.\u001b[39;49mfloat(), dim\u001b[39m=\u001b[39;49m\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m)\u001b[39m.\u001b[39mtype(weight\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 24\u001b[0m a \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbts,bcs->bct\u001b[39m\u001b[39m\"\u001b[39m, weight, v)\n\u001b[1;32m 25\u001b[0m \u001b[39mreturn\u001b[39;00m a\u001b[39m.\u001b[39mreshape(bs, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, length)\n",
|
| 1680 |
+
"\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.86 GiB already allocated; 222.50 MiB free; 22.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
|
|
|
|
| 1681 |
]
|
| 1682 |
}
|
| 1683 |
],
|