Xsmos committed on
Commit
713c506
·
verified ·
1 Parent(s): 7b07cfd
Files changed (2) hide show
  1. diffusion.ipynb +149 -75
  2. load_h5.py +27 -0
diffusion.ipynb CHANGED
@@ -21,7 +21,8 @@
21
  "- 已擴展爲接受不同維度的情形\n",
22
  "- 融合cond, guide_w, drop_out這些參數\n",
23
  "- 生成的21cm圖像該暗的地方不夠暗,似乎換成MNIST的數字圖像就沒問題\n",
24
- "- 我用diffusion模型生成MNIST的數字時發現,儘管生成的數據的範圍也存在負數數值,如-0.1,但畫出來的圖像卻是理想的黑色。數據的分佈與21cm的結果的分佈沒多大差別,我現在打算把代碼退回到21cm的情形"
 
25
  ]
26
  },
27
  {
@@ -32,7 +33,7 @@
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
- "model_id": "611fecae78ae4f2dadb03a78fc815f12",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
@@ -85,7 +86,7 @@
85
  },
86
  {
87
  "cell_type": "code",
88
- "execution_count": 2,
89
  "metadata": {},
90
  "outputs": [],
91
  "source": [
@@ -94,7 +95,7 @@
94
  },
95
  {
96
  "cell_type": "code",
97
- "execution_count": 3,
98
  "metadata": {},
99
  "outputs": [],
100
  "source": [
@@ -191,7 +192,7 @@
191
  },
192
  {
193
  "cell_type": "code",
194
- "execution_count": 4,
195
  "metadata": {},
196
  "outputs": [],
197
  "source": [
@@ -201,7 +202,7 @@
201
  },
202
  {
203
  "cell_type": "code",
204
- "execution_count": 5,
205
  "metadata": {},
206
  "outputs": [],
207
  "source": [
@@ -234,7 +235,7 @@
234
  },
235
  {
236
  "cell_type": "code",
237
- "execution_count": 6,
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
@@ -341,7 +342,7 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": 7,
345
  "metadata": {},
346
  "outputs": [],
347
  "source": [
@@ -383,7 +384,7 @@
383
  },
384
  {
385
  "cell_type": "code",
386
- "execution_count": 8,
387
  "metadata": {},
388
  "outputs": [],
389
  "source": [
@@ -408,7 +409,7 @@
408
  },
409
  {
410
  "cell_type": "code",
411
- "execution_count": 9,
412
  "metadata": {},
413
  "outputs": [],
414
  "source": [
@@ -438,7 +439,7 @@
438
  },
439
  {
440
  "cell_type": "code",
441
- "execution_count": 10,
442
  "metadata": {},
443
  "outputs": [],
444
  "source": [
@@ -453,7 +454,7 @@
453
  },
454
  {
455
  "cell_type": "code",
456
- "execution_count": 11,
457
  "metadata": {},
458
  "outputs": [],
459
  "source": [
@@ -467,7 +468,7 @@
467
  },
468
  {
469
  "cell_type": "code",
470
- "execution_count": 12,
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
@@ -485,7 +486,7 @@
485
  },
486
  {
487
  "cell_type": "code",
488
- "execution_count": 13,
489
  "metadata": {},
490
  "outputs": [],
491
  "source": [
@@ -567,7 +568,7 @@
567
  },
568
  {
569
  "cell_type": "code",
570
- "execution_count": 14,
571
  "metadata": {},
572
  "outputs": [],
573
  "source": [
@@ -600,7 +601,7 @@
600
  },
601
  {
602
  "cell_type": "code",
603
- "execution_count": 15,
604
  "metadata": {},
605
  "outputs": [],
606
  "source": [
@@ -649,7 +650,7 @@
649
  },
650
  {
651
  "cell_type": "code",
652
- "execution_count": 16,
653
  "metadata": {},
654
  "outputs": [],
655
  "source": [
@@ -678,7 +679,7 @@
678
  },
679
  {
680
  "cell_type": "code",
681
- "execution_count": 17,
682
  "metadata": {},
683
  "outputs": [],
684
  "source": [
@@ -934,7 +935,7 @@
934
  },
935
  {
936
  "cell_type": "code",
937
- "execution_count": 18,
938
  "metadata": {},
939
  "outputs": [],
940
  "source": [
@@ -964,7 +965,7 @@
964
  },
965
  {
966
  "cell_type": "code",
967
- "execution_count": 19,
968
  "metadata": {},
969
  "outputs": [],
970
  "source": [
@@ -1030,7 +1031,7 @@
1030
  },
1031
  {
1032
  "cell_type": "code",
1033
- "execution_count": 20,
1034
  "metadata": {},
1035
  "outputs": [],
1036
  "source": [
@@ -1040,7 +1041,7 @@
1040
  },
1041
  {
1042
  "cell_type": "code",
1043
- "execution_count": 21,
1044
  "metadata": {},
1045
  "outputs": [],
1046
  "source": [
@@ -1049,7 +1050,7 @@
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
- "execution_count": 22,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
@@ -1073,7 +1074,7 @@
1073
  },
1074
  {
1075
  "cell_type": "code",
1076
- "execution_count": 23,
1077
  "metadata": {},
1078
  "outputs": [],
1079
  "source": [
@@ -1271,7 +1272,7 @@
1271
  },
1272
  {
1273
  "cell_type": "code",
1274
- "execution_count": 24,
1275
  "metadata": {},
1276
  "outputs": [
1277
  {
@@ -1410,6 +1411,7 @@
1410
  " del self.nn_model\n",
1411
  " if self.config.ema:\n",
1412
  " del self.ema_model\n",
 
1413
  "\n",
1414
  " def save(self, ep):\n",
1415
  " # save model\n",
@@ -1438,7 +1440,7 @@
1438
  " # n_sample = params.shape[0]\n",
1439
  " params = params or torch.tensor([0.2,0.8]).repeat(5,1)\n",
1440
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1441
- " print(\"params.shape =\", params.shape)\n",
1442
  " # print(\"params =\", params)\n",
1443
  " # print(\"len(params) =\", len(params))\n",
1444
  " # model = self.ema_model if ema else self.nn_model\n",
@@ -1460,12 +1462,13 @@
1460
  " # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
1461
  " # print(f\"resumed ema_model from {config.resume}\")\n",
1462
  "\n",
1463
- " x_last, x_entire = self.ddpm.sample(\n",
1464
- " nn_model=nn_model, \n",
1465
- " params=params.to(self.config.device), \n",
1466
- " device=self.config.device, \n",
1467
- " guide_w=self.config.guide_w\n",
1468
- " )\n",
 
1469
  "\n",
1470
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1471
  "\n",
@@ -1479,7 +1482,7 @@
1479
  },
1480
  {
1481
  "cell_type": "code",
1482
- "execution_count": 25,
1483
  "metadata": {},
1484
  "outputs": [
1485
  {
@@ -1506,14 +1509,14 @@
1506
  "output_type": "stream",
1507
  "text": [
1508
  "params loaded: (200, 2)\n",
1509
- "images rescaled to [-1.0, 1.082669734954834]\n",
1510
- "params rescaled to [0.0, 0.9967326920094951]\n"
1511
  ]
1512
  },
1513
  {
1514
  "data": {
1515
  "application/vnd.jupyter.widget-view+json": {
1516
- "model_id": "942983559b1a4599a837802a5661819e",
1517
  "version_major": 2,
1518
  "version_minor": 0
1519
  },
@@ -1527,7 +1530,7 @@
1527
  {
1528
  "data": {
1529
  "application/vnd.jupyter.widget-view+json": {
1530
- "model_id": "0fb002cefad94cf1bcd9e5051f476a9a",
1531
  "version_major": 2,
1532
  "version_minor": 0
1533
  },
@@ -1541,7 +1544,7 @@
1541
  {
1542
  "data": {
1543
  "application/vnd.jupyter.widget-view+json": {
1544
- "model_id": "299c836b34294f0f87150b8ae853f3ce",
1545
  "version_major": 2,
1546
  "version_minor": 0
1547
  },
@@ -1555,7 +1558,7 @@
1555
  {
1556
  "data": {
1557
  "application/vnd.jupyter.widget-view+json": {
1558
- "model_id": "6c837426e91448c792a13b793cea49b2",
1559
  "version_major": 2,
1560
  "version_minor": 0
1561
  },
@@ -1569,7 +1572,7 @@
1569
  {
1570
  "data": {
1571
  "application/vnd.jupyter.widget-view+json": {
1572
- "model_id": "56d4f8cb22ee47cd96eb54bd3190c604",
1573
  "version_major": 2,
1574
  "version_minor": 0
1575
  },
@@ -1583,7 +1586,7 @@
1583
  {
1584
  "data": {
1585
  "application/vnd.jupyter.widget-view+json": {
1586
- "model_id": "7655aad351544b97846605477be63089",
1587
  "version_major": 2,
1588
  "version_minor": 0
1589
  },
@@ -1597,7 +1600,7 @@
1597
  {
1598
  "data": {
1599
  "application/vnd.jupyter.widget-view+json": {
1600
- "model_id": "d869a0412b1c4e7eb986b28c319fe8a0",
1601
  "version_major": 2,
1602
  "version_minor": 0
1603
  },
@@ -1611,7 +1614,7 @@
1611
  {
1612
  "data": {
1613
  "application/vnd.jupyter.widget-view+json": {
1614
- "model_id": "16148281b53f4c09a1fbcf4704ba1cf0",
1615
  "version_major": 2,
1616
  "version_minor": 0
1617
  },
@@ -1625,7 +1628,7 @@
1625
  {
1626
  "data": {
1627
  "application/vnd.jupyter.widget-view+json": {
1628
- "model_id": "628e0253ac4045578d22ea29648021c3",
1629
  "version_major": 2,
1630
  "version_minor": 0
1631
  },
@@ -1639,7 +1642,7 @@
1639
  {
1640
  "data": {
1641
  "application/vnd.jupyter.widget-view+json": {
1642
- "model_id": "0c3a68177127449ab185502b7b02245f",
1643
  "version_major": 2,
1644
  "version_minor": 0
1645
  },
@@ -1664,13 +1667,18 @@
1664
  "name": "stdout",
1665
  "output_type": "stream",
1666
  "text": [
 
 
 
 
 
1667
  "nn_model resumed from ./outputs/model_state_09.pth\n"
1668
  ]
1669
  },
1670
  {
1671
  "data": {
1672
  "application/vnd.jupyter.widget-view+json": {
1673
- "model_id": "f0b11cda17be4a77b3d926968350f23f",
1674
  "version_major": 2,
1675
  "version_minor": 0
1676
  },
@@ -1683,27 +1691,24 @@
1683
  },
1684
  {
1685
  "ename": "RuntimeError",
1686
- "evalue": "CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.64 GiB total capacity; 22.57 GiB already allocated; 12.50 MiB free; 22.60 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1687
  "output_type": "error",
1688
  "traceback": [
1689
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1690
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1691
- "Cell \u001b[0;32mIn[33], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1692
- "Cell \u001b[0;32mIn[31], line 176\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 170\u001b[0m nn_model\u001b[39m.\u001b[39meval()\n\u001b[1;32m 172\u001b[0m \u001b[39m# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[39m# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[39m# print(f\"resumed ema_model from {config.resume}\")\u001b[39;00m\n\u001b[0;32m--> 176\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 177\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 178\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 179\u001b[0m device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 180\u001b[0m guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 181\u001b[0m )\n\u001b[1;32m 183\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m 
\u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 185\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1693
- "Cell \u001b[0;32mIn[6], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1694
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1695
- "Cell \u001b[0;32mIn[17], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
1696
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1697
- "Cell \u001b[0;32mIn[12], line 5\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[39mfor\u001b[39;00m layer \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[1;32m 4\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, TimestepBlock):\n\u001b[0;32m----> 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[1;32m 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n",
1698
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1699
- "Cell \u001b[0;32mIn[13], line 72\u001b[0m, in \u001b[0;36mResBlock.forward\u001b[0;34m(self, x, emb)\u001b[0m\n\u001b[1;32m 70\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 71\u001b[0m h \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m emb_out\n\u001b[0;32m---> 72\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mout_layers(h)\n\u001b[1;32m 73\u001b[0m \u001b[39m# print(\"ResBlock, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mskip_connection(x) \u001b[39m+\u001b[39m h\n",
1700
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1701
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/container.py:139\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m):\n\u001b[1;32m 138\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m:\n\u001b[0;32m--> 139\u001b[0m \u001b[39minput\u001b[39m \u001b[39m=\u001b[39m module(\u001b[39minput\u001b[39;49m)\n\u001b[1;32m 140\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39minput\u001b[39m\n",
1702
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1703
- "Cell \u001b[0;32mIn[7], line 7\u001b[0m, in \u001b[0;36mGroupNorm32.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, x):\n\u001b[0;32m----> 7\u001b[0m y \u001b[39m=\u001b[39m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49mforward(x\u001b[39m.\u001b[39;49mfloat())\u001b[39m.\u001b[39mto(x\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 8\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mswish \u001b[39m==\u001b[39m \u001b[39m1.0\u001b[39m:\n\u001b[1;32m 9\u001b[0m y \u001b[39m=\u001b[39m F\u001b[39m.\u001b[39msilu(y)\n",
1704
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/normalization.py:272\u001b[0m, in \u001b[0;36mGroupNorm.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 271\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39minput\u001b[39m: Tensor) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Tensor:\n\u001b[0;32m--> 272\u001b[0m \u001b[39mreturn\u001b[39;00m F\u001b[39m.\u001b[39;49mgroup_norm(\n\u001b[1;32m 273\u001b[0m \u001b[39minput\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mnum_groups, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mweight, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbias, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49meps)\n",
1705
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/functional.py:2516\u001b[0m, in \u001b[0;36mgroup_norm\u001b[0;34m(input, num_groups, weight, bias, eps)\u001b[0m\n\u001b[1;32m 2514\u001b[0m \u001b[39mreturn\u001b[39;00m handle_torch_function(group_norm, (\u001b[39minput\u001b[39m, weight, bias,), \u001b[39minput\u001b[39m, num_groups, weight\u001b[39m=\u001b[39mweight, bias\u001b[39m=\u001b[39mbias, eps\u001b[39m=\u001b[39meps)\n\u001b[1;32m 2515\u001b[0m _verify_batch_size([\u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize(\u001b[39m0\u001b[39m) \u001b[39m*\u001b[39m \u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize(\u001b[39m1\u001b[39m) \u001b[39m/\u001b[39m\u001b[39m/\u001b[39m num_groups, num_groups] \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(\u001b[39minput\u001b[39m\u001b[39m.\u001b[39msize()[\u001b[39m2\u001b[39m:]))\n\u001b[0;32m-> 2516\u001b[0m \u001b[39mreturn\u001b[39;00m torch\u001b[39m.\u001b[39;49mgroup_norm(\u001b[39minput\u001b[39;49m, num_groups, weight, bias, eps, torch\u001b[39m.\u001b[39;49mbackends\u001b[39m.\u001b[39;49mcudnn\u001b[39m.\u001b[39;49menabled)\n",
1706
- "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.64 GiB total capacity; 22.57 GiB already allocated; 12.50 MiB free; 22.60 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
1707
  ]
1708
  }
1709
  ],
@@ -1868,39 +1873,108 @@
1868
  },
1869
  {
1870
  "cell_type": "code",
1871
- "execution_count": 25,
1872
  "metadata": {},
1873
  "outputs": [
1874
  {
1875
- "ename": "NameError",
1876
- "evalue": "name 'config' is not defined",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1877
  "output_type": "error",
1878
  "traceback": [
1879
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1880
- "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
1881
- "Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm \u001b[39m=\u001b[39m DDPMScheduler(betas\u001b[39m=\u001b[39m(\u001b[39m1e-4\u001b[39m, \u001b[39m0.02\u001b[39m), num_timesteps\u001b[39m=\u001b[39mconfig\u001b[39m.\u001b[39mnum_timesteps, device\u001b[39m=\u001b[39mconfig\u001b[39m.\u001b[39mdevice)\n\u001b[1;32m 3\u001b[0m nn_model \u001b[39m=\u001b[39m ContextUnet(n_param\u001b[39m=\u001b[39mconfig\u001b[39m.\u001b[39mn_param, image_size\u001b[39m=\u001b[39mconfig\u001b[39m.\u001b[39mHII_DIM)\n\u001b[1;32m 4\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mresuming nn_model\u001b[39m\u001b[39m\"\u001b[39m)\n",
1882
- "\u001b[0;31mNameError\u001b[0m: name 'config' is not defined"
 
1883
  ]
1884
  }
1885
  ],
1886
  "source": [
 
1887
  "# ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
 
1888
  "\n",
1889
- "# nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
1890
- "# print(\"resuming nn_model\")\n",
1891
- "# nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
1892
- "# # nn_model = ContextUnet(n_param=1, image_size=28)\n",
1893
- "# # nn_model.train()\n",
1894
- "# nn_model.to(ddpm.device)\n",
1895
- "# nn_model.eval()\n",
1896
  "\n",
1897
  "# n_sample = 20\n",
1898
- "# with torch.no_grad():\n",
1899
- "# x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
 
 
 
 
 
 
 
1900
  "\n",
1901
- "# np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
1902
  ]
1903
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1904
  {
1905
  "cell_type": "code",
1906
  "execution_count": 21,
 
21
  "- 已擴展爲接受不同維度的情形\n",
22
  "- 融合cond, guide_w, drop_out這些參數\n",
23
  "- 生成的21cm圖像該暗的地方不夠暗,似乎換成MNIST的數字圖像就沒問題\n",
24
+ "- 我用diffusion模型生成MNIST的數字時發現,儘管生成的數據的範圍也存在負數數值,如-0.1,但畫出來的圖像卻是理想的黑色。數據的分佈與21cm的結果的分佈沒多大差別,我現在打算把代碼退回到21cm的情形\n",
25
+ "- 我統一了ddpm21cm這個module,能統一實現訓練和生成樣本,但目前有個bug, sample時總是會cuda out of memory,然而單獨resume model並sample就不會。"
26
  ]
27
  },
28
  {
 
33
  {
34
  "data": {
35
  "application/vnd.jupyter.widget-view+json": {
36
+ "model_id": "4f2bbf6f5e904828bc65afc7ad97df36",
37
  "version_major": 2,
38
  "version_minor": 0
39
  },
 
86
  },
87
  {
88
  "cell_type": "code",
89
+ "execution_count": 3,
90
  "metadata": {},
91
  "outputs": [],
92
  "source": [
 
95
  },
96
  {
97
  "cell_type": "code",
98
+ "execution_count": 4,
99
  "metadata": {},
100
  "outputs": [],
101
  "source": [
 
192
  },
193
  {
194
  "cell_type": "code",
195
+ "execution_count": 5,
196
  "metadata": {},
197
  "outputs": [],
198
  "source": [
 
202
  },
203
  {
204
  "cell_type": "code",
205
+ "execution_count": 6,
206
  "metadata": {},
207
  "outputs": [],
208
  "source": [
 
235
  },
236
  {
237
  "cell_type": "code",
238
+ "execution_count": 7,
239
  "metadata": {},
240
  "outputs": [],
241
  "source": [
 
342
  },
343
  {
344
  "cell_type": "code",
345
+ "execution_count": 8,
346
  "metadata": {},
347
  "outputs": [],
348
  "source": [
 
384
  },
385
  {
386
  "cell_type": "code",
387
+ "execution_count": 9,
388
  "metadata": {},
389
  "outputs": [],
390
  "source": [
 
409
  },
410
  {
411
  "cell_type": "code",
412
+ "execution_count": 10,
413
  "metadata": {},
414
  "outputs": [],
415
  "source": [
 
439
  },
440
  {
441
  "cell_type": "code",
442
+ "execution_count": 11,
443
  "metadata": {},
444
  "outputs": [],
445
  "source": [
 
454
  },
455
  {
456
  "cell_type": "code",
457
+ "execution_count": 12,
458
  "metadata": {},
459
  "outputs": [],
460
  "source": [
 
468
  },
469
  {
470
  "cell_type": "code",
471
+ "execution_count": 13,
472
  "metadata": {},
473
  "outputs": [],
474
  "source": [
 
486
  },
487
  {
488
  "cell_type": "code",
489
+ "execution_count": 14,
490
  "metadata": {},
491
  "outputs": [],
492
  "source": [
 
568
  },
569
  {
570
  "cell_type": "code",
571
+ "execution_count": 15,
572
  "metadata": {},
573
  "outputs": [],
574
  "source": [
 
601
  },
602
  {
603
  "cell_type": "code",
604
+ "execution_count": 16,
605
  "metadata": {},
606
  "outputs": [],
607
  "source": [
 
650
  },
651
  {
652
  "cell_type": "code",
653
+ "execution_count": 17,
654
  "metadata": {},
655
  "outputs": [],
656
  "source": [
 
679
  },
680
  {
681
  "cell_type": "code",
682
+ "execution_count": 18,
683
  "metadata": {},
684
  "outputs": [],
685
  "source": [
 
935
  },
936
  {
937
  "cell_type": "code",
938
+ "execution_count": 19,
939
  "metadata": {},
940
  "outputs": [],
941
  "source": [
 
965
  },
966
  {
967
  "cell_type": "code",
968
+ "execution_count": 20,
969
  "metadata": {},
970
  "outputs": [],
971
  "source": [
 
1031
  },
1032
  {
1033
  "cell_type": "code",
1034
+ "execution_count": 21,
1035
  "metadata": {},
1036
  "outputs": [],
1037
  "source": [
 
1041
  },
1042
  {
1043
  "cell_type": "code",
1044
+ "execution_count": 22,
1045
  "metadata": {},
1046
  "outputs": [],
1047
  "source": [
 
1050
  },
1051
  {
1052
  "cell_type": "code",
1053
+ "execution_count": 23,
1054
  "metadata": {},
1055
  "outputs": [],
1056
  "source": [
 
1074
  },
1075
  {
1076
  "cell_type": "code",
1077
+ "execution_count": 24,
1078
  "metadata": {},
1079
  "outputs": [],
1080
  "source": [
 
1272
  },
1273
  {
1274
  "cell_type": "code",
1275
+ "execution_count": 25,
1276
  "metadata": {},
1277
  "outputs": [
1278
  {
 
1411
  " del self.nn_model\n",
1412
  " if self.config.ema:\n",
1413
  " del self.ema_model\n",
1414
+ " torch.cuda.empty_cache()\n",
1415
  "\n",
1416
  " def save(self, ep):\n",
1417
  " # save model\n",
 
1440
  " # n_sample = params.shape[0]\n",
1441
  " params = params or torch.tensor([0.2,0.8]).repeat(5,1)\n",
1442
  " assert params.dim() == 2, \"params must be a 2D torch.tensor\"\n",
1443
+ " print(\"params =\", params)\n",
1444
  " # print(\"params =\", params)\n",
1445
  " # print(\"len(params) =\", len(params))\n",
1446
  " # model = self.ema_model if ema else self.nn_model\n",
 
1462
  " # self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
1463
  " # print(f\"resumed ema_model from {config.resume}\")\n",
1464
  "\n",
1465
+ " with torch.no_grad():\n",
1466
+ " x_last, x_entire = self.ddpm.sample(\n",
1467
+ " nn_model=nn_model, \n",
1468
+ " params=params.to(self.config.device), \n",
1469
+ " device=self.config.device, \n",
1470
+ " guide_w=self.config.guide_w\n",
1471
+ " )\n",
1472
  "\n",
1473
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1474
  "\n",
 
1482
  },
1483
  {
1484
  "cell_type": "code",
1485
+ "execution_count": 26,
1486
  "metadata": {},
1487
  "outputs": [
1488
  {
 
1509
  "output_type": "stream",
1510
  "text": [
1511
  "params loaded: (200, 2)\n",
1512
+ "images rescaled to [-1.0, 1.064338207244873]\n",
1513
+ "params rescaled to [0.0, 0.9988593502151616]\n"
1514
  ]
1515
  },
1516
  {
1517
  "data": {
1518
  "application/vnd.jupyter.widget-view+json": {
1519
+ "model_id": "2e0b629831714bc2b32e25d44a72f4b3",
1520
  "version_major": 2,
1521
  "version_minor": 0
1522
  },
 
1530
  {
1531
  "data": {
1532
  "application/vnd.jupyter.widget-view+json": {
1533
+ "model_id": "c634a180ede04f3cb09ab74daf0401c6",
1534
  "version_major": 2,
1535
  "version_minor": 0
1536
  },
 
1544
  {
1545
  "data": {
1546
  "application/vnd.jupyter.widget-view+json": {
1547
+ "model_id": "6f3a0791c42b4d7e958f2a9d57f64de8",
1548
  "version_major": 2,
1549
  "version_minor": 0
1550
  },
 
1558
  {
1559
  "data": {
1560
  "application/vnd.jupyter.widget-view+json": {
1561
+ "model_id": "9dce2de3e8a14aee83e2b182dc06608f",
1562
  "version_major": 2,
1563
  "version_minor": 0
1564
  },
 
1572
  {
1573
  "data": {
1574
  "application/vnd.jupyter.widget-view+json": {
1575
+ "model_id": "d4596bdc71cc4d4cb780442b97849883",
1576
  "version_major": 2,
1577
  "version_minor": 0
1578
  },
 
1586
  {
1587
  "data": {
1588
  "application/vnd.jupyter.widget-view+json": {
1589
+ "model_id": "6e68847216504241b81ebcb71c48f687",
1590
  "version_major": 2,
1591
  "version_minor": 0
1592
  },
 
1600
  {
1601
  "data": {
1602
  "application/vnd.jupyter.widget-view+json": {
1603
+ "model_id": "830c25eb902a47e7997dcdb40099c5a4",
1604
  "version_major": 2,
1605
  "version_minor": 0
1606
  },
 
1614
  {
1615
  "data": {
1616
  "application/vnd.jupyter.widget-view+json": {
1617
+ "model_id": "87fdac7b595c4d0ea7258ee8bb35de17",
1618
  "version_major": 2,
1619
  "version_minor": 0
1620
  },
 
1628
  {
1629
  "data": {
1630
  "application/vnd.jupyter.widget-view+json": {
1631
+ "model_id": "b9f6be95f4bd403d85f6df34756e7b8d",
1632
  "version_major": 2,
1633
  "version_minor": 0
1634
  },
 
1642
  {
1643
  "data": {
1644
  "application/vnd.jupyter.widget-view+json": {
1645
+ "model_id": "28ec5d881b37440ba5f4c863fc552c17",
1646
  "version_major": 2,
1647
  "version_minor": 0
1648
  },
 
1667
  "name": "stdout",
1668
  "output_type": "stream",
1669
  "text": [
1670
+ "params = tensor([[0.2000, 0.8000],\n",
1671
+ " [0.2000, 0.8000],\n",
1672
+ " [0.2000, 0.8000],\n",
1673
+ " [0.2000, 0.8000],\n",
1674
+ " [0.2000, 0.8000]])\n",
1675
  "nn_model resumed from ./outputs/model_state_09.pth\n"
1676
  ]
1677
  },
1678
  {
1679
  "data": {
1680
  "application/vnd.jupyter.widget-view+json": {
1681
+ "model_id": "58944c3b1e4f42bb8771f776c35a90a7",
1682
  "version_major": 2,
1683
  "version_minor": 0
1684
  },
 
1691
  },
1692
  {
1693
  "ename": "RuntimeError",
1694
+ "evalue": "CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.65 GiB already allocated; 432.50 MiB free; 22.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
1695
  "output_type": "error",
1696
  "traceback": [
1697
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1698
  "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1699
+ "Cell \u001b[0;32mIn[26], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1700
+ "Cell \u001b[0;32mIn[25], line 177\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 171\u001b[0m nn_model\u001b[39m.\u001b[39meval()\n\u001b[1;32m 173\u001b[0m \u001b[39m# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\u001b[39;00m\n\u001b[1;32m 174\u001b[0m \u001b[39m# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\u001b[39;00m\n\u001b[1;32m 175\u001b[0m \u001b[39m# print(f\"resumed ema_model from {config.resume}\")\u001b[39;00m\n\u001b[0;32m--> 177\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 178\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 179\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 180\u001b[0m device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 181\u001b[0m guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 182\u001b[0m )\n\u001b[1;32m 184\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m 
\u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 186\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1701
+ "Cell \u001b[0;32mIn[7], line 75\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 71\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 73\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[1;32m 74\u001b[0m \u001b[39m# print(\"nn_model input shape\", x_i.shape, t_is.shape, c_i.shape)\u001b[39;00m\n\u001b[0;32m---> 75\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 76\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 77\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
 
 
1702
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1703
+ "Cell \u001b[0;32mIn[18], line 241\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[39mfor\u001b[39;00m module \u001b[39min\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39moutput_blocks:\n\u001b[1;32m 238\u001b[0m \u001b[39m# print(\"for module in self.output_blocks, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 239\u001b[0m \u001b[39m# print(\"len(hs) =\", len(hs), \", hs[-1].shape =\", hs[-1].shape)\u001b[39;00m\n\u001b[1;32m 240\u001b[0m h \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([h, hs\u001b[39m.\u001b[39mpop()], dim\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[0;32m--> 241\u001b[0m h \u001b[39m=\u001b[39m module(h, emb)\n\u001b[1;32m 242\u001b[0m \u001b[39m# print(\"module decoder, h.shape =\", h.shape)\u001b[39;00m\n\u001b[1;32m 244\u001b[0m h \u001b[39m=\u001b[39m h\u001b[39m.\u001b[39mtype(x\u001b[39m.\u001b[39mdtype)\n",
1704
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1705
+ "Cell \u001b[0;32mIn[13], line 7\u001b[0m, in \u001b[0;36mTimestepEmbedSequential.forward\u001b[0;34m(self, x, emb, encoder_out)\u001b[0m\n\u001b[1;32m 5\u001b[0m x \u001b[39m=\u001b[39m layer(x, emb)\n\u001b[1;32m 6\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39misinstance\u001b[39m(layer, AttentionBlock):\n\u001b[0;32m----> 7\u001b[0m x \u001b[39m=\u001b[39m layer(x, encoder_out)\n\u001b[1;32m 8\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 9\u001b[0m x \u001b[39m=\u001b[39m layer(x)\n",
1706
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1707
+ "Cell \u001b[0;32mIn[16], line 37\u001b[0m, in \u001b[0;36mAttentionBlock.forward\u001b[0;34m(self, x, encoder_out)\u001b[0m\n\u001b[1;32m 35\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mattention(qkv, encoder_out)\n\u001b[1;32m 36\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[0;32m---> 37\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mattention(qkv)\n\u001b[1;32m 38\u001b[0m \u001b[39m# print(\"AttentionBlock, before proj_out, torch.unique(h).shape =\", torch.unique(h).shape)\u001b[39;00m\n\u001b[1;32m 39\u001b[0m h \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproj_out(h)\n",
1708
  "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1709
+ "Cell \u001b[0;32mIn[15], line 21\u001b[0m, in \u001b[0;36mQKVAttention.forward\u001b[0;34m(self, qkv, encoder_kv)\u001b[0m\n\u001b[1;32m 18\u001b[0m v \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat([ev,v], dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\n\u001b[1;32m 20\u001b[0m scale \u001b[39m=\u001b[39m \u001b[39m1\u001b[39m \u001b[39m/\u001b[39m math\u001b[39m.\u001b[39msqrt(math\u001b[39m.\u001b[39msqrt(ch))\n\u001b[0;32m---> 21\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49meinsum(\u001b[39m\"\u001b[39;49m\u001b[39mbct,bcs->bts\u001b[39;49m\u001b[39m\"\u001b[39;49m, q\u001b[39m*\u001b[39;49mscale, k\u001b[39m*\u001b[39;49mscale)\n\u001b[1;32m 22\u001b[0m weight \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39msoftmax(weight\u001b[39m.\u001b[39mfloat(), dim\u001b[39m=\u001b[39m\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m)\u001b[39m.\u001b[39mtype(weight\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 24\u001b[0m a \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39meinsum(\u001b[39m\"\u001b[39m\u001b[39mbts,bcs->bct\u001b[39m\u001b[39m\"\u001b[39m, weight, v)\n",
1710
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/functional.py:360\u001b[0m, in \u001b[0;36meinsum\u001b[0;34m(*args)\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[39m# recurse incase operands contains value that has torch function\u001b[39;00m\n\u001b[1;32m 357\u001b[0m \u001b[39m# in the original implementation this line is omitted\u001b[39;00m\n\u001b[1;32m 358\u001b[0m \u001b[39mreturn\u001b[39;00m einsum(equation, \u001b[39m*\u001b[39m_operands)\n\u001b[0;32m--> 360\u001b[0m \u001b[39mreturn\u001b[39;00m _VF\u001b[39m.\u001b[39;49meinsum(equation, operands)\n",
1711
+ "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 640.00 MiB (GPU 0; 23.64 GiB total capacity; 21.65 GiB already allocated; 432.50 MiB free; 22.23 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
 
1712
  ]
1713
  }
1714
  ],
 
1873
  },
1874
  {
1875
  "cell_type": "code",
1876
+ "execution_count": 36,
1877
  "metadata": {},
1878
  "outputs": [
1879
  {
1880
+ "name": "stdout",
1881
+ "output_type": "stream",
1882
+ "text": [
1883
+ "resuming nn_model\n"
1884
+ ]
1885
+ },
1886
+ {
1887
+ "data": {
1888
+ "application/vnd.jupyter.widget-view+json": {
1889
+ "model_id": "0e5bf6a8bad7403cab54e0b75464142a",
1890
+ "version_major": 2,
1891
+ "version_minor": 0
1892
+ },
1893
+ "text/plain": [
1894
+ " 0%| | 0/1000 [00:00<?, ?it/s]"
1895
+ ]
1896
+ },
1897
+ "metadata": {},
1898
+ "output_type": "display_data"
1899
+ },
1900
+ {
1901
+ "ename": "KeyboardInterrupt",
1902
+ "evalue": "",
1903
  "output_type": "error",
1904
  "traceback": [
1905
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1906
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
1907
+ "Cell \u001b[0;32mIn[36], line 17\u001b[0m\n\u001b[1;32m 14\u001b[0m params \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mtensor((\u001b[39m0.2\u001b[39m,\u001b[39m0.8\u001b[39m))\u001b[39m.\u001b[39mrepeat(\u001b[39m10\u001b[39m,\u001b[39m1\u001b[39m)\n\u001b[1;32m 15\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[1;32m 16\u001b[0m \u001b[39m# x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\u001b[39;00m\n\u001b[0;32m---> 17\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m ddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 18\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 19\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(config\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 20\u001b[0m device\u001b[39m=\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 21\u001b[0m guide_w\u001b[39m=\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 22\u001b[0m )\n\u001b[1;32m 24\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(config\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m_ema.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n",
1908
+ "Cell \u001b[0;32mIn[6], line 59\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 56\u001b[0m pbar_sample\u001b[39m.\u001b[39mset_description(\u001b[39m\"\u001b[39m\u001b[39mSampling\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 57\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mreversed\u001b[39m(\u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_timesteps)):\n\u001b[1;32m 58\u001b[0m \u001b[39m# print(f'sampling timestep {i:4d}',end='\\r')\u001b[39;00m\n\u001b[0;32m---> 59\u001b[0m t_is \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mtensor([i])\u001b[39m.\u001b[39;49mto(device)\n\u001b[1;32m 60\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(n_sample)\n\u001b[1;32m 62\u001b[0m z \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(n_sample, \u001b[39m*\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mimg_shape)\u001b[39m.\u001b[39mto(device) \u001b[39mif\u001b[39;00m i \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m \u001b[39melse\u001b[39;00m \u001b[39m0\u001b[39m\n",
1909
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
1910
  ]
1911
  }
1912
  ],
1913
  "source": [
1914
+ "config = TrainConfig()\n",
1915
  "# ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
1916
+ "ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
1917
  "\n",
1918
+ "nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
1919
+ "print(\"resuming nn_model\")\n",
1920
+ "nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
1921
+ "# nn_model = ContextUnet(n_param=1, image_size=28)\n",
1922
+ "# nn_model.train()\n",
1923
+ "nn_model.to(ddpm.device)\n",
1924
+ "nn_model.eval()\n",
1925
  "\n",
1926
  "# n_sample = 20\n",
1927
+ "params = torch.tensor((0.2,0.8)).repeat(10,1)\n",
1928
+ "with torch.no_grad():\n",
1929
+ " # x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
1930
+ " x_last, x_entire = ddpm.sample(\n",
1931
+ " nn_model=nn_model, \n",
1932
+ " params=params.to(config.device), \n",
1933
+ " device=config.device, \n",
1934
+ " guide_w=config.guide_w\n",
1935
+ " )\n",
1936
  "\n",
1937
+ "np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last)"
1938
  ]
1939
  },
1940
+ {
1941
+ "cell_type": "code",
1942
+ "execution_count": 32,
1943
+ "metadata": {},
1944
+ "outputs": [
1945
+ {
1946
+ "data": {
1947
+ "text/plain": [
1948
+ "(4, 1, 64, 512)"
1949
+ ]
1950
+ },
1951
+ "execution_count": 32,
1952
+ "metadata": {},
1953
+ "output_type": "execute_result"
1954
+ }
1955
+ ],
1956
+ "source": [
1957
+ "x_last.shape"
1958
+ ]
1959
+ },
1960
+ {
1961
+ "cell_type": "code",
1962
+ "execution_count": 35,
1963
+ "metadata": {},
1964
+ "outputs": [
1965
+ {
1966
+ "data": {
1967
+ "text/plain": [
1968
+ "'cuda'"
1969
+ ]
1970
+ },
1971
+ "execution_count": 35,
1972
+ "metadata": {},
1973
+ "output_type": "execute_result"
1974
+ }
1975
+ ],
1976
+ "source": []
1977
+ },
1978
  {
1979
  "cell_type": "code",
1980
  "execution_count": 21,
load_h5.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import h5py
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.data import DataLoader, Dataset
6
+ # from datasets import Dataset
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import random
10
+ from abc import ABC, abstractmethod
11
+ import torch.nn.functional as F
12
+ import math
13
+ from PIL import Image
14
+ import os
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ import copy
17
+ from tqdm.auto import tqdm
18
+ # from torchvision import transforms
19
+ # from diffusers import UNet2DModel#, UNet3DConditionModel
20
+ # from diffusers import DDPMScheduler
21
+ from diffusers.utils import make_image_grid
22
+ import datetime
23
+ from pathlib import Path
24
+ from diffusers.optimization import get_cosine_schedule_with_warmup
25
+ from accelerate import notebook_launcher, Accelerator
26
+ from huggingface_hub import create_repo, upload_folder
27
+