Xsmos committed
Commit 58b2929 · verified · 1 Parent(s): 0757b47
Files changed (1)
  1. diffusion.ipynb +25 -31
diffusion.ipynb CHANGED
@@ -121,14 +121,14 @@
  " self.images = torch.from_numpy(self.images)\n",
  " print(f\"images rescaled to [{self.images.min()}, {self.images.max()}]\")\n",
  "\n",
- " cond_filter = torch.bernoulli(torch.ones(self.params.shape[0],1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()\n",
+ " cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()\n",
  " self.params = torch.from_numpy(self.params*cond_filter)\n",
  " print(f\"params rescaled to [{self.params.min()}, {self.params.max()}]\")\n",
  "\n",
  " def load_h5(self):\n",
  " with h5py.File(self.dir_name, 'r') as f:\n",
  " print(f\"dataset content: {f.keys()}\")\n",
- " max_num_image = f['brightness_temp'].shape[0]\n",
+ " max_num_image = len(f['brightness_temp'])#.shape[0]\n",
  " print(f\"{max_num_image} images can be loaded\")\n",
  " field_shape = f['brightness_temp'].shape[1:]\n",
  " print(f\"field.shape = {field_shape}\")\n",
@@ -275,8 +275,8 @@
  " return noisy_images, noise, ts\n",
  "\n",
  " def sample(self, nn_model, params, device, guide_w = 0):\n",
- " n_sample = params.shape[0]\n",
- " print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
+ " n_sample = len(params) #params.shape[0]\n",
+ " # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
  " x_i = torch.randn(n_sample, *self.shape[1:]).to(device)\n",
  " # print(\"x_i.shape =\", x_i.shape)\n",
  " if guide_w != -1:\n",
@@ -945,7 +945,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 103,
+ "execution_count": 123,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -966,6 +966,8 @@
  " num_image = 20 # 2400\n",
  " HII_DIM = 64\n",
  " num_redshift = 512#256#256#64#512#128\n",
+ " img_shape = (HII_DIM, num_redshift) if dim == 2 else (HII_DIM, HII_DIM, num_redshift)\n",
+ "\n",
  " n_epoch = 10#2#5#25 # 120\n",
  " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
  " train_batch_size = 10#10#20#2#100 # 10\n",
@@ -1008,7 +1010,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 104,
+ "execution_count": 124,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1018,7 +1020,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 105,
+ "execution_count": 125,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1027,7 +1029,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 106,
+ "execution_count": 126,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1051,7 +1053,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 107,
+ "execution_count": 127,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1249,7 +1251,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 117,
+ "execution_count": 128,
  "metadata": {},
  "outputs": [
  {
@@ -1263,8 +1265,8 @@
  "loading 20 images randomly\n",
  "images loaded: (20, 1, 64, 512)\n",
  "params loaded: (20, 2)\n",
- "images rescaled to [-1.0, 0.8818775415420532]\n",
- "params rescaled to [0.0, 0.9965706900954632]\n",
+ "images rescaled to [-1.0, 1.038496732711792]\n",
+ "params rescaled to [0.0, 0.9816321951033768]\n",
  "resumed nn_model from model_state.pth\n",
  "Number of parameters for nn_model: 111048705\n"
  ]
@@ -1424,7 +1426,6 @@
  "\n",
  " def sample(self, file, params=torch.tensor((0.2,0.8)), ema=False, entire=False):\n",
  " # n_sample = params.shape[0]\n",
- " shape = (self.config.HII, self.config.num_redshift) if self.config.dim == 2 else (self.config.HII, self.config.HII, self.config.num_redshift)\n",
  " model = self.ema_model if ema else self.nn_model\n",
  " # params = torch.tile(params, (n_sample,1)).to(device)\n",
  "\n",
@@ -1441,13 +1442,13 @@
  },
  {
  "cell_type": "code",
- "execution_count": 110,
+ "execution_count": 129,
  "metadata": {},
  "outputs": [
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "7b26c7444a1144728f0299db5d5683b1",
+ "model_id": "067df92056c8456aa796e3416bac122a",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1461,7 +1462,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "e3b8c1a18460443d986282f284bf0b42",
+ "model_id": "8e77d9787ee049e5896b1be75d34bf05",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1475,7 +1476,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "84d558fcea2f4b998f311da7452a2c93",
+ "model_id": "8211e85e22354d7da06f66786ff33d4a",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1489,7 +1490,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "364208bd1fa04859baf233ed30451638",
+ "model_id": "31d068ad21c642468bb2a90c7af57c83",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1503,7 +1504,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "e8aa693ca9b444d7b6adb77531f8b718",
+ "model_id": "2ca6304e757f4c8696bacfc36692e791",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1517,7 +1518,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "aab44a74df8c4a7a8baf975b8ec50b8b",
+ "model_id": "7cc536030a784596995ec5130b7638c5",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1531,7 +1532,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "461da88dafc8437a858b6eed2042c709",
+ "model_id": "b415a15a942046f08e3e2c92404c14ad",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1545,7 +1546,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "a85d41b905884ec6947eb1fb94f2f934",
+ "model_id": "2de1a814b7d34998b63eec43c1d43c12",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1559,7 +1560,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "f9677e0a3b6049d1a9877eec1689f506",
+ "model_id": "2ae161b79b0d4e688b12432455a6c065",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1573,7 +1574,7 @@
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "5bd75ea4440b4acbb54b55c14fec272b",
+ "model_id": "7497a93eb57a40e281141126947f78ae",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1583,13 +1584,6 @@
  },
  "metadata": {},
  "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "saved model at ./outputs/model_state_09.pth\n"
- ]
  }
  ],
  "source": [
 