Xsmos commited on
Commit
7b07cfd
·
verified ·
1 Parent(s): 1977ea7
Files changed (1) hide show
  1. diffusion.ipynb +41 -38
diffusion.ipynb CHANGED
@@ -32,7 +32,7 @@
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
- "model_id": "41e7ee587bf94d41bf00f9044268b0c6",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
@@ -51,7 +51,7 @@
51
  },
52
  {
53
  "cell_type": "code",
54
- "execution_count": 1,
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
@@ -964,7 +964,7 @@
964
  },
965
  {
966
  "cell_type": "code",
967
- "execution_count": 26,
968
  "metadata": {},
969
  "outputs": [],
970
  "source": [
@@ -1030,7 +1030,7 @@
1030
  },
1031
  {
1032
  "cell_type": "code",
1033
- "execution_count": 27,
1034
  "metadata": {},
1035
  "outputs": [],
1036
  "source": [
@@ -1040,7 +1040,7 @@
1040
  },
1041
  {
1042
  "cell_type": "code",
1043
- "execution_count": 28,
1044
  "metadata": {},
1045
  "outputs": [],
1046
  "source": [
@@ -1049,7 +1049,7 @@
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
- "execution_count": 29,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
@@ -1073,7 +1073,7 @@
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
- "execution_count": 30,
1077
  "metadata": {},
1078
  "outputs": [],
1079
  "source": [
@@ -1271,7 +1271,7 @@
1271
  },
1272
  {
1273
  "cell_type": "code",
1274
- "execution_count": 31,
1275
  "metadata": {},
1276
  "outputs": [
1277
  {
@@ -1436,9 +1436,9 @@
1436
  "\n",
1437
  " def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
1438
  " # n_sample = params.shape[0]\n",
1439
- " params = params or torch.tensor([0.2,0.8]).repeat(2,1)\n",
1440
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1441
- " # print(\"params.shape =\", params.shape)\n",
1442
  " # print(\"params =\", params)\n",
1443
  " # print(\"len(params) =\", len(params))\n",
1444
  " # model = self.ema_model if ema else self.nn_model\n",
@@ -1479,7 +1479,7 @@
1479
  },
1480
  {
1481
  "cell_type": "code",
1482
- "execution_count": 32,
1483
  "metadata": {},
1484
  "outputs": [
1485
  {
@@ -1506,14 +1506,14 @@
1506
  "output_type": "stream",
1507
  "text": [
1508
  "params loaded: (200, 2)\n",
1509
- "images rescaled to [-1.0, 1.0804238319396973]\n",
1510
- "params rescaled to [0.0, 0.9988157888295665]\n"
1511
  ]
1512
  },
1513
  {
1514
  "data": {
1515
  "application/vnd.jupyter.widget-view+json": {
1516
- "model_id": "072ec7e4f98348f9adb34ac73477c5c4",
1517
  "version_major": 2,
1518
  "version_minor": 0
1519
  },
@@ -1527,7 +1527,7 @@
1527
  {
1528
  "data": {
1529
  "application/vnd.jupyter.widget-view+json": {
1530
- "model_id": "3e06b390509e42ca8a8f760f2a8a281b",
1531
  "version_major": 2,
1532
  "version_minor": 0
1533
  },
@@ -1541,7 +1541,7 @@
1541
  {
1542
  "data": {
1543
  "application/vnd.jupyter.widget-view+json": {
1544
- "model_id": "c00939dd2faa466791f9a2eb1198be6d",
1545
  "version_major": 2,
1546
  "version_minor": 0
1547
  },
@@ -1555,7 +1555,7 @@
1555
  {
1556
  "data": {
1557
  "application/vnd.jupyter.widget-view+json": {
1558
- "model_id": "0882246dccfc499bbedb686d57359668",
1559
  "version_major": 2,
1560
  "version_minor": 0
1561
  },
@@ -1569,7 +1569,7 @@
1569
  {
1570
  "data": {
1571
  "application/vnd.jupyter.widget-view+json": {
1572
- "model_id": "a7c47614106340c6b2f8b5900bbfd0a3",
1573
  "version_major": 2,
1574
  "version_minor": 0
1575
  },
@@ -1583,7 +1583,7 @@
1583
  {
1584
  "data": {
1585
  "application/vnd.jupyter.widget-view+json": {
1586
- "model_id": "6405b0e6b1fc4dd0ac3c04b4221de066",
1587
  "version_major": 2,
1588
  "version_minor": 0
1589
  },
@@ -1597,7 +1597,7 @@
1597
  {
1598
  "data": {
1599
  "application/vnd.jupyter.widget-view+json": {
1600
- "model_id": "760f1557e66041388dbb2473a1989a39",
1601
  "version_major": 2,
1602
  "version_minor": 0
1603
  },
@@ -1611,7 +1611,7 @@
1611
  {
1612
  "data": {
1613
  "application/vnd.jupyter.widget-view+json": {
1614
- "model_id": "eee8df37c9c6448fab09a410c05eb0f9",
1615
  "version_major": 2,
1616
  "version_minor": 0
1617
  },
@@ -1625,7 +1625,7 @@
1625
  {
1626
  "data": {
1627
  "application/vnd.jupyter.widget-view+json": {
1628
- "model_id": "36fb41e7a8f046379c8df59eaacbf1db",
1629
  "version_major": 2,
1630
  "version_minor": 0
1631
  },
@@ -1639,7 +1639,7 @@
1639
  {
1640
  "data": {
1641
  "application/vnd.jupyter.widget-view+json": {
1642
- "model_id": "3ebb946b78414ee5aa04cd3221ea8e9e",
1643
  "version_major": 2,
1644
  "version_minor": 0
1645
  },
@@ -1670,7 +1670,7 @@
1670
  {
1671
  "data": {
1672
  "application/vnd.jupyter.widget-view+json": {
1673
- "model_id": "892dab290b9d4dc0b4d60cedcb2027ec",
1674
  "version_major": 2,
1675
  "version_minor": 0
1676
  },
@@ -1683,24 +1683,27 @@
1683
  },
1684
  {
1685
  "ename": "RuntimeError",
1686
- "evalue": "CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 23.64 GiB total capacity; 22.34 GiB already allocated; 136.50 MiB free; 22.48 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1687
  "output_type": "error",
1688
  "traceback": [
1689
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1690
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1691
- "Cell \u001b[0;32mIn[31], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1692
- "Cell \u001b[0;32mIn[29], line 176\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 170\u001b[0m nn_model\u001b[39m.\u001b[39meval()\n\u001b[1;32m 172\u001b[0m \u001b[39m# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[39m# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[39m# print(f\"resumed ema_model from {config.resume}\")\u001b[39;00m\n\u001b[0;32m--> 176\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 177\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 178\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 179\u001b[0m device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 180\u001b[0m guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 181\u001b[0m )\n\u001b[1;32m 183\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 185\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1693
- "Cell \u001b[0;32mIn[7], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1694
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1695
- "Cell \u001b[0;32mIn[18], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
1696
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1697
- "Cell \u001b[0;32mIn[13], line 7\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[0;32m----> 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n\u001b[1;32m 8\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m x \u001b[39m=\u001b[39m layer(x)\n",
1698
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1699
- "Cell \u001b[0;32mIn[16], line 37\u001b[0m, in \u001b[0;36mAttentionBlock.forward\u001b[0;34m(self, x, encoder_out)\u001b[0m\n\u001b[1;32m 35\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mattention(qkv, encoder_out)\n\u001b[1;32m 36\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 37\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattention(qkv)\n\u001b[1;32m 38\u001b[0m \u001b[39m# print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 39\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproj_out(h)\n",
1700
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1701
- "Cell \u001b[0;32mIn[15], line 21\u001b[0m, in \u001b[0;36mQKVAttention.forward\u001b[0;34m(self, qkv, encoder_kv)\u001b[0m\n\u001b[1;32m 18\u001b[0m v \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([ev,v], dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 20\u001b[0m scale \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m/\u001b[39m math\u001b[39m.\u001b[39msqrt(math\u001b[39m.\u001b[39msqrt(ch))\n\u001b[0;32m---> 21\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49meinsum(\u001b[39m\"\u001b[39;49m\u001b[39mbct,bcs->bts\u001b[39;49m\u001b[39m\"\u001b[39;49m, q\u001b[39m*\u001b[39;49mscale, k\u001b[39m*\u001b[39;49mscale)\n\u001b[1;32m 22\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39msoftmax(weight\u001b[39m.\u001b[39mfloat(), dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\u001b[39m.\u001b[39mtype(weight\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 24\u001b[0m a \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbts,bcs->bct\u001b[39m\u001b[39m\"\u001b[39m, weight, v)\n",
1702
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/functional.py:360\u001b[0m, in \u001b[0;36meinsum\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[39m# recurse incase operands contains value that has torch function\u001b[39;00m\n\u001b[1;32m 357\u001b[0m \u001b[39m# in the original implementation this line is omitted\u001b[39;00m\n\u001b[1;32m 358\u001b[0m \u001b[39mreturn\u001b[39;00m einsum(equation, \u001b[39m*\u001b[39m_operands)\n\u001b[0;32m--> 360\u001b[0m \u001b[39mreturn\u001b[39;00m _VF\u001b[39m.\u001b[39;49meinsum(equation, operands)\n",
1703
- "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 23.64 GiB total capacity; 22.34 GiB already allocated; 136.50 MiB free; 22.48 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
 
 
 
1704
  ]
1705
  }
1706
  ],
@@ -1710,7 +1713,7 @@
1710
  },
1711
  {
1712
  "cell_type": "code",
1713
- "execution_count": 45,
1714
  "metadata": {},
1715
  "outputs": [],
1716
  "source": [
@@ -1719,7 +1722,7 @@
1719
  },
1720
  {
1721
  "cell_type": "code",
1722
- "execution_count": 46,
1723
  "metadata": {},
1724
  "outputs": [],
1725
  "source": [
@@ -1730,7 +1733,7 @@
1730
  },
1731
  {
1732
  "cell_type": "code",
1733
- "execution_count": 20,
1734
  "metadata": {},
1735
  "outputs": [
1736
  {
@@ -1758,7 +1761,7 @@
1758
  },
1759
  {
1760
  "cell_type": "code",
1761
- "execution_count": 28,
1762
  "metadata": {},
1763
  "outputs": [],
1764
  "source": [
 
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
+ "model_id": "611fecae78ae4f2dadb03a78fc815f12",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
 
51
  },
52
  {
53
  "cell_type": "code",
54
+ "execution_count": 2,
55
  "metadata": {},
56
  "outputs": [],
57
  "source": [
 
964
  },
965
  {
966
  "cell_type": "code",
967
+ "execution_count": 19,
968
  "metadata": {},
969
  "outputs": [],
970
  "source": [
 
1030
  },
1031
  {
1032
  "cell_type": "code",
1033
+ "execution_count": 20,
1034
  "metadata": {},
1035
  "outputs": [],
1036
  "source": [
 
1040
  },
1041
  {
1042
  "cell_type": "code",
1043
+ "execution_count": 21,
1044
  "metadata": {},
1045
  "outputs": [],
1046
  "source": [
 
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
+ "execution_count": 22,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
 
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
+ "execution_count": 23,
1077
  "metadata": {},
1078
  "outputs": [],
1079
  "source": [
 
1271
  },
1272
  {
1273
  "cell_type": "code",
1274
+ "execution_count": 24,
1275
  "metadata": {},
1276
  "outputs": [
1277
  {
 
1436
  "\n",
1437
  " def sample(self, file, params:torch.tensor=None, ema=False, entire=False):\n",
1438
  " # n_sample = params.shape[0]\n",
1439
+ " params = params or torch.tensor([0.2,0.8]).repeat(5,1)\n",
1440
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1441
+ " print(\"params.shape =\", params.shape)\n",
1442
  " # print(\"params =\", params)\n",
1443
  " # print(\"len(params) =\", len(params))\n",
1444
  " # model = self.ema_model if ema else self.nn_model\n",
 
1479
  },
1480
  {
1481
  "cell_type": "code",
1482
+ "execution_count": 25,
1483
  "metadata": {},
1484
  "outputs": [
1485
  {
 
1506
  "output_type": "stream",
1507
  "text": [
1508
  "params loaded: (200, 2)\n",
1509
+ "images rescaled to [-1.0, 1.082669734954834]\n",
1510
+ "params rescaled to [0.0, 0.9967326920094951]\n"
1511
  ]
1512
  },
1513
  {
1514
  "data": {
1515
  "application/vnd.jupyter.widget-view+json": {
1516
+ "model_id": "942983559b1a4599a837802a5661819e",
1517
  "version_major": 2,
1518
  "version_minor": 0
1519
  },
 
1527
  {
1528
  "data": {
1529
  "application/vnd.jupyter.widget-view+json": {
1530
+ "model_id": "0fb002cefad94cf1bcd9e5051f476a9a",
1531
  "version_major": 2,
1532
  "version_minor": 0
1533
  },
 
1541
  {
1542
  "data": {
1543
  "application/vnd.jupyter.widget-view+json": {
1544
+ "model_id": "299c836b34294f0f87150b8ae853f3ce",
1545
  "version_major": 2,
1546
  "version_minor": 0
1547
  },
 
1555
  {
1556
  "data": {
1557
  "application/vnd.jupyter.widget-view+json": {
1558
+ "model_id": "6c837426e91448c792a13b793cea49b2",
1559
  "version_major": 2,
1560
  "version_minor": 0
1561
  },
 
1569
  {
1570
  "data": {
1571
  "application/vnd.jupyter.widget-view+json": {
1572
+ "model_id": "56d4f8cb22ee47cd96eb54bd3190c604",
1573
  "version_major": 2,
1574
  "version_minor": 0
1575
  },
 
1583
  {
1584
  "data": {
1585
  "application/vnd.jupyter.widget-view+json": {
1586
+ "model_id": "7655aad351544b97846605477be63089",
1587
  "version_major": 2,
1588
  "version_minor": 0
1589
  },
 
1597
  {
1598
  "data": {
1599
  "application/vnd.jupyter.widget-view+json": {
1600
+ "model_id": "d869a0412b1c4e7eb986b28c319fe8a0",
1601
  "version_major": 2,
1602
  "version_minor": 0
1603
  },
 
1611
  {
1612
  "data": {
1613
  "application/vnd.jupyter.widget-view+json": {
1614
+ "model_id": "16148281b53f4c09a1fbcf4704ba1cf0",
1615
  "version_major": 2,
1616
  "version_minor": 0
1617
  },
 
1625
  {
1626
  "data": {
1627
  "application/vnd.jupyter.widget-view+json": {
1628
+ "model_id": "628e0253ac4045578d22ea29648021c3",
1629
  "version_major": 2,
1630
  "version_minor": 0
1631
  },
 
1639
  {
1640
  "data": {
1641
  "application/vnd.jupyter.widget-view+json": {
1642
+ "model_id": "0c3a68177127449ab185502b7b02245f",
1643
  "version_major": 2,
1644
  "version_minor": 0
1645
  },
 
1670
  {
1671
  "data": {
1672
  "application/vnd.jupyter.widget-view+json": {
1673
+ "model_id": "f0b11cda17be4a77b3d926968350f23f",
1674
  "version_major": 2,
1675
  "version_minor": 0
1676
  },
 
1683
  },
1684
  {
1685
  "ename": "RuntimeError",
1686
+ "evalue": "CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.64 GiB total capacity; 22.57 GiB already allocated; 12.50 MiB free; 22.60 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1687
  "output_type": "error",
1688
  "traceback": [
1689
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1690
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1691
+ "Cell \u001b[0;32mIn[33], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1692
+ "Cell \u001b[0;32mIn[31], line 176\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 170\u001b[0m nn_model\u001b[39m.\u001b[39meval()\n\u001b[1;32m 172\u001b[0m \u001b[39m# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[39m# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[39m# print(f\"resumed ema_model from {config.resume}\")\u001b[39;00m\n\u001b[0;32m--> 176\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 177\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 178\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 179\u001b[0m device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 180\u001b[0m guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 181\u001b[0m )\n\u001b[1;32m 183\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 185\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1693
+ "Cell \u001b[0;32mIn[6], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1694
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1695
+ "Cell \u001b[0;32mIn[17], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
1696
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1697
+ "Cell \u001b[0;32mIn[12], line 5\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[39mfor\u001b[39;00m layer \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[1;32m 4\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, TimestepBlock):\n\u001b[0;32m----> 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[1;32m 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n",
1698
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1699
+ "Cell \u001b[0;32mIn[13], line 72\u001b[0m, in \u001b[0;36mResBlock.forward\u001b[0;34m(self, x, emb)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 71\u001b[0m h \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m emb_out\n\u001b[0;32m---> 72\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mout_layers(h)\n\u001b[1;32m 73\u001b[0m \u001b[39m# print(\"ResBlock, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mskip_connection(x) \u001b[39m+\u001b[39m h\n",
1700
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1701
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m):\n\u001b[1;32m 138\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[0;32m--> 139\u001b[0m \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m module(\u001b[39minput\u001b[39;49m)\n\u001b[1;32m 140\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39minput\u001b[39m\n",
1702
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1703
+ "Cell \u001b[0;32mIn[7], line 7\u001b[0m, in \u001b[0;36mGroupNorm32.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x):\n\u001b[0;32m----> 7\u001b[0m y \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49mforward(x\u001b[39m.\u001b[39;49mfloat())\u001b[39m.\u001b[39mto(x\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 8\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mswish \u001b[39m==\u001b[39m \u001b[39m1.0\u001b[39m:\n\u001b[1;32m 9\u001b[0m y \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39msilu(y)\n",
1704
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/normalization.py:272\u001b[0m, in \u001b[0;36mGroupNorm.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 271\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 272\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mgroup_norm(\n\u001b[1;32m 273\u001b[0m \u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mnum_groups, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49meps)\n",
1705
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/functional.py:2516\u001b[0m, in \u001b[0;36mgroup_norm\u001b[0;34m(input, num_groups, weight, bias, eps)\u001b[0m\n\u001b[1;32m 2514\u001b[0m \u001b[39mreturn\u001b[39;00m handle_torch_function(group_norm, (\u001b[39minput\u001b[39m, weight, bias,), \u001b[39minput\u001b[39m, num_groups, weight\u001b[39m=\u001b[39mweight, bias\u001b[39m=\u001b[39mbias, eps\u001b[39m=\u001b[39meps)\n\u001b[1;32m 2515\u001b[0m _verify_batch_size([\u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize(\u001b[39m0\u001b[39m) \u001b[39m*\u001b[39m \u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize(\u001b[39m1\u001b[39m) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m num_groups, num_groups] \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(\u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize()[\u001b[39m2\u001b[39m:]))\n\u001b[0;32m-> 2516\u001b[0m \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39;49mgroup_norm(\u001b[39minput\u001b[39;49m, num_groups, weight, bias, eps, torch\u001b[39m.\u001b[39;49mbackends\u001b[39m.\u001b[39;49mcudnn\u001b[39m.\u001b[39;49menabled)\n",
1706
+ "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.64 GiB total capacity; 22.57 GiB already allocated; 12.50 MiB free; 22.60 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
1707
  ]
1708
  }
1709
  ],
 
1713
  },
1714
  {
1715
  "cell_type": "code",
1716
+ "execution_count": null,
1717
  "metadata": {},
1718
  "outputs": [],
1719
  "source": [
 
1722
  },
1723
  {
1724
  "cell_type": "code",
1725
+ "execution_count": null,
1726
  "metadata": {},
1727
  "outputs": [],
1728
  "source": [
 
1733
  },
1734
  {
1735
  "cell_type": "code",
1736
+ "execution_count": null,
1737
  "metadata": {},
1738
  "outputs": [
1739
  {
 
1761
  },
1762
  {
1763
  "cell_type": "code",
1764
+ "execution_count": null,
1765
  "metadata": {},
1766
  "outputs": [],
1767
  "source": [