0521-1651
Browse files- diffusion.ipynb +289 -196
diffusion.ipynb
CHANGED
|
@@ -1053,203 +1053,235 @@
|
|
| 1053 |
"metadata": {},
|
| 1054 |
"outputs": [],
|
| 1055 |
"source": [
|
| 1056 |
-
"
|
| 1057 |
-
"#
|
| 1058 |
-
"#
|
| 1059 |
-
"#
|
| 1060 |
-
"#
|
| 1061 |
-
"#
|
| 1062 |
-
"#
|
| 1063 |
-
"#
|
| 1064 |
-
"#
|
| 1065 |
-
"#
|
| 1066 |
-
"#
|
| 1067 |
-
"#
|
| 1068 |
-
"#
|
| 1069 |
-
"
|
| 1070 |
-
"
|
| 1071 |
-
"
|
| 1072 |
-
"
|
| 1073 |
-
"
|
| 1074 |
-
"
|
| 1075 |
-
"
|
| 1076 |
-
"
|
| 1077 |
-
"
|
| 1078 |
-
"
|
| 1079 |
-
"
|
| 1080 |
-
"
|
| 1081 |
-
"
|
| 1082 |
-
"
|
| 1083 |
-
"\n",
|
| 1084 |
-
"
|
| 1085 |
-
"
|
| 1086 |
" \n",
|
| 1087 |
-
"#
|
| 1088 |
-
"#
|
| 1089 |
-
"\n",
|
| 1090 |
-
"#
|
| 1091 |
-
"\n",
|
| 1092 |
-
"#
|
| 1093 |
-
"#
|
| 1094 |
-
"#
|
| 1095 |
-
"#
|
| 1096 |
-
"#
|
| 1097 |
-
"\n",
|
| 1098 |
-
"#
|
| 1099 |
-
"#
|
| 1100 |
-
"#
|
| 1101 |
-
"#
|
| 1102 |
-
"\n",
|
| 1103 |
-
"#
|
| 1104 |
-
"
|
| 1105 |
-
"
|
| 1106 |
-
"\n",
|
| 1107 |
-
"#
|
| 1108 |
-
"#
|
| 1109 |
-
"#
|
| 1110 |
-
"\n",
|
| 1111 |
-
"#
|
| 1112 |
-
"
|
| 1113 |
-
"\n",
|
| 1114 |
-
"\n",
|
| 1115 |
-
"#
|
| 1116 |
-
"#
|
| 1117 |
-
"\n",
|
| 1118 |
-
"#
|
| 1119 |
-
"
|
| 1120 |
-
"
|
| 1121 |
-
"
|
| 1122 |
-
"
|
| 1123 |
-
"#
|
| 1124 |
-
"
|
| 1125 |
-
"#
|
| 1126 |
-
"
|
| 1127 |
-
"#
|
| 1128 |
-
"
|
| 1129 |
-
"
|
| 1130 |
-
"\n",
|
| 1131 |
-
"#
|
| 1132 |
-
"#
|
| 1133 |
-
"#
|
| 1134 |
-
"#
|
| 1135 |
-
"
|
| 1136 |
-
"
|
| 1137 |
-
"#
|
| 1138 |
-
"#
|
| 1139 |
-
"
|
| 1140 |
-
"#
|
| 1141 |
-
"#
|
| 1142 |
-
"#
|
| 1143 |
-
"\n",
|
| 1144 |
-
"#
|
| 1145 |
-
"
|
| 1146 |
-
"
|
| 1147 |
-
"
|
| 1148 |
-
"#
|
| 1149 |
-
"
|
| 1150 |
-
"#
|
| 1151 |
-
"
|
| 1152 |
-
"
|
| 1153 |
-
"\n",
|
| 1154 |
-
"#
|
| 1155 |
-
"#
|
| 1156 |
-
"#
|
| 1157 |
" \n",
|
| 1158 |
-
"
|
| 1159 |
-
"#
|
| 1160 |
-
"
|
| 1161 |
-
"
|
| 1162 |
-
"
|
| 1163 |
-
"
|
| 1164 |
" \n",
|
| 1165 |
-
"
|
| 1166 |
-
"
|
| 1167 |
-
"#
|
| 1168 |
-
"#
|
| 1169 |
-
"
|
| 1170 |
-
"
|
| 1171 |
-
"
|
| 1172 |
-
"
|
| 1173 |
-
"\n",
|
| 1174 |
-
"#
|
| 1175 |
-
"
|
| 1176 |
-
"
|
| 1177 |
-
"\n",
|
| 1178 |
-
"#
|
| 1179 |
-
"
|
| 1180 |
-
"
|
| 1181 |
-
"
|
| 1182 |
-
"
|
| 1183 |
-
"
|
| 1184 |
-
"
|
| 1185 |
-
"
|
| 1186 |
-
"\n",
|
| 1187 |
-
"#
|
| 1188 |
-
"#
|
| 1189 |
-
"
|
| 1190 |
-
"
|
| 1191 |
-
"\n",
|
| 1192 |
-
"\n",
|
| 1193 |
-
"
|
| 1194 |
-
"#
|
| 1195 |
-
"
|
| 1196 |
-
"
|
| 1197 |
-
"
|
| 1198 |
-
"#
|
| 1199 |
-
"
|
| 1200 |
-
"
|
| 1201 |
-
"
|
| 1202 |
-
"
|
| 1203 |
-
"
|
| 1204 |
-
"
|
| 1205 |
-
"
|
| 1206 |
-
"
|
| 1207 |
-
"
|
| 1208 |
-
"
|
| 1209 |
-
"
|
| 1210 |
-
"
|
| 1211 |
-
"
|
| 1212 |
-
"
|
| 1213 |
-
"
|
| 1214 |
-
"#
|
| 1215 |
-
"\n",
|
| 1216 |
-
"#
|
| 1217 |
-
"#
|
| 1218 |
" \n",
|
| 1219 |
-
"#
|
| 1220 |
-
"#
|
| 1221 |
"\n",
|
| 1222 |
-
"#
|
| 1223 |
-
"#
|
| 1224 |
"\n",
|
| 1225 |
-
"#
|
| 1226 |
-
"
|
| 1227 |
"\n",
|
| 1228 |
-
"#
|
| 1229 |
-
"
|
| 1230 |
-
"#
|
| 1231 |
-
"#
|
| 1232 |
"\n",
|
| 1233 |
-
"
|
| 1234 |
-
"#
|
| 1235 |
-
"
|
| 1236 |
"\n",
|
| 1237 |
-
"
|
| 1238 |
-
"#
|
| 1239 |
-
"#
|
| 1240 |
"\n",
|
| 1241 |
-
"#
|
| 1242 |
-
"#
|
| 1243 |
-
"#
|
| 1244 |
-
"#
|
| 1245 |
"\n"
|
| 1246 |
]
|
| 1247 |
},
|
| 1248 |
{
|
| 1249 |
"cell_type": "code",
|
| 1250 |
-
"execution_count":
|
| 1251 |
"metadata": {},
|
| 1252 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1253 |
"source": [
|
| 1254 |
"# @dataclass\n",
|
| 1255 |
"class DDPM21CM:\n",
|
|
@@ -1257,6 +1289,8 @@
|
|
| 1257 |
" self.config = config\n",
|
| 1258 |
"\n",
|
| 1259 |
" dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
|
|
|
|
|
|
|
| 1260 |
" self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
|
| 1261 |
" del dataset\n",
|
| 1262 |
"\n",
|
|
@@ -1386,8 +1420,15 @@
|
|
| 1386 |
" print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
|
| 1387 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 1388 |
"\n",
|
| 1389 |
-
" def sample(self):\n",
|
| 1390 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1391 |
"\n",
|
| 1392 |
"ddpm21cm = DDPM21CM(TrainConfig())\n",
|
| 1393 |
"# print(\"device =\", config.device)"
|
|
@@ -1395,13 +1436,13 @@
|
|
| 1395 |
},
|
| 1396 |
{
|
| 1397 |
"cell_type": "code",
|
| 1398 |
-
"execution_count":
|
| 1399 |
"metadata": {},
|
| 1400 |
"outputs": [
|
| 1401 |
{
|
| 1402 |
"data": {
|
| 1403 |
"application/vnd.jupyter.widget-view+json": {
|
| 1404 |
-
"model_id": "
|
| 1405 |
"version_major": 2,
|
| 1406 |
"version_minor": 0
|
| 1407 |
},
|
|
@@ -1413,23 +1454,75 @@
|
|
| 1413 |
"output_type": "display_data"
|
| 1414 |
},
|
| 1415 |
{
|
| 1416 |
-
"
|
| 1417 |
-
|
| 1418 |
-
|
| 1419 |
-
|
| 1420 |
-
|
| 1421 |
-
|
| 1422 |
-
"
|
| 1423 |
-
|
| 1424 |
-
|
| 1425 |
-
|
| 1426 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1427 |
}
|
| 1428 |
],
|
| 1429 |
"source": [
|
| 1430 |
"ddpm21cm.train()"
|
| 1431 |
]
|
| 1432 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1433 |
{
|
| 1434 |
"cell_type": "code",
|
| 1435 |
"execution_count": 45,
|
|
|
|
| 1053 |
"metadata": {},
|
| 1054 |
"outputs": [],
|
| 1055 |
"source": [
|
| 1056 |
+
"def train_loop(config, nn_model, ddpm, optimizer, dataloader, lr_scheduler): \n",
|
| 1057 |
+
" ########################\n",
|
| 1058 |
+
" ## ready for training ##\n",
|
| 1059 |
+
" ########################\n",
|
| 1060 |
+
" # initialize the dataset\n",
|
| 1061 |
+
" # num_image = 600\n",
|
| 1062 |
+
" # HII_DIM = 64\n",
|
| 1063 |
+
" # num_redshift = 64#512#128\n",
|
| 1064 |
+
" # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
|
| 1065 |
+
" # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
|
| 1066 |
+
" # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
|
| 1067 |
+
" # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
|
| 1068 |
+
" # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
|
| 1069 |
+
" accelerator = Accelerator(\n",
|
| 1070 |
+
" mixed_precision=config.mixed_precision,\n",
|
| 1071 |
+
" gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
|
| 1072 |
+
" log_with=\"tensorboard\",\n",
|
| 1073 |
+
" project_dir=os.path.join(config.output_dir, \"logs\"),\n",
|
| 1074 |
+
" )\n",
|
| 1075 |
+
" if accelerator.is_main_process:\n",
|
| 1076 |
+
" if config.output_dir is not None:\n",
|
| 1077 |
+
" os.makedirs(config.output_dir, exist_ok=True)\n",
|
| 1078 |
+
" if config.push_to_hub:\n",
|
| 1079 |
+
" repo_id = create_repo(\n",
|
| 1080 |
+
" repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
|
| 1081 |
+
" ).repo_id\n",
|
| 1082 |
+
" accelerator.init_trackers(f\"{config.date}\")\n",
|
| 1083 |
+
"\n",
|
| 1084 |
+
" nn_model, optimizer, dataloader, lr_scheduler = accelerator.prepare(\n",
|
| 1085 |
+
" nn_model, optimizer, dataloader, lr_scheduler)\n",
|
| 1086 |
" \n",
|
| 1087 |
+
" # initialize the DDPM\n",
|
| 1088 |
+
" # logger = SummaryWriter(os.path.join(\"runs\", config.run_name)) # To log\n",
|
| 1089 |
+
"\n",
|
| 1090 |
+
" # ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
|
| 1091 |
+
"\n",
|
| 1092 |
+
" # # initialize the unet\n",
|
| 1093 |
+
" # nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
|
| 1094 |
+
" # # nn_model = ContextUnet(n_param=1, image_size=28)\n",
|
| 1095 |
+
" # nn_model.train()\n",
|
| 1096 |
+
" # nn_model.to(ddpm.device)\n",
|
| 1097 |
+
"\n",
|
| 1098 |
+
" # parameters to be optimized\n",
|
| 1099 |
+
" # params_to_optimize = [\n",
|
| 1100 |
+
" # {'params': nn_model.parameters()}\n",
|
| 1101 |
+
" # ]\n",
|
| 1102 |
+
"\n",
|
| 1103 |
+
" # number of parameters to be trained\n",
|
| 1104 |
+
" number_of_params = sum(x.numel() for x in nn_model.parameters())\n",
|
| 1105 |
+
" print(f\"Number of parameters for unet: {number_of_params}\")\n",
|
| 1106 |
+
"\n",
|
| 1107 |
+
" # # optionally load a model\n",
|
| 1108 |
+
" # if config.resume:\n",
|
| 1109 |
+
" # ddpm.load_state_dict(torch.load(os.path.join(config.save_dir, f\"train-{ep}xscale_test_{run_name}.npy\")))\n",
|
| 1110 |
+
"\n",
|
| 1111 |
+
" # define the loss function\n",
|
| 1112 |
+
" loss_mse = nn.MSELoss()\n",
|
| 1113 |
+
"\n",
|
| 1114 |
+
"\n",
|
| 1115 |
+
" # initialize optimizer\n",
|
| 1116 |
+
" # optim = torch.optim.Adam(params_to_optimize, lr=config.lrate)\n",
|
| 1117 |
+
"\n",
|
| 1118 |
+
" # whether to use ema\n",
|
| 1119 |
+
" if config.ema:\n",
|
| 1120 |
+
" ema = EMA(config.ema_rate)\n",
|
| 1121 |
+
" if config.resume:\n",
|
| 1122 |
+
" print(\"resuming ema_model\")\n",
|
| 1123 |
+
" # ema_model = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
|
| 1124 |
+
" ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
|
| 1125 |
+
" # print(\"ema_model.device =\", ema_model.device)\n",
|
| 1126 |
+
" ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
|
| 1127 |
+
" # ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"train-{ep}xscale_test_{config.run_name}_ema.npy\")))\n",
|
| 1128 |
+
" else:\n",
|
| 1129 |
+
" ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
|
| 1130 |
+
"\n",
|
| 1131 |
+
" ################### \n",
|
| 1132 |
+
" ## training loop ##\n",
|
| 1133 |
+
" ###################\n",
|
| 1134 |
+
" # plot_unet = True\n",
|
| 1135 |
+
" global_step = 0\n",
|
| 1136 |
+
" for ep in range(config.n_epoch):\n",
|
| 1137 |
+
" # print(f'epoch {ep}')\n",
|
| 1138 |
+
" # print(\"ddpm.train()\")\n",
|
| 1139 |
+
" ddpm.train()\n",
|
| 1140 |
+
" # linear lrate decay\n",
|
| 1141 |
+
" # if config.lr_decay:\n",
|
| 1142 |
+
" # optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n",
|
| 1143 |
+
"\n",
|
| 1144 |
+
" # data loader with progress bar\n",
|
| 1145 |
+
" pbar_train = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)\n",
|
| 1146 |
+
" pbar_train.set_description(f\"Epoch {ep}\")\n",
|
| 1147 |
+
" for i, (x, c) in enumerate(dataloader):\n",
|
| 1148 |
+
" # global_step = ep * len(dataloader) + i\n",
|
| 1149 |
+
" with accelerator.accumulate(nn_model):\n",
|
| 1150 |
+
" # optim.zero_grad()\n",
|
| 1151 |
+
" x = x.to(config.device)\n",
|
| 1152 |
+
" xt, noise, ts = ddpm.add_noise(x)\n",
|
| 1153 |
+
"\n",
|
| 1154 |
+
" # noise = torch.randn(x.shape, device=x.device)\n",
|
| 1155 |
+
" # ts = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device, dtype=torch.int64)\n",
|
| 1156 |
+
" # xt = ddpm.add_noise(x, noise, ts)\n",
|
| 1157 |
" \n",
|
| 1158 |
+
" if config.guide_w == -1:\n",
|
| 1159 |
+
" # noise_pred = nn_model(xt, ts, return_dict=False)[0]\n",
|
| 1160 |
+
" noise_pred = nn_model(xt, ts)\n",
|
| 1161 |
+
" else:\n",
|
| 1162 |
+
" c = c.to(config.device)\n",
|
| 1163 |
+
" noise_pred = nn_model(xt, ts, c)\n",
|
| 1164 |
" \n",
|
| 1165 |
+
" loss = loss_mse(noise, noise_pred)\n",
|
| 1166 |
+
" accelerator.backward(loss)\n",
|
| 1167 |
+
" # loss.backward()\n",
|
| 1168 |
+
" # optim.step()\n",
|
| 1169 |
+
" accelerator.clip_grad_norm_(nn_model.parameters(), 1)\n",
|
| 1170 |
+
" optimizer.step()\n",
|
| 1171 |
+
" lr_scheduler.step()\n",
|
| 1172 |
+
" optimizer.zero_grad()\n",
|
| 1173 |
+
"\n",
|
| 1174 |
+
" # ema update\n",
|
| 1175 |
+
" if config.ema:\n",
|
| 1176 |
+
" ema.step_ema(ema_model, nn_model)\n",
|
| 1177 |
+
"\n",
|
| 1178 |
+
" # pbar.set_description(f\"epoch {ep} loss {loss.item():.4f}\")\n",
|
| 1179 |
+
" pbar_train.update(1)\n",
|
| 1180 |
+
" logs = dict(\n",
|
| 1181 |
+
" loss=loss.detach().item(),\n",
|
| 1182 |
+
" lr=optimizer.param_groups[0]['lr'],\n",
|
| 1183 |
+
" step=global_step\n",
|
| 1184 |
+
" )\n",
|
| 1185 |
+
" pbar_train.set_postfix(**logs)\n",
|
| 1186 |
+
"\n",
|
| 1187 |
+
" # logging loss\n",
|
| 1188 |
+
" # logger.add_scalar(\"MSE\", loss.item(), global_step=global_step)\n",
|
| 1189 |
+
" accelerator.log(logs, step=global_step)\n",
|
| 1190 |
+
" global_step += 1\n",
|
| 1191 |
+
"\n",
|
| 1192 |
+
"\n",
|
| 1193 |
+
" if accelerator.is_main_process:\n",
|
| 1194 |
+
" # sample the image\n",
|
| 1195 |
+
" if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
|
| 1196 |
+
" nn_model.eval()\n",
|
| 1197 |
+
" with torch.no_grad():\n",
|
| 1198 |
+
" # save model\n",
|
| 1199 |
+
" if config.push_to_hub:\n",
|
| 1200 |
+
" upload_folder(\n",
|
| 1201 |
+
" repo_id = repo_id,\n",
|
| 1202 |
+
" folder_path = \".\",#config.output_dir,\n",
|
| 1203 |
+
" commit_message = f\"{config.date}\",\n",
|
| 1204 |
+
" ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\"],\n",
|
| 1205 |
+
" )\n",
|
| 1206 |
+
" if config.save_model:\n",
|
| 1207 |
+
" model_state = {\n",
|
| 1208 |
+
" 'epoch': ep,\n",
|
| 1209 |
+
" 'unet_state_dict': nn_model.state_dict(),\n",
|
| 1210 |
+
" 'ema_unet_state_dict': ema_model.state_dict(),\n",
|
| 1211 |
+
" }\n",
|
| 1212 |
+
" torch.save(model_state, config.output_dir + f\"model_state.pth\")\n",
|
| 1213 |
+
" print('saved model at ' + config.output_dir + f\"model_state.pth\")\n",
|
| 1214 |
+
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 1215 |
+
"\n",
|
| 1216 |
+
" # loop over the guidance scale\n",
|
| 1217 |
+
" # for w in config.ws_test: \n",
|
| 1218 |
" \n",
|
| 1219 |
+
" # pipeline = DDPMPipeline(unet=nn_model, scheduler=ddpm)\n",
|
| 1220 |
+
" # evaluate(config, ep, pipeline)\n",
|
| 1221 |
"\n",
|
| 1222 |
+
" # only output the image x0, omit the stored intermediate steps, OTHERWISE, uncomment \n",
|
| 1223 |
+
" # line 142, 143 and output 'x_last, x_store = ' here.\n",
|
| 1224 |
"\n",
|
| 1225 |
+
" # x_last_tot = []\n",
|
| 1226 |
+
" x_last, x_entire = ddpm.sample(nn_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
|
| 1227 |
"\n",
|
| 1228 |
+
" # sample_save_dir = os.path.join(config.save_dir, f\"{config.run_name}.npy\")\n",
|
| 1229 |
+
" np.save(os.path.join(config.output_dir, f\"{config.run_name}.npy\"), x_last)\n",
|
| 1230 |
+
" # np.save(os.path.join(config.save_dir, f\"{config.run_name}_entire.npy\"), x_entire)\n",
|
| 1231 |
+
" # print(f\"saved to {config.save_dir}\")\n",
|
| 1232 |
"\n",
|
| 1233 |
+
" if config.ema:\n",
|
| 1234 |
+
" # x_last_tot_ema = []\n",
|
| 1235 |
+
" x_last_ema, x_entire_ema = ddpm.sample(ema_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
|
| 1236 |
"\n",
|
| 1237 |
+
" np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)\n",
|
| 1238 |
+
" # np.save(os.path.join(config.save_dir, f\"{config.run_name}_ema_entire.npy\"), x_entire_ema)\n",
|
| 1239 |
+
" # print(f\"saved to {config.save_dir}\")\n",
|
| 1240 |
"\n",
|
| 1241 |
+
" # x_last_tot.append(np.array(x_last.cpu()))\n",
|
| 1242 |
+
" # x_last_tot=np.array(x_last_tot)\n",
|
| 1243 |
+
" # x_last_tot_ema.append(np.array(x_last_ema.cpu()))\n",
|
| 1244 |
+
" # x_last_tot_ema=np.array(x_last_tot_ema)\n",
|
| 1245 |
"\n"
|
| 1246 |
]
|
| 1247 |
},
|
| 1248 |
{
|
| 1249 |
"cell_type": "code",
|
| 1250 |
+
"execution_count": 67,
|
| 1251 |
"metadata": {},
|
| 1252 |
+
"outputs": [
|
| 1253 |
+
{
|
| 1254 |
+
"name": "stdout",
|
| 1255 |
+
"output_type": "stream",
|
| 1256 |
+
"text": [
|
| 1257 |
+
"dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
|
| 1258 |
+
"51200 images can be loaded\n",
|
| 1259 |
+
"field.shape = (64, 64, 514)\n",
|
| 1260 |
+
"params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
|
| 1261 |
+
"loading 20 images randomly\n",
|
| 1262 |
+
"images loaded: (20, 1, 64, 512)\n",
|
| 1263 |
+
"params loaded: (20, 2)\n",
|
| 1264 |
+
"images rescaled to [-1.0, 0.8148523569107056]\n",
|
| 1265 |
+
"params rescaled to [0.0, 0.9062639012309924]\n",
|
| 1266 |
+
"resumed nn_model from model_state.pth\n",
|
| 1267 |
+
"Number of parameters for nn_model: 111048705\n"
|
| 1268 |
+
]
|
| 1269 |
+
},
|
| 1270 |
+
{
|
| 1271 |
+
"name": "stderr",
|
| 1272 |
+
"output_type": "stream",
|
| 1273 |
+
"text": [
|
| 1274 |
+
"Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
|
| 1275 |
+
]
|
| 1276 |
+
},
|
| 1277 |
+
{
|
| 1278 |
+
"name": "stdout",
|
| 1279 |
+
"output_type": "stream",
|
| 1280 |
+
"text": [
|
| 1281 |
+
"resumed ema_model from model_state.pth\n"
|
| 1282 |
+
]
|
| 1283 |
+
}
|
| 1284 |
+
],
|
| 1285 |
"source": [
|
| 1286 |
"# @dataclass\n",
|
| 1287 |
"class DDPM21CM:\n",
|
|
|
|
| 1289 |
" self.config = config\n",
|
| 1290 |
"\n",
|
| 1291 |
" dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
|
| 1292 |
+
" self.shape_loaded = dataset.images.shape\n",
|
| 1293 |
+
" # print(\"shape_loaded =\", self.shape_loaded)\n",
|
| 1294 |
" self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
|
| 1295 |
" del dataset\n",
|
| 1296 |
"\n",
|
|
|
|
| 1420 |
" print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
|
| 1421 |
" # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
|
| 1422 |
"\n",
|
| 1423 |
+
" def sample(self, file, n_sample=12, ema=False, entire=False):\n",
|
| 1424 |
+
" model = self.ema_model if ema else self.nn_model\n",
|
| 1425 |
+
"\n",
|
| 1426 |
+
" x_last, x_entire = self.ddpm.sample(model, n_sample, self.shape_loaded[1:], self.config.device, test_param=self.config.test_param, guide_w=self.config.guide_w)\n",
|
| 1427 |
+
"\n",
|
| 1428 |
+
" np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
|
| 1429 |
+
" if entire:\n",
|
| 1430 |
+
" np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}_entire.npy\"), x_last)\n",
|
| 1431 |
+
"\n",
|
| 1432 |
"\n",
|
| 1433 |
"ddpm21cm = DDPM21CM(TrainConfig())\n",
|
| 1434 |
"# print(\"device =\", config.device)"
|
|
|
|
| 1436 |
},
|
| 1437 |
{
|
| 1438 |
"cell_type": "code",
|
| 1439 |
+
"execution_count": 68,
|
| 1440 |
"metadata": {},
|
| 1441 |
"outputs": [
|
| 1442 |
{
|
| 1443 |
"data": {
|
| 1444 |
"application/vnd.jupyter.widget-view+json": {
|
| 1445 |
+
"model_id": "4874b29272224f1aa6bcbead8dc5d11f",
|
| 1446 |
"version_major": 2,
|
| 1447 |
"version_minor": 0
|
| 1448 |
},
|
|
|
|
| 1454 |
"output_type": "display_data"
|
| 1455 |
},
|
| 1456 |
{
|
| 1457 |
+
"data": {
|
| 1458 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1459 |
+
"model_id": "1caa087a9a7b4252a633054990ff76d8",
|
| 1460 |
+
"version_major": 2,
|
| 1461 |
+
"version_minor": 0
|
| 1462 |
+
},
|
| 1463 |
+
"text/plain": [
|
| 1464 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1465 |
+
]
|
| 1466 |
+
},
|
| 1467 |
+
"metadata": {},
|
| 1468 |
+
"output_type": "display_data"
|
| 1469 |
+
},
|
| 1470 |
+
{
|
| 1471 |
+
"data": {
|
| 1472 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1473 |
+
"model_id": "81644fce614d49f5aa63e291ec458ccf",
|
| 1474 |
+
"version_major": 2,
|
| 1475 |
+
"version_minor": 0
|
| 1476 |
+
},
|
| 1477 |
+
"text/plain": [
|
| 1478 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1479 |
+
]
|
| 1480 |
+
},
|
| 1481 |
+
"metadata": {},
|
| 1482 |
+
"output_type": "display_data"
|
| 1483 |
+
},
|
| 1484 |
+
{
|
| 1485 |
+
"data": {
|
| 1486 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1487 |
+
"model_id": "a011d7ece27e42128a6ec51227313e60",
|
| 1488 |
+
"version_major": 2,
|
| 1489 |
+
"version_minor": 0
|
| 1490 |
+
},
|
| 1491 |
+
"text/plain": [
|
| 1492 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1493 |
+
]
|
| 1494 |
+
},
|
| 1495 |
+
"metadata": {},
|
| 1496 |
+
"output_type": "display_data"
|
| 1497 |
+
},
|
| 1498 |
+
{
|
| 1499 |
+
"data": {
|
| 1500 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 1501 |
+
"model_id": "6d2f56d5d27443e589a8cca5d45892e3",
|
| 1502 |
+
"version_major": 2,
|
| 1503 |
+
"version_minor": 0
|
| 1504 |
+
},
|
| 1505 |
+
"text/plain": [
|
| 1506 |
+
" 0%| | 0/2 [00:00<?, ?it/s]"
|
| 1507 |
+
]
|
| 1508 |
+
},
|
| 1509 |
+
"metadata": {},
|
| 1510 |
+
"output_type": "display_data"
|
| 1511 |
}
|
| 1512 |
],
|
| 1513 |
"source": [
|
| 1514 |
"ddpm21cm.train()"
|
| 1515 |
]
|
| 1516 |
},
|
| 1517 |
+
{
|
| 1518 |
+
"cell_type": "code",
|
| 1519 |
+
"execution_count": null,
|
| 1520 |
+
"metadata": {},
|
| 1521 |
+
"outputs": [],
|
| 1522 |
+
"source": [
|
| 1523 |
+
"ddpm21cm.sample()"
|
| 1524 |
+
]
|
| 1525 |
+
},
|
| 1526 |
{
|
| 1527 |
"cell_type": "code",
|
| 1528 |
"execution_count": 45,
|