0521-2013
Browse files- diffusion.ipynb +105 -88
diffusion.ipynb
CHANGED
|
@@ -32,7 +32,7 @@
|
|
| 32 |
{
|
| 33 |
"data": {
|
| 34 |
"application/vnd.jupyter.widget-view+json": {
|
| 35 |
-
"model_id": "
|
| 36 |
"version_major": 2,
|
| 37 |
"version_minor": 0
|
| 38 |
},
|
|
@@ -234,18 +234,19 @@
|
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "code",
|
| 237 |
-
"execution_count":
|
| 238 |
"metadata": {},
|
| 239 |
"outputs": [],
|
| 240 |
"source": [
|
| 241 |
"class DDPMScheduler(nn.Module):\n",
|
| 242 |
-
" def __init__(self, betas: tuple, num_timesteps: int, device='cpu'):\n",
|
| 243 |
" super().__init__()\n",
|
| 244 |
" \n",
|
| 245 |
" beta_1, beta_T = betas\n",
|
| 246 |
" assert 0 < beta_1 <= beta_T <= 1, \"ensure 0 < beta_1 <= beta_T <= 1\"\n",
|
| 247 |
" self.device = device\n",
|
| 248 |
" self.num_timesteps = num_timesteps\n",
|
|
|
|
| 249 |
" self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1\n",
|
| 250 |
" self.beta_t = self.beta_t.to(self.device)\n",
|
| 251 |
"\n",
|
|
@@ -277,7 +278,7 @@
|
|
| 277 |
" def sample(self, nn_model, params, device, guide_w = 0):\n",
|
| 278 |
" n_sample = len(params) #params.shape[0]\n",
|
| 279 |
" # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
|
| 280 |
-
" x_i = torch.randn(n_sample, *self.
|
| 281 |
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 282 |
" if guide_w != -1:\n",
|
| 283 |
" c_i = params\n",
|
|
@@ -297,7 +298,7 @@
|
|
| 297 |
" t_is = torch.tensor([i]).to(device)\n",
|
| 298 |
" t_is = t_is.repeat(n_sample)\n",
|
| 299 |
"\n",
|
| 300 |
-
" z = torch.randn(n_sample, *self.
|
| 301 |
"\n",
|
| 302 |
" if guide_w == -1:\n",
|
| 303 |
" # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
|
|
@@ -305,7 +306,7 @@
|
|
| 305 |
" # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
|
| 306 |
" else:\n",
|
| 307 |
" # double batch\n",
|
| 308 |
-
" x_i = x_i.repeat(2, *torch.ones(len(self.
|
| 309 |
" t_is = t_is.repeat(2)\n",
|
| 310 |
"\n",
|
| 311 |
" # split predictions and compute weighting\n",
|
|
@@ -338,7 +339,7 @@
|
|
| 338 |
},
|
| 339 |
{
|
| 340 |
"cell_type": "code",
|
| 341 |
-
"execution_count":
|
| 342 |
"metadata": {},
|
| 343 |
"outputs": [],
|
| 344 |
"source": [
|
|
@@ -380,23 +381,23 @@
|
|
| 380 |
},
|
| 381 |
{
|
| 382 |
"cell_type": "code",
|
| 383 |
-
"execution_count":
|
| 384 |
"metadata": {},
|
| 385 |
"outputs": [],
|
| 386 |
"source": [
|
| 387 |
"class Downsample(nn.Module):\n",
|
| 388 |
-
" def __init__(self, channels, use_conv, out_channels=None):\n",
|
| 389 |
" super().__init__()\n",
|
| 390 |
" self.channels = channels\n",
|
| 391 |
" self.out_channels = out_channels or channels\n",
|
| 392 |
-
" stride = config.stride\n",
|
| 393 |
" if use_conv:\n",
|
| 394 |
" # print(\"conv\")\n",
|
| 395 |
-
" self.op = Conv[
|
| 396 |
" else:\n",
|
| 397 |
" # print(\"pool\")\n",
|
| 398 |
" assert channels == self.out_channels\n",
|
| 399 |
-
" self.op = AvgPool[
|
| 400 |
"\n",
|
| 401 |
" def forward(self, x):\n",
|
| 402 |
" assert x.shape[1] == self.channels\n",
|
|
@@ -405,25 +406,26 @@
|
|
| 405 |
},
|
| 406 |
{
|
| 407 |
"cell_type": "code",
|
| 408 |
-
"execution_count":
|
| 409 |
"metadata": {},
|
| 410 |
"outputs": [],
|
| 411 |
"source": [
|
| 412 |
"class Upsample(nn.Module):\n",
|
| 413 |
-
" def __init__(self, channels, use_conv, out_channels=None):\n",
|
| 414 |
" super().__init__()\n",
|
| 415 |
" self.channels = channels\n",
|
| 416 |
" self.out_channels = out_channels\n",
|
| 417 |
" self.use_conv = use_conv\n",
|
|
|
|
| 418 |
" if self.use_conv:\n",
|
| 419 |
-
" self.conv = Conv[
|
| 420 |
"\n",
|
| 421 |
" def forward(self, x):\n",
|
| 422 |
" assert x.shape[1] == self.channels\n",
|
| 423 |
-
" stride = config.stride\n",
|
| 424 |
" # print(torch.tensor(x.shape[2:]))\n",
|
| 425 |
" # print(torch.tensor(stride))\n",
|
| 426 |
-
" shape = torch.tensor(x.shape[2:]) * torch.tensor(stride)\n",
|
| 427 |
" shape = tuple(shape.detach().numpy())\n",
|
| 428 |
" # print(shape)\n",
|
| 429 |
" x = F.interpolate(x, shape, mode='nearest')\n",
|
|
@@ -434,7 +436,7 @@
|
|
| 434 |
},
|
| 435 |
{
|
| 436 |
"cell_type": "code",
|
| 437 |
-
"execution_count":
|
| 438 |
"metadata": {},
|
| 439 |
"outputs": [],
|
| 440 |
"source": [
|
|
@@ -449,7 +451,7 @@
|
|
| 449 |
},
|
| 450 |
{
|
| 451 |
"cell_type": "code",
|
| 452 |
-
"execution_count":
|
| 453 |
"metadata": {},
|
| 454 |
"outputs": [],
|
| 455 |
"source": [
|
|
@@ -463,7 +465,7 @@
|
|
| 463 |
},
|
| 464 |
{
|
| 465 |
"cell_type": "code",
|
| 466 |
-
"execution_count":
|
| 467 |
"metadata": {},
|
| 468 |
"outputs": [],
|
| 469 |
"source": [
|
|
@@ -481,32 +483,33 @@
|
|
| 481 |
},
|
| 482 |
{
|
| 483 |
"cell_type": "code",
|
| 484 |
-
"execution_count":
|
| 485 |
"metadata": {},
|
| 486 |
"outputs": [],
|
| 487 |
"source": [
|
| 488 |
"class ResBlock(TimestepBlock):\n",
|
| 489 |
" def __init__(\n",
|
| 490 |
-
" self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False,\n",
|
| 491 |
" ):\n",
|
| 492 |
" super().__init__()\n",
|
| 493 |
" self.out_channels = out_channels or channels\n",
|
| 494 |
" self.use_scale_shift_norm = use_scale_shift_norm\n",
|
|
|
|
| 495 |
"\n",
|
| 496 |
" self.in_layers = nn.Sequential(\n",
|
| 497 |
" # nn.BatchNorm2d(channels), # normalize to standard gaussian\n",
|
| 498 |
" normalization(channels, swish=1.0),\n",
|
| 499 |
" nn.Identity(),\n",
|
| 500 |
-
" Conv[
|
| 501 |
" )\n",
|
| 502 |
"\n",
|
| 503 |
" self.updown = up or down\n",
|
| 504 |
" if up:\n",
|
| 505 |
-
" self.h_updown = Upsample(channels, False)\n",
|
| 506 |
-
" self.x_updown = Upsample(channels, False)\n",
|
| 507 |
" elif down:\n",
|
| 508 |
-
" self.h_updown = Downsample(channels, False)\n",
|
| 509 |
-
" self.x_updown = Downsample(channels, False)\n",
|
| 510 |
" else:\n",
|
| 511 |
" self.h_updown = self.x_updown = nn.Identity()\n",
|
| 512 |
"\n",
|
|
@@ -523,15 +526,15 @@
|
|
| 523 |
" normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),\n",
|
| 524 |
" nn.SiLU() if use_scale_shift_norm else nn.Identity(),\n",
|
| 525 |
" nn.Dropout(p=dropout),\n",
|
| 526 |
-
" zero_module(Conv[
|
| 527 |
" )\n",
|
| 528 |
"\n",
|
| 529 |
" if self.out_channels == channels:\n",
|
| 530 |
" self.skip_connection = nn.Identity()\n",
|
| 531 |
" elif use_conv:\n",
|
| 532 |
-
" self.skip_connection = Conv[
|
| 533 |
" else:\n",
|
| 534 |
-
" self.skip_connection = Conv[
|
| 535 |
" \n",
|
| 536 |
"\n",
|
| 537 |
" def forward(self, x, emb):\n",
|
|
@@ -562,7 +565,7 @@
|
|
| 562 |
},
|
| 563 |
{
|
| 564 |
"cell_type": "code",
|
| 565 |
-
"execution_count":
|
| 566 |
"metadata": {},
|
| 567 |
"outputs": [],
|
| 568 |
"source": [
|
|
@@ -595,7 +598,7 @@
|
|
| 595 |
},
|
| 596 |
{
|
| 597 |
"cell_type": "code",
|
| 598 |
-
"execution_count":
|
| 599 |
"metadata": {},
|
| 600 |
"outputs": [],
|
| 601 |
"source": [
|
|
@@ -644,7 +647,7 @@
|
|
| 644 |
},
|
| 645 |
{
|
| 646 |
"cell_type": "code",
|
| 647 |
-
"execution_count":
|
| 648 |
"metadata": {},
|
| 649 |
"outputs": [],
|
| 650 |
"source": [
|
|
@@ -673,7 +676,7 @@
|
|
| 673 |
},
|
| 674 |
{
|
| 675 |
"cell_type": "code",
|
| 676 |
-
"execution_count":
|
| 677 |
"metadata": {},
|
| 678 |
"outputs": [],
|
| 679 |
"source": [
|
|
@@ -697,6 +700,8 @@
|
|
| 697 |
" resblock_updown = False,\n",
|
| 698 |
" conv_resample = True,\n",
|
| 699 |
" encoder_channels = None,\n",
|
|
|
|
|
|
|
| 700 |
" ):\n",
|
| 701 |
" super().__init__()\n",
|
| 702 |
"\n",
|
|
@@ -742,7 +747,7 @@
|
|
| 742 |
"\n",
|
| 743 |
" ###################### input_blocks ######################\n",
|
| 744 |
" self.input_blocks = nn.ModuleList(\n",
|
| 745 |
-
" [TimestepEmbedSequential(Conv[
|
| 746 |
" )\n",
|
| 747 |
" self._feature_size = ch\n",
|
| 748 |
" input_block_chans = [ch]\n",
|
|
@@ -758,6 +763,8 @@
|
|
| 758 |
" out_channels = int(mult * model_channels),\n",
|
| 759 |
" use_checkpoint = use_checkpoint,\n",
|
| 760 |
" use_scale_shift_norm = use_scale_shift_norm,\n",
|
|
|
|
|
|
|
| 761 |
" )\n",
|
| 762 |
" ]\n",
|
| 763 |
" ch = int(mult * model_channels)\n",
|
|
@@ -788,9 +795,11 @@
|
|
| 788 |
" use_checkpoint=use_checkpoint,\n",
|
| 789 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
| 790 |
" down=True,\n",
|
|
|
|
|
|
|
| 791 |
" )\n",
|
| 792 |
" if resblock_updown\n",
|
| 793 |
-
" else Downsample(ch, conv_resample, out_channels=out_ch)\n",
|
| 794 |
" )\n",
|
| 795 |
" )\n",
|
| 796 |
" ch = out_ch\n",
|
|
@@ -807,6 +816,8 @@
|
|
| 807 |
" dropout,\n",
|
| 808 |
" use_checkpoint=use_checkpoint,\n",
|
| 809 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
|
|
|
|
|
|
| 810 |
" ),\n",
|
| 811 |
" AttentionBlock(\n",
|
| 812 |
" ch,\n",
|
|
@@ -821,6 +832,8 @@
|
|
| 821 |
" dropout,\n",
|
| 822 |
" use_checkpoint=use_checkpoint,\n",
|
| 823 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
|
|
|
|
|
|
| 824 |
" ),\n",
|
| 825 |
" )\n",
|
| 826 |
" self._feature_size += ch\n",
|
|
@@ -840,6 +853,8 @@
|
|
| 840 |
" # dims=dims,\n",
|
| 841 |
" use_checkpoint=use_checkpoint,\n",
|
| 842 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
|
|
|
|
|
|
| 843 |
" )\n",
|
| 844 |
" ]\n",
|
| 845 |
" ch = int(model_channels * mult)\n",
|
|
@@ -866,9 +881,11 @@
|
|
| 866 |
" use_checkpoint=use_checkpoint,\n",
|
| 867 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
| 868 |
" up=True,\n",
|
|
|
|
|
|
|
| 869 |
" )\n",
|
| 870 |
" if resblock_updown\n",
|
| 871 |
-
" else Upsample(ch, conv_resample, out_channels=out_ch)\n",
|
| 872 |
" )\n",
|
| 873 |
" ds //= 2\n",
|
| 874 |
" self.output_blocks.append(TimestepEmbedSequential(*layers))\n",
|
|
@@ -878,7 +895,7 @@
|
|
| 878 |
" # nn.BatchNorm2d(ch),\n",
|
| 879 |
" normalization(ch, swish=1.0),\n",
|
| 880 |
" nn.Identity(),\n",
|
| 881 |
-
" zero_module(Conv[
|
| 882 |
" )\n",
|
| 883 |
" # self.use_fp16 = use_fp16\n",
|
| 884 |
"\n",
|
|
@@ -915,7 +932,7 @@
|
|
| 915 |
},
|
| 916 |
{
|
| 917 |
"cell_type": "code",
|
| 918 |
-
"execution_count":
|
| 919 |
"metadata": {},
|
| 920 |
"outputs": [],
|
| 921 |
"source": [
|
|
@@ -945,7 +962,7 @@
|
|
| 945 |
},
|
| 946 |
{
|
| 947 |
"cell_type": "code",
|
| 948 |
-
"execution_count":
|
| 949 |
"metadata": {},
|
| 950 |
"outputs": [],
|
| 951 |
"source": [
|
|
@@ -963,7 +980,7 @@
|
|
| 963 |
" # dim = 2\n",
|
| 964 |
" dim = 2\n",
|
| 965 |
" stride = (2,2) if dim == 2 else (2,2,4)\n",
|
| 966 |
-
" num_image =
|
| 967 |
" HII_DIM = 64\n",
|
| 968 |
" num_redshift = 512#256#256#64#512#128\n",
|
| 969 |
" img_shape = (HII_DIM, num_redshift) if dim == 2 else (HII_DIM, HII_DIM, num_redshift)\n",
|
|
@@ -1010,7 +1027,7 @@
|
|
| 1010 |
},
|
| 1011 |
{
|
| 1012 |
"cell_type": "code",
|
| 1013 |
-
"execution_count":
|
| 1014 |
"metadata": {},
|
| 1015 |
"outputs": [],
|
| 1016 |
"source": [
|
|
@@ -1020,7 +1037,7 @@
|
|
| 1020 |
},
|
| 1021 |
{
|
| 1022 |
"cell_type": "code",
|
| 1023 |
-
"execution_count":
|
| 1024 |
"metadata": {},
|
| 1025 |
"outputs": [],
|
| 1026 |
"source": [
|
|
@@ -1029,7 +1046,7 @@
|
|
| 1029 |
},
|
| 1030 |
{
|
| 1031 |
"cell_type": "code",
|
| 1032 |
-
"execution_count":
|
| 1033 |
"metadata": {},
|
| 1034 |
"outputs": [],
|
| 1035 |
"source": [
|
|
@@ -1053,7 +1070,7 @@
|
|
| 1053 |
},
|
| 1054 |
{
|
| 1055 |
"cell_type": "code",
|
| 1056 |
-
"execution_count":
|
| 1057 |
"metadata": {},
|
| 1058 |
"outputs": [],
|
| 1059 |
"source": [
|
|
@@ -1251,7 +1268,7 @@
|
|
| 1251 |
},
|
| 1252 |
{
|
| 1253 |
"cell_type": "code",
|
| 1254 |
-
"execution_count":
|
| 1255 |
"metadata": {},
|
| 1256 |
"outputs": [
|
| 1257 |
{
|
|
@@ -1262,11 +1279,11 @@
|
|
| 1262 |
"51200 images can be loaded\n",
|
| 1263 |
"field.shape = (64, 64, 514)\n",
|
| 1264 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 1265 |
-
"loading
|
| 1266 |
-
"images loaded: (
|
| 1267 |
-
"params loaded: (
|
| 1268 |
-
"images rescaled to [-1.0, 1.
|
| 1269 |
-
"params rescaled to [0.0, 0.
|
| 1270 |
"resumed nn_model from model_state.pth\n",
|
| 1271 |
"Number of parameters for nn_model: 111048705\n"
|
| 1272 |
]
|
|
@@ -1298,10 +1315,10 @@
|
|
| 1298 |
" self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
|
| 1299 |
" del dataset\n",
|
| 1300 |
"\n",
|
| 1301 |
-
" self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
|
| 1302 |
"\n",
|
| 1303 |
" # initialize the unet\n",
|
| 1304 |
-
" self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
|
| 1305 |
"\n",
|
| 1306 |
" if config.resume:\n",
|
| 1307 |
" self.nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['unet_state_dict'])\n",
|
|
@@ -1318,7 +1335,7 @@
|
|
| 1318 |
" if config.ema:\n",
|
| 1319 |
" self.ema = EMA(config.ema_rate)\n",
|
| 1320 |
" if config.resume:\n",
|
| 1321 |
-
" self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
|
| 1322 |
" self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
|
| 1323 |
" print(f\"resumed ema_model from {config.resume}\")\n",
|
| 1324 |
" else:\n",
|
|
@@ -1442,18 +1459,18 @@
|
|
| 1442 |
},
|
| 1443 |
{
|
| 1444 |
"cell_type": "code",
|
| 1445 |
-
"execution_count":
|
| 1446 |
"metadata": {},
|
| 1447 |
"outputs": [
|
| 1448 |
{
|
| 1449 |
"data": {
|
| 1450 |
"application/vnd.jupyter.widget-view+json": {
|
| 1451 |
-
"model_id": "
|
| 1452 |
"version_major": 2,
|
| 1453 |
"version_minor": 0
|
| 1454 |
},
|
| 1455 |
"text/plain": [
|
| 1456 |
-
" 0%| | 0/
|
| 1457 |
]
|
| 1458 |
},
|
| 1459 |
"metadata": {},
|
|
@@ -1462,12 +1479,12 @@
|
|
| 1462 |
{
|
| 1463 |
"data": {
|
| 1464 |
"application/vnd.jupyter.widget-view+json": {
|
| 1465 |
-
"model_id": "
|
| 1466 |
"version_major": 2,
|
| 1467 |
"version_minor": 0
|
| 1468 |
},
|
| 1469 |
"text/plain": [
|
| 1470 |
-
" 0%| | 0/
|
| 1471 |
]
|
| 1472 |
},
|
| 1473 |
"metadata": {},
|
|
@@ -1476,12 +1493,12 @@
|
|
| 1476 |
{
|
| 1477 |
"data": {
|
| 1478 |
"application/vnd.jupyter.widget-view+json": {
|
| 1479 |
-
"model_id": "
|
| 1480 |
"version_major": 2,
|
| 1481 |
"version_minor": 0
|
| 1482 |
},
|
| 1483 |
"text/plain": [
|
| 1484 |
-
" 0%| | 0/
|
| 1485 |
]
|
| 1486 |
},
|
| 1487 |
"metadata": {},
|
|
@@ -1490,12 +1507,12 @@
|
|
| 1490 |
{
|
| 1491 |
"data": {
|
| 1492 |
"application/vnd.jupyter.widget-view+json": {
|
| 1493 |
-
"model_id": "
|
| 1494 |
"version_major": 2,
|
| 1495 |
"version_minor": 0
|
| 1496 |
},
|
| 1497 |
"text/plain": [
|
| 1498 |
-
" 0%| | 0/
|
| 1499 |
]
|
| 1500 |
},
|
| 1501 |
"metadata": {},
|
|
@@ -1504,12 +1521,12 @@
|
|
| 1504 |
{
|
| 1505 |
"data": {
|
| 1506 |
"application/vnd.jupyter.widget-view+json": {
|
| 1507 |
-
"model_id": "
|
| 1508 |
"version_major": 2,
|
| 1509 |
"version_minor": 0
|
| 1510 |
},
|
| 1511 |
"text/plain": [
|
| 1512 |
-
" 0%| | 0/
|
| 1513 |
]
|
| 1514 |
},
|
| 1515 |
"metadata": {},
|
|
@@ -1518,12 +1535,12 @@
|
|
| 1518 |
{
|
| 1519 |
"data": {
|
| 1520 |
"application/vnd.jupyter.widget-view+json": {
|
| 1521 |
-
"model_id": "
|
| 1522 |
"version_major": 2,
|
| 1523 |
"version_minor": 0
|
| 1524 |
},
|
| 1525 |
"text/plain": [
|
| 1526 |
-
" 0%| | 0/
|
| 1527 |
]
|
| 1528 |
},
|
| 1529 |
"metadata": {},
|
|
@@ -1532,12 +1549,12 @@
|
|
| 1532 |
{
|
| 1533 |
"data": {
|
| 1534 |
"application/vnd.jupyter.widget-view+json": {
|
| 1535 |
-
"model_id": "
|
| 1536 |
"version_major": 2,
|
| 1537 |
"version_minor": 0
|
| 1538 |
},
|
| 1539 |
"text/plain": [
|
| 1540 |
-
" 0%| | 0/
|
| 1541 |
]
|
| 1542 |
},
|
| 1543 |
"metadata": {},
|
|
@@ -1546,12 +1563,12 @@
|
|
| 1546 |
{
|
| 1547 |
"data": {
|
| 1548 |
"application/vnd.jupyter.widget-view+json": {
|
| 1549 |
-
"model_id": "
|
| 1550 |
"version_major": 2,
|
| 1551 |
"version_minor": 0
|
| 1552 |
},
|
| 1553 |
"text/plain": [
|
| 1554 |
-
" 0%| | 0/
|
| 1555 |
]
|
| 1556 |
},
|
| 1557 |
"metadata": {},
|
|
@@ -1560,12 +1577,12 @@
|
|
| 1560 |
{
|
| 1561 |
"data": {
|
| 1562 |
"application/vnd.jupyter.widget-view+json": {
|
| 1563 |
-
"model_id": "
|
| 1564 |
"version_major": 2,
|
| 1565 |
"version_minor": 0
|
| 1566 |
},
|
| 1567 |
"text/plain": [
|
| 1568 |
-
" 0%| | 0/
|
| 1569 |
]
|
| 1570 |
},
|
| 1571 |
"metadata": {},
|
|
@@ -1574,12 +1591,12 @@
|
|
| 1574 |
{
|
| 1575 |
"data": {
|
| 1576 |
"application/vnd.jupyter.widget-view+json": {
|
| 1577 |
-
"model_id": "
|
| 1578 |
"version_major": 2,
|
| 1579 |
"version_minor": 0
|
| 1580 |
},
|
| 1581 |
"text/plain": [
|
| 1582 |
-
" 0%| | 0/
|
| 1583 |
]
|
| 1584 |
},
|
| 1585 |
"metadata": {},
|
|
@@ -1592,7 +1609,7 @@
|
|
| 1592 |
},
|
| 1593 |
{
|
| 1594 |
"cell_type": "code",
|
| 1595 |
-
"execution_count":
|
| 1596 |
"metadata": {},
|
| 1597 |
"outputs": [
|
| 1598 |
{
|
|
@@ -1804,21 +1821,21 @@
|
|
| 1804 |
}
|
| 1805 |
],
|
| 1806 |
"source": [
|
| 1807 |
-
"ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
|
| 1808 |
"\n",
|
| 1809 |
-
"nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
|
| 1810 |
-
"print(\"resuming nn_model\")\n",
|
| 1811 |
-
"nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
|
| 1812 |
-
"# nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 1813 |
-
"# nn_model.train()\n",
|
| 1814 |
-
"nn_model.to(ddpm.device)\n",
|
| 1815 |
-
"nn_model.eval()\n",
|
| 1816 |
"\n",
|
| 1817 |
-
"n_sample = 20\n",
|
| 1818 |
-
"with torch.no_grad():\n",
|
| 1819 |
-
"
|
| 1820 |
"\n",
|
| 1821 |
-
"np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
|
| 1822 |
]
|
| 1823 |
},
|
| 1824 |
{
|
|
|
|
| 32 |
{
|
| 33 |
"data": {
|
| 34 |
"application/vnd.jupyter.widget-view+json": {
|
| 35 |
+
"model_id": "8be92a01e78a47b792d93b35d557885d",
|
| 36 |
"version_major": 2,
|
| 37 |
"version_minor": 0
|
| 38 |
},
|
|
|
|
| 234 |
},
|
| 235 |
{
|
| 236 |
"cell_type": "code",
|
| 237 |
+
"execution_count": 7,
|
| 238 |
"metadata": {},
|
| 239 |
"outputs": [],
|
| 240 |
"source": [
|
| 241 |
"class DDPMScheduler(nn.Module):\n",
|
| 242 |
+
" def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu'):\n",
|
| 243 |
" super().__init__()\n",
|
| 244 |
" \n",
|
| 245 |
" beta_1, beta_T = betas\n",
|
| 246 |
" assert 0 < beta_1 <= beta_T <= 1, \"ensure 0 < beta_1 <= beta_T <= 1\"\n",
|
| 247 |
" self.device = device\n",
|
| 248 |
" self.num_timesteps = num_timesteps\n",
|
| 249 |
+
" self.img_shape = img_shape\n",
|
| 250 |
" self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1\n",
|
| 251 |
" self.beta_t = self.beta_t.to(self.device)\n",
|
| 252 |
"\n",
|
|
|
|
| 278 |
" def sample(self, nn_model, params, device, guide_w = 0):\n",
|
| 279 |
" n_sample = len(params) #params.shape[0]\n",
|
| 280 |
" # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
|
| 281 |
+
" x_i = torch.randn(n_sample, *self.img_shape[1:]).to(device)\n",
|
| 282 |
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 283 |
" if guide_w != -1:\n",
|
| 284 |
" c_i = params\n",
|
|
|
|
| 298 |
" t_is = torch.tensor([i]).to(device)\n",
|
| 299 |
" t_is = t_is.repeat(n_sample)\n",
|
| 300 |
"\n",
|
| 301 |
+
" z = torch.randn(n_sample, *self.img_shape[1:]).to(device) if i > 0 else 0\n",
|
| 302 |
"\n",
|
| 303 |
" if guide_w == -1:\n",
|
| 304 |
" # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
|
|
|
|
| 306 |
" # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
|
| 307 |
" else:\n",
|
| 308 |
" # double batch\n",
|
| 309 |
+
" x_i = x_i.repeat(2, *torch.ones(len(self.img_shape[1:]), dtype=int).tolist())\n",
|
| 310 |
" t_is = t_is.repeat(2)\n",
|
| 311 |
"\n",
|
| 312 |
" # split predictions and compute weighting\n",
|
|
|
|
| 339 |
},
|
| 340 |
{
|
| 341 |
"cell_type": "code",
|
| 342 |
+
"execution_count": 8,
|
| 343 |
"metadata": {},
|
| 344 |
"outputs": [],
|
| 345 |
"source": [
|
|
|
|
| 381 |
},
|
| 382 |
{
|
| 383 |
"cell_type": "code",
|
| 384 |
+
"execution_count": 9,
|
| 385 |
"metadata": {},
|
| 386 |
"outputs": [],
|
| 387 |
"source": [
|
| 388 |
"class Downsample(nn.Module):\n",
|
| 389 |
+
" def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
|
| 390 |
" super().__init__()\n",
|
| 391 |
" self.channels = channels\n",
|
| 392 |
" self.out_channels = out_channels or channels\n",
|
| 393 |
+
" # stride = config.stride\n",
|
| 394 |
" if use_conv:\n",
|
| 395 |
" # print(\"conv\")\n",
|
| 396 |
+
" self.op = Conv[dim](channels, self.out_channels, 3, stride=stride, padding=1)\n",
|
| 397 |
" else:\n",
|
| 398 |
" # print(\"pool\")\n",
|
| 399 |
" assert channels == self.out_channels\n",
|
| 400 |
+
" self.op = AvgPool[dim](kernel_size=stride, stride=stride)\n",
|
| 401 |
"\n",
|
| 402 |
" def forward(self, x):\n",
|
| 403 |
" assert x.shape[1] == self.channels\n",
|
|
|
|
| 406 |
},
|
| 407 |
{
|
| 408 |
"cell_type": "code",
|
| 409 |
+
"execution_count": 10,
|
| 410 |
"metadata": {},
|
| 411 |
"outputs": [],
|
| 412 |
"source": [
|
| 413 |
"class Upsample(nn.Module):\n",
|
| 414 |
+
" def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
|
| 415 |
" super().__init__()\n",
|
| 416 |
" self.channels = channels\n",
|
| 417 |
" self.out_channels = out_channels\n",
|
| 418 |
" self.use_conv = use_conv\n",
|
| 419 |
+
" self.stride = stride\n",
|
| 420 |
" if self.use_conv:\n",
|
| 421 |
+
" self.conv = Conv[dim](self.channels, self.out_channels, 3, padding=1)\n",
|
| 422 |
"\n",
|
| 423 |
" def forward(self, x):\n",
|
| 424 |
" assert x.shape[1] == self.channels\n",
|
| 425 |
+
" # stride = config.stride\n",
|
| 426 |
" # print(torch.tensor(x.shape[2:]))\n",
|
| 427 |
" # print(torch.tensor(stride))\n",
|
| 428 |
+
" shape = torch.tensor(x.shape[2:]) * torch.tensor(self.stride)\n",
|
| 429 |
" shape = tuple(shape.detach().numpy())\n",
|
| 430 |
" # print(shape)\n",
|
| 431 |
" x = F.interpolate(x, shape, mode='nearest')\n",
|
|
|
|
| 436 |
},
|
| 437 |
{
|
| 438 |
"cell_type": "code",
|
| 439 |
+
"execution_count": 11,
|
| 440 |
"metadata": {},
|
| 441 |
"outputs": [],
|
| 442 |
"source": [
|
|
|
|
| 451 |
},
|
| 452 |
{
|
| 453 |
"cell_type": "code",
|
| 454 |
+
"execution_count": 12,
|
| 455 |
"metadata": {},
|
| 456 |
"outputs": [],
|
| 457 |
"source": [
|
|
|
|
| 465 |
},
|
| 466 |
{
|
| 467 |
"cell_type": "code",
|
| 468 |
+
"execution_count": 13,
|
| 469 |
"metadata": {},
|
| 470 |
"outputs": [],
|
| 471 |
"source": [
|
|
|
|
| 483 |
},
|
| 484 |
{
|
| 485 |
"cell_type": "code",
|
| 486 |
+
"execution_count": 14,
|
| 487 |
"metadata": {},
|
| 488 |
"outputs": [],
|
| 489 |
"source": [
|
| 490 |
"class ResBlock(TimestepBlock):\n",
|
| 491 |
" def __init__(\n",
|
| 492 |
+
" self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False, dim=2, stride=(2,2),\n",
|
| 493 |
" ):\n",
|
| 494 |
" super().__init__()\n",
|
| 495 |
" self.out_channels = out_channels or channels\n",
|
| 496 |
" self.use_scale_shift_norm = use_scale_shift_norm\n",
|
| 497 |
+
" self.stride = stride\n",
|
| 498 |
"\n",
|
| 499 |
" self.in_layers = nn.Sequential(\n",
|
| 500 |
" # nn.BatchNorm2d(channels), # normalize to standard gaussian\n",
|
| 501 |
" normalization(channels, swish=1.0),\n",
|
| 502 |
" nn.Identity(),\n",
|
| 503 |
+
" Conv[dim](channels, self.out_channels, 3, padding=1),\n",
|
| 504 |
" )\n",
|
| 505 |
"\n",
|
| 506 |
" self.updown = up or down\n",
|
| 507 |
" if up:\n",
|
| 508 |
+
" self.h_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
|
| 509 |
+
" self.x_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
|
| 510 |
" elif down:\n",
|
| 511 |
+
" self.h_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
|
| 512 |
+
" self.x_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
|
| 513 |
" else:\n",
|
| 514 |
" self.h_updown = self.x_updown = nn.Identity()\n",
|
| 515 |
"\n",
|
|
|
|
| 526 |
" normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),\n",
|
| 527 |
" nn.SiLU() if use_scale_shift_norm else nn.Identity(),\n",
|
| 528 |
" nn.Dropout(p=dropout),\n",
|
| 529 |
+
" zero_module(Conv[dim](self.out_channels, self.out_channels, 3, padding=1)),\n",
|
| 530 |
" )\n",
|
| 531 |
"\n",
|
| 532 |
" if self.out_channels == channels:\n",
|
| 533 |
" self.skip_connection = nn.Identity()\n",
|
| 534 |
" elif use_conv:\n",
|
| 535 |
+
" self.skip_connection = Conv[dim](channels, self.out_channels, 3, padding=1)\n",
|
| 536 |
" else:\n",
|
| 537 |
+
" self.skip_connection = Conv[dim](channels, self.out_channels, 1)\n",
|
| 538 |
" \n",
|
| 539 |
"\n",
|
| 540 |
" def forward(self, x, emb):\n",
|
|
|
|
| 565 |
},
|
| 566 |
{
|
| 567 |
"cell_type": "code",
|
| 568 |
+
"execution_count": 15,
|
| 569 |
"metadata": {},
|
| 570 |
"outputs": [],
|
| 571 |
"source": [
|
|
|
|
| 598 |
},
|
| 599 |
{
|
| 600 |
"cell_type": "code",
|
| 601 |
+
"execution_count": 16,
|
| 602 |
"metadata": {},
|
| 603 |
"outputs": [],
|
| 604 |
"source": [
|
|
|
|
| 647 |
},
|
| 648 |
{
|
| 649 |
"cell_type": "code",
|
| 650 |
+
"execution_count": 17,
|
| 651 |
"metadata": {},
|
| 652 |
"outputs": [],
|
| 653 |
"source": [
|
|
|
|
| 676 |
},
|
| 677 |
{
|
| 678 |
"cell_type": "code",
|
| 679 |
+
"execution_count": 18,
|
| 680 |
"metadata": {},
|
| 681 |
"outputs": [],
|
| 682 |
"source": [
|
|
|
|
| 700 |
" resblock_updown = False,\n",
|
| 701 |
" conv_resample = True,\n",
|
| 702 |
" encoder_channels = None,\n",
|
| 703 |
+
" dim = 2,\n",
|
| 704 |
+
" stride = (2,2)\n",
|
| 705 |
" ):\n",
|
| 706 |
" super().__init__()\n",
|
| 707 |
"\n",
|
|
|
|
| 747 |
"\n",
|
| 748 |
" ###################### input_blocks ######################\n",
|
| 749 |
" self.input_blocks = nn.ModuleList(\n",
|
| 750 |
+
" [TimestepEmbedSequential(Conv[dim](in_channels, ch, 3, padding=1))]\n",
|
| 751 |
" )\n",
|
| 752 |
" self._feature_size = ch\n",
|
| 753 |
" input_block_chans = [ch]\n",
|
|
|
|
| 763 |
" out_channels = int(mult * model_channels),\n",
|
| 764 |
" use_checkpoint = use_checkpoint,\n",
|
| 765 |
" use_scale_shift_norm = use_scale_shift_norm,\n",
|
| 766 |
+
" dim = dim,\n",
|
| 767 |
+
" stride = stride,\n",
|
| 768 |
" )\n",
|
| 769 |
" ]\n",
|
| 770 |
" ch = int(mult * model_channels)\n",
|
|
|
|
| 795 |
" use_checkpoint=use_checkpoint,\n",
|
| 796 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
| 797 |
" down=True,\n",
|
| 798 |
+
" dim = dim,\n",
|
| 799 |
+
" stride = stride,\n",
|
| 800 |
" )\n",
|
| 801 |
" if resblock_updown\n",
|
| 802 |
+
" else Downsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
|
| 803 |
" )\n",
|
| 804 |
" )\n",
|
| 805 |
" ch = out_ch\n",
|
|
|
|
| 816 |
" dropout,\n",
|
| 817 |
" use_checkpoint=use_checkpoint,\n",
|
| 818 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
| 819 |
+
" dim = dim,\n",
|
| 820 |
+
" stride = stride,\n",
|
| 821 |
" ),\n",
|
| 822 |
" AttentionBlock(\n",
|
| 823 |
" ch,\n",
|
|
|
|
| 832 |
" dropout,\n",
|
| 833 |
" use_checkpoint=use_checkpoint,\n",
|
| 834 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
| 835 |
+
" dim = dim,\n",
|
| 836 |
+
" stride = stride,\n",
|
| 837 |
" ),\n",
|
| 838 |
" )\n",
|
| 839 |
" self._feature_size += ch\n",
|
|
|
|
| 853 |
" # dims=dims,\n",
|
| 854 |
" use_checkpoint=use_checkpoint,\n",
|
| 855 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
| 856 |
+
" dim = dim,\n",
|
| 857 |
+
" stride = stride,\n",
|
| 858 |
" )\n",
|
| 859 |
" ]\n",
|
| 860 |
" ch = int(model_channels * mult)\n",
|
|
|
|
| 881 |
" use_checkpoint=use_checkpoint,\n",
|
| 882 |
" use_scale_shift_norm=use_scale_shift_norm,\n",
|
| 883 |
" up=True,\n",
|
| 884 |
+
" dim = dim,\n",
|
| 885 |
+
" stride = stride,\n",
|
| 886 |
" )\n",
|
| 887 |
" if resblock_updown\n",
|
| 888 |
+
" else Upsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
|
| 889 |
" )\n",
|
| 890 |
" ds //= 2\n",
|
| 891 |
" self.output_blocks.append(TimestepEmbedSequential(*layers))\n",
|
|
|
|
| 895 |
" # nn.BatchNorm2d(ch),\n",
|
| 896 |
" normalization(ch, swish=1.0),\n",
|
| 897 |
" nn.Identity(),\n",
|
| 898 |
+
" zero_module(Conv[dim](input_ch, out_channels, 3, padding=1)),\n",
|
| 899 |
" )\n",
|
| 900 |
" # self.use_fp16 = use_fp16\n",
|
| 901 |
"\n",
|
|
|
|
| 932 |
},
|
| 933 |
{
|
| 934 |
"cell_type": "code",
|
| 935 |
+
"execution_count": 19,
|
| 936 |
"metadata": {},
|
| 937 |
"outputs": [],
|
| 938 |
"source": [
|
|
|
|
| 962 |
},
|
| 963 |
{
|
| 964 |
"cell_type": "code",
|
| 965 |
+
"execution_count": 20,
|
| 966 |
"metadata": {},
|
| 967 |
"outputs": [],
|
| 968 |
"source": [
|
|
|
|
| 980 |
" # dim = 2\n",
|
| 981 |
" dim = 2\n",
|
| 982 |
" stride = (2,2) if dim == 2 else (2,2,4)\n",
|
| 983 |
+
" num_image = 200 # 2400\n",
|
| 984 |
" HII_DIM = 64\n",
|
| 985 |
" num_redshift = 512#256#256#64#512#128\n",
|
| 986 |
" img_shape = (HII_DIM, num_redshift) if dim == 2 else (HII_DIM, HII_DIM, num_redshift)\n",
|
|
|
|
| 1027 |
},
|
| 1028 |
{
|
| 1029 |
"cell_type": "code",
|
| 1030 |
+
"execution_count": 21,
|
| 1031 |
"metadata": {},
|
| 1032 |
"outputs": [],
|
| 1033 |
"source": [
|
|
|
|
| 1037 |
},
|
| 1038 |
{
|
| 1039 |
"cell_type": "code",
|
| 1040 |
+
"execution_count": 22,
|
| 1041 |
"metadata": {},
|
| 1042 |
"outputs": [],
|
| 1043 |
"source": [
|
|
|
|
| 1046 |
},
|
| 1047 |
{
|
| 1048 |
"cell_type": "code",
|
| 1049 |
+
"execution_count": 23,
|
| 1050 |
"metadata": {},
|
| 1051 |
"outputs": [],
|
| 1052 |
"source": [
|
|
|
|
| 1070 |
},
|
| 1071 |
{
|
| 1072 |
"cell_type": "code",
|
| 1073 |
+
"execution_count": 24,
|
| 1074 |
"metadata": {},
|
| 1075 |
"outputs": [],
|
| 1076 |
"source": [
|
|
|
|
| 1268 |
},
|
| 1269 |
{
|
| 1270 |
"cell_type": "code",
|
| 1271 |
+
"execution_count": 25,
|
| 1272 |
"metadata": {},
|
| 1273 |
"outputs": [
|
| 1274 |
{
|
|
|
|
| 1279 |
"51200 images can be loaded\n",
|
| 1280 |
"field.shape = (64, 64, 514)\n",
|
| 1281 |
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 1282 |
+
"loading 200 images randomly\n",
|
| 1283 |
+
"images loaded: (200, 1, 64, 512)\n",
|
| 1284 |
+
"params loaded: (200, 2)\n",
|
| 1285 |
+
"images rescaled to [-1.0, 1.082756519317627]\n",
|
| 1286 |
+
"params rescaled to [0.0, 0.9938162632551855]\n",
|
| 1287 |
"resumed nn_model from model_state.pth\n",
|
| 1288 |
"Number of parameters for nn_model: 111048705\n"
|
| 1289 |
]
|
|
|
|
| 1315 |
" self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
|
| 1316 |
" del dataset\n",
|
| 1317 |
"\n",
|
| 1318 |
+
" self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
|
| 1319 |
"\n",
|
| 1320 |
" # initialize the unet\n",
|
| 1321 |
+
" self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
|
| 1322 |
"\n",
|
| 1323 |
" if config.resume:\n",
|
| 1324 |
" self.nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['unet_state_dict'])\n",
|
|
|
|
| 1335 |
" if config.ema:\n",
|
| 1336 |
" self.ema = EMA(config.ema_rate)\n",
|
| 1337 |
" if config.resume:\n",
|
| 1338 |
+
" self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
|
| 1339 |
" self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
|
| 1340 |
" print(f\"resumed ema_model from {config.resume}\")\n",
|
| 1341 |
" else:\n",
|
|
|
|
| 1459 |
},
|
| 1460 |
{
|
| 1461 |
"cell_type": "code",
|
| 1462 |
+
"execution_count": 26,
|
| 1463 |
"metadata": {},
|
| 1464 |
"outputs": [
|
| 1465 |
{
|
| 1466 |
"data": {
|
| 1467 |
"application/vnd.jupyter.widget-view+json": {
|
| 1468 |
+
"model_id": "7a0b627f28ef409f8504113bc3af36e3",
|
| 1469 |
"version_major": 2,
|
| 1470 |
"version_minor": 0
|
| 1471 |
},
|
| 1472 |
"text/plain": [
|
| 1473 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1474 |
]
|
| 1475 |
},
|
| 1476 |
"metadata": {},
|
|
|
|
| 1479 |
{
|
| 1480 |
"data": {
|
| 1481 |
"application/vnd.jupyter.widget-view+json": {
|
| 1482 |
+
"model_id": "62f09cd440a84841b336ab15e76e2fe6",
|
| 1483 |
"version_major": 2,
|
| 1484 |
"version_minor": 0
|
| 1485 |
},
|
| 1486 |
"text/plain": [
|
| 1487 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1488 |
]
|
| 1489 |
},
|
| 1490 |
"metadata": {},
|
|
|
|
| 1493 |
{
|
| 1494 |
"data": {
|
| 1495 |
"application/vnd.jupyter.widget-view+json": {
|
| 1496 |
+
"model_id": "9db24e29de0c47328f1aba68db61bbae",
|
| 1497 |
"version_major": 2,
|
| 1498 |
"version_minor": 0
|
| 1499 |
},
|
| 1500 |
"text/plain": [
|
| 1501 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1502 |
]
|
| 1503 |
},
|
| 1504 |
"metadata": {},
|
|
|
|
| 1507 |
{
|
| 1508 |
"data": {
|
| 1509 |
"application/vnd.jupyter.widget-view+json": {
|
| 1510 |
+
"model_id": "ee59d1a664d04a2b90a7a448a816ed10",
|
| 1511 |
"version_major": 2,
|
| 1512 |
"version_minor": 0
|
| 1513 |
},
|
| 1514 |
"text/plain": [
|
| 1515 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1516 |
]
|
| 1517 |
},
|
| 1518 |
"metadata": {},
|
|
|
|
| 1521 |
{
|
| 1522 |
"data": {
|
| 1523 |
"application/vnd.jupyter.widget-view+json": {
|
| 1524 |
+
"model_id": "8690c736f7eb4a23925b450c05659575",
|
| 1525 |
"version_major": 2,
|
| 1526 |
"version_minor": 0
|
| 1527 |
},
|
| 1528 |
"text/plain": [
|
| 1529 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1530 |
]
|
| 1531 |
},
|
| 1532 |
"metadata": {},
|
|
|
|
| 1535 |
{
|
| 1536 |
"data": {
|
| 1537 |
"application/vnd.jupyter.widget-view+json": {
|
| 1538 |
+
"model_id": "7dc014a33bfd43408e0aafc208bb403e",
|
| 1539 |
"version_major": 2,
|
| 1540 |
"version_minor": 0
|
| 1541 |
},
|
| 1542 |
"text/plain": [
|
| 1543 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1544 |
]
|
| 1545 |
},
|
| 1546 |
"metadata": {},
|
|
|
|
| 1549 |
{
|
| 1550 |
"data": {
|
| 1551 |
"application/vnd.jupyter.widget-view+json": {
|
| 1552 |
+
"model_id": "6715e5cccc6d480397f76bcea34f94e5",
|
| 1553 |
"version_major": 2,
|
| 1554 |
"version_minor": 0
|
| 1555 |
},
|
| 1556 |
"text/plain": [
|
| 1557 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1558 |
]
|
| 1559 |
},
|
| 1560 |
"metadata": {},
|
|
|
|
| 1563 |
{
|
| 1564 |
"data": {
|
| 1565 |
"application/vnd.jupyter.widget-view+json": {
|
| 1566 |
+
"model_id": "b7410efd4a5d4efdb9b8be38ba1c2fcb",
|
| 1567 |
"version_major": 2,
|
| 1568 |
"version_minor": 0
|
| 1569 |
},
|
| 1570 |
"text/plain": [
|
| 1571 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1572 |
]
|
| 1573 |
},
|
| 1574 |
"metadata": {},
|
|
|
|
| 1577 |
{
|
| 1578 |
"data": {
|
| 1579 |
"application/vnd.jupyter.widget-view+json": {
|
| 1580 |
+
"model_id": "3b6c0478c9ff4a99b7f79ba4422dbd7d",
|
| 1581 |
"version_major": 2,
|
| 1582 |
"version_minor": 0
|
| 1583 |
},
|
| 1584 |
"text/plain": [
|
| 1585 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1586 |
]
|
| 1587 |
},
|
| 1588 |
"metadata": {},
|
|
|
|
| 1591 |
{
|
| 1592 |
"data": {
|
| 1593 |
"application/vnd.jupyter.widget-view+json": {
|
| 1594 |
+
"model_id": "d26f49f5a9804d84b6b6a531a56eb03a",
|
| 1595 |
"version_major": 2,
|
| 1596 |
"version_minor": 0
|
| 1597 |
},
|
| 1598 |
"text/plain": [
|
| 1599 |
+
" 0%| | 0/20 [00:00<?, ?it/s]"
|
| 1600 |
]
|
| 1601 |
},
|
| 1602 |
"metadata": {},
|
|
|
|
| 1609 |
},
|
| 1610 |
{
|
| 1611 |
"cell_type": "code",
|
| 1612 |
+
"execution_count": null,
|
| 1613 |
"metadata": {},
|
| 1614 |
"outputs": [
|
| 1615 |
{
|
|
|
|
| 1821 |
}
|
| 1822 |
],
|
| 1823 |
"source": [
|
| 1824 |
+
"# ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
|
| 1825 |
"\n",
|
| 1826 |
+
"# nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
|
| 1827 |
+
"# print(\"resuming nn_model\")\n",
|
| 1828 |
+
"# nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
|
| 1829 |
+
"# # nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 1830 |
+
"# # nn_model.train()\n",
|
| 1831 |
+
"# nn_model.to(ddpm.device)\n",
|
| 1832 |
+
"# nn_model.eval()\n",
|
| 1833 |
"\n",
|
| 1834 |
+
"# n_sample = 20\n",
|
| 1835 |
+
"# with torch.no_grad():\n",
|
| 1836 |
+
"# x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
|
| 1837 |
"\n",
|
| 1838 |
+
"# np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
|
| 1839 |
]
|
| 1840 |
},
|
| 1841 |
{
|