Xsmos commited on
Commit
e1f900a
·
verified ·
1 Parent(s): 4192c13
Files changed (1) hide show
  1. diffusion.ipynb +61 -66
diffusion.ipynb CHANGED
@@ -32,7 +32,7 @@
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
- "model_id": "b9a1289a64b14b6e9aebf8148c38c9a8",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
@@ -51,7 +51,7 @@
51
  },
52
  {
53
  "cell_type": "code",
54
- "execution_count": 2,
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
@@ -85,7 +85,7 @@
85
  },
86
  {
87
  "cell_type": "code",
88
- "execution_count": 3,
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
@@ -94,7 +94,7 @@
94
  },
95
  {
96
  "cell_type": "code",
97
- "execution_count": 4,
98
  "metadata": {},
99
  "outputs": [],
100
  "source": [
@@ -191,7 +191,7 @@
191
  },
192
  {
193
  "cell_type": "code",
194
- "execution_count": 5,
195
  "metadata": {},
196
  "outputs": [],
197
  "source": [
@@ -201,7 +201,7 @@
201
  },
202
  {
203
  "cell_type": "code",
204
- "execution_count": 6,
205
  "metadata": {},
206
  "outputs": [],
207
  "source": [
@@ -234,7 +234,7 @@
234
  },
235
  {
236
  "cell_type": "code",
237
- "execution_count": 7,
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
@@ -279,7 +279,7 @@
279
  " n_sample = len(params) #params.shape[0]\n",
280
  " # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
281
  " x_i = torch.randn(n_sample, *self.img_shape).to(device)\n",
282
- " print(\"x_i.shape =\", x_i.shape)\n",
283
  " # print(\"x_i.shape =\", x_i.shape)\n",
284
  " if guide_w != -1:\n",
285
  " c_i = params\n",
@@ -311,7 +311,7 @@
311
  " t_is = t_is.repeat(2)\n",
312
  "\n",
313
  " # split predictions and compute weighting\n",
314
- " print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\n",
315
  " eps = nn_model(x_i, t_is, c_i)\n",
316
  " eps1 = eps[:n_sample]\n",
317
  " eps2 = eps[n_sample:]\n",
@@ -341,7 +341,7 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": 8,
345
  "metadata": {},
346
  "outputs": [],
347
  "source": [
@@ -383,7 +383,7 @@
383
  },
384
  {
385
  "cell_type": "code",
386
- "execution_count": 9,
387
  "metadata": {},
388
  "outputs": [],
389
  "source": [
@@ -408,7 +408,7 @@
408
  },
409
  {
410
  "cell_type": "code",
411
- "execution_count": 10,
412
  "metadata": {},
413
  "outputs": [],
414
  "source": [
@@ -438,7 +438,7 @@
438
  },
439
  {
440
  "cell_type": "code",
441
- "execution_count": 11,
442
  "metadata": {},
443
  "outputs": [],
444
  "source": [
@@ -453,7 +453,7 @@
453
  },
454
  {
455
  "cell_type": "code",
456
- "execution_count": 12,
457
  "metadata": {},
458
  "outputs": [],
459
  "source": [
@@ -467,7 +467,7 @@
467
  },
468
  {
469
  "cell_type": "code",
470
- "execution_count": 13,
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
@@ -485,7 +485,7 @@
485
  },
486
  {
487
  "cell_type": "code",
488
- "execution_count": 14,
489
  "metadata": {},
490
  "outputs": [],
491
  "source": [
@@ -567,7 +567,7 @@
567
  },
568
  {
569
  "cell_type": "code",
570
- "execution_count": 15,
571
  "metadata": {},
572
  "outputs": [],
573
  "source": [
@@ -600,7 +600,7 @@
600
  },
601
  {
602
  "cell_type": "code",
603
- "execution_count": 16,
604
  "metadata": {},
605
  "outputs": [],
606
  "source": [
@@ -649,7 +649,7 @@
649
  },
650
  {
651
  "cell_type": "code",
652
- "execution_count": 17,
653
  "metadata": {},
654
  "outputs": [],
655
  "source": [
@@ -678,7 +678,7 @@
678
  },
679
  {
680
  "cell_type": "code",
681
- "execution_count": 18,
682
  "metadata": {},
683
  "outputs": [],
684
  "source": [
@@ -934,7 +934,7 @@
934
  },
935
  {
936
  "cell_type": "code",
937
- "execution_count": 19,
938
  "metadata": {},
939
  "outputs": [],
940
  "source": [
@@ -964,7 +964,7 @@
964
  },
965
  {
966
  "cell_type": "code",
967
- "execution_count": 20,
968
  "metadata": {},
969
  "outputs": [],
970
  "source": [
@@ -1030,7 +1030,7 @@
1030
  },
1031
  {
1032
  "cell_type": "code",
1033
- "execution_count": 21,
1034
  "metadata": {},
1035
  "outputs": [],
1036
  "source": [
@@ -1040,7 +1040,7 @@
1040
  },
1041
  {
1042
  "cell_type": "code",
1043
- "execution_count": 22,
1044
  "metadata": {},
1045
  "outputs": [],
1046
  "source": [
@@ -1049,7 +1049,7 @@
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
- "execution_count": 23,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
@@ -1073,7 +1073,7 @@
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
- "execution_count": 24,
1077
  "metadata": {},
1078
  "outputs": [],
1079
  "source": [
@@ -1271,7 +1271,7 @@
1271
  },
1272
  {
1273
  "cell_type": "code",
1274
- "execution_count": 25,
1275
  "metadata": {},
1276
  "outputs": [
1277
  {
@@ -1336,6 +1336,7 @@
1336
  " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True)\n",
1337
  " # del dataset\n",
1338
  " self.accelerate(self.config)\n",
 
1339
  "\n",
1340
  " def accelerate(self, config):\n",
1341
  " self.accelerator = Accelerator(\n",
@@ -1429,11 +1430,12 @@
1429
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1430
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1431
  "\n",
1432
- " def sample(self, file, params=None, ema=False, entire=False):\n",
1433
  " # n_sample = params.shape[0]\n",
1434
- " params = params or torch.tensor([0.2,0.8]).repeat(10,1)\n",
1435
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1436
- " print(\"params.shape =\", params.shape)\n",
 
1437
  " # print(\"len(params) =\", len(params))\n",
1438
  " model = self.ema_model if ema else self.nn_model\n",
1439
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
@@ -1451,7 +1453,7 @@
1451
  },
1452
  {
1453
  "cell_type": "code",
1454
- "execution_count": 26,
1455
  "metadata": {},
1456
  "outputs": [
1457
  {
@@ -1478,14 +1480,14 @@
1478
  "output_type": "stream",
1479
  "text": [
1480
  "params loaded: (200, 2)\n",
1481
- "images rescaled to [-1.0, 1.0288152694702148]\n",
1482
- "params rescaled to [0.0, 0.9958380936251952]\n"
1483
  ]
1484
  },
1485
  {
1486
  "data": {
1487
  "application/vnd.jupyter.widget-view+json": {
1488
- "model_id": "fa014381477545e4a0fb6976d0cd3e0e",
1489
  "version_major": 2,
1490
  "version_minor": 0
1491
  },
@@ -1499,7 +1501,7 @@
1499
  {
1500
  "data": {
1501
  "application/vnd.jupyter.widget-view+json": {
1502
- "model_id": "8bdd69fdc66c422c99b59032ef36b1f4",
1503
  "version_major": 2,
1504
  "version_minor": 0
1505
  },
@@ -1513,7 +1515,7 @@
1513
  {
1514
  "data": {
1515
  "application/vnd.jupyter.widget-view+json": {
1516
- "model_id": "3cea3f68bad54e9bb4a63a897bea12dd",
1517
  "version_major": 2,
1518
  "version_minor": 0
1519
  },
@@ -1527,7 +1529,7 @@
1527
  {
1528
  "data": {
1529
  "application/vnd.jupyter.widget-view+json": {
1530
- "model_id": "26f29f7e73fd493fa62d52e6441589f1",
1531
  "version_major": 2,
1532
  "version_minor": 0
1533
  },
@@ -1541,7 +1543,7 @@
1541
  {
1542
  "data": {
1543
  "application/vnd.jupyter.widget-view+json": {
1544
- "model_id": "439448549c8840c0a970861a66e83264",
1545
  "version_major": 2,
1546
  "version_minor": 0
1547
  },
@@ -1555,7 +1557,7 @@
1555
  {
1556
  "data": {
1557
  "application/vnd.jupyter.widget-view+json": {
1558
- "model_id": "acb1905d443a451c937cc0c114e558c9",
1559
  "version_major": 2,
1560
  "version_minor": 0
1561
  },
@@ -1569,7 +1571,7 @@
1569
  {
1570
  "data": {
1571
  "application/vnd.jupyter.widget-view+json": {
1572
- "model_id": "4f81671b7e1b43218b5aa6620e59503b",
1573
  "version_major": 2,
1574
  "version_minor": 0
1575
  },
@@ -1583,7 +1585,7 @@
1583
  {
1584
  "data": {
1585
  "application/vnd.jupyter.widget-view+json": {
1586
- "model_id": "66673bbcaffc412796022e1ed372b6cf",
1587
  "version_major": 2,
1588
  "version_minor": 0
1589
  },
@@ -1597,7 +1599,7 @@
1597
  {
1598
  "data": {
1599
  "application/vnd.jupyter.widget-view+json": {
1600
- "model_id": "2e04dfb1e9da4895a6f83516684bec03",
1601
  "version_major": 2,
1602
  "version_minor": 0
1603
  },
@@ -1611,7 +1613,7 @@
1611
  {
1612
  "data": {
1613
  "application/vnd.jupyter.widget-view+json": {
1614
- "model_id": "1dc4006116ca46e0a2b74a74858fa57a",
1615
  "version_major": 2,
1616
  "version_minor": 0
1617
  },
@@ -1629,21 +1631,24 @@
1629
  },
1630
  {
1631
  "cell_type": "code",
1632
- "execution_count": 128,
1633
  "metadata": {},
1634
  "outputs": [
1635
  {
1636
  "name": "stdout",
1637
  "output_type": "stream",
1638
  "text": [
1639
- "params.shape = torch.Size([10, 2])\n",
1640
- "x_i.shape = torch.Size([10, 1, 64, 512])\n"
 
 
 
1641
  ]
1642
  },
1643
  {
1644
  "data": {
1645
  "application/vnd.jupyter.widget-view+json": {
1646
- "model_id": "44830b578bb94dc5bf65eafaff19b1f5",
1647
  "version_major": 2,
1648
  "version_minor": 0
1649
  },
@@ -1654,35 +1659,25 @@
1654
  "metadata": {},
1655
  "output_type": "display_data"
1656
  },
1657
- {
1658
- "name": "stdout",
1659
- "output_type": "stream",
1660
- "text": [
1661
- "nn_model input shape torch.Size([20, 1, 64, 512]) torch.Size([20]) torch.Size([20, 2])\n"
1662
- ]
1663
- },
1664
  {
1665
  "ename": "RuntimeError",
1666
- "evalue": "CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 23.64 GiB total capacity; 21.71 GiB already allocated; 170.50 MiB free; 22.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1667
  "output_type": "error",
1668
  "traceback": [
1669
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1670
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1671
- "Cell \u001b[0;32mIn[128], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1672
- "Cell \u001b[0;32mIn[127], line 154\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 151\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 152\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 154\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 156\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 157\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1673
- "Cell \u001b[0;32mIn[111], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mnn_model input shape\u001b[39m\u001b[39m\"\u001b[39m, x_i\u001b[39m.\u001b[39mshape, t_is\u001b[39m.\u001b[39mshape, c_i\u001b[39m.\u001b[39mshape)\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1674
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1675
- "Cell \u001b[0;32mIn[122], line 230\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[39m# print(\"0,h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minput_blocks:\n\u001b[0;32m--> 230\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 231\u001b[0m hs\u001b[39m.\u001b[39mappend(h)\n\u001b[1;32m 232\u001b[0m \u001b[39m# print(\"module encoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 233\u001b[0m \u001b[39m# print(\"2,h.shape =\", h.shape)\u001b[39;00m\n",
1676
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1677
- "Cell \u001b[0;32mIn[117], line 5\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[39mfor\u001b[39;00m layer \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[1;32m 4\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, TimestepBlock):\n\u001b[0;32m----> 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[1;32m 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n",
1678
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1679
- "Cell \u001b[0;32mIn[118], line 72\u001b[0m, in \u001b[0;36mResBlock.forward\u001b[0;34m(self, x, emb)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 71\u001b[0m h \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m emb_out\n\u001b[0;32m---> 72\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mout_layers(h)\n\u001b[1;32m 73\u001b[0m \u001b[39m# print(\"ResBlock, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mskip_connection(x) \u001b[39m+\u001b[39m h\n",
1680
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1681
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m):\n\u001b[1;32m 138\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[0;32m--> 139\u001b[0m \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m module(\u001b[39minput\u001b[39;49m)\n\u001b[1;32m 140\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39minput\u001b[39m\n",
1682
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1683
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/conv.py:457\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 456\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 457\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_conv_forward(\u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias)\n",
1684
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/conv.py:453\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpadding_mode \u001b[39m!=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mzeros\u001b[39m\u001b[39m'\u001b[39m:\n\u001b[1;32m 450\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39mconv2d(F\u001b[39m.\u001b[39mpad(\u001b[39minput\u001b[39m, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpadding_mode),\n\u001b[1;32m 451\u001b[0m weight, bias, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstride,\n\u001b[1;32m 452\u001b[0m _pair(\u001b[39m0\u001b[39m), \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdilation, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgroups)\n\u001b[0;32m--> 453\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mconv2d(\u001b[39minput\u001b[39;49m, weight, bias, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstride,\n\u001b[1;32m 454\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpadding, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdilation, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgroups)\n",
1685
- "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 23.64 GiB total capacity; 21.71 GiB already allocated; 170.50 MiB free; 22.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
1686
  ]
1687
  }
1688
  ],
 
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
+ "model_id": "237469b68706428ea8491187899b7943",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
 
51
  },
52
  {
53
  "cell_type": "code",
54
+ "execution_count": 1,
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
 
85
  },
86
  {
87
  "cell_type": "code",
88
+ "execution_count": 2,
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
 
94
  },
95
  {
96
  "cell_type": "code",
97
+ "execution_count": 3,
98
  "metadata": {},
99
  "outputs": [],
100
  "source": [
 
191
  },
192
  {
193
  "cell_type": "code",
194
+ "execution_count": 4,
195
  "metadata": {},
196
  "outputs": [],
197
  "source": [
 
201
  },
202
  {
203
  "cell_type": "code",
204
+ "execution_count": 5,
205
  "metadata": {},
206
  "outputs": [],
207
  "source": [
 
234
  },
235
  {
236
  "cell_type": "code",
237
+ "execution_count": 6,
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
 
279
  " n_sample = len(params) #params.shape[0]\n",
280
  " # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
281
  " x_i = torch.randn(n_sample, *self.img_shape).to(device)\n",
282
+ " # print(\"x_i.shape =\", x_i.shape)\n",
283
  " # print(\"x_i.shape =\", x_i.shape)\n",
284
  " if guide_w != -1:\n",
285
  " c_i = params\n",
 
311
  " t_is = t_is.repeat(2)\n",
312
  "\n",
313
  " # split predictions and compute weighting\n",
314
+ " # print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\n",
315
  " eps = nn_model(x_i, t_is, c_i)\n",
316
  " eps1 = eps[:n_sample]\n",
317
  " eps2 = eps[n_sample:]\n",
 
341
  },
342
  {
343
  "cell_type": "code",
344
+ "execution_count": 7,
345
  "metadata": {},
346
  "outputs": [],
347
  "source": [
 
383
  },
384
  {
385
  "cell_type": "code",
386
+ "execution_count": 8,
387
  "metadata": {},
388
  "outputs": [],
389
  "source": [
 
408
  },
409
  {
410
  "cell_type": "code",
411
+ "execution_count": 9,
412
  "metadata": {},
413
  "outputs": [],
414
  "source": [
 
438
  },
439
  {
440
  "cell_type": "code",
441
+ "execution_count": 10,
442
  "metadata": {},
443
  "outputs": [],
444
  "source": [
 
453
  },
454
  {
455
  "cell_type": "code",
456
+ "execution_count": 11,
457
  "metadata": {},
458
  "outputs": [],
459
  "source": [
 
467
  },
468
  {
469
  "cell_type": "code",
470
+ "execution_count": 12,
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
 
485
  },
486
  {
487
  "cell_type": "code",
488
+ "execution_count": 13,
489
  "metadata": {},
490
  "outputs": [],
491
  "source": [
 
567
  },
568
  {
569
  "cell_type": "code",
570
+ "execution_count": 14,
571
  "metadata": {},
572
  "outputs": [],
573
  "source": [
 
600
  },
601
  {
602
  "cell_type": "code",
603
+ "execution_count": 15,
604
  "metadata": {},
605
  "outputs": [],
606
  "source": [
 
649
  },
650
  {
651
  "cell_type": "code",
652
+ "execution_count": 16,
653
  "metadata": {},
654
  "outputs": [],
655
  "source": [
 
678
  },
679
  {
680
  "cell_type": "code",
681
+ "execution_count": 17,
682
  "metadata": {},
683
  "outputs": [],
684
  "source": [
 
934
  },
935
  {
936
  "cell_type": "code",
937
+ "execution_count": 18,
938
  "metadata": {},
939
  "outputs": [],
940
  "source": [
 
964
  },
965
  {
966
  "cell_type": "code",
967
+ "execution_count": 19,
968
  "metadata": {},
969
  "outputs": [],
970
  "source": [
 
1030
  },
1031
  {
1032
  "cell_type": "code",
1033
+ "execution_count": 20,
1034
  "metadata": {},
1035
  "outputs": [],
1036
  "source": [
 
1040
  },
1041
  {
1042
  "cell_type": "code",
1043
+ "execution_count": 21,
1044
  "metadata": {},
1045
  "outputs": [],
1046
  "source": [
 
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
+ "execution_count": 22,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
 
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
+ "execution_count": 23,
1077
  "metadata": {},
1078
  "outputs": [],
1079
  "source": [
 
1271
  },
1272
  {
1273
  "cell_type": "code",
1274
+ "execution_count": 24,
1275
  "metadata": {},
1276
  "outputs": [
1277
  {
 
1336
  " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True)\n",
1337
  " # del dataset\n",
1338
  " self.accelerate(self.config)\n",
1339
+ " del dataset\n",
1340
  "\n",
1341
  " def accelerate(self, config):\n",
1342
  " self.accelerator = Accelerator(\n",
 
1430
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1431
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1432
  "\n",
1433
+ " def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
1434
  " # n_sample = params.shape[0]\n",
1435
+ " params = params or torch.tensor([0.2,0.8]).repeat(2,1)\n",
1436
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1437
+ " # print(\"params.shape =\", params.shape)\n",
1438
+ " print(\"params =\", params)\n",
1439
  " # print(\"len(params) =\", len(params))\n",
1440
  " model = self.ema_model if ema else self.nn_model\n",
1441
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
 
1453
  },
1454
  {
1455
  "cell_type": "code",
1456
+ "execution_count": 25,
1457
  "metadata": {},
1458
  "outputs": [
1459
  {
 
1480
  "output_type": "stream",
1481
  "text": [
1482
  "params loaded: (200, 2)\n",
1483
+ "images rescaled to [-1.0, 1.093963384628296]\n",
1484
+ "params rescaled to [0.0, 0.9996232858956116]\n"
1485
  ]
1486
  },
1487
  {
1488
  "data": {
1489
  "application/vnd.jupyter.widget-view+json": {
1490
+ "model_id": "86cf95c56e904304b70355f9cb7c9ed1",
1491
  "version_major": 2,
1492
  "version_minor": 0
1493
  },
 
1501
  {
1502
  "data": {
1503
  "application/vnd.jupyter.widget-view+json": {
1504
+ "model_id": "ed9760495855476fbd8cd1a7d8ad6128",
1505
  "version_major": 2,
1506
  "version_minor": 0
1507
  },
 
1515
  {
1516
  "data": {
1517
  "application/vnd.jupyter.widget-view+json": {
1518
+ "model_id": "f1d47d8f48e34900a36729a878892d25",
1519
  "version_major": 2,
1520
  "version_minor": 0
1521
  },
 
1529
  {
1530
  "data": {
1531
  "application/vnd.jupyter.widget-view+json": {
1532
+ "model_id": "d8d9e44be0ca4492ac1fdd0850008bed",
1533
  "version_major": 2,
1534
  "version_minor": 0
1535
  },
 
1543
  {
1544
  "data": {
1545
  "application/vnd.jupyter.widget-view+json": {
1546
+ "model_id": "8f163de4249f46ae94ce868269b3fa59",
1547
  "version_major": 2,
1548
  "version_minor": 0
1549
  },
 
1557
  {
1558
  "data": {
1559
  "application/vnd.jupyter.widget-view+json": {
1560
+ "model_id": "a9d040b034734589acdbfe235d8d22ce",
1561
  "version_major": 2,
1562
  "version_minor": 0
1563
  },
 
1571
  {
1572
  "data": {
1573
  "application/vnd.jupyter.widget-view+json": {
1574
+ "model_id": "cb7c7f4ebde5475d9b21db9838cce7d7",
1575
  "version_major": 2,
1576
  "version_minor": 0
1577
  },
 
1585
  {
1586
  "data": {
1587
  "application/vnd.jupyter.widget-view+json": {
1588
+ "model_id": "dcfb2b71bc98404db01b10c1151fcc0b",
1589
  "version_major": 2,
1590
  "version_minor": 0
1591
  },
 
1599
  {
1600
  "data": {
1601
  "application/vnd.jupyter.widget-view+json": {
1602
+ "model_id": "e46ce03a41854df081bd24b257de3b8a",
1603
  "version_major": 2,
1604
  "version_minor": 0
1605
  },
 
1613
  {
1614
  "data": {
1615
  "application/vnd.jupyter.widget-view+json": {
1616
+ "model_id": "0833d662b39d487386cec129c48548b4",
1617
  "version_major": 2,
1618
  "version_minor": 0
1619
  },
 
1631
  },
1632
  {
1633
  "cell_type": "code",
1634
+ "execution_count": null,
1635
  "metadata": {},
1636
  "outputs": [
1637
  {
1638
  "name": "stdout",
1639
  "output_type": "stream",
1640
  "text": [
1641
+ "params = tensor([[0.2000, 0.8000],\n",
1642
+ " [0.2000, 0.8000],\n",
1643
+ " [0.2000, 0.8000],\n",
1644
+ " [0.2000, 0.8000],\n",
1645
+ " [0.2000, 0.8000]])\n"
1646
  ]
1647
  },
1648
  {
1649
  "data": {
1650
  "application/vnd.jupyter.widget-view+json": {
1651
+ "model_id": "12c488597076479fbd2857b20d69da29",
1652
  "version_major": 2,
1653
  "version_minor": 0
1654
  },
 
1659
  "metadata": {},
1660
  "output_type": "display_data"
1661
  },
 
 
 
 
 
 
 
1662
  {
1663
  "ename": "RuntimeError",
1664
+ "evalue": "CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.86 GiB already allocated; 222.50 MiB free; 22.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1665
  "output_type": "error",
1666
  "traceback": [
1667
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1668
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1669
+ "Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1670
+ "Cell \u001b[0;32mIn[24], line 156\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 153\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 154\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 156\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 158\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 159\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1671
+ "Cell \u001b[0;32mIn[6], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
 
 
1672
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1673
+ "Cell \u001b[0;32mIn[17], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
1674
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1675
+ "Cell \u001b[0;32mIn[12], line 7\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[0;32m----> 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n\u001b[1;32m 8\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m x \u001b[39m=\u001b[39m layer(x)\n",
1676
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1677
+ "Cell \u001b[0;32mIn[15], line 37\u001b[0m, in \u001b[0;36mAttentionBlock.forward\u001b[0;34m(self, x, encoder_out)\u001b[0m\n\u001b[1;32m 35\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mattention(qkv, encoder_out)\n\u001b[1;32m 36\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 37\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattention(qkv)\n\u001b[1;32m 38\u001b[0m \u001b[39m# print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 39\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproj_out(h)\n",
1678
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1679
+ "Cell \u001b[0;32mIn[14], line 22\u001b[0m, in \u001b[0;36mQKVAttention.forward\u001b[0;34m(self, qkv, encoder_kv)\u001b[0m\n\u001b[1;32m 20\u001b[0m scale \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m/\u001b[39m math\u001b[39m.\u001b[39msqrt(math\u001b[39m.\u001b[39msqrt(ch))\n\u001b[1;32m 21\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbct,bcs->bts\u001b[39m\u001b[39m\"\u001b[39m, q\u001b[39m*\u001b[39mscale, k\u001b[39m*\u001b[39mscale)\n\u001b[0;32m---> 22\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49msoftmax(weight\u001b[39m.\u001b[39;49mfloat(), dim\u001b[39m=\u001b[39;49m\u001b[39m-\u001b[39;49m\u001b[39m1\u001b[39;49m)\u001b[39m.\u001b[39mtype(weight\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 24\u001b[0m a \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbts,bcs->bct\u001b[39m\u001b[39m\"\u001b[39m, weight, v)\n\u001b[1;32m 25\u001b[0m \u001b[39mreturn\u001b[39;00m a\u001b[39m.\u001b[39mreshape(bs, \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m, length)\n",
1680
+ "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.86 GiB already allocated; 222.50 MiB free; 22.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
 
1681
  ]
1682
  }
1683
  ],