Xsmos committed on
Commit
5ef4a5c
·
verified ·
1 Parent(s): 443fbc2
Files changed (1) hide show
  1. diffusion.ipynb +78 -73
diffusion.ipynb CHANGED
@@ -32,7 +32,7 @@
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
- "model_id": "8be92a01e78a47b792d93b35d557885d",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
@@ -962,7 +962,7 @@
962
  },
963
  {
964
  "cell_type": "code",
965
- "execution_count": 20,
966
  "metadata": {},
967
  "outputs": [],
968
  "source": [
@@ -987,7 +987,7 @@
987
  "\n",
988
  " n_epoch = 10#2#5#25 # 120\n",
989
  " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
990
- " train_batch_size = 10#10#20#2#100 # 10\n",
991
  " # n_sample = 24 # 64, the number of samples in sampling process\n",
992
  " n_param = 2\n",
993
  " guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
@@ -1032,7 +1032,7 @@
1032
  "outputs": [],
1033
  "source": [
1034
  "# dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
1035
- "# dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)"
1036
  ]
1037
  },
1038
  {
@@ -1083,10 +1083,10 @@
1083
  "# # HII_DIM = 64\n",
1084
  "# # num_redshift = 64#512#128\n",
1085
  "# # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
1086
- "# # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
1087
  "# # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
1088
  "# # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
1089
- "# # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
1090
  "# accelerator = Accelerator(\n",
1091
  "# mixed_precision=config.mixed_precision,\n",
1092
  "# gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
@@ -1268,37 +1268,15 @@
1268
  },
1269
  {
1270
  "cell_type": "code",
1271
- "execution_count": 25,
1272
  "metadata": {},
1273
  "outputs": [
1274
  {
1275
  "name": "stdout",
1276
  "output_type": "stream",
1277
  "text": [
1278
- "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
1279
- "51200 images can be loaded\n",
1280
- "field.shape = (64, 64, 514)\n",
1281
- "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
1282
- "loading 200 images randomly\n",
1283
- "images loaded: (200, 1, 64, 512)\n",
1284
- "params loaded: (200, 2)\n",
1285
- "images rescaled to [-1.0, 1.082756519317627]\n",
1286
- "params rescaled to [0.0, 0.9938162632551855]\n",
1287
  "resumed nn_model from model_state.pth\n",
1288
- "Number of parameters for nn_model: 111048705\n"
1289
- ]
1290
- },
1291
- {
1292
- "name": "stderr",
1293
- "output_type": "stream",
1294
- "text": [
1295
- "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
1296
- ]
1297
- },
1298
- {
1299
- "name": "stdout",
1300
- "output_type": "stream",
1301
- "text": [
1302
  "resumed ema_model from model_state.pth\n"
1303
  ]
1304
  }
@@ -1306,15 +1284,14 @@
1306
  "source": [
1307
  "# @dataclass\n",
1308
  "class DDPM21CM:\n",
1309
- " def __init__(self, config):\n",
 
1310
  " self.config = config\n",
1311
- "\n",
1312
- " dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
1313
- " # self.shape_loaded = dataset.images.shape\n",
1314
- " # print(\"shape_loaded =\", self.shape_loaded)\n",
1315
- " self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
1316
- " del dataset\n",
1317
- "\n",
1318
  " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
1319
  "\n",
1320
  " # initialize the unet\n",
@@ -1345,10 +1322,17 @@
1345
  " self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
1346
  " optimizer=self.optimizer,\n",
1347
  " num_warmup_steps=config.lr_warmup_steps,\n",
1348
- " num_training_steps=(len(self.dataloader) * config.n_epoch),\n",
 
1349
  " )\n",
1350
  "\n",
1351
- " self.accelerate(config)\n",
 
 
 
 
 
 
1352
  "\n",
1353
  " def accelerate(self, config):\n",
1354
  " self.accelerator = Accelerator(\n",
@@ -1376,6 +1360,7 @@
1376
  " ## training loop ##\n",
1377
  " ###################\n",
1378
  " # plot_unet = True\n",
 
1379
  " global_step = 0\n",
1380
  " for ep in range(self.config.n_epoch):\n",
1381
  " self.ddpm.train()\n",
@@ -1446,26 +1431,54 @@
1446
  " model = self.ema_model if ema else self.nn_model\n",
1447
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
1448
  "\n",
1449
- " x_last, x_entire = self.ddpm.sample(model, params=params, shape=shape, device=self.config.device, guide_w=self.config.guide_w)\n",
1450
  "\n",
1451
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1452
  " if entire:\n",
1453
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}_entire.npy\"), x_last)\n",
1454
  "\n",
1455
  "\n",
1456
- "ddpm21cm = DDPM21CM(TrainConfig())\n",
1457
  "# print(\"device =\", config.device)"
1458
  ]
1459
  },
1460
  {
1461
  "cell_type": "code",
1462
- "execution_count": 26,
1463
  "metadata": {},
1464
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1465
  {
1466
  "data": {
1467
  "application/vnd.jupyter.widget-view+json": {
1468
- "model_id": "7a0b627f28ef409f8504113bc3af36e3",
1469
  "version_major": 2,
1470
  "version_minor": 0
1471
  },
@@ -1479,7 +1492,7 @@
1479
  {
1480
  "data": {
1481
  "application/vnd.jupyter.widget-view+json": {
1482
- "model_id": "62f09cd440a84841b336ab15e76e2fe6",
1483
  "version_major": 2,
1484
  "version_minor": 0
1485
  },
@@ -1493,7 +1506,7 @@
1493
  {
1494
  "data": {
1495
  "application/vnd.jupyter.widget-view+json": {
1496
- "model_id": "9db24e29de0c47328f1aba68db61bbae",
1497
  "version_major": 2,
1498
  "version_minor": 0
1499
  },
@@ -1507,7 +1520,7 @@
1507
  {
1508
  "data": {
1509
  "application/vnd.jupyter.widget-view+json": {
1510
- "model_id": "ee59d1a664d04a2b90a7a448a816ed10",
1511
  "version_major": 2,
1512
  "version_minor": 0
1513
  },
@@ -1521,7 +1534,7 @@
1521
  {
1522
  "data": {
1523
  "application/vnd.jupyter.widget-view+json": {
1524
- "model_id": "8690c736f7eb4a23925b450c05659575",
1525
  "version_major": 2,
1526
  "version_minor": 0
1527
  },
@@ -1535,7 +1548,7 @@
1535
  {
1536
  "data": {
1537
  "application/vnd.jupyter.widget-view+json": {
1538
- "model_id": "7dc014a33bfd43408e0aafc208bb403e",
1539
  "version_major": 2,
1540
  "version_minor": 0
1541
  },
@@ -1549,7 +1562,7 @@
1549
  {
1550
  "data": {
1551
  "application/vnd.jupyter.widget-view+json": {
1552
- "model_id": "6715e5cccc6d480397f76bcea34f94e5",
1553
  "version_major": 2,
1554
  "version_minor": 0
1555
  },
@@ -1563,7 +1576,7 @@
1563
  {
1564
  "data": {
1565
  "application/vnd.jupyter.widget-view+json": {
1566
- "model_id": "b7410efd4a5d4efdb9b8be38ba1c2fcb",
1567
  "version_major": 2,
1568
  "version_minor": 0
1569
  },
@@ -1577,7 +1590,7 @@
1577
  {
1578
  "data": {
1579
  "application/vnd.jupyter.widget-view+json": {
1580
- "model_id": "3b6c0478c9ff4a99b7f79ba4422dbd7d",
1581
  "version_major": 2,
1582
  "version_minor": 0
1583
  },
@@ -1591,7 +1604,7 @@
1591
  {
1592
  "data": {
1593
  "application/vnd.jupyter.widget-view+json": {
1594
- "model_id": "d26f49f5a9804d84b6b6a531a56eb03a",
1595
  "version_major": 2,
1596
  "version_minor": 0
1597
  },
@@ -1609,28 +1622,20 @@
1609
  },
1610
  {
1611
  "cell_type": "code",
1612
- "execution_count": null,
1613
  "metadata": {},
1614
  "outputs": [
1615
  {
1616
- "name": "stdout",
1617
- "output_type": "stream",
1618
- "text": [
1619
- "params.shape[0], len(params) 2 2\n"
1620
- ]
1621
- },
1622
- {
1623
- "ename": "AttributeError",
1624
- "evalue": "'DDPMScheduler' object has no attribute 'shape'",
1625
  "output_type": "error",
1626
  "traceback": [
1627
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1628
- "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
1629
- "Cell \u001b[0;32mIn[116], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1630
- "Cell \u001b[0;32mIn[115], line 143\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 140\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 141\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 143\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams, device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 145\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 146\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1631
- "Cell \u001b[0;32mIn[90], line 40\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 38\u001b[0m n_sample \u001b[39m=\u001b[39m params\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[1;32m 39\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39m\"\u001b[39m\u001b[39mparams.shape[0], len(params)\u001b[39m\u001b[39m\"\u001b[39m, params\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m], \u001b[39mlen\u001b[39m(params))\n\u001b[0;32m---> 40\u001b[0m x_i \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mrandn(n_sample, \u001b[39m*\u001b[39m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshape[\u001b[39m1\u001b[39m:])\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 41\u001b[0m \u001b[39m# print(\"x_i.shape =\", x_i.shape)\u001b[39;00m\n\u001b[1;32m 42\u001b[0m \u001b[39mif\u001b[39;00m guide_w \u001b[39m!=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m:\n",
1632
- "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1207\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1205\u001b[0m \u001b[39mif\u001b[39;00m name \u001b[39min\u001b[39;00m modules:\n\u001b[1;32m 1206\u001b[0m \u001b[39mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1207\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m object has no attribute \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m.\u001b[39mformat(\n\u001b[1;32m 1208\u001b[0m \u001b[39mtype\u001b[39m(\u001b[39mself\u001b[39m)\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m, name))\n",
1633
- "\u001b[0;31mAttributeError\u001b[0m: 'DDPMScheduler' object has no attribute 'shape'"
1634
  ]
1635
  }
1636
  ],
@@ -2030,7 +2035,7 @@
2030
  "@dataclass\n",
2031
  "class TrainingConfig:\n",
2032
  " image_size = 128\n",
2033
- " train_batch_size = 16\n",
2034
  " eval_batch_size = 16\n",
2035
  " num_epochs = 50\n",
2036
  " gradient_accumulation_steps = 1\n",
@@ -2067,7 +2072,7 @@
2067
  " return {\"images\": images}\n",
2068
  "\n",
2069
  "dataset.set_transform(transform)\n",
2070
- "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
2071
  "\n",
2072
  "model = UNet2DModel(\n",
2073
  " sample_size = config.image_size,\n",
@@ -2268,7 +2273,7 @@
2268
  "class TrainingConfig:\n",
2269
  " num_images = 600\n",
2270
  " image_size = [64,512]\n",
2271
- " train_batch_size = 10\n",
2272
  " eval_batch_size = 24\n",
2273
  " num_epochs = 20\n",
2274
  " gradient_accumulation_steps = 1\n",
@@ -2547,7 +2552,7 @@
2547
  "outputs": [],
2548
  "source": [
2549
  "# dataset.set_transform(transform)\n",
2550
- "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)"
2551
  ]
2552
  },
2553
  {
 
32
  {
33
  "data": {
34
  "application/vnd.jupyter.widget-view+json": {
35
+ "model_id": "f925cb378800455fb1216e84a2900e58",
36
  "version_major": 2,
37
  "version_minor": 0
38
  },
 
962
  },
963
  {
964
  "cell_type": "code",
965
+ "execution_count": 29,
966
  "metadata": {},
967
  "outputs": [],
968
  "source": [
 
987
  "\n",
988
  " n_epoch = 10#2#5#25 # 120\n",
989
  " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
990
+ " batch_size = 10#10#20#2#100 # 10\n",
991
  " # n_sample = 24 # 64, the number of samples in sampling process\n",
992
  " n_param = 2\n",
993
  " guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
 
1032
  "outputs": [],
1033
  "source": [
1034
  "# dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
1035
+ "# dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)"
1036
  ]
1037
  },
1038
  {
 
1083
  "# # HII_DIM = 64\n",
1084
  "# # num_redshift = 64#512#128\n",
1085
  "# # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
1086
+ "# # dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
1087
  "# # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
1088
  "# # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
1089
+ "# # dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
1090
  "# accelerator = Accelerator(\n",
1091
  "# mixed_precision=config.mixed_precision,\n",
1092
  "# gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
 
1268
  },
1269
  {
1270
  "cell_type": "code",
1271
+ "execution_count": 31,
1272
  "metadata": {},
1273
  "outputs": [
1274
  {
1275
  "name": "stdout",
1276
  "output_type": "stream",
1277
  "text": [
 
 
 
 
 
 
 
 
 
1278
  "resumed nn_model from model_state.pth\n",
1279
+ "Number of parameters for nn_model: 111048705\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
1280
  "resumed ema_model from model_state.pth\n"
1281
  ]
1282
  }
 
1284
  "source": [
1285
  "# @dataclass\n",
1286
  "class DDPM21CM:\n",
1287
+ " def __init__(self):\n",
1288
+ " config = TrainConfig()\n",
1289
  " self.config = config\n",
1290
+ " # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
1291
+ " # # self.shape_loaded = dataset.images.shape\n",
1292
+ " # # print(\"shape_loaded =\", self.shape_loaded)\n",
1293
+ " # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
1294
+ " # del dataset\n",
 
 
1295
  " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
1296
  "\n",
1297
  " # initialize the unet\n",
 
1322
  " self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
1323
  " optimizer=self.optimizer,\n",
1324
  " num_warmup_steps=config.lr_warmup_steps,\n",
1325
+ " num_training_steps=(int(config.num_image/config.batch_size) * config.n_epoch),\n",
1326
+ " # num_training_steps=(len(self.dataloader) * config.n_epoch),\n",
1327
  " )\n",
1328
  "\n",
1329
+ " def load(self):\n",
1330
+ " dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim)\n",
1331
+ " # self.shape_loaded = dataset.images.shape\n",
1332
+ " # print(\"shape_loaded =\", self.shape_loaded)\n",
1333
+ " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True)\n",
1334
+ " # del dataset\n",
1335
+ " self.accelerate(self.config)\n",
1336
  "\n",
1337
  " def accelerate(self, config):\n",
1338
  " self.accelerator = Accelerator(\n",
 
1360
  " ## training loop ##\n",
1361
  " ###################\n",
1362
  " # plot_unet = True\n",
1363
+ " self.load()\n",
1364
  " global_step = 0\n",
1365
  " for ep in range(self.config.n_epoch):\n",
1366
  " self.ddpm.train()\n",
 
1431
  " model = self.ema_model if ema else self.nn_model\n",
1432
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
1433
  "\n",
1434
+ " x_last, x_entire = self.ddpm.sample(model, params=params, device=self.config.device, guide_w=self.config.guide_w)\n",
1435
  "\n",
1436
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1437
  " if entire:\n",
1438
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}_entire.npy\"), x_last)\n",
1439
  "\n",
1440
  "\n",
1441
+ "ddpm21cm = DDPM21CM()\n",
1442
  "# print(\"device =\", config.device)"
1443
  ]
1444
  },
1445
  {
1446
  "cell_type": "code",
1447
+ "execution_count": 32,
1448
  "metadata": {},
1449
  "outputs": [
1450
+ {
1451
+ "name": "stdout",
1452
+ "output_type": "stream",
1453
+ "text": [
1454
+ "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
1455
+ "51200 images can be loaded\n",
1456
+ "field.shape = (64, 64, 514)\n",
1457
+ "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
1458
+ "loading 200 images randomly\n",
1459
+ "images loaded: (200, 1, 64, 512)\n"
1460
+ ]
1461
+ },
1462
+ {
1463
+ "name": "stderr",
1464
+ "output_type": "stream",
1465
+ "text": [
1466
+ "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
1467
+ ]
1468
+ },
1469
+ {
1470
+ "name": "stdout",
1471
+ "output_type": "stream",
1472
+ "text": [
1473
+ "params loaded: (200, 2)\n",
1474
+ "images rescaled to [-1.0, 1.095154047012329]\n",
1475
+ "params rescaled to [0.0, 0.997810682944812]\n"
1476
+ ]
1477
+ },
1478
  {
1479
  "data": {
1480
  "application/vnd.jupyter.widget-view+json": {
1481
+ "model_id": "e6947686dc1344b4b446910cca5326dc",
1482
  "version_major": 2,
1483
  "version_minor": 0
1484
  },
 
1492
  {
1493
  "data": {
1494
  "application/vnd.jupyter.widget-view+json": {
1495
+ "model_id": "90757bce414c44778e816eb75fbeb106",
1496
  "version_major": 2,
1497
  "version_minor": 0
1498
  },
 
1506
  {
1507
  "data": {
1508
  "application/vnd.jupyter.widget-view+json": {
1509
+ "model_id": "cc51568c1fb94c82a96726973b103b36",
1510
  "version_major": 2,
1511
  "version_minor": 0
1512
  },
 
1520
  {
1521
  "data": {
1522
  "application/vnd.jupyter.widget-view+json": {
1523
+ "model_id": "62c75108d8344ec69b34873181e3d308",
1524
  "version_major": 2,
1525
  "version_minor": 0
1526
  },
 
1534
  {
1535
  "data": {
1536
  "application/vnd.jupyter.widget-view+json": {
1537
+ "model_id": "6bece94eb3cd45c7933bae67ad774f06",
1538
  "version_major": 2,
1539
  "version_minor": 0
1540
  },
 
1548
  {
1549
  "data": {
1550
  "application/vnd.jupyter.widget-view+json": {
1551
+ "model_id": "d3c093174fab4863af65f477f8baee8f",
1552
  "version_major": 2,
1553
  "version_minor": 0
1554
  },
 
1562
  {
1563
  "data": {
1564
  "application/vnd.jupyter.widget-view+json": {
1565
+ "model_id": "f2565e70177a402ba90215f4d80f717e",
1566
  "version_major": 2,
1567
  "version_minor": 0
1568
  },
 
1576
  {
1577
  "data": {
1578
  "application/vnd.jupyter.widget-view+json": {
1579
+ "model_id": "f11d5a63fb1040c2a4366207e53476b7",
1580
  "version_major": 2,
1581
  "version_minor": 0
1582
  },
 
1590
  {
1591
  "data": {
1592
  "application/vnd.jupyter.widget-view+json": {
1593
+ "model_id": "4641851ffed541edb4bfd99d156bf000",
1594
  "version_major": 2,
1595
  "version_minor": 0
1596
  },
 
1604
  {
1605
  "data": {
1606
  "application/vnd.jupyter.widget-view+json": {
1607
+ "model_id": "9338bea79f5d49c4b79a2327b9833954",
1608
  "version_major": 2,
1609
  "version_minor": 0
1610
  },
 
1622
  },
1623
  {
1624
  "cell_type": "code",
1625
+ "execution_count": 29,
1626
  "metadata": {},
1627
  "outputs": [
1628
  {
1629
+ "ename": "IndexError",
1630
+ "evalue": "tuple index out of range",
 
 
 
 
 
 
 
1631
  "output_type": "error",
1632
  "traceback": [
1633
  "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1634
+ "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
1635
+ "Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1636
+ "Cell \u001b[0;32mIn[28], line 143\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 140\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 141\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 143\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams, device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 145\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 146\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1637
+ "Cell \u001b[0;32mIn[7], line 45\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[39mif\u001b[39;00m guide_w \u001b[39m!=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m:\n\u001b[1;32m 44\u001b[0m c_i \u001b[39m=\u001b[39m params\n\u001b[0;32m---> 45\u001b[0m uncond_tokens \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mzeros(\u001b[39mint\u001b[39m(n_sample), params\u001b[39m.\u001b[39;49mshape[\u001b[39m1\u001b[39;49m])\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 46\u001b[0m \u001b[39m# uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)\u001b[39;00m\n\u001b[1;32m 47\u001b[0m \u001b[39m# uncond_tokens = uncond_tokens.repeat(int(n_sample),1)\u001b[39;00m\n\u001b[1;32m 48\u001b[0m c_i \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat((c_i, uncond_tokens), \u001b[39m0\u001b[39m)\n",
1638
+ "\u001b[0;31mIndexError\u001b[0m: tuple index out of range"
 
1639
  ]
1640
  }
1641
  ],
 
2035
  "@dataclass\n",
2036
  "class TrainingConfig:\n",
2037
  " image_size = 128\n",
2038
+ " batch_size = 16\n",
2039
  " eval_batch_size = 16\n",
2040
  " num_epochs = 50\n",
2041
  " gradient_accumulation_steps = 1\n",
 
2072
  " return {\"images\": images}\n",
2073
  "\n",
2074
  "dataset.set_transform(transform)\n",
2075
+ "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
2076
  "\n",
2077
  "model = UNet2DModel(\n",
2078
  " sample_size = config.image_size,\n",
 
2273
  "class TrainingConfig:\n",
2274
  " num_images = 600\n",
2275
  " image_size = [64,512]\n",
2276
+ " batch_size = 10\n",
2277
  " eval_batch_size = 24\n",
2278
  " num_epochs = 20\n",
2279
  " gradient_accumulation_steps = 1\n",
 
2552
  "outputs": [],
2553
  "source": [
2554
  "# dataset.set_transform(transform)\n",
2555
+ "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True)"
2556
  ]
2557
  },
2558
  {