Xsmos committed on
Commit
0757b47
·
verified ·
1 Parent(s): d1e5b1a
Files changed (1) hide show
  1. diffusion.ipynb +89 -36
diffusion.ipynb CHANGED
@@ -256,19 +256,19 @@
256
  " self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
257
  "\n",
258
  " def add_noise(self, clean_images):\n",
259
- " self.shape = clean_images.shape\n",
260
- " expand = torch.ones(len(self.shape)-1, dtype=int)\n",
261
  " # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
262
  " # expand = [1 for i in range(len(shape)-1)]\n",
263
  "\n",
264
  " noise = torch.randn_like(clean_images).to(self.device)\n",
265
- " ts = torch.randint(0, self.num_timesteps, (self.shape[0],)).to(self.device)\n",
266
  " \n",
267
  " # test_expand = test.view(test.shape[0],*expand)\n",
268
  " # extend_dim = [None for i in range(shape.dim()-1)]\n",
269
  " noisy_images = (\n",
270
- " clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(self.shape[0], *expand.tolist())\n",
271
- " + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(self.shape[0], *expand.tolist())\n",
272
  " )\n",
273
  " # print(x_t.shape)\n",
274
  "\n",
@@ -1249,7 +1249,7 @@
1249
  },
1250
  {
1251
  "cell_type": "code",
1252
- "execution_count": 109,
1253
  "metadata": {},
1254
  "outputs": [
1255
  {
@@ -1263,8 +1263,8 @@
1263
  "loading 20 images randomly\n",
1264
  "images loaded: (20, 1, 64, 512)\n",
1265
  "params loaded: (20, 2)\n",
1266
- "images rescaled to [-1.0, 0.946401834487915]\n",
1267
- "params rescaled to [0.0, 0.9683106587269014]\n",
1268
  "resumed nn_model from model_state.pth\n",
1269
  "Number of parameters for nn_model: 111048705\n"
1270
  ]
@@ -1422,12 +1422,13 @@
1422
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1423
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1424
  "\n",
1425
- " def sample(self, file, params=[0.2,0.8], ema=False, entire=False):\n",
1426
- " n_sample = params.shape[0]\n",
 
1427
  " model = self.ema_model if ema else self.nn_model\n",
1428
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
1429
  "\n",
1430
- " x_last, x_entire = self.ddpm.sample(model, n_sample, self.config.device, params=params, guide_w=self.config.guide_w)\n",
1431
  "\n",
1432
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1433
  " if entire:\n",
@@ -1526,47 +1527,99 @@
1526
  },
1527
  "metadata": {},
1528
  "output_type": "display_data"
1529
- }
1530
- ],
1531
- "source": [
1532
- "ddpm21cm.train()"
1533
- ]
1534
- },
1535
- {
1536
- "cell_type": "code",
1537
- "execution_count": 69,
1538
- "metadata": {},
1539
- "outputs": [
1540
  {
1541
  "data": {
1542
  "application/vnd.jupyter.widget-view+json": {
1543
- "model_id": "89a5be983ade43d89a1be3d977750a40",
1544
  "version_major": 2,
1545
  "version_minor": 0
1546
  },
1547
  "text/plain": [
1548
- " 0%| | 0/1000 [00:00<?, ?it/s]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1549
  ]
1550
  },
1551
  "metadata": {},
1552
  "output_type": "display_data"
1553
  },
1554
  {
1555
- "ename": "RuntimeError",
1556
- "evalue": "The size of tensor a (24) must match the size of tensor b (36) at non-singleton dimension 0",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1557
  "output_type": "error",
1558
  "traceback": [
1559
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1560
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1561
- "Cell \u001b[0;32mIn[69], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1562
- "Cell \u001b[0;32mIn[67], line 141\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, n_sample, ema, entire)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39msample\u001b[39m(\u001b[39mself\u001b[39m, file, n_sample\u001b[39m=\u001b[39m\u001b[39m12\u001b[39m, ema\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m, entire\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m):\n\u001b[1;32m 139\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[0;32m--> 141\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, n_sample, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshape_loaded[\u001b[39m1\u001b[39;49m:], \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, test_param\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mtest_param, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 143\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 144\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1563
- "Cell \u001b[0;32mIn[7], line 70\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, n_sample, shape, device, test_param, guide_w)\u001b[0m\n\u001b[1;32m 67\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 69\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 71\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 72\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1564
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1565
- "File \u001b[0;32m~/.conda/envs/diffusers/lib/python3.9/site-packages/accelerate/utils/operations.py:822\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 822\u001b[0m \u001b[39mreturn\u001b[39;00m model_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
1566
- "File \u001b[0;32m~/.conda/envs/diffusers/lib/python3.9/site-packages/accelerate/utils/operations.py:810\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 810\u001b[0m \u001b[39mreturn\u001b[39;00m convert_to_fp32(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs))\n",
1567
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/amp/autocast_mode.py:12\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 10\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_autocast\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 11\u001b[0m \u001b[39mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 12\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
1568
- "Cell \u001b[0;32mIn[18], line 211\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[39mif\u001b[39;00m y \u001b[39m!=\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 210\u001b[0m text_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtoken_embedding(y\u001b[39m.\u001b[39mfloat())\n\u001b[0;32m--> 211\u001b[0m emb \u001b[39m=\u001b[39m emb \u001b[39m+\u001b[39;49m text_outputs\u001b[39m.\u001b[39;49mto(emb)\n\u001b[1;32m 213\u001b[0m h \u001b[39m=\u001b[39m x\u001b[39m.\u001b[39mtype(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 214\u001b[0m \u001b[39m# print(\"0,h.shape =\", h.shape)\u001b[39;00m\n",
1569
- "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (24) must match the size of tensor b (36) at non-singleton dimension 0"
1570
  ]
1571
  }
1572
  ],
 
256
  " self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
257
  "\n",
258
  " def add_noise(self, clean_images):\n",
259
+ " shape = clean_images.shape\n",
260
+ " expand = torch.ones(len(shape)-1, dtype=int)\n",
261
  " # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
262
  " # expand = [1 for i in range(len(shape)-1)]\n",
263
  "\n",
264
  " noise = torch.randn_like(clean_images).to(self.device)\n",
265
+ " ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)\n",
266
  " \n",
267
  " # test_expand = test.view(test.shape[0],*expand)\n",
268
  " # extend_dim = [None for i in range(shape.dim()-1)]\n",
269
  " noisy_images = (\n",
270
+ " clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
271
+ " + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
272
  " )\n",
273
  " # print(x_t.shape)\n",
274
  "\n",
 
1249
  },
1250
  {
1251
  "cell_type": "code",
1252
+ "execution_count": 117,
1253
  "metadata": {},
1254
  "outputs": [
1255
  {
 
1263
  "loading 20 images randomly\n",
1264
  "images loaded: (20, 1, 64, 512)\n",
1265
  "params loaded: (20, 2)\n",
1266
+ "images rescaled to [-1.0, 0.8818775415420532]\n",
1267
+ "params rescaled to [0.0, 0.9965706900954632]\n",
1268
  "resumed nn_model from model_state.pth\n",
1269
  "Number of parameters for nn_model: 111048705\n"
1270
  ]
 
1422
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1423
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1424
  "\n",
1425
+ " def sample(self, file, params=torch.tensor((0.2,0.8)), ema=False, entire=False):\n",
1426
+ " # n_sample = params.shape[0]\n",
1427
+ " shape = (self.config.HII, self.config.num_redshift) if self.config.dim == 2 else (self.config.HII, self.config.HII, self.config.num_redshift)\n",
1428
  " model = self.ema_model if ema else self.nn_model\n",
1429
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
1430
  "\n",
1431
+ " x_last, x_entire = self.ddpm.sample(model, params=params, shape=shape, device=self.config.device, guide_w=self.config.guide_w)\n",
1432
  "\n",
1433
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1434
  " if entire:\n",
 
1527
  },
1528
  "metadata": {},
1529
  "output_type": "display_data"
1530
+ },
 
 
 
 
 
 
 
 
 
 
1531
  {
1532
  "data": {
1533
  "application/vnd.jupyter.widget-view+json": {
1534
+ "model_id": "461da88dafc8437a858b6eed2042c709",
1535
  "version_major": 2,
1536
  "version_minor": 0
1537
  },
1538
  "text/plain": [
1539
+ " 0%| | 0/2 [00:00<?, ?it/s]"
1540
+ ]
1541
+ },
1542
+ "metadata": {},
1543
+ "output_type": "display_data"
1544
+ },
1545
+ {
1546
+ "data": {
1547
+ "application/vnd.jupyter.widget-view+json": {
1548
+ "model_id": "a85d41b905884ec6947eb1fb94f2f934",
1549
+ "version_major": 2,
1550
+ "version_minor": 0
1551
+ },
1552
+ "text/plain": [
1553
+ " 0%| | 0/2 [00:00<?, ?it/s]"
1554
+ ]
1555
+ },
1556
+ "metadata": {},
1557
+ "output_type": "display_data"
1558
+ },
1559
+ {
1560
+ "data": {
1561
+ "application/vnd.jupyter.widget-view+json": {
1562
+ "model_id": "f9677e0a3b6049d1a9877eec1689f506",
1563
+ "version_major": 2,
1564
+ "version_minor": 0
1565
+ },
1566
+ "text/plain": [
1567
+ " 0%| | 0/2 [00:00<?, ?it/s]"
1568
  ]
1569
  },
1570
  "metadata": {},
1571
  "output_type": "display_data"
1572
  },
1573
  {
1574
+ "data": {
1575
+ "application/vnd.jupyter.widget-view+json": {
1576
+ "model_id": "5bd75ea4440b4acbb54b55c14fec272b",
1577
+ "version_major": 2,
1578
+ "version_minor": 0
1579
+ },
1580
+ "text/plain": [
1581
+ " 0%| | 0/2 [00:00<?, ?it/s]"
1582
+ ]
1583
+ },
1584
+ "metadata": {},
1585
+ "output_type": "display_data"
1586
+ },
1587
+ {
1588
+ "name": "stdout",
1589
+ "output_type": "stream",
1590
+ "text": [
1591
+ "saved model at ./outputs/model_state_09.pth\n"
1592
+ ]
1593
+ }
1594
+ ],
1595
+ "source": [
1596
+ "ddpm21cm.train()"
1597
+ ]
1598
+ },
1599
+ {
1600
+ "cell_type": "code",
1601
+ "execution_count": 116,
1602
+ "metadata": {},
1603
+ "outputs": [
1604
+ {
1605
+ "name": "stdout",
1606
+ "output_type": "stream",
1607
+ "text": [
1608
+ "params.shape[0], len(params) 2 2\n"
1609
+ ]
1610
+ },
1611
+ {
1612
+ "ename": "AttributeError",
1613
+ "evalue": "'DDPMScheduler' object has no attribute 'shape'",
1614
  "output_type": "error",
1615
  "traceback": [
1616
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1617
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1618
+ "Cell \u001b[0;32mIn[116], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1619
+ "Cell \u001b[0;32mIn[115], line 143\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 140\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 141\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 143\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams, device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 145\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 146\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1620
+ "Cell \u001b[0;32mIn[90], line 40\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 38\u001b[0m n_sample \u001b[39m=\u001b[39m params\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[1;32m 39\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mparams.shape[0], len(params)\u001b[39m\u001b[39m\"\u001b[39m, params\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m], \u001b[39mlen\u001b[39m(params))\n\u001b[0;32m---> 40\u001b[0m x_i \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(n_sample, \u001b[39m*\u001b[39m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshape[\u001b[39m1\u001b[39m:])\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 41\u001b[0m \u001b[39m# print(\"x_i.shape =\", x_i.shape)\u001b[39;00m\n\u001b[1;32m 42\u001b[0m \u001b[39mif\u001b[39;00m guide_w \u001b[39m!=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m:\n",
1621
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1207\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m 1206\u001b[0m \u001b[39mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1207\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m object has no attribute \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1208\u001b[0m \u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, name))\n",
1622
+ "\u001b[0;31mAttributeError\u001b[0m: 'DDPMScheduler' object has no attribute 'shape'"
 
 
 
 
1623
  ]
1624
  }
1625
  ],