0521-1738

diffusion.ipynb (+279 −228)
@@ -234,7 +234,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 90,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -256,30 +256,32 @@
 " self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
 "\n",
 " def add_noise(self, clean_images):\n",
- " shape = clean_images.shape\n",
- " expand = torch.ones(len(shape)-1, dtype=int)\n",
+ " self.shape = clean_images.shape\n",
+ " expand = torch.ones(len(self.shape)-1, dtype=int)\n",
 " # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
 " # expand = [1 for i in range(len(shape)-1)]\n",
 "\n",
 " noise = torch.randn_like(clean_images).to(self.device)\n",
- " ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)\n",
+ " ts = torch.randint(0, self.num_timesteps, (self.shape[0],)).to(self.device)\n",
 " \n",
 " # test_expand = test.view(test.shape[0],*expand)\n",
 " # extend_dim = [None for i in range(shape.dim()-1)]\n",
 " noisy_images = (\n",
- " clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
- " + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
+ " clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(self.shape[0], *expand.tolist())\n",
+ " + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(self.shape[0], *expand.tolist())\n",
 " )\n",
 " # print(x_t.shape)\n",
 "\n",
 " return noisy_images, noise, ts\n",
 "\n",
- " def sample(self, nn_model,
- "
+ " def sample(self, nn_model, params, device, guide_w = 0):\n",
+ " n_sample = params.shape[0]\n",
+ " print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
+ " x_i = torch.randn(n_sample, *self.shape[1:]).to(device)\n",
 " # print(\"x_i.shape =\", x_i.shape)\n",
 " if guide_w != -1:\n",
- " c_i =
- " uncond_tokens = torch.zeros(int(n_sample),
+ " c_i = params\n",
+ " uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)\n",
 " # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)\n",
 " # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)\n",
 " c_i = torch.cat((c_i, uncond_tokens), 0)\n",
@@ -295,7 +297,7 @@
 " t_is = torch.tensor([i]).to(device)\n",
 " t_is = t_is.repeat(n_sample)\n",
 "\n",
- " z = torch.randn(n_sample, *shape).to(device) if i > 0 else 0\n",
+ " z = torch.randn(n_sample, *self.shape[1:]).to(device) if i > 0 else 0\n",
 "\n",
 " if guide_w == -1:\n",
 " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
@@ -303,7 +305,7 @@
 " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
 " else:\n",
 " # double batch\n",
- " x_i = x_i.repeat(2, *torch.ones(len(shape), dtype=int).tolist())\n",
+ " x_i = x_i.repeat(2, *torch.ones(len(self.shape[1:]), dtype=int).tolist())\n",
 " t_is = t_is.repeat(2)\n",
 "\n",
 " # split predictions and compute weighting\n",
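The hunks above are the substance of this commit's scheduler change: `add_noise` now caches the batch shape on `self.shape` so that `sample` can draw `x_i` without being handed a shape, and `sample` now takes the conditioning tensor `params` directly, deriving `n_sample` from its rows. Outside the notebook-JSON escaping, a minimal sketch of what these methods compute (assuming a standard linear β-schedule; the notebook's scheduler class itself is not reproduced here):

```python
# Sketch only, not the notebook's class: mirrors the edited lines above.
import torch

num_timesteps = 1000
beta_t = torch.linspace(1e-4, 0.02, num_timesteps)   # assumed linear schedule
alpha_t = 1.0 - beta_t
bar_alpha_t = torch.cumprod(alpha_t, dim=0)          # alpha-bar, as in the cell above

def add_noise(clean_images: torch.Tensor):
    """Forward process: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    b = clean_images.shape[0]
    ts = torch.randint(0, num_timesteps, (b,))
    noise = torch.randn_like(clean_images)
    # view(b, 1, 1, ...) so the per-sample scalar abar_t[ts] broadcasts over image dims
    expand = [1] * (clean_images.dim() - 1)
    noisy = (clean_images * bar_alpha_t[ts].sqrt().view(b, *expand)
             + noise * (1.0 - bar_alpha_t[ts]).sqrt().view(b, *expand))
    return noisy, noise, ts

def guided_eps(eps_double: torch.Tensor, n_sample: int, guide_w: float):
    """Classifier-free guidance: the sampler doubles the batch (conditional +
    unconditional) and blends the two noise predictions; this is the usual
    weighting, consistent with the eps1/eps2 split visible in the traceback
    further down."""
    eps_cond, eps_uncond = eps_double[:n_sample], eps_double[n_sample:]
    return (1 + guide_w) * eps_cond - guide_w * eps_uncond

x0 = torch.randn(4, 1, 64, 512)   # same layout as the loaded lightcone images
xt, eps, ts = add_noise(x0)
assert xt.shape == x0.shape
```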
@@ -336,7 +338,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 91,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -378,7 +380,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 92,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -403,7 +405,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 93,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -432,7 +434,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 94,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -447,7 +449,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 95,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -461,7 +463,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 96,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -479,7 +481,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 97,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -560,7 +562,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 98,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -593,7 +595,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 99,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -642,7 +644,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 100,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -671,7 +673,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 101,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -913,7 +915,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 102,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -943,7 +945,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 103,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -967,7 +969,7 @@
 " n_epoch = 10#2#5#25 # 120\n",
 " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
 " train_batch_size = 10#10#20#2#100 # 10\n",
- " n_sample = 24 # 64, the number of samples in sampling process\n",
+ " # n_sample = 24 # 64, the number of samples in sampling process\n",
 " n_param = 2\n",
 " guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
 " drop_prob = 0.28 # only takes effect when guide_w != -1\n",
@@ -987,9 +989,9 @@
 " # cond = True # if training using the conditional information\n",
 " # lr_decay = False #True# if using the learning rate decay\n",
 " resume = 'model_state.pth' # if resume from the trained checkpoints\n",
- "
- "
- " #
+ " # params_single = torch.tensor([0.2,0.80000023])\n",
+ " # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
+ " # params = params\n",
 " # data_dir = './data' # data directory\n",
 "\n",
 " output_dir = \"./outputs/\"\n",
@@ -1006,7 +1008,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 104,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -1016,7 +1018,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 105,
  "metadata": {},
  "outputs": [],
  "source": [

@@ -1025,7 +1027,7 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 106,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -1049,205 +1051,205 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 107,
  "metadata": {},
  "outputs": [],
  "source": [
- "def train_loop(config, nn_model, ddpm, optimizer, dataloader, lr_scheduler): \n",
- [the remaining ~180 deleted source lines are truncated to empty strings in this view]
+ "# def train_loop(config, nn_model, ddpm, optimizer, dataloader, lr_scheduler): \n",
+ "# ########################\n",
+ "# ## ready for training ##\n",
+ "# ########################\n",
+ "# # initialize the dataset\n",
+ "# # num_image = 600\n",
+ "# # HII_DIM = 64\n",
+ "# # num_redshift = 64#512#128\n",
+ "# # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
+ "# # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
+ "# # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
+ "# # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
+ "# # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
+ "# accelerator = Accelerator(\n",
+ "# mixed_precision=config.mixed_precision,\n",
+ "# gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
+ "# log_with=\"tensorboard\",\n",
+ "# project_dir=os.path.join(config.output_dir, \"logs\"),\n",
+ "# )\n",
+ "# if accelerator.is_main_process:\n",
+ "# if config.output_dir is not None:\n",
+ "# os.makedirs(config.output_dir, exist_ok=True)\n",
+ "# if config.push_to_hub:\n",
+ "# repo_id = create_repo(\n",
+ "# repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
+ "# ).repo_id\n",
+ "# accelerator.init_trackers(f\"{config.date}\")\n",
+ "\n",
+ "# nn_model, optimizer, dataloader, lr_scheduler = accelerator.prepare(\n",
+ "# nn_model, optimizer, dataloader, lr_scheduler)\n",
 " \n",
+ "# # initialize the DDPM\n",
+ "# # logger = SummaryWriter(os.path.join(\"runs\", config.run_name)) # To log\n",
+ "\n",
+ "# # ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
+ "\n",
+ "# # # initialize the unet\n",
+ "# # nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
+ "# # # nn_model = ContextUnet(n_param=1, image_size=28)\n",
+ "# # nn_model.train()\n",
+ "# # nn_model.to(ddpm.device)\n",
+ "\n",
+ "# # parameters to be optimized\n",
+ "# # params_to_optimize = [\n",
+ "# # {'params': nn_model.parameters()}\n",
+ "# # ]\n",
+ "\n",
+ "# # number of parameters to be trained\n",
+ "# number_of_params = sum(x.numel() for x in nn_model.parameters())\n",
+ "# print(f\"Number of parameters for unet: {number_of_params}\")\n",
+ "\n",
+ "# # # optionally load a model\n",
+ "# # if config.resume:\n",
+ "# # ddpm.load_state_dict(torch.load(os.path.join(config.save_dir, f\"train-{ep}xscale_test_{run_name}.npy\")))\n",
+ "\n",
+ "# # define the loss function\n",
+ "# loss_mse = nn.MSELoss()\n",
+ "\n",
+ "\n",
+ "# # initialize optimizer\n",
+ "# # optim = torch.optim.Adam(params_to_optimize, lr=config.lrate)\n",
+ "\n",
+ "# # whether to use ema\n",
+ "# if config.ema:\n",
+ "# ema = EMA(config.ema_rate)\n",
+ "# if config.resume:\n",
+ "# print(\"resuming ema_model\")\n",
+ "# # ema_model = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
+ "# ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
+ "# # print(\"ema_model.device =\", ema_model.device)\n",
+ "# ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
+ "# # ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"train-{ep}xscale_test_{config.run_name}_ema.npy\")))\n",
+ "# else:\n",
+ "# ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
+ "\n",
+ "# ################### \n",
+ "# ## training loop ##\n",
+ "# ###################\n",
+ "# # plot_unet = True\n",
+ "# global_step = 0\n",
+ "# for ep in range(config.n_epoch):\n",
+ "# # print(f'epoch {ep}')\n",
+ "# # print(\"ddpm.train()\")\n",
+ "# ddpm.train()\n",
+ "# # linear lrate decay\n",
+ "# # if config.lr_decay:\n",
+ "# # optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n",
+ "\n",
+ "# # data loader with progress bar\n",
+ "# pbar_train = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)\n",
+ "# pbar_train.set_description(f\"Epoch {ep}\")\n",
+ "# for i, (x, c) in enumerate(dataloader):\n",
+ "# # global_step = ep * len(dataloader) + i\n",
+ "# with accelerator.accumulate(nn_model):\n",
+ "# # optim.zero_grad()\n",
+ "# x = x.to(config.device)\n",
+ "# xt, noise, ts = ddpm.add_noise(x)\n",
+ "\n",
+ "# # noise = torch.randn(x.shape, device=x.device)\n",
+ "# # ts = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device, dtype=torch.int64)\n",
+ "# # xt = ddpm.add_noise(x, noise, ts)\n",
 " \n",
+ "# if config.guide_w == -1:\n",
+ "# # noise_pred = nn_model(xt, ts, return_dict=False)[0]\n",
+ "# noise_pred = nn_model(xt, ts)\n",
+ "# else:\n",
+ "# c = c.to(config.device)\n",
+ "# noise_pred = nn_model(xt, ts, c)\n",
 " \n",
+ "# loss = loss_mse(noise, noise_pred)\n",
+ "# accelerator.backward(loss)\n",
+ "# # loss.backward()\n",
+ "# # optim.step()\n",
+ "# accelerator.clip_grad_norm_(nn_model.parameters(), 1)\n",
+ "# optimizer.step()\n",
+ "# lr_scheduler.step()\n",
+ "# optimizer.zero_grad()\n",
+ "\n",
+ "# # ema update\n",
+ "# if config.ema:\n",
+ "# ema.step_ema(ema_model, nn_model)\n",
+ "\n",
+ "# # pbar.set_description(f\"epoch {ep} loss {loss.item():.4f}\")\n",
+ "# pbar_train.update(1)\n",
+ "# logs = dict(\n",
+ "# loss=loss.detach().item(),\n",
+ "# lr=optimizer.param_groups[0]['lr'],\n",
+ "# step=global_step\n",
+ "# )\n",
+ "# pbar_train.set_postfix(**logs)\n",
+ "\n",
+ "# # logging loss\n",
+ "# # logger.add_scalar(\"MSE\", loss.item(), global_step=global_step)\n",
+ "# accelerator.log(logs, step=global_step)\n",
+ "# global_step += 1\n",
+ "\n",
+ "\n",
+ "# if accelerator.is_main_process:\n",
+ "# # sample the image\n",
+ "# if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
+ "# nn_model.eval()\n",
+ "# with torch.no_grad():\n",
+ "# # save model\n",
+ "# if config.push_to_hub:\n",
+ "# upload_folder(\n",
+ "# repo_id = repo_id,\n",
+ "# folder_path = \".\",#config.output_dir,\n",
+ "# commit_message = f\"{config.date}\",\n",
+ "# ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\"],\n",
+ "# )\n",
+ "# if config.save_model:\n",
+ "# model_state = {\n",
+ "# 'epoch': ep,\n",
+ "# 'unet_state_dict': nn_model.state_dict(),\n",
+ "# 'ema_unet_state_dict': ema_model.state_dict(),\n",
+ "# }\n",
+ "# torch.save(model_state, config.output_dir + f\"model_state.pth\")\n",
+ "# print('saved model at ' + config.output_dir + f\"model_state.pth\")\n",
+ "# # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
+ "\n",
+ "# # loop over the guidance scale\n",
+ "# # for w in config.ws_test: \n",
 " \n",
+ "# # pipeline = DDPMPipeline(unet=nn_model, scheduler=ddpm)\n",
+ "# # evaluate(config, ep, pipeline)\n",
 "\n",
+ "# # only output the image x0, omit the stored intermediate steps, OTHERWISE, uncomment \n",
+ "# # line 142, 143 and output 'x_last, x_store = ' here.\n",
 "\n",
+ "# # x_last_tot = []\n",
+ "# x_last, x_entire = ddpm.sample(nn_model,config.n_sample, x.shape[1:], config.device, params=config.params, guide_w=config.guide_w)\n",
 "\n",
+ "# # sample_save_dir = os.path.join(config.save_dir, f\"{config.run_name}.npy\")\n",
+ "# np.save(os.path.join(config.output_dir, f\"{config.run_name}.npy\"), x_last)\n",
+ "# # np.save(os.path.join(config.save_dir, f\"{config.run_name}_entire.npy\"), x_entire)\n",
+ "# # print(f\"saved to {config.save_dir}\")\n",
 "\n",
+ "# if config.ema:\n",
+ "# # x_last_tot_ema = []\n",
+ "# x_last_ema, x_entire_ema = ddpm.sample(ema_model,config.n_sample, x.shape[1:], config.device, params=config.params, guide_w=config.guide_w)\n",
 "\n",
+ "# np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)\n",
+ "# # np.save(os.path.join(config.save_dir, f\"{config.run_name}_ema_entire.npy\"), x_entire_ema)\n",
+ "# # print(f\"saved to {config.save_dir}\")\n",
 "\n",
+ "# # x_last_tot.append(np.array(x_last.cpu()))\n",
+ "# # x_last_tot=np.array(x_last_tot)\n",
+ "# # x_last_tot_ema.append(np.array(x_last_ema.cpu()))\n",
+ "# # x_last_tot_ema=np.array(x_last_tot_ema)\n",
 "\n"
 ]
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 109,
  "metadata": {},
  "outputs": [
  {
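The whole `train_loop` cell is commented out in this commit (its logic lives in the `DDPM21CM` class edited further down). One piece worth keeping in mind when reading it is the EMA bookkeeping: `ema_model` starts as a frozen deep copy and is nudged toward `nn_model` after every optimizer step via `ema.step_ema(ema_model, nn_model)`. The notebook's `EMA` class is not shown in this diff, so the following is only a sketch of the usual update such a helper performs:

```python
# Hypothetical stand-in for the notebook's EMA helper (not part of this diff):
# keeps a shadow copy whose weights decay toward the live model's weights.
import copy
import torch

class EMA:
    def __init__(self, rate: float = 0.995):
        self.rate = rate

    @torch.no_grad()
    def step_ema(self, ema_model: torch.nn.Module, model: torch.nn.Module) -> None:
        for p_ema, p in zip(ema_model.parameters(), model.parameters()):
            # p_ema <- rate * p_ema + (1 - rate) * p
            p_ema.mul_(self.rate).add_(p, alpha=1.0 - self.rate)

# usage, matching the commented-out loop above:
net = torch.nn.Linear(2, 2)
ema_model = copy.deepcopy(net).eval().requires_grad_(False)
EMA(0.995).step_ema(ema_model, net)
```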
@@ -1261,8 +1263,8 @@
 "loading 20 images randomly\n",
 "images loaded: (20, 1, 64, 512)\n",
 "params loaded: (20, 2)\n",
- "images rescaled to [-1.0, 0.
- "params rescaled to [0.0, 0.
+ "images rescaled to [-1.0, 0.946401834487915]\n",
+ "params rescaled to [0.0, 0.9683106587269014]\n",
 "resumed nn_model from model_state.pth\n",
 "Number of parameters for nn_model: 111048705\n"
@@ -1289,7 +1291,7 @@
 " self.config = config\n",
 "\n",
 " dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
- " self.shape_loaded = dataset.images.shape\n",
+ " # self.shape_loaded = dataset.images.shape\n",
 " # print(\"shape_loaded =\", self.shape_loaded)\n",
 " self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
 " del dataset\n",
@@ -1420,10 +1422,12 @@
 " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
 " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
 "\n",
- " def sample(self, file,
+ " def sample(self, file, params=[0.2,0.8], ema=False, entire=False):\n",
+ " n_sample = params.shape[0]\n",
 " model = self.ema_model if ema else self.nn_model\n",
+ " # params = torch.tile(params, (n_sample,1)).to(device)\n",
 "\n",
- " x_last, x_entire = self.ddpm.sample(model, n_sample, self.
+ " x_last, x_entire = self.ddpm.sample(model, n_sample, self.config.device, params=params, guide_w=self.config.guide_w)\n",
 "\n",
 " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
 " if entire:\n",
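Note a pitfall in the new signature above: the default `params=[0.2,0.8]` is a plain Python list, so `n_sample = params.shape[0]` would raise `AttributeError` unless a tensor is passed explicitly (and for this 1-D default it would count parameters, not samples). A small, hypothetical normalizer that would make that first line safe:

```python
# Hypothetical helper (not in the notebook): coerce the conditioning
# parameters to a (n_sample, n_param) float tensor before using .shape[0].
import torch

def normalize_params(params, n_param: int = 2, device: str = "cpu") -> torch.Tensor:
    params = torch.as_tensor(params, dtype=torch.float32, device=device)
    if params.dim() == 1:          # a single parameter pair -> batch of one
        params = params.unsqueeze(0)
    assert params.shape[1] == n_param
    return params

assert normalize_params([0.2, 0.8]).shape == (1, 2)  # n_sample = params.shape[0] now works
```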
@@ -1436,13 +1440,27 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 110,
  "metadata": {},
  "outputs": [
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "
+ "model_id": "7b26c7444a1144728f0299db5d5683b1",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "e3b8c1a18460443d986282f284bf0b42",
  "version_major": 2,
  "version_minor": 0
 },
@@ -1456,7 +1474,7 @@
 {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "
+ "model_id": "84d558fcea2f4b998f311da7452a2c93",
  "version_major": 2,
  "version_minor": 0
 },

@@ -1470,7 +1488,7 @@
 {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "
+ "model_id": "364208bd1fa04859baf233ed30451638",
  "version_major": 2,
  "version_minor": 0
 },

@@ -1484,7 +1502,7 @@
 {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "
+ "model_id": "e8aa693ca9b444d7b6adb77531f8b718",
  "version_major": 2,
  "version_minor": 0
 },

@@ -1498,7 +1516,7 @@
 {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "
+ "model_id": "aab44a74df8c4a7a8baf975b8ec50b8b",
  "version_major": 2,
  "version_minor": 0
 },
@@ -1516,11 +1534,44 @@
 },
 {
  "cell_type": "code",
- "execution_count":
+ "execution_count": 69,
  "metadata": {},
- "outputs": [
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "89a5be983ade43d89a1be3d977750a40",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/1000 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "ename": "RuntimeError",
+ "evalue": "The size of tensor a (24) must match the size of tensor b (36) at non-singleton dimension 0",
+ "output_type": "error",
+ "traceback": [
+ "---------------------------------------------------------------------------",
+ "RuntimeError                              Traceback (most recent call last)",
+ "Cell In[69], line 1 ----> 1 ddpm21cm.sample(\"./outputs/model_state_09.pth\")",
+ "Cell In[67], line 141, in DDPM21CM.sample(self, file, n_sample, ema, entire): 138 def sample(self, file, n_sample=12, ema=False, entire=False): 139 model = self.ema_model if ema else self.nn_model --> 141 x_last, x_entire = self.ddpm.sample(model, n_sample, self.shape_loaded[1:], self.config.device, test_param=self.config.test_param, guide_w=self.config.guide_w)",
+ "Cell In[7], line 70, in DDPMScheduler.sample(self, nn_model, n_sample, shape, device, test_param, guide_w): 69 # split predictions and compute weighting ---> 70 eps = nn_model(x_i, t_is, c_i) 71 eps1 = eps[:n_sample] 72 eps2 = eps[n_sample:]",
+ "File /usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs): -> 1130 return forward_call(*input, **kwargs)",
+ "File ~/.conda/envs/diffusers/lib/python3.9/site-packages/accelerate/utils/operations.py:822, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs): --> 822 return model_forward(*args, **kwargs)",
+ "File ~/.conda/envs/diffusers/lib/python3.9/site-packages/accelerate/utils/operations.py:810, in ConvertOutputsToFp32.__call__(self, *args, **kwargs): --> 810 return convert_to_fp32(self.model_forward(*args, **kwargs))",
+ "File /usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/amp/autocast_mode.py:12, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs): ---> 12 return func(*args, **kwargs)",
+ "Cell In[18], line 211, in ContextUnet.forward(self, x, timesteps, y): 209 if y != None: 210 text_outputs = self.token_embedding(y.float()) --> 211 emb = emb + text_outputs.to(emb)",
+ "RuntimeError: The size of tensor a (24) must match the size of tensor b (36) at non-singleton dimension 0"
+ ]
+ }
+ ],
 "source": [
- "ddpm21cm.sample()"
+ "ddpm21cm.sample(\"./outputs/model_state_09.pth\")"
 ]
 },
 {
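The error output above is the batch-size invariant this refactor trips over: the traceback still shows the old `DDPM21CM.sample(self, file, n_sample=12, ...)` passing `n_sample` and `test_param` through to the scheduler, so after the guidance double-batch the images carry 2×12 = 24 rows while the concatenated conditioning `c_i` carries 36, and `emb + text_outputs` fails inside `ContextUnet.forward`. A hypothetical pre-flight guard before the UNet call would surface this earlier:

```python
# Hypothetical guard (not in the notebook): after classifier-free-guidance
# doubling, images, timesteps and conditioning must share the batch dimension.
def check_cfg_batch(x_i, t_is, c_i):
    assert x_i.shape[0] == t_is.shape[0] == c_i.shape[0], (
        f"batch mismatch: x_i={x_i.shape[0]}, t_is={t_is.shape[0]}, c_i={c_i.shape[0]}"
    )
```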
@@ -1718,7 +1769,7 @@
 "\n",
 "n_sample = 20\n",
 "with torch.no_grad():\n",
- " x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device,
+ " x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
 "\n",
 "np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
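The last hunk builds the conditioning batch by tiling a single parameter pair across `n_sample` rows; standalone, the idiom is:

```python
import torch

n_sample = 20
params_single = torch.tensor([0.2, 0.8])           # one conditioning pair
params = torch.tile(params_single, (n_sample, 1))  # shape (20, 2): one row per sample
assert params.shape == (n_sample, 2)
```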