Xsmos commited on
Commit
4192c13
·
verified ·
1 Parent(s): ac2c546
Files changed (1) hide show
  1. diffusion.ipynb +55 -49
diffusion.ipynb CHANGED
@@ -32,7 +32,7 @@
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
- "model_id": "f925cb378800455fb1216e84a2900e58",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
@@ -234,7 +234,7 @@
234
  },
235
  {
236
  "cell_type": "code",
237
- "execution_count": 84,
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
@@ -278,7 +278,7 @@
278
  " def sample(self, nn_model, params, device, guide_w = 0):\n",
279
  " n_sample = len(params) #params.shape[0]\n",
280
  " # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
281
- " x_i = torch.randn(n_sample, *self.img_shape[1:]).to(device)\n",
282
  " print(\"x_i.shape =\", x_i.shape)\n",
283
  " # print(\"x_i.shape =\", x_i.shape)\n",
284
  " if guide_w != -1:\n",
@@ -299,7 +299,7 @@
299
  " t_is = torch.tensor([i]).to(device)\n",
300
  " t_is = t_is.repeat(n_sample)\n",
301
  "\n",
302
- " z = torch.randn(n_sample, *self.img_shape[1:]).to(device) if i > 0 else 0\n",
303
  "\n",
304
  " if guide_w == -1:\n",
305
  " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
@@ -307,7 +307,7 @@
307
  " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
308
  " else:\n",
309
  " # double batch\n",
310
- " x_i = x_i.repeat(2, *torch.ones(len(self.img_shape[1:]), dtype=int).tolist())\n",
311
  " t_is = t_is.repeat(2)\n",
312
  "\n",
313
  " # split predictions and compute weighting\n",
@@ -341,7 +341,7 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": 85,
345
  "metadata": {},
346
  "outputs": [],
347
  "source": [
@@ -383,7 +383,7 @@
383
  },
384
  {
385
  "cell_type": "code",
386
- "execution_count": 86,
387
  "metadata": {},
388
  "outputs": [],
389
  "source": [
@@ -408,7 +408,7 @@
408
  },
409
  {
410
  "cell_type": "code",
411
- "execution_count": 87,
412
  "metadata": {},
413
  "outputs": [],
414
  "source": [
@@ -438,7 +438,7 @@
438
  },
439
  {
440
  "cell_type": "code",
441
- "execution_count": 88,
442
  "metadata": {},
443
  "outputs": [],
444
  "source": [
@@ -453,7 +453,7 @@
453
  },
454
  {
455
  "cell_type": "code",
456
- "execution_count": 89,
457
  "metadata": {},
458
  "outputs": [],
459
  "source": [
@@ -467,7 +467,7 @@
467
  },
468
  {
469
  "cell_type": "code",
470
- "execution_count": 90,
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
@@ -485,7 +485,7 @@
485
  },
486
  {
487
  "cell_type": "code",
488
- "execution_count": 91,
489
  "metadata": {},
490
  "outputs": [],
491
  "source": [
@@ -567,7 +567,7 @@
567
  },
568
  {
569
  "cell_type": "code",
570
- "execution_count": 92,
571
  "metadata": {},
572
  "outputs": [],
573
  "source": [
@@ -600,7 +600,7 @@
600
  },
601
  {
602
  "cell_type": "code",
603
- "execution_count": 93,
604
  "metadata": {},
605
  "outputs": [],
606
  "source": [
@@ -649,7 +649,7 @@
649
  },
650
  {
651
  "cell_type": "code",
652
- "execution_count": 94,
653
  "metadata": {},
654
  "outputs": [],
655
  "source": [
@@ -678,7 +678,7 @@
678
  },
679
  {
680
  "cell_type": "code",
681
- "execution_count": 95,
682
  "metadata": {},
683
  "outputs": [],
684
  "source": [
@@ -934,7 +934,7 @@
934
  },
935
  {
936
  "cell_type": "code",
937
- "execution_count": 96,
938
  "metadata": {},
939
  "outputs": [],
940
  "source": [
@@ -964,7 +964,7 @@
964
  },
965
  {
966
  "cell_type": "code",
967
- "execution_count": 100,
968
  "metadata": {},
969
  "outputs": [],
970
  "source": [
@@ -1030,7 +1030,7 @@
1030
  },
1031
  {
1032
  "cell_type": "code",
1033
- "execution_count": 101,
1034
  "metadata": {},
1035
  "outputs": [],
1036
  "source": [
@@ -1040,7 +1040,7 @@
1040
  },
1041
  {
1042
  "cell_type": "code",
1043
- "execution_count": 102,
1044
  "metadata": {},
1045
  "outputs": [],
1046
  "source": [
@@ -1049,7 +1049,7 @@
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
- "execution_count": 103,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
@@ -1073,7 +1073,7 @@
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
- "execution_count": 104,
1077
  "metadata": {},
1078
  "outputs": [],
1079
  "source": [
@@ -1271,7 +1271,7 @@
1271
  },
1272
  {
1273
  "cell_type": "code",
1274
- "execution_count": 105,
1275
  "metadata": {},
1276
  "outputs": [
1277
  {
@@ -1429,10 +1429,11 @@
1429
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1430
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1431
  "\n",
1432
- " def sample(self, file, params=torch.tensor((0.2,0.8)).view(1,2), ema=False, entire=False):\n",
1433
  " # n_sample = params.shape[0]\n",
 
1434
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1435
- " # print(\"params.shape =\", params.shape)\n",
1436
  " # print(\"len(params) =\", len(params))\n",
1437
  " model = self.ema_model if ema else self.nn_model\n",
1438
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
@@ -1450,7 +1451,7 @@
1450
  },
1451
  {
1452
  "cell_type": "code",
1453
- "execution_count": 106,
1454
  "metadata": {},
1455
  "outputs": [
1456
  {
@@ -1477,14 +1478,14 @@
1477
  "output_type": "stream",
1478
  "text": [
1479
  "params loaded: (200, 2)\n",
1480
- "images rescaled to [-1.0, 1.1032476425170898]\n",
1481
- "params rescaled to [0.0, 0.9962284381407488]\n"
1482
  ]
1483
  },
1484
  {
1485
  "data": {
1486
  "application/vnd.jupyter.widget-view+json": {
1487
- "model_id": "986e997d5f8c4c24b45d221987296db6",
1488
  "version_major": 2,
1489
  "version_minor": 0
1490
  },
@@ -1498,7 +1499,7 @@
1498
  {
1499
  "data": {
1500
  "application/vnd.jupyter.widget-view+json": {
1501
- "model_id": "2aab12179dd24a8788d8ce3dfcce6a34",
1502
  "version_major": 2,
1503
  "version_minor": 0
1504
  },
@@ -1512,7 +1513,7 @@
1512
  {
1513
  "data": {
1514
  "application/vnd.jupyter.widget-view+json": {
1515
- "model_id": "1560051f3a3f460984173d8564715be7",
1516
  "version_major": 2,
1517
  "version_minor": 0
1518
  },
@@ -1526,7 +1527,7 @@
1526
  {
1527
  "data": {
1528
  "application/vnd.jupyter.widget-view+json": {
1529
- "model_id": "d782dec46b654831a9bdd35482be746a",
1530
  "version_major": 2,
1531
  "version_minor": 0
1532
  },
@@ -1540,7 +1541,7 @@
1540
  {
1541
  "data": {
1542
  "application/vnd.jupyter.widget-view+json": {
1543
- "model_id": "f633f6d189934559844de4a3bacfa8d4",
1544
  "version_major": 2,
1545
  "version_minor": 0
1546
  },
@@ -1554,7 +1555,7 @@
1554
  {
1555
  "data": {
1556
  "application/vnd.jupyter.widget-view+json": {
1557
- "model_id": "e4035b6e33df421e8c844839387f574b",
1558
  "version_major": 2,
1559
  "version_minor": 0
1560
  },
@@ -1568,7 +1569,7 @@
1568
  {
1569
  "data": {
1570
  "application/vnd.jupyter.widget-view+json": {
1571
- "model_id": "069d8cc887b04a4cbb2119342a0bf44e",
1572
  "version_major": 2,
1573
  "version_minor": 0
1574
  },
@@ -1582,7 +1583,7 @@
1582
  {
1583
  "data": {
1584
  "application/vnd.jupyter.widget-view+json": {
1585
- "model_id": "1c58007af7184304a8bb2774236f9729",
1586
  "version_major": 2,
1587
  "version_minor": 0
1588
  },
@@ -1596,7 +1597,7 @@
1596
  {
1597
  "data": {
1598
  "application/vnd.jupyter.widget-view+json": {
1599
- "model_id": "b4af0c5b1460456488bea9de6e159a51",
1600
  "version_major": 2,
1601
  "version_minor": 0
1602
  },
@@ -1610,7 +1611,7 @@
1610
  {
1611
  "data": {
1612
  "application/vnd.jupyter.widget-view+json": {
1613
- "model_id": "0483ca6f03fe4be094639c26fc660925",
1614
  "version_major": 2,
1615
  "version_minor": 0
1616
  },
@@ -1628,20 +1629,21 @@
1628
  },
1629
  {
1630
  "cell_type": "code",
1631
- "execution_count": null,
1632
  "metadata": {},
1633
  "outputs": [
1634
  {
1635
  "name": "stdout",
1636
  "output_type": "stream",
1637
  "text": [
1638
- "x_i.shape = torch.Size([1, 512])\n"
 
1639
  ]
1640
  },
1641
  {
1642
  "data": {
1643
  "application/vnd.jupyter.widget-view+json": {
1644
- "model_id": "cbef2a1163e74da5918a66fafc43caaa",
1645
  "version_major": 2,
1646
  "version_minor": 0
1647
  },
@@ -1656,27 +1658,31 @@
1656
  "name": "stdout",
1657
  "output_type": "stream",
1658
  "text": [
1659
- "nn_model input shape torch.Size([2, 512]) torch.Size([2]) torch.Size([2, 2])\n"
1660
  ]
1661
  },
1662
  {
1663
  "ename": "RuntimeError",
1664
- "evalue": "Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 512]",
1665
  "output_type": "error",
1666
  "traceback": [
1667
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1668
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1669
- "Cell \u001b[0;32mIn[99], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1670
- "Cell \u001b[0;32mIn[98], line 153\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 150\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 151\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 155\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 156\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1671
- "Cell \u001b[0;32mIn[84], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mnn_model input shape\u001b[39m\u001b[39m\"\u001b[39m, x_i\u001b[39m.\u001b[39mshape, t_is\u001b[39m.\u001b[39mshape, c_i\u001b[39m.\u001b[39mshape)\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
 
 
 
 
1672
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1673
- "Cell \u001b[0;32mIn[95], line 230\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[39m# print(\"0,h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minput_blocks:\n\u001b[0;32m--> 230\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 231\u001b[0m hs\u001b[39m.\u001b[39mappend(h)\n\u001b[1;32m 232\u001b[0m \u001b[39m# print(\"module encoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 233\u001b[0m \u001b[39m# print(\"2,h.shape =\", h.shape)\u001b[39;00m\n",
1674
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1675
- "Cell \u001b[0;32mIn[90], line 9\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n\u001b[1;32m 8\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m----> 9\u001b[0m x \u001b[39m=\u001b[39m layer(x)\n\u001b[1;32m 10\u001b[0m \u001b[39mreturn\u001b[39;00m x\n",
1676
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1677
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/conv.py:457\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 456\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 457\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_conv_forward(\u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias)\n",
1678
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/conv.py:453\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpadding_mode \u001b[39m!=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mzeros\u001b[39m\u001b[39m'\u001b[39m:\n\u001b[1;32m 450\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39mconv2d(F\u001b[39m.\u001b[39mpad(\u001b[39minput\u001b[39m, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpadding_mode),\n\u001b[1;32m 451\u001b[0m weight, bias, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstride,\n\u001b[1;32m 452\u001b[0m _pair(\u001b[39m0\u001b[39m), \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdilation, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgroups)\n\u001b[0;32m--> 453\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mconv2d(\u001b[39minput\u001b[39;49m, weight, bias, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstride,\n\u001b[1;32m 454\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpadding, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdilation, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgroups)\n",
1679
- "\u001b[0;31mRuntimeError\u001b[0m: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [2, 512]"
1680
  ]
1681
  }
1682
  ],
 
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
+ "model_id": "b9a1289a64b14b6e9aebf8148c38c9a8",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
 
234
  },
235
  {
236
  "cell_type": "code",
237
+ "execution_count": 7,
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
 
278
  " def sample(self, nn_model, params, device, guide_w = 0):\n",
279
  " n_sample = len(params) #params.shape[0]\n",
280
  " # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
281
+ " x_i = torch.randn(n_sample, *self.img_shape).to(device)\n",
282
  " print(\"x_i.shape =\", x_i.shape)\n",
283
  " # print(\"x_i.shape =\", x_i.shape)\n",
284
  " if guide_w != -1:\n",
 
299
  " t_is = torch.tensor([i]).to(device)\n",
300
  " t_is = t_is.repeat(n_sample)\n",
301
  "\n",
302
+ " z = torch.randn(n_sample, *self.img_shape).to(device) if i > 0 else 0\n",
303
  "\n",
304
  " if guide_w == -1:\n",
305
  " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
 
307
  " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
308
  " else:\n",
309
  " # double batch\n",
310
+ " x_i = x_i.repeat(2, *torch.ones(len(self.img_shape), dtype=int).tolist())\n",
311
  " t_is = t_is.repeat(2)\n",
312
  "\n",
313
  " # split predictions and compute weighting\n",
 
341
  },
342
  {
343
  "cell_type": "code",
344
+ "execution_count": 8,
345
  "metadata": {},
346
  "outputs": [],
347
  "source": [
 
383
  },
384
  {
385
  "cell_type": "code",
386
+ "execution_count": 9,
387
  "metadata": {},
388
  "outputs": [],
389
  "source": [
 
408
  },
409
  {
410
  "cell_type": "code",
411
+ "execution_count": 10,
412
  "metadata": {},
413
  "outputs": [],
414
  "source": [
 
438
  },
439
  {
440
  "cell_type": "code",
441
+ "execution_count": 11,
442
  "metadata": {},
443
  "outputs": [],
444
  "source": [
 
453
  },
454
  {
455
  "cell_type": "code",
456
+ "execution_count": 12,
457
  "metadata": {},
458
  "outputs": [],
459
  "source": [
 
467
  },
468
  {
469
  "cell_type": "code",
470
+ "execution_count": 13,
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
 
485
  },
486
  {
487
  "cell_type": "code",
488
+ "execution_count": 14,
489
  "metadata": {},
490
  "outputs": [],
491
  "source": [
 
567
  },
568
  {
569
  "cell_type": "code",
570
+ "execution_count": 15,
571
  "metadata": {},
572
  "outputs": [],
573
  "source": [
 
600
  },
601
  {
602
  "cell_type": "code",
603
+ "execution_count": 16,
604
  "metadata": {},
605
  "outputs": [],
606
  "source": [
 
649
  },
650
  {
651
  "cell_type": "code",
652
+ "execution_count": 17,
653
  "metadata": {},
654
  "outputs": [],
655
  "source": [
 
678
  },
679
  {
680
  "cell_type": "code",
681
+ "execution_count": 18,
682
  "metadata": {},
683
  "outputs": [],
684
  "source": [
 
934
  },
935
  {
936
  "cell_type": "code",
937
+ "execution_count": 19,
938
  "metadata": {},
939
  "outputs": [],
940
  "source": [
 
964
  },
965
  {
966
  "cell_type": "code",
967
+ "execution_count": 20,
968
  "metadata": {},
969
  "outputs": [],
970
  "source": [
 
1030
  },
1031
  {
1032
  "cell_type": "code",
1033
+ "execution_count": 21,
1034
  "metadata": {},
1035
  "outputs": [],
1036
  "source": [
 
1040
  },
1041
  {
1042
  "cell_type": "code",
1043
+ "execution_count": 22,
1044
  "metadata": {},
1045
  "outputs": [],
1046
  "source": [
 
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
+ "execution_count": 23,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
 
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
+ "execution_count": 24,
1077
  "metadata": {},
1078
  "outputs": [],
1079
  "source": [
 
1271
  },
1272
  {
1273
  "cell_type": "code",
1274
+ "execution_count": 25,
1275
  "metadata": {},
1276
  "outputs": [
1277
  {
 
1429
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1430
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1431
  "\n",
1432
+ " def sample(self, file, params=None, ema=False, entire=False):\n",
1433
  " # n_sample = params.shape[0]\n",
1434
+ " params = params or torch.tensor([0.2,0.8]).repeat(10,1)\n",
1435
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1436
+ " print(\"params.shape =\", params.shape)\n",
1437
  " # print(\"len(params) =\", len(params))\n",
1438
  " model = self.ema_model if ema else self.nn_model\n",
1439
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
 
1451
  },
1452
  {
1453
  "cell_type": "code",
1454
+ "execution_count": 26,
1455
  "metadata": {},
1456
  "outputs": [
1457
  {
 
1478
  "output_type": "stream",
1479
  "text": [
1480
  "params loaded: (200, 2)\n",
1481
+ "images rescaled to [-1.0, 1.0288152694702148]\n",
1482
+ "params rescaled to [0.0, 0.9958380936251952]\n"
1483
  ]
1484
  },
1485
  {
1486
  "data": {
1487
  "application/vnd.jupyter.widget-view+json": {
1488
+ "model_id": "fa014381477545e4a0fb6976d0cd3e0e",
1489
  "version_major": 2,
1490
  "version_minor": 0
1491
  },
 
1499
  {
1500
  "data": {
1501
  "application/vnd.jupyter.widget-view+json": {
1502
+ "model_id": "8bdd69fdc66c422c99b59032ef36b1f4",
1503
  "version_major": 2,
1504
  "version_minor": 0
1505
  },
 
1513
  {
1514
  "data": {
1515
  "application/vnd.jupyter.widget-view+json": {
1516
+ "model_id": "3cea3f68bad54e9bb4a63a897bea12dd",
1517
  "version_major": 2,
1518
  "version_minor": 0
1519
  },
 
1527
  {
1528
  "data": {
1529
  "application/vnd.jupyter.widget-view+json": {
1530
+ "model_id": "26f29f7e73fd493fa62d52e6441589f1",
1531
  "version_major": 2,
1532
  "version_minor": 0
1533
  },
 
1541
  {
1542
  "data": {
1543
  "application/vnd.jupyter.widget-view+json": {
1544
+ "model_id": "439448549c8840c0a970861a66e83264",
1545
  "version_major": 2,
1546
  "version_minor": 0
1547
  },
 
1555
  {
1556
  "data": {
1557
  "application/vnd.jupyter.widget-view+json": {
1558
+ "model_id": "acb1905d443a451c937cc0c114e558c9",
1559
  "version_major": 2,
1560
  "version_minor": 0
1561
  },
 
1569
  {
1570
  "data": {
1571
  "application/vnd.jupyter.widget-view+json": {
1572
+ "model_id": "4f81671b7e1b43218b5aa6620e59503b",
1573
  "version_major": 2,
1574
  "version_minor": 0
1575
  },
 
1583
  {
1584
  "data": {
1585
  "application/vnd.jupyter.widget-view+json": {
1586
+ "model_id": "66673bbcaffc412796022e1ed372b6cf",
1587
  "version_major": 2,
1588
  "version_minor": 0
1589
  },
 
1597
  {
1598
  "data": {
1599
  "application/vnd.jupyter.widget-view+json": {
1600
+ "model_id": "2e04dfb1e9da4895a6f83516684bec03",
1601
  "version_major": 2,
1602
  "version_minor": 0
1603
  },
 
1611
  {
1612
  "data": {
1613
  "application/vnd.jupyter.widget-view+json": {
1614
+ "model_id": "1dc4006116ca46e0a2b74a74858fa57a",
1615
  "version_major": 2,
1616
  "version_minor": 0
1617
  },
 
1629
  },
1630
  {
1631
  "cell_type": "code",
1632
+ "execution_count": 128,
1633
  "metadata": {},
1634
  "outputs": [
1635
  {
1636
  "name": "stdout",
1637
  "output_type": "stream",
1638
  "text": [
1639
+ "params.shape = torch.Size([10, 2])\n",
1640
+ "x_i.shape = torch.Size([10, 1, 64, 512])\n"
1641
  ]
1642
  },
1643
  {
1644
  "data": {
1645
  "application/vnd.jupyter.widget-view+json": {
1646
+ "model_id": "44830b578bb94dc5bf65eafaff19b1f5",
1647
  "version_major": 2,
1648
  "version_minor": 0
1649
  },
 
1658
  "name": "stdout",
1659
  "output_type": "stream",
1660
  "text": [
1661
+ "nn_model input shape torch.Size([20, 1, 64, 512]) torch.Size([20]) torch.Size([20, 2])\n"
1662
  ]
1663
  },
1664
  {
1665
  "ename": "RuntimeError",
1666
+ "evalue": "CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 23.64 GiB total capacity; 21.71 GiB already allocated; 170.50 MiB free; 22.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1667
  "output_type": "error",
1668
  "traceback": [
1669
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1670
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1671
+ "Cell \u001b[0;32mIn[128], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1672
+ "Cell \u001b[0;32mIn[127], line 154\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 151\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 152\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 154\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 156\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 157\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1673
+ "Cell \u001b[0;32mIn[111], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mnn_model input shape\u001b[39m\u001b[39m\"\u001b[39m, x_i\u001b[39m.\u001b[39mshape, t_is\u001b[39m.\u001b[39mshape, c_i\u001b[39m.\u001b[39mshape)\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1674
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1675
+ "Cell \u001b[0;32mIn[122], line 230\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[39m# print(\"0,h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 229\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39minput_blocks:\n\u001b[0;32m--> 230\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 231\u001b[0m hs\u001b[39m.\u001b[39mappend(h)\n\u001b[1;32m 232\u001b[0m \u001b[39m# print(\"module encoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 233\u001b[0m \u001b[39m# print(\"2,h.shape =\", h.shape)\u001b[39;00m\n",
1676
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1677
+ "Cell \u001b[0;32mIn[117], line 5\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[39mfor\u001b[39;00m layer \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[1;32m 4\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, TimestepBlock):\n\u001b[0;32m----> 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[1;32m 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n",
1678
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1679
+ "Cell \u001b[0;32mIn[118], line 72\u001b[0m, in \u001b[0;36mResBlock.forward\u001b[0;34m(self, x, emb)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 71\u001b[0m h \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m emb_out\n\u001b[0;32m---> 72\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mout_layers(h)\n\u001b[1;32m 73\u001b[0m \u001b[39m# print(\"ResBlock, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mskip_connection(x) \u001b[39m+\u001b[39m h\n",
1680
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1681
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m):\n\u001b[1;32m 138\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[0;32m--> 139\u001b[0m \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m module(\u001b[39minput\u001b[39;49m)\n\u001b[1;32m 140\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39minput\u001b[39m\n",
1682
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1683
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/conv.py:457\u001b[0m, in \u001b[0;36mConv2d.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 456\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 457\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_conv_forward(\u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias)\n",
1684
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/conv.py:453\u001b[0m, in \u001b[0;36mConv2d._conv_forward\u001b[0;34m(self, input, weight, bias)\u001b[0m\n\u001b[1;32m 449\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpadding_mode \u001b[39m!=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mzeros\u001b[39m\u001b[39m'\u001b[39m:\n\u001b[1;32m 450\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39mconv2d(F\u001b[39m.\u001b[39mpad(\u001b[39minput\u001b[39m, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reversed_padding_repeated_twice, mode\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpadding_mode),\n\u001b[1;32m 451\u001b[0m weight, bias, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstride,\n\u001b[1;32m 452\u001b[0m _pair(\u001b[39m0\u001b[39m), \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdilation, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgroups)\n\u001b[0;32m--> 453\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mconv2d(\u001b[39minput\u001b[39;49m, weight, bias, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstride,\n\u001b[1;32m 454\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpadding, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mdilation, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgroups)\n",
1685
+ "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 320.00 MiB (GPU 0; 23.64 GiB total capacity; 21.71 GiB already allocated; 170.50 MiB free; 22.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
1686
  ]
1687
  }
1688
  ],