0521-1738
Browse files- diffusion.ipynb +89 -36
diffusion.ipynb
CHANGED
|
@@ -256,19 +256,19 @@
|
|
| 256 |
" self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
|
| 257 |
"\n",
|
| 258 |
" def add_noise(self, clean_images):\n",
|
| 259 |
-
"
|
| 260 |
-
" expand = torch.ones(len(
|
| 261 |
" # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
|
| 262 |
" # expand = [1 for i in range(len(shape)-1)]\n",
|
| 263 |
"\n",
|
| 264 |
" noise = torch.randn_like(clean_images).to(self.device)\n",
|
| 265 |
-
" ts = torch.randint(0, self.num_timesteps, (
|
| 266 |
" \n",
|
| 267 |
" # test_expand = test.view(test.shape[0],*expand)\n",
|
| 268 |
" # extend_dim = [None for i in range(shape.dim()-1)]\n",
|
| 269 |
" noisy_images = (\n",
|
| 270 |
-
" clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(
|
| 271 |
-
" + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(
|
| 272 |
" )\n",
|
| 273 |
" # print(x_t.shape)\n",
|
| 274 |
"\n",
|
|
@@ -1249,7 +1249,7 @@
|
|
| 1249 |
},
|
| 1250 |
{
|
| 1251 |
"cell_type": "code",
|
| 1252 |
-
"execution_count":
|
| 1253 |
"metadata": {},
|
| 1254 |
"outputs": [
|
| 1255 |
{
|
|
@@ -1263,8 +1263,8 @@
|
|
| 1263 |
"loading 20 images randomly\n",
|
| 1264 |
"images loaded: (20, 1, 64, 512)\n",
|
| 1265 |
"params loaded: (20, 2)\n",
|
| 1266 |
-
"images rescaled to [-1.0, 0.
|
| 1267 |
-
"params rescaled to [0.0, 0.
|
| 1268 |
"resumed nn_model from model_state.pth\n",
|
| 1269 |
"Number of parameters for nn_model: 111048705\n"
|
| 1270 |
]
|
|
@@ -1422,12 +1422,13 @@
|
|
| 1422 |
" print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
|
| 1423 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 1424 |
"\n",
|
| 1425 |
-
" def sample(self, file, params=
|
| 1426 |
-
" n_sample = params.shape[0]\n",
|
|
|
|
| 1427 |
" model = self.ema_model if ema else self.nn_model\n",
|
| 1428 |
" # params = torch.tile(params, (n_sample,1)).to(device)\n",
|
| 1429 |
"\n",
|
| 1430 |
-
" x_last, x_entire = self.ddpm.sample(model,
|
| 1431 |
"\n",
|
| 1432 |
" np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
|
| 1433 |
" if entire:\n",
|
|
@@ -1526,47 +1527,99 @@
|
|
| 1526 |
},
|
| 1527 |
"metadata": {},
|
| 1528 |
"output_type": "display_data"
|
| 1529 |
-
}
|
| 1530 |
-
],
|
| 1531 |
-
"source": [
|
| 1532 |
-
"ddpm21cm.train()"
|
| 1533 |
-
]
|
| 1534 |
-
},
|
| 1535 |
-
{
|
| 1536 |
-
"cell_type": "code",
|
| 1537 |
-
"execution_count": 69,
|
| 1538 |
-
"metadata": {},
|
| 1539 |
-
"outputs": [
|
| 1540 |
{
|
| 1541 |
"data": {
|
| 1542 |
"application/vnd.jupyter.widget-view+json": {
|
| 1543 |
-
"model_id": "
|
| 1544 |
"version_major": 2,
|
| 1545 |
"version_minor": 0
|
| 1546 |
},
|
| 1547 |
"text/plain": [
|
| 1548 |
-
" 0%| | 0/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1549 |
]
|
| 1550 |
},
|
| 1551 |
"metadata": {},
|
| 1552 |
"output_type": "display_data"
|
| 1553 |
},
|
| 1554 |
{
|
| 1555 |
-
"
|
| 1556 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1557 |
"output_type": "error",
|
| 1558 |
"traceback": [
|
| 1559 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1560 |
-
"\u001b[0;
|
| 1561 |
-
"Cell \u001b[0;32mIn[
|
| 1562 |
-
"Cell \u001b[0;32mIn[
|
| 1563 |
-
"Cell \u001b[0;32mIn[
|
| 1564 |
-
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:
|
| 1565 |
-
"
|
| 1566 |
-
"File \u001b[0;32m~/.conda/envs/diffusers/lib/python3.9/site-packages/accelerate/utils/operations.py:810\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 810\u001b[0m \u001b[39mreturn\u001b[39;00m convert_to_fp32(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs))\n",
|
| 1567 |
-
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/amp/autocast_mode.py:12\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 10\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_autocast\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 11\u001b[0m \u001b[39mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 12\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
|
| 1568 |
-
"Cell \u001b[0;32mIn[18], line 211\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[39mif\u001b[39;00m y \u001b[39m!=\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 210\u001b[0m text_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtoken_embedding(y\u001b[39m.\u001b[39mfloat())\n\u001b[0;32m--> 211\u001b[0m emb \u001b[39m=\u001b[39m emb \u001b[39m+\u001b[39;49m text_outputs\u001b[39m.\u001b[39;49mto(emb)\n\u001b[1;32m 213\u001b[0m h \u001b[39m=\u001b[39m x\u001b[39m.\u001b[39mtype(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 214\u001b[0m \u001b[39m# print(\"0,h.shape =\", h.shape)\u001b[39;00m\n",
|
| 1569 |
-
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (24) must match the size of tensor b (36) at non-singleton dimension 0"
|
| 1570 |
]
|
| 1571 |
}
|
| 1572 |
],
|
|
|
|
| 256 |
" self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
|
| 257 |
"\n",
|
| 258 |
" def add_noise(self, clean_images):\n",
|
| 259 |
+
" shape = clean_images.shape\n",
|
| 260 |
+
" expand = torch.ones(len(shape)-1, dtype=int)\n",
|
| 261 |
" # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
|
| 262 |
" # expand = [1 for i in range(len(shape)-1)]\n",
|
| 263 |
"\n",
|
| 264 |
" noise = torch.randn_like(clean_images).to(self.device)\n",
|
| 265 |
+
" ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)\n",
|
| 266 |
" \n",
|
| 267 |
" # test_expand = test.view(test.shape[0],*expand)\n",
|
| 268 |
" # extend_dim = [None for i in range(shape.dim()-1)]\n",
|
| 269 |
" noisy_images = (\n",
|
| 270 |
+
" clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
|
| 271 |
+
" + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
|
| 272 |
" )\n",
|
| 273 |
" # print(x_t.shape)\n",
|
| 274 |
"\n",
|
|
|
|
| 1249 |
},
|
| 1250 |
{
|
| 1251 |
"cell_type": "code",
|
| 1252 |
+
"execution_count": 117,
|
| 1253 |
"metadata": {},
|
| 1254 |
"outputs": [
|
| 1255 |
{
|
|
|
|
| 1263 |
"loading 20 images randomly\n",
|
| 1264 |
"images loaded: (20, 1, 64, 512)\n",
|
| 1265 |
"params loaded: (20, 2)\n",
|
| 1266 |
+
"images rescaled to [-1.0, 0.8818775415420532]\n",
|
| 1267 |
+
"params rescaled to [0.0, 0.9965706900954632]\n",
|
| 1268 |
"resumed nn_model from model_state.pth\n",
|
| 1269 |
"Number of parameters for nn_model: 111048705\n"
|
| 1270 |
]
|
|
|
|
| 1422 |
" print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
|
| 1423 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 1424 |
"\n",
|
| 1425 |
+
" def sample(self, file, params=torch.tensor((0.2,0.8)), ema=False, entire=False):\n",
|
| 1426 |
+
" # n_sample = params.shape[0]\n",
|
| 1427 |
+
" shape = (self.config.HII, self.config.num_redshift) if self.config.dim == 2 else (self.config.HII, self.config.HII, self.config.num_redshift)\n",
|
| 1428 |
" model = self.ema_model if ema else self.nn_model\n",
|
| 1429 |
" # params = torch.tile(params, (n_sample,1)).to(device)\n",
|
| 1430 |
"\n",
|
| 1431 |
+
" x_last, x_entire = self.ddpm.sample(model, params=params, shape=shape, device=self.config.device, guide_w=self.config.guide_w)\n",
|
| 1432 |
"\n",
|
| 1433 |
" np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
|
| 1434 |
" if entire:\n",
|
|
|
|
| 1527 |
},
|
| 1528 |
"metadata": {},
|
| 1529 |
"output_type": "display_data"
|
| 1530 |
+
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1531 |
{
|
| 1532 |
"data": {
|
| 1533 |
"application/vnd.jupyter.widget-view+json": {
|
| 1534 |
+
"model_id": "461da88dafc8437a858b6eed2042c709",
|
| 1535 |
"version_major": 2,
|
| 1536 |
"version_minor": 0
|
| 1537 |
},
|
| 1538 |
"text/plain": [
|
| 1539 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1540 |
+
]
|
| 1541 |
+
},
|
| 1542 |
+
"metadata": {},
|
| 1543 |
+
"output_type": "display_data"
|
| 1544 |
+
},
|
| 1545 |
+
{
|
| 1546 |
+
"data": {
|
| 1547 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1548 |
+
"model_id": "a85d41b905884ec6947eb1fb94f2f934",
|
| 1549 |
+
"version_major": 2,
|
| 1550 |
+
"version_minor": 0
|
| 1551 |
+
},
|
| 1552 |
+
"text/plain": [
|
| 1553 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1554 |
+
]
|
| 1555 |
+
},
|
| 1556 |
+
"metadata": {},
|
| 1557 |
+
"output_type": "display_data"
|
| 1558 |
+
},
|
| 1559 |
+
{
|
| 1560 |
+
"data": {
|
| 1561 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1562 |
+
"model_id": "f9677e0a3b6049d1a9877eec1689f506",
|
| 1563 |
+
"version_major": 2,
|
| 1564 |
+
"version_minor": 0
|
| 1565 |
+
},
|
| 1566 |
+
"text/plain": [
|
| 1567 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1568 |
]
|
| 1569 |
},
|
| 1570 |
"metadata": {},
|
| 1571 |
"output_type": "display_data"
|
| 1572 |
},
|
| 1573 |
{
|
| 1574 |
+
"data": {
|
| 1575 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1576 |
+
"model_id": "5bd75ea4440b4acbb54b55c14fec272b",
|
| 1577 |
+
"version_major": 2,
|
| 1578 |
+
"version_minor": 0
|
| 1579 |
+
},
|
| 1580 |
+
"text/plain": [
|
| 1581 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1582 |
+
]
|
| 1583 |
+
},
|
| 1584 |
+
"metadata": {},
|
| 1585 |
+
"output_type": "display_data"
|
| 1586 |
+
},
|
| 1587 |
+
{
|
| 1588 |
+
"name": "stdout",
|
| 1589 |
+
"output_type": "stream",
|
| 1590 |
+
"text": [
|
| 1591 |
+
"saved model at ./outputs/model_state_09.pth\n"
|
| 1592 |
+
]
|
| 1593 |
+
}
|
| 1594 |
+
],
|
| 1595 |
+
"source": [
|
| 1596 |
+
"ddpm21cm.train()"
|
| 1597 |
+
]
|
| 1598 |
+
},
|
| 1599 |
+
{
|
| 1600 |
+
"cell_type": "code",
|
| 1601 |
+
"execution_count": 116,
|
| 1602 |
+
"metadata": {},
|
| 1603 |
+
"outputs": [
|
| 1604 |
+
{
|
| 1605 |
+
"name": "stdout",
|
| 1606 |
+
"output_type": "stream",
|
| 1607 |
+
"text": [
|
| 1608 |
+
"params.shape[0], len(params) 2 2\n"
|
| 1609 |
+
]
|
| 1610 |
+
},
|
| 1611 |
+
{
|
| 1612 |
+
"ename": "AttributeError",
|
| 1613 |
+
"evalue": "'DDPMScheduler' object has no attribute 'shape'",
|
| 1614 |
"output_type": "error",
|
| 1615 |
"traceback": [
|
| 1616 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 1617 |
+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
| 1618 |
+
"Cell \u001b[0;32mIn[116], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
|
| 1619 |
+
"Cell \u001b[0;32mIn[115], line 143\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 140\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 141\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 143\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams, device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 145\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 146\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
|
| 1620 |
+
"Cell \u001b[0;32mIn[90], line 40\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 38\u001b[0m n_sample \u001b[39m=\u001b[39m params\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[1;32m 39\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mparams.shape[0], len(params)\u001b[39m\u001b[39m\"\u001b[39m, params\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m], \u001b[39mlen\u001b[39m(params))\n\u001b[0;32m---> 40\u001b[0m x_i \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(n_sample, \u001b[39m*\u001b[39m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshape[\u001b[39m1\u001b[39m:])\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 41\u001b[0m \u001b[39m# print(\"x_i.shape =\", x_i.shape)\u001b[39;00m\n\u001b[1;32m 42\u001b[0m \u001b[39mif\u001b[39;00m guide_w \u001b[39m!=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m:\n",
|
| 1621 |
+
"File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1207\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m 1206\u001b[0m \u001b[39mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1207\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m object has no attribute \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1208\u001b[0m \u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, name))\n",
|
| 1622 |
+
"\u001b[0;31mAttributeError\u001b[0m: 'DDPMScheduler' object has no attribute 'shape'"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1623 |
]
|
| 1624 |
}
|
| 1625 |
],
|