0521-1939
Browse files- diffusion.ipynb +25 -31
diffusion.ipynb
CHANGED
|
@@ -121,14 +121,14 @@
|
|
| 121 |
" self.images = torch.from_numpy(self.images)\n",
|
| 122 |
" print(f\"images rescaled to [{self.images.min()}, {self.images.max()}]\")\n",
|
| 123 |
"\n",
|
| 124 |
-
" cond_filter = torch.bernoulli(torch.ones(self.params
|
| 125 |
" self.params = torch.from_numpy(self.params*cond_filter)\n",
|
| 126 |
" print(f\"params rescaled to [{self.params.min()}, {self.params.max()}]\")\n",
|
| 127 |
"\n",
|
| 128 |
" def load_h5(self):\n",
|
| 129 |
" with h5py.File(self.dir_name, 'r') as f:\n",
|
| 130 |
" print(f\"dataset content: {f.keys()}\")\n",
|
| 131 |
-
" max_num_image = f['brightness_temp']
|
| 132 |
" print(f\"{max_num_image} images can be loaded\")\n",
|
| 133 |
" field_shape = f['brightness_temp'].shape[1:]\n",
|
| 134 |
" print(f\"field.shape = {field_shape}\")\n",
|
|
@@ -275,8 +275,8 @@
|
|
| 275 |
" return noisy_images, noise, ts\n",
|
| 276 |
"\n",
|
| 277 |
" def sample(self, nn_model, params, device, guide_w = 0):\n",
|
| 278 |
-
" n_sample = params.shape[0]\n",
|
| 279 |
-
" print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
|
| 280 |
" x_i = torch.randn(n_sample, *self.shape[1:]).to(device)\n",
|
| 281 |
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 282 |
" if guide_w != -1:\n",
|
|
@@ -945,7 +945,7 @@
|
|
| 945 |
},
|
| 946 |
{
|
| 947 |
"cell_type": "code",
|
| 948 |
-
"execution_count":
|
| 949 |
"metadata": {},
|
| 950 |
"outputs": [],
|
| 951 |
"source": [
|
|
@@ -966,6 +966,8 @@
|
|
| 966 |
" num_image = 20 # 2400\n",
|
| 967 |
" HII_DIM = 64\n",
|
| 968 |
" num_redshift = 512#256#256#64#512#128\n",
|
|
|
|
|
|
|
| 969 |
" n_epoch = 10#2#5#25 # 120\n",
|
| 970 |
" num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|
| 971 |
" train_batch_size = 10#10#20#2#100 # 10\n",
|
|
@@ -1008,7 +1010,7 @@
|
|
| 1008 |
},
|
| 1009 |
{
|
| 1010 |
"cell_type": "code",
|
| 1011 |
-
"execution_count":
|
| 1012 |
"metadata": {},
|
| 1013 |
"outputs": [],
|
| 1014 |
"source": [
|
|
@@ -1018,7 +1020,7 @@
|
|
| 1018 |
},
|
| 1019 |
{
|
| 1020 |
"cell_type": "code",
|
| 1021 |
-
"execution_count":
|
| 1022 |
"metadata": {},
|
| 1023 |
"outputs": [],
|
| 1024 |
"source": [
|
|
@@ -1027,7 +1029,7 @@
|
|
| 1027 |
},
|
| 1028 |
{
|
| 1029 |
"cell_type": "code",
|
| 1030 |
-
"execution_count":
|
| 1031 |
"metadata": {},
|
| 1032 |
"outputs": [],
|
| 1033 |
"source": [
|
|
@@ -1051,7 +1053,7 @@
|
|
| 1051 |
},
|
| 1052 |
{
|
| 1053 |
"cell_type": "code",
|
| 1054 |
-
"execution_count":
|
| 1055 |
"metadata": {},
|
| 1056 |
"outputs": [],
|
| 1057 |
"source": [
|
|
@@ -1249,7 +1251,7 @@
|
|
| 1249 |
},
|
| 1250 |
{
|
| 1251 |
"cell_type": "code",
|
| 1252 |
-
"execution_count":
|
| 1253 |
"metadata": {},
|
| 1254 |
"outputs": [
|
| 1255 |
{
|
|
@@ -1263,8 +1265,8 @@
|
|
| 1263 |
"loading 20 images randomly\n",
|
| 1264 |
"images loaded: (20, 1, 64, 512)\n",
|
| 1265 |
"params loaded: (20, 2)\n",
|
| 1266 |
-
"images rescaled to [-1.0,
|
| 1267 |
-
"params rescaled to [0.0, 0.
|
| 1268 |
"resumed nn_model from model_state.pth\n",
|
| 1269 |
"Number of parameters for nn_model: 111048705\n"
|
| 1270 |
]
|
|
@@ -1424,7 +1426,6 @@
|
|
| 1424 |
"\n",
|
| 1425 |
" def sample(self, file, params=torch.tensor((0.2,0.8)), ema=False, entire=False):\n",
|
| 1426 |
" # n_sample = params.shape[0]\n",
|
| 1427 |
-
" shape = (self.config.HII, self.config.num_redshift) if self.config.dim == 2 else (self.config.HII, self.config.HII, self.config.num_redshift)\n",
|
| 1428 |
" model = self.ema_model if ema else self.nn_model\n",
|
| 1429 |
" # params = torch.tile(params, (n_sample,1)).to(device)\n",
|
| 1430 |
"\n",
|
|
@@ -1441,13 +1442,13 @@
|
|
| 1441 |
},
|
| 1442 |
{
|
| 1443 |
"cell_type": "code",
|
| 1444 |
-
"execution_count":
|
| 1445 |
"metadata": {},
|
| 1446 |
"outputs": [
|
| 1447 |
{
|
| 1448 |
"data": {
|
| 1449 |
"application/vnd.jupyter.widget-view+json": {
|
| 1450 |
-
"model_id": "
|
| 1451 |
"version_major": 2,
|
| 1452 |
"version_minor": 0
|
| 1453 |
},
|
|
@@ -1461,7 +1462,7 @@
|
|
| 1461 |
{
|
| 1462 |
"data": {
|
| 1463 |
"application/vnd.jupyter.widget-view+json": {
|
| 1464 |
-
"model_id": "
|
| 1465 |
"version_major": 2,
|
| 1466 |
"version_minor": 0
|
| 1467 |
},
|
|
@@ -1475,7 +1476,7 @@
|
|
| 1475 |
{
|
| 1476 |
"data": {
|
| 1477 |
"application/vnd.jupyter.widget-view+json": {
|
| 1478 |
-
"model_id": "
|
| 1479 |
"version_major": 2,
|
| 1480 |
"version_minor": 0
|
| 1481 |
},
|
|
@@ -1489,7 +1490,7 @@
|
|
| 1489 |
{
|
| 1490 |
"data": {
|
| 1491 |
"application/vnd.jupyter.widget-view+json": {
|
| 1492 |
-
"model_id": "
|
| 1493 |
"version_major": 2,
|
| 1494 |
"version_minor": 0
|
| 1495 |
},
|
|
@@ -1503,7 +1504,7 @@
|
|
| 1503 |
{
|
| 1504 |
"data": {
|
| 1505 |
"application/vnd.jupyter.widget-view+json": {
|
| 1506 |
-
"model_id": "
|
| 1507 |
"version_major": 2,
|
| 1508 |
"version_minor": 0
|
| 1509 |
},
|
|
@@ -1517,7 +1518,7 @@
|
|
| 1517 |
{
|
| 1518 |
"data": {
|
| 1519 |
"application/vnd.jupyter.widget-view+json": {
|
| 1520 |
-
"model_id": "
|
| 1521 |
"version_major": 2,
|
| 1522 |
"version_minor": 0
|
| 1523 |
},
|
|
@@ -1531,7 +1532,7 @@
|
|
| 1531 |
{
|
| 1532 |
"data": {
|
| 1533 |
"application/vnd.jupyter.widget-view+json": {
|
| 1534 |
-
"model_id": "
|
| 1535 |
"version_major": 2,
|
| 1536 |
"version_minor": 0
|
| 1537 |
},
|
|
@@ -1545,7 +1546,7 @@
|
|
| 1545 |
{
|
| 1546 |
"data": {
|
| 1547 |
"application/vnd.jupyter.widget-view+json": {
|
| 1548 |
-
"model_id": "
|
| 1549 |
"version_major": 2,
|
| 1550 |
"version_minor": 0
|
| 1551 |
},
|
|
@@ -1559,7 +1560,7 @@
|
|
| 1559 |
{
|
| 1560 |
"data": {
|
| 1561 |
"application/vnd.jupyter.widget-view+json": {
|
| 1562 |
-
"model_id": "
|
| 1563 |
"version_major": 2,
|
| 1564 |
"version_minor": 0
|
| 1565 |
},
|
|
@@ -1573,7 +1574,7 @@
|
|
| 1573 |
{
|
| 1574 |
"data": {
|
| 1575 |
"application/vnd.jupyter.widget-view+json": {
|
| 1576 |
-
"model_id": "
|
| 1577 |
"version_major": 2,
|
| 1578 |
"version_minor": 0
|
| 1579 |
},
|
|
@@ -1583,13 +1584,6 @@
|
|
| 1583 |
},
|
| 1584 |
"metadata": {},
|
| 1585 |
"output_type": "display_data"
|
| 1586 |
-
},
|
| 1587 |
-
{
|
| 1588 |
-
"name": "stdout",
|
| 1589 |
-
"output_type": "stream",
|
| 1590 |
-
"text": [
|
| 1591 |
-
"saved model at ./outputs/model_state_09.pth\n"
|
| 1592 |
-
]
|
| 1593 |
}
|
| 1594 |
],
|
| 1595 |
"source": [
|
|
|
|
| 121 |
" self.images = torch.from_numpy(self.images)\n",
|
| 122 |
" print(f\"images rescaled to [{self.images.min()}, {self.images.max()}]\")\n",
|
| 123 |
"\n",
|
| 124 |
+
" cond_filter = torch.bernoulli(torch.ones(len(self.params),1)-self.drop_prob).repeat(1,self.params.shape[1]).numpy()\n",
|
| 125 |
" self.params = torch.from_numpy(self.params*cond_filter)\n",
|
| 126 |
" print(f\"params rescaled to [{self.params.min()}, {self.params.max()}]\")\n",
|
| 127 |
"\n",
|
| 128 |
" def load_h5(self):\n",
|
| 129 |
" with h5py.File(self.dir_name, 'r') as f:\n",
|
| 130 |
" print(f\"dataset content: {f.keys()}\")\n",
|
| 131 |
+
" max_num_image = len(f['brightness_temp'])#.shape[0]\n",
|
| 132 |
" print(f\"{max_num_image} images can be loaded\")\n",
|
| 133 |
" field_shape = f['brightness_temp'].shape[1:]\n",
|
| 134 |
" print(f\"field.shape = {field_shape}\")\n",
|
|
|
|
| 275 |
" return noisy_images, noise, ts\n",
|
| 276 |
"\n",
|
| 277 |
" def sample(self, nn_model, params, device, guide_w = 0):\n",
|
| 278 |
+
" n_sample = len(params) #params.shape[0]\n",
|
| 279 |
+
" # print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
|
| 280 |
" x_i = torch.randn(n_sample, *self.shape[1:]).to(device)\n",
|
| 281 |
" # print(\"x_i.shape =\", x_i.shape)\n",
|
| 282 |
" if guide_w != -1:\n",
|
|
|
|
| 945 |
},
|
| 946 |
{
|
| 947 |
"cell_type": "code",
|
| 948 |
+
"execution_count": 123,
|
| 949 |
"metadata": {},
|
| 950 |
"outputs": [],
|
| 951 |
"source": [
|
|
|
|
| 966 |
" num_image = 20 # 2400\n",
|
| 967 |
" HII_DIM = 64\n",
|
| 968 |
" num_redshift = 512#256#256#64#512#128\n",
|
| 969 |
+
" img_shape = (HII_DIM, num_redshift) if dim == 2 else (HII_DIM, HII_DIM, num_redshift)\n",
|
| 970 |
+
"\n",
|
| 971 |
" n_epoch = 10#2#5#25 # 120\n",
|
| 972 |
" num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|
| 973 |
" train_batch_size = 10#10#20#2#100 # 10\n",
|
|
|
|
| 1010 |
},
|
| 1011 |
{
|
| 1012 |
"cell_type": "code",
|
| 1013 |
+
"execution_count": 124,
|
| 1014 |
"metadata": {},
|
| 1015 |
"outputs": [],
|
| 1016 |
"source": [
|
|
|
|
| 1020 |
},
|
| 1021 |
{
|
| 1022 |
"cell_type": "code",
|
| 1023 |
+
"execution_count": 125,
|
| 1024 |
"metadata": {},
|
| 1025 |
"outputs": [],
|
| 1026 |
"source": [
|
|
|
|
| 1029 |
},
|
| 1030 |
{
|
| 1031 |
"cell_type": "code",
|
| 1032 |
+
"execution_count": 126,
|
| 1033 |
"metadata": {},
|
| 1034 |
"outputs": [],
|
| 1035 |
"source": [
|
|
|
|
| 1053 |
},
|
| 1054 |
{
|
| 1055 |
"cell_type": "code",
|
| 1056 |
+
"execution_count": 127,
|
| 1057 |
"metadata": {},
|
| 1058 |
"outputs": [],
|
| 1059 |
"source": [
|
|
|
|
| 1251 |
},
|
| 1252 |
{
|
| 1253 |
"cell_type": "code",
|
| 1254 |
+
"execution_count": 128,
|
| 1255 |
"metadata": {},
|
| 1256 |
"outputs": [
|
| 1257 |
{
|
|
|
|
| 1265 |
"loading 20 images randomly\n",
|
| 1266 |
"images loaded: (20, 1, 64, 512)\n",
|
| 1267 |
"params loaded: (20, 2)\n",
|
| 1268 |
+
"images rescaled to [-1.0, 1.038496732711792]\n",
|
| 1269 |
+
"params rescaled to [0.0, 0.9816321951033768]\n",
|
| 1270 |
"resumed nn_model from model_state.pth\n",
|
| 1271 |
"Number of parameters for nn_model: 111048705\n"
|
| 1272 |
]
|
|
|
|
| 1426 |
"\n",
|
| 1427 |
" def sample(self, file, params=torch.tensor((0.2,0.8)), ema=False, entire=False):\n",
|
| 1428 |
" # n_sample = params.shape[0]\n",
|
|
|
|
| 1429 |
" model = self.ema_model if ema else self.nn_model\n",
|
| 1430 |
" # params = torch.tile(params, (n_sample,1)).to(device)\n",
|
| 1431 |
"\n",
|
|
|
|
| 1442 |
},
|
| 1443 |
{
|
| 1444 |
"cell_type": "code",
|
| 1445 |
+
"execution_count": 129,
|
| 1446 |
"metadata": {},
|
| 1447 |
"outputs": [
|
| 1448 |
{
|
| 1449 |
"data": {
|
| 1450 |
"application/vnd.jupyter.widget-view+json": {
|
| 1451 |
+
"model_id": "067df92056c8456aa796e3416bac122a",
|
| 1452 |
"version_major": 2,
|
| 1453 |
"version_minor": 0
|
| 1454 |
},
|
|
|
|
| 1462 |
{
|
| 1463 |
"data": {
|
| 1464 |
"application/vnd.jupyter.widget-view+json": {
|
| 1465 |
+
"model_id": "8e77d9787ee049e5896b1be75d34bf05",
|
| 1466 |
"version_major": 2,
|
| 1467 |
"version_minor": 0
|
| 1468 |
},
|
|
|
|
| 1476 |
{
|
| 1477 |
"data": {
|
| 1478 |
"application/vnd.jupyter.widget-view+json": {
|
| 1479 |
+
"model_id": "8211e85e22354d7da06f66786ff33d4a",
|
| 1480 |
"version_major": 2,
|
| 1481 |
"version_minor": 0
|
| 1482 |
},
|
|
|
|
| 1490 |
{
|
| 1491 |
"data": {
|
| 1492 |
"application/vnd.jupyter.widget-view+json": {
|
| 1493 |
+
"model_id": "31d068ad21c642468bb2a90c7af57c83",
|
| 1494 |
"version_major": 2,
|
| 1495 |
"version_minor": 0
|
| 1496 |
},
|
|
|
|
| 1504 |
{
|
| 1505 |
"data": {
|
| 1506 |
"application/vnd.jupyter.widget-view+json": {
|
| 1507 |
+
"model_id": "2ca6304e757f4c8696bacfc36692e791",
|
| 1508 |
"version_major": 2,
|
| 1509 |
"version_minor": 0
|
| 1510 |
},
|
|
|
|
| 1518 |
{
|
| 1519 |
"data": {
|
| 1520 |
"application/vnd.jupyter.widget-view+json": {
|
| 1521 |
+
"model_id": "7cc536030a784596995ec5130b7638c5",
|
| 1522 |
"version_major": 2,
|
| 1523 |
"version_minor": 0
|
| 1524 |
},
|
|
|
|
| 1532 |
{
|
| 1533 |
"data": {
|
| 1534 |
"application/vnd.jupyter.widget-view+json": {
|
| 1535 |
+
"model_id": "b415a15a942046f08e3e2c92404c14ad",
|
| 1536 |
"version_major": 2,
|
| 1537 |
"version_minor": 0
|
| 1538 |
},
|
|
|
|
| 1546 |
{
|
| 1547 |
"data": {
|
| 1548 |
"application/vnd.jupyter.widget-view+json": {
|
| 1549 |
+
"model_id": "2de1a814b7d34998b63eec43c1d43c12",
|
| 1550 |
"version_major": 2,
|
| 1551 |
"version_minor": 0
|
| 1552 |
},
|
|
|
|
| 1560 |
{
|
| 1561 |
"data": {
|
| 1562 |
"application/vnd.jupyter.widget-view+json": {
|
| 1563 |
+
"model_id": "2ae161b79b0d4e688b12432455a6c065",
|
| 1564 |
"version_major": 2,
|
| 1565 |
"version_minor": 0
|
| 1566 |
},
|
|
|
|
| 1574 |
{
|
| 1575 |
"data": {
|
| 1576 |
"application/vnd.jupyter.widget-view+json": {
|
| 1577 |
+
"model_id": "7497a93eb57a40e281141126947f78ae",
|
| 1578 |
"version_major": 2,
|
| 1579 |
"version_minor": 0
|
| 1580 |
},
|
|
|
|
| 1584 |
},
|
| 1585 |
"metadata": {},
|
| 1586 |
"output_type": "display_data"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1587 |
}
|
| 1588 |
],
|
| 1589 |
"source": [
|