diff --git "a/diffusion.ipynb" "b/diffusion.ipynb" --- "a/diffusion.ipynb" +++ "b/diffusion.ipynb" @@ -67,6 +67,30 @@ "from huggingface_hub import notebook_login" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b7a51a73994a43178f1baa379210c892", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
\n", "51200 images can be loaded\n", "field.shape = (64, 64, 514)\n", "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n", - "loading 1600 images randomly\n", - "images loaded: (1600, 1, 64, 64)\n" + "loading 4 images randomly\n" ] }, { @@ -594,20 +593,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "params loaded: (1600, 2)\n", - "images rescaled to [-1.0, 1.119088888168335]\n", - "params rescaled to [0.0001848356100438764, 0.9999958400767995]\n" + "images loaded: (4, 1, 64, 64, 64)\n", + "params loaded: (4, 2)\n", + "images rescaled to [-1.0, 1.049072027206421]\n", + "params rescaled to [0.02179058530500466, 0.9278468466439764]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fe5b9721d5d743e4a7628ad9232bf72d", + "model_id": "2ecbe26474ef446498cae235911aeb71", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/32 [00:00\n", - "51200 images can be loaded\n", - "field.shape = (64, 64, 514)\n", - "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n", - "loading 3200 images randomly\n", - "images loaded: (3200, 1, 64, 64)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" + "total 4.4G\n", + "-rw-r--r-- 1 bxia34 13M Jul 2 21:45 Tvir4.800000190734863-zeta131.34100341796875-N1600.npy\n", + "-rw-r--r-- 1 bxia34 13M Jul 2 21:26 Tvir5.4770002365112305-zeta200.0-N1600.npy\n", + "-rw-r--r-- 1 bxia34 13M Jul 2 21:08 Tvir4.698999881744385-zeta30.0-N1600.npy\n", + "-rw-r--r-- 1 bxia34 13M Jul 2 20:49 Tvir5.599999904632568-zeta19.03700065612793-N1600.npy\n", + "-rw-r--r-- 1 bxia34 13M Jul 2 20:31 Tvir4.400000095367432-zeta131.34100341796875-N1600.npy\n", + "-rw-r--r-- 1 bxia34 848M Jul 2 20:13 model_state-N25600\n", + "drwxr-xr-x 15 bxia34 4.0K Jul 2 19:09 \u001b[0m\u001b[01;34mlogs\u001b[0m/\n", + "-rw-r--r-- 1 bxia34 848M Jul 2 18:45 model_state-N12800\n", + "-rw-r--r-- 1 bxia34 848M Jul 2 18:01 model_state-N6400\n", + "-rw-r--r-- 1 bxia34 848M Jul 2 17:37 model_state-N3200\n", + "-rw-r--r-- 1 bxia34 848M Jul 2 17:25 model_state-N1600\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 12:31 Tvir4.800000190734863-zeta131.34100341796875-N2000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 12:12 Tvir5.4770002365112305-zeta200.0-N2000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 11:54 Tvir4.698999881744385-zeta30.0-N2000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 11:35 Tvir5.599999904632568-zeta19.03700065612793-N2000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 11:17 Tvir4.400000095367432-zeta131.34100341796875-N2000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 10:37 Tvir4.800000190734863-zeta131.34100341796875-N32000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:47 Tvir5.4770002365112305-zeta200.0-N32000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:29 Tvir4.698999881744385-zeta30.0-N32000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 04:11 Tvir5.599999904632568-zeta19.03700065612793-N32000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 03:53 Tvir4.400000095367432-zeta131.34100341796875-N32000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:23 Tvir4.800000190734863-zeta131.34100341796875-N20000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jul 1 00:05 Tvir5.4770002365112305-zeta200.0-N20000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:47 Tvir4.698999881744385-zeta30.0-N20000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:29 Tvir5.599999904632568-zeta19.03700065612793-N20000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 23:11 Tvir4.400000095367432-zeta131.34100341796875-N20000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 20:08 Tvir4.800000190734863-zeta131.34100341796875-N15000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 19:50 Tvir5.4770002365112305-zeta200.0-N15000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 19:32 Tvir4.698999881744385-zeta30.0-N15000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 19:14 Tvir5.599999904632568-zeta19.03700065612793-N15000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 30 18:57 Tvir4.400000095367432-zeta131.34100341796875-N15000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 12:41 Tvir4.800000190734863-zeta131.34100341796875-N7000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 12:23 Tvir5.4770002365112305-zeta200.0-N7000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 12:06 Tvir4.698999881744385-zeta30.0-N7000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 11:48 Tvir5.599999904632568-zeta19.03700065612793-N7000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 11:30 Tvir4.400000095367432-zeta131.34100341796875-N7000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:56 Tvir4.800000190734863-zeta131.34100341796875-N25600.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:38 Tvir5.4770002365112305-zeta200.0-N25600.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:21 Tvir4.698999881744385-zeta30.0-N25600.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 04:03 Tvir5.599999904632568-zeta19.03700065612793-N25600.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 03:45 Tvir4.400000095367432-zeta131.34100341796875-N25600.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 00:35 Tvir4.800000190734863-zeta131.34100341796875-N3000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 29 00:17 Tvir5.4770002365112305-zeta200.0-N3000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 23:59 Tvir4.698999881744385-zeta30.0-N3000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 23:42 Tvir5.599999904632568-zeta19.03700065612793-N3000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 23:20 Tvir4.400000095367432-zeta131.34100341796875-N3000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 21:06 Tvir4.800000190734863-zeta131.34100341796875-N10000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 20:49 Tvir5.4770002365112305-zeta200.0-N10000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 20:31 Tvir4.698999881744385-zeta30.0-N10000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 20:13 Tvir5.599999904632568-zeta19.03700065612793-N10000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 19:56 Tvir4.400000095367432-zeta131.34100341796875-N10000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 18:30 Tvir4.800000190734863-zeta131.34100341796875-N1000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 18:13 Tvir5.4770002365112305-zeta200.0-N1000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 17:55 Tvir4.698999881744385-zeta30.0-N1000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 17:37 Tvir5.599999904632568-zeta19.03700065612793-N1000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 17:20 Tvir4.400000095367432-zeta131.34100341796875-N1000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 28 14:03 Tvir4.400000095367432-zeta131.34100341796875-N5000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:58 Tvir4.800000190734863-zeta131.34100341796875-N5000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:40 Tvir5.4770002365112305-zeta200.0-N5000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:22 Tvir4.698999881744385-zeta30.0-N5000.npy\n", + "-rw-r--r-- 1 bxia34 3.1M Jun 10 18:05 Tvir5.599999904632568-zeta19.03700065612793-N5000.npy\n" ] - }, + } + ], + "source": [ + "ll -lth outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "params loaded: (3200, 2)\n", - "images rescaled to [-1.0, 1.1910929679870605]\n", - "params rescaled to [0.0002077208516410245, 0.999968750248265]\n" + "Number of parameters for nn_model: 111048705\n", + "sampling 800 images with normalized params = tensor([[0.2000, 0.5056]])\n", + "nn_model resumed from ./outputs/model_state-N3200\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "356d54e499ff46548303ff8d0a0b9520", + "model_id": "2d3427d677774c9785ee081b3b3b5542", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/64 [00:00\n", - "51200 images can be loaded\n", - "field.shape = (64, 64, 514)\n", - "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n", - "loading 6400 images randomly\n", - "images loaded: (6400, 1, 64, 64)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "params loaded: (6400, 2)\n", - "images rescaled to [-1.0, 1.1910929679870605]\n", - "params rescaled to [0.0001702067256199591, 0.9998461462686923]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7dcbe3c1221443a0910da2952ba6b58d", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/128 [00:00\n", - "51200 images can be loaded\n", - "field.shape = (64, 64, 514)\n", - "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n", - "loading 12800 images randomly\n", - "images loaded: (12800, 1, 64, 64)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "params loaded: (12800, 2)\n", - "images rescaled to [-1.0, 1.2380504608154297]\n", - "params rescaled to [3.8339325512049e-05, 0.9999958400767995]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7a44e73529d34fcdbc790b622ce0af6a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/256 [00:00\n", - "51200 images can be loaded\n", - "field.shape = (64, 64, 514)\n", - "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n", - "loading 25600 images randomly\n", - "images loaded: (25600, 1, 64, 64)\n", - "params loaded: (25600, 2)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "images rescaled to [-1.0, 1.2380504608154297]\n", - "params rescaled to [3.349517406547875e-06, 0.9999922179553216]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "488a6a64228f4ca487199a34fe14f334", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/512 [00:00 9\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39mf\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state-N\u001b[39;49m\u001b[39m{\u001b[39;49;00mnum_image\u001b[39m}\u001b[39;49;00m\u001b[39m\"\u001b[39;49m, params\u001b[39m=\u001b[39;49mtorch\u001b[39m.\u001b[39;49mtensor([\u001b[39m4.4\u001b[39;49m, \u001b[39m131.341\u001b[39;49m]), repeat\u001b[39m=\u001b[39;49mrepeat)\n\u001b[1;32m 11\u001b[0m ddpm21cm\u001b[39m.\u001b[39msample(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m./outputs/model_state-N\u001b[39m\u001b[39m{\u001b[39;00mnum_image\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m, params\u001b[39m=\u001b[39mtorch\u001b[39m.\u001b[39mtensor((\u001b[39m5.6\u001b[39m, \u001b[39m19.037\u001b[39m)), repeat\u001b[39m=\u001b[39mrepeat)\n\u001b[1;32m 13\u001b[0m \u001b[39m# ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.699, 30)), repeat=repeat)\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \n\u001b[1;32m 15\u001b[0m \u001b[39m# ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((5.477, 200)), repeat=repeat)\u001b[39;00m\n\u001b[1;32m 16\u001b[0m \n\u001b[1;32m 17\u001b[0m \u001b[39m# ddpm21cm.sample(f\"./outputs/model_state-N{num_image}\", params=torch.tensor((4.8, 131.341)), repeat=repeat)\u001b[39;00m\n", - "Cell \u001b[0;32mIn[5], line 207\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, repeat, ema, entire)\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[39m# self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[39m# self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\u001b[39;00m\n\u001b[1;32m 204\u001b[0m \u001b[39m# print(f\"resumed ema_model from {config.resume}\")\u001b[39;00m\n\u001b[1;32m 206\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n\u001b[0;32m--> 207\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(\n\u001b[1;32m 208\u001b[0m nn_model\u001b[39m=\u001b[39;49mnn_model, \n\u001b[1;32m 209\u001b[0m params\u001b[39m=\u001b[39;49mparams\u001b[39m.\u001b[39;49mto(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice), \n\u001b[1;32m 210\u001b[0m device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, \n\u001b[1;32m 211\u001b[0m guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w\n\u001b[1;32m 212\u001b[0m )\n\u001b[1;32m 214\u001b[0m \u001b[39m# np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else ''}.npy\"), x_last)\u001b[39;00m\n\u001b[1;32m 215\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mTvir\u001b[39m\u001b[39m{\u001b[39;00mparams_backup[\u001b[39m0\u001b[39m]\u001b[39m}\u001b[39;00m\u001b[39m-zeta\u001b[39m\u001b[39m{\u001b[39;00mparams_backup[\u001b[39m1\u001b[39m]\u001b[39m}\u001b[39;00m\u001b[39m-N\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mnum_image\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m\u001b[39m'\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n", - "Cell \u001b[0;32mIn[2], line 59\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 56\u001b[0m pbar_sample\u001b[39m.\u001b[39mset_description(\u001b[39m\"\u001b[39m\u001b[39mSampling\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 57\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mreversed\u001b[39m(\u001b[39mrange\u001b[39m(\u001b[39m0\u001b[39m, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnum_timesteps)):\n\u001b[1;32m 58\u001b[0m \u001b[39m# print(f'sampling timestep {i:4d}',end='\\r')\u001b[39;00m\n\u001b[0;32m---> 59\u001b[0m t_is \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39;49mtensor([i])\u001b[39m.\u001b[39;49mto(device)\n\u001b[1;32m 60\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(n_sample)\n\u001b[1;32m 62\u001b[0m z \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(n_sample, \u001b[39m*\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mimg_shape)\u001b[39m.\u001b[39mto(device) \u001b[39mif\u001b[39;00m i \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m \u001b[39melse\u001b[39;00m \u001b[39m0\u001b[39m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] } ], "source": [ "if __name__ == \"__main__\":\n", " num_image_list = [1600,3200,6400,12800,25600]\n", + " # num_image_list = [3200,6400,12800,25600]\n", " # args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)\n", " repeat = 800\n", " config = TrainConfig()\n",