Xsmos committed on
Commit 443fbc2 · verified · 1 Parent(s): 58b2929

0521-2013

Files changed (1)
  1. diffusion.ipynb +105 -88

diffusion.ipynb CHANGED
@@ -32,7 +32,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "c9806d87a1f2404fb462189f2912d675",
+ "model_id": "8be92a01e78a47b792d93b35d557885d",
  "version_major": 2,
  "version_minor": 0
  },
@@ -234,18 +234,19 @@
  },
  {
  "cell_type": "code",
- "execution_count": 90,
+ "execution_count": 7,
  "metadata": {},
  "outputs": [],
  "source": [
  "class DDPMScheduler(nn.Module):\n",
- " def __init__(self, betas: tuple, num_timesteps: int, device='cpu'):\n",
+ " def __init__(self, betas: tuple, num_timesteps: int, img_shape: list, device='cpu'):\n",
  " super().__init__()\n",
  " \n",
  " beta_1, beta_T = betas\n",
  " assert 0 < beta_1 <= beta_T <= 1, \"ensure 0 < beta_1 <= beta_T <= 1\"\n",
  " self.device = device\n",
  " self.num_timesteps = num_timesteps\n",
+ " self.img_shape = img_shape\n",
  " self.beta_t = torch.linspace(beta_1, beta_T, self.num_timesteps) #* (beta_T-beta_1) + beta_1\n",
  " self.beta_t = self.beta_t.to(self.device)\n",
  "\n",
@@ -277,7 +278,7 @@
  " def sample(self, nn_model, params, device, guide_w = 0):\n",
  " n_sample = len(params) #params.shape[0]\n",
  " # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
- " x_i = torch.randn(n_sample, *self.shape[1:]).to(device)\n",
+ " x_i = torch.randn(n_sample, *self.img_shape[1:]).to(device)\n",
  " # print(\"x_i.shape =\", x_i.shape)\n",
  " if guide_w != -1:\n",
  " c_i = params\n",
@@ -297,7 +298,7 @@
  " t_is = torch.tensor([i]).to(device)\n",
  " t_is = t_is.repeat(n_sample)\n",
  "\n",
- " z = torch.randn(n_sample, *self.shape[1:]).to(device) if i > 0 else 0\n",
+ " z = torch.randn(n_sample, *self.img_shape[1:]).to(device) if i > 0 else 0\n",
  "\n",
  " if guide_w == -1:\n",
  " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
@@ -305,7 +306,7 @@
  " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
  " else:\n",
  " # double batch\n",
- " x_i = x_i.repeat(2, *torch.ones(len(self.shape[1:]), dtype=int).tolist())\n",
+ " x_i = x_i.repeat(2, *torch.ones(len(self.img_shape[1:]), dtype=int).tolist())\n",
  " t_is = t_is.repeat(2)\n",
  "\n",
  " # split predictions and compute weighting\n",
@@ -338,7 +339,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 91,
+ "execution_count": 8,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -380,23 +381,23 @@
  },
  {
  "cell_type": "code",
- "execution_count": 92,
+ "execution_count": 9,
  "metadata": {},
  "outputs": [],
  "source": [
  "class Downsample(nn.Module):\n",
- " def __init__(self, channels, use_conv, out_channels=None):\n",
+ " def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
  " super().__init__()\n",
  " self.channels = channels\n",
  " self.out_channels = out_channels or channels\n",
- " stride = config.stride\n",
+ " # stride = config.stride\n",
  " if use_conv:\n",
  " # print(\"conv\")\n",
- " self.op = Conv[config.dim](channels, self.out_channels, 3, stride=stride, padding=1)\n",
+ " self.op = Conv[dim](channels, self.out_channels, 3, stride=stride, padding=1)\n",
  " else:\n",
  " # print(\"pool\")\n",
  " assert channels == self.out_channels\n",
- " self.op = AvgPool[config.dim](kernel_size=stride, stride=stride)\n",
+ " self.op = AvgPool[dim](kernel_size=stride, stride=stride)\n",
  "\n",
  " def forward(self, x):\n",
  " assert x.shape[1] == self.channels\n",
@@ -405,25 +406,26 @@
  },
  {
  "cell_type": "code",
- "execution_count": 93,
+ "execution_count": 10,
  "metadata": {},
  "outputs": [],
  "source": [
  "class Upsample(nn.Module):\n",
- " def __init__(self, channels, use_conv, out_channels=None):\n",
+ " def __init__(self, channels, use_conv, out_channels=None, dim=2, stride=(2,2)):\n",
  " super().__init__()\n",
  " self.channels = channels\n",
  " self.out_channels = out_channels\n",
  " self.use_conv = use_conv\n",
+ " self.stride = stride\n",
  " if self.use_conv:\n",
- " self.conv = Conv[config.dim](self.channels, self.out_channels, 3, padding=1)\n",
+ " self.conv = Conv[dim](self.channels, self.out_channels, 3, padding=1)\n",
  "\n",
  " def forward(self, x):\n",
  " assert x.shape[1] == self.channels\n",
- " stride = config.stride\n",
+ " # stride = config.stride\n",
  " # print(torch.tensor(x.shape[2:]))\n",
  " # print(torch.tensor(stride))\n",
- " shape = torch.tensor(x.shape[2:]) * torch.tensor(stride)\n",
+ " shape = torch.tensor(x.shape[2:]) * torch.tensor(self.stride)\n",
  " shape = tuple(shape.detach().numpy())\n",
  " # print(shape)\n",
  " x = F.interpolate(x, shape, mode='nearest')\n",
@@ -434,7 +436,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 94,
+ "execution_count": 11,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -449,7 +451,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 95,
+ "execution_count": 12,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -463,7 +465,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 96,
+ "execution_count": 13,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -481,32 +483,33 @@
  },
  {
  "cell_type": "code",
- "execution_count": 97,
+ "execution_count": 14,
  "metadata": {},
  "outputs": [],
  "source": [
  "class ResBlock(TimestepBlock):\n",
  " def __init__(\n",
- " self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False,\n",
+ " self, channels, emb_channels, dropout, out_channels=None, use_conv=False, use_checkpoint=False, use_scale_shift_norm=False, up=False, down=False, dim=2, stride=(2,2),\n",
  " ):\n",
  " super().__init__()\n",
  " self.out_channels = out_channels or channels\n",
  " self.use_scale_shift_norm = use_scale_shift_norm\n",
+ " self.stride = stride\n",
  "\n",
  " self.in_layers = nn.Sequential(\n",
  " # nn.BatchNorm2d(channels), # normalize to standard gaussian\n",
  " normalization(channels, swish=1.0),\n",
  " nn.Identity(),\n",
- " Conv[config.dim](channels, self.out_channels, 3, padding=1),\n",
+ " Conv[dim](channels, self.out_channels, 3, padding=1),\n",
  " )\n",
  "\n",
  " self.updown = up or down\n",
  " if up:\n",
- " self.h_updown = Upsample(channels, False)\n",
- " self.x_updown = Upsample(channels, False)\n",
+ " self.h_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
+ " self.x_updown = Upsample(channels, False, dim=dim, stride=stride)\n",
  " elif down:\n",
- " self.h_updown = Downsample(channels, False)\n",
- " self.x_updown = Downsample(channels, False)\n",
+ " self.h_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
+ " self.x_updown = Downsample(channels, False, dim=dim, stride=stride)\n",
  " else:\n",
  " self.h_updown = self.x_updown = nn.Identity()\n",
  "\n",
@@ -523,15 +526,15 @@
  " normalization(self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0),\n",
  " nn.SiLU() if use_scale_shift_norm else nn.Identity(),\n",
  " nn.Dropout(p=dropout),\n",
- " zero_module(Conv[config.dim](self.out_channels, self.out_channels, 3, padding=1)),\n",
+ " zero_module(Conv[dim](self.out_channels, self.out_channels, 3, padding=1)),\n",
  " )\n",
  "\n",
  " if self.out_channels == channels:\n",
  " self.skip_connection = nn.Identity()\n",
  " elif use_conv:\n",
- " self.skip_connection = Conv[config.dim](channels, self.out_channels, 3, padding=1)\n",
+ " self.skip_connection = Conv[dim](channels, self.out_channels, 3, padding=1)\n",
  " else:\n",
- " self.skip_connection = Conv[config.dim](channels, self.out_channels, 1)\n",
+ " self.skip_connection = Conv[dim](channels, self.out_channels, 1)\n",
  " \n",
  "\n",
  " def forward(self, x, emb):\n",
@@ -562,7 +565,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 98,
+ "execution_count": 15,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -595,7 +598,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 99,
+ "execution_count": 16,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -644,7 +647,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 100,
+ "execution_count": 17,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -673,7 +676,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 101,
+ "execution_count": 18,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -697,6 +700,8 @@
  " resblock_updown = False,\n",
  " conv_resample = True,\n",
  " encoder_channels = None,\n",
+ " dim = 2,\n",
+ " stride = (2,2)\n",
  " ):\n",
  " super().__init__()\n",
  "\n",
@@ -742,7 +747,7 @@
  "\n",
  " ###################### input_blocks ######################\n",
  " self.input_blocks = nn.ModuleList(\n",
- " [TimestepEmbedSequential(Conv[config.dim](in_channels, ch, 3, padding=1))]\n",
+ " [TimestepEmbedSequential(Conv[dim](in_channels, ch, 3, padding=1))]\n",
  " )\n",
  " self._feature_size = ch\n",
  " input_block_chans = [ch]\n",
@@ -758,6 +763,8 @@
  " out_channels = int(mult * model_channels),\n",
  " use_checkpoint = use_checkpoint,\n",
  " use_scale_shift_norm = use_scale_shift_norm,\n",
+ " dim = dim,\n",
+ " stride = stride,\n",
  " )\n",
  " ]\n",
  " ch = int(mult * model_channels)\n",
@@ -788,9 +795,11 @@
  " use_checkpoint=use_checkpoint,\n",
  " use_scale_shift_norm=use_scale_shift_norm,\n",
  " down=True,\n",
+ " dim = dim,\n",
+ " stride = stride,\n",
  " )\n",
  " if resblock_updown\n",
- " else Downsample(ch, conv_resample, out_channels=out_ch)\n",
+ " else Downsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
  " )\n",
  " )\n",
  " ch = out_ch\n",
@@ -807,6 +816,8 @@
  " dropout,\n",
  " use_checkpoint=use_checkpoint,\n",
  " use_scale_shift_norm=use_scale_shift_norm,\n",
+ " dim = dim,\n",
+ " stride = stride,\n",
  " ),\n",
  " AttentionBlock(\n",
  " ch,\n",
@@ -821,6 +832,8 @@
  " dropout,\n",
  " use_checkpoint=use_checkpoint,\n",
  " use_scale_shift_norm=use_scale_shift_norm,\n",
+ " dim = dim,\n",
+ " stride = stride,\n",
  " ),\n",
  " )\n",
  " self._feature_size += ch\n",
@@ -840,6 +853,8 @@
  " # dims=dims,\n",
  " use_checkpoint=use_checkpoint,\n",
  " use_scale_shift_norm=use_scale_shift_norm,\n",
+ " dim = dim,\n",
+ " stride = stride,\n",
  " )\n",
  " ]\n",
  " ch = int(model_channels * mult)\n",
@@ -866,9 +881,11 @@
  " use_checkpoint=use_checkpoint,\n",
  " use_scale_shift_norm=use_scale_shift_norm,\n",
  " up=True,\n",
+ " dim = dim,\n",
+ " stride = stride,\n",
  " )\n",
  " if resblock_updown\n",
- " else Upsample(ch, conv_resample, out_channels=out_ch)\n",
+ " else Upsample(ch, conv_resample, out_channels=out_ch, dim=dim, stride=stride)\n",
  " )\n",
  " ds //= 2\n",
  " self.output_blocks.append(TimestepEmbedSequential(*layers))\n",
@@ -878,7 +895,7 @@
  " # nn.BatchNorm2d(ch),\n",
  " normalization(ch, swish=1.0),\n",
  " nn.Identity(),\n",
- " zero_module(Conv[config.dim](input_ch, out_channels, 3, padding=1)),\n",
+ " zero_module(Conv[dim](input_ch, out_channels, 3, padding=1)),\n",
  " )\n",
  " # self.use_fp16 = use_fp16\n",
  "\n",
@@ -915,7 +932,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 102,
+ "execution_count": 19,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -945,7 +962,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 123,
+ "execution_count": 20,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -963,7 +980,7 @@
  " # dim = 2\n",
  " dim = 2\n",
  " stride = (2,2) if dim == 2 else (2,2,4)\n",
- " num_image = 20 # 2400\n",
+ " num_image = 200 # 2400\n",
  " HII_DIM = 64\n",
  " num_redshift = 512#256#256#64#512#128\n",
  " img_shape = (HII_DIM, num_redshift) if dim == 2 else (HII_DIM, HII_DIM, num_redshift)\n",
@@ -1010,7 +1027,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 124,
+ "execution_count": 21,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1020,7 +1037,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 125,
+ "execution_count": 22,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1029,7 +1046,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 126,
+ "execution_count": 23,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1053,7 +1070,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 127,
+ "execution_count": 24,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1251,7 +1268,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 128,
+ "execution_count": 25,
  "metadata": {},
  "outputs": [
  {
@@ -1262,11 +1279,11 @@
  "51200 images can be loaded\n",
  "field.shape = (64, 64, 514)\n",
  "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
- "loading 20 images randomly\n",
- "images loaded: (20, 1, 64, 512)\n",
- "params loaded: (20, 2)\n",
- "images rescaled to [-1.0, 1.038496732711792]\n",
- "params rescaled to [0.0, 0.9816321951033768]\n",
+ "loading 200 images randomly\n",
+ "images loaded: (200, 1, 64, 512)\n",
+ "params loaded: (200, 2)\n",
+ "images rescaled to [-1.0, 1.082756519317627]\n",
+ "params rescaled to [0.0, 0.9938162632551855]\n",
  "resumed nn_model from model_state.pth\n",
  "Number of parameters for nn_model: 111048705\n"
  ]
@@ -1298,10 +1315,10 @@
  " self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
  " del dataset\n",
  "\n",
- " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
+ " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
  "\n",
  " # initialize the unet\n",
- " self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
+ " self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)\n",
  "\n",
  " if config.resume:\n",
  " self.nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['unet_state_dict'])\n",
@@ -1318,7 +1335,7 @@
  " if config.ema:\n",
  " self.ema = EMA(config.ema_rate)\n",
  " if config.resume:\n",
- " self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
+ " self.ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride).to(config.device)\n",
  " self.ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"{config.resume}\"))['ema_unet_state_dict'])\n",
  " print(f\"resumed ema_model from {config.resume}\")\n",
  " else:\n",
@@ -1442,18 +1459,18 @@
  },
  {
  "cell_type": "code",
- "execution_count": 129,
+ "execution_count": 26,
  "metadata": {},
  "outputs": [
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "067df92056c8456aa796e3416bac122a",
+ "model_id": "7a0b627f28ef409f8504113bc3af36e3",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1462,12 +1479,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "8e77d9787ee049e5896b1be75d34bf05",
+ "model_id": "62f09cd440a84841b336ab15e76e2fe6",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1476,12 +1493,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "8211e85e22354d7da06f66786ff33d4a",
+ "model_id": "9db24e29de0c47328f1aba68db61bbae",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1490,12 +1507,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "31d068ad21c642468bb2a90c7af57c83",
+ "model_id": "ee59d1a664d04a2b90a7a448a816ed10",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1504,12 +1521,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "2ca6304e757f4c8696bacfc36692e791",
+ "model_id": "8690c736f7eb4a23925b450c05659575",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1518,12 +1535,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "7cc536030a784596995ec5130b7638c5",
+ "model_id": "7dc014a33bfd43408e0aafc208bb403e",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1532,12 +1549,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "b415a15a942046f08e3e2c92404c14ad",
+ "model_id": "6715e5cccc6d480397f76bcea34f94e5",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1546,12 +1563,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "2de1a814b7d34998b63eec43c1d43c12",
+ "model_id": "b7410efd4a5d4efdb9b8be38ba1c2fcb",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1560,12 +1577,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "2ae161b79b0d4e688b12432455a6c065",
+ "model_id": "3b6c0478c9ff4a99b7f79ba4422dbd7d",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1574,12 +1591,12 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "7497a93eb57a40e281141126947f78ae",
+ "model_id": "d26f49f5a9804d84b6b6a531a56eb03a",
  "version_major": 2,
  "version_minor": 0
  },
  "text/plain": [
- " 0%| | 0/2 [00:00<?, ?it/s]"
+ " 0%| | 0/20 [00:00<?, ?it/s]"
  ]
  },
  "metadata": {},
@@ -1592,7 +1609,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 116,
+ "execution_count": null,
  "metadata": {},
  "outputs": [
  {
@@ -1804,21 +1821,21 @@
  }
  ],
  "source": [
- "ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
+ "# ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
  "\n",
- "nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
- "print(\"resuming nn_model\")\n",
- "nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
- "# nn_model = ContextUnet(n_param=1, image_size=28)\n",
- "# nn_model.train()\n",
- "nn_model.to(ddpm.device)\n",
- "nn_model.eval()\n",
+ "# nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
+ "# print(\"resuming nn_model\")\n",
+ "# nn_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
+ "# # nn_model = ContextUnet(n_param=1, image_size=28)\n",
+ "# # nn_model.train()\n",
+ "# nn_model.to(ddpm.device)\n",
+ "# nn_model.eval()\n",
  "\n",
- "n_sample = 20\n",
- "with torch.no_grad():\n",
- " x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
+ "# n_sample = 20\n",
+ "# with torch.no_grad():\n",
+ "# x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
  "\n",
- "np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
+ "# np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
  ]
  },
  {
 
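Taken together, this commit replaces the implicit `config.dim`/`config.stride` globals in `Downsample`, `Upsample`, `ResBlock`, and `ContextUnet` with explicit `dim` and `stride` arguments, and gives `DDPMScheduler` an explicit `img_shape` in place of the previously undefined `self.shape`. Below is a minimal sketch of how the new signatures fit together; the `Conv`/`AvgPool` dispatch tables are assumptions inferred from the `Conv[dim](...)` indexing in the diff, and the concrete shapes come from the notebook's config cell.

# Sketch under stated assumptions; not the notebook's verbatim code.
import torch
import torch.nn as nn

# Assumed dispatch tables, inferred from the Conv[dim](...) / AvgPool[dim](...) calls.
Conv = {2: nn.Conv2d, 3: nn.Conv3d}
AvgPool = {2: nn.AvgPool2d, 3: nn.AvgPool3d}

dim = 2
stride = (2, 2) if dim == 2 else (2, 2, 4)   # as in the config cell
img_shape = (1, 64, 512)                     # (channels, HII_DIM, num_redshift)

# DDPMScheduler now stores img_shape and draws noise from it, e.g.:
n_sample = 4
x_i = torch.randn(n_sample, *img_shape[1:])  # replaces the old self.shape[1:]

# Downsample's two branches, with dim/stride passed in rather than read from config:
conv_down = Conv[dim](8, 8, 3, stride=stride, padding=1)     # use_conv=True branch
pool_down = AvgPool[dim](kernel_size=stride, stride=stride)  # pooling branch

Making the dimensionality explicit is what lets the same blocks serve both the 2D `(64, 512)` lightcones and the 3D `(64, 64, 512)` variant in the config cell without depending on a module-level `config`.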