Xsmos committed on 0521-1651
Commit 99273da · verified · Parent: 6dfab72
Files changed (1):
  1. diffusion.ipynb +289 -196
diffusion.ipynb CHANGED
@@ -1053,203 +1053,235 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# def train_loop(config, nn_model, ddpm, optimizer, dataloader, lr_scheduler): \n",
- "#     ########################\n",
- "#     ## ready for training ##\n",
- "#     ########################\n",
- "#     # initialize the dataset\n",
- "#     # num_image = 600\n",
- "#     # HII_DIM = 64\n",
- "#     # num_redshift = 64#512#128\n",
- "#     # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
- "#     # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
- "#     # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
- "#     # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
- "#     # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
- "#     accelerator = Accelerator(\n",
- "#         mixed_precision=config.mixed_precision,\n",
- "#         gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
- "#         log_with=\"tensorboard\",\n",
- "#         project_dir=os.path.join(config.output_dir, \"logs\"),\n",
- "#     )\n",
- "#     if accelerator.is_main_process:\n",
- "#         if config.output_dir is not None:\n",
- "#             os.makedirs(config.output_dir, exist_ok=True)\n",
- "#         if config.push_to_hub:\n",
- "#             repo_id = create_repo(\n",
- "#                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
- "#             ).repo_id\n",
- "#         accelerator.init_trackers(f\"{config.date}\")\n",
- "\n",
- "#     nn_model, optimizer, dataloader, lr_scheduler = accelerator.prepare(\n",
- "#         nn_model, optimizer, dataloader, lr_scheduler)\n",
+ "def train_loop(config, nn_model, ddpm, optimizer, dataloader, lr_scheduler): \n",
+ "    ########################\n",
+ "    ## ready for training ##\n",
+ "    ########################\n",
+ "    # initialize the dataset\n",
+ "    # num_image = 600\n",
+ "    # HII_DIM = 64\n",
+ "    # num_redshift = 64#512#128\n",
+ "    # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
+ "    # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
+ "    # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
+ "    # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
+ "    # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
+ "    accelerator = Accelerator(\n",
+ "        mixed_precision=config.mixed_precision,\n",
+ "        gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
+ "        log_with=\"tensorboard\",\n",
+ "        project_dir=os.path.join(config.output_dir, \"logs\"),\n",
+ "    )\n",
+ "    if accelerator.is_main_process:\n",
+ "        if config.output_dir is not None:\n",
+ "            os.makedirs(config.output_dir, exist_ok=True)\n",
+ "        if config.push_to_hub:\n",
+ "            repo_id = create_repo(\n",
+ "                repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
+ "            ).repo_id\n",
+ "        accelerator.init_trackers(f\"{config.date}\")\n",
+ "\n",
+ "    nn_model, optimizer, dataloader, lr_scheduler = accelerator.prepare(\n",
+ "        nn_model, optimizer, dataloader, lr_scheduler)\n",
  " \n",
- "#     # initialize the DDPM\n",
- "#     # logger = SummaryWriter(os.path.join(\"runs\", config.run_name)) # To log\n",
- "\n",
- "#     # ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
- "\n",
- "#     # # initialize the unet\n",
- "#     # nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
- "#     # # nn_model = ContextUnet(n_param=1, image_size=28)\n",
- "#     # nn_model.train()\n",
- "#     # nn_model.to(ddpm.device)\n",
- "\n",
- "#     # parameters to be optimized\n",
- "#     # params_to_optimize = [\n",
- "#     # {'params': nn_model.parameters()}\n",
- "#     # ]\n",
- "\n",
- "#     # number of parameters to be trained\n",
- "#     number_of_params = sum(x.numel() for x in nn_model.parameters())\n",
- "#     print(f\"Number of parameters for unet: {number_of_params}\")\n",
- "\n",
- "#     # # optionally load a model\n",
- "#     # if config.resume:\n",
- "#     # ddpm.load_state_dict(torch.load(os.path.join(config.save_dir, f\"train-{ep}xscale_test_{run_name}.npy\")))\n",
- "\n",
- "#     # define the loss function\n",
- "#     loss_mse = nn.MSELoss()\n",
- "\n",
- "\n",
- "#     # initialize optimizer\n",
- "#     # optim = torch.optim.Adam(params_to_optimize, lr=config.lrate)\n",
- "\n",
- "#     # whether to use ema\n",
- "#     if config.ema:\n",
- "#         ema = EMA(config.ema_rate)\n",
- "#         if config.resume:\n",
- "#             print(\"resuming ema_model\")\n",
- "#             # ema_model = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
- "#             ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
- "#             # print(\"ema_model.device =\", ema_model.device)\n",
- "#             ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
- "#             # ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"train-{ep}xscale_test_{config.run_name}_ema.npy\")))\n",
- "#         else:\n",
- "#             ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
- "\n",
- "#     ################### \n",
- "#     ## training loop ##\n",
- "#     ###################\n",
- "#     # plot_unet = True\n",
- "#     global_step = 0\n",
- "#     for ep in range(config.n_epoch):\n",
- "#         # print(f'epoch {ep}')\n",
- "#         # print(\"ddpm.train()\")\n",
- "#         ddpm.train()\n",
- "#         # linear lrate decay\n",
- "#         # if config.lr_decay:\n",
- "#         # optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n",
- "\n",
- "#         # data loader with progress bar\n",
- "#         pbar_train = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)\n",
- "#         pbar_train.set_description(f\"Epoch {ep}\")\n",
- "#         for i, (x, c) in enumerate(dataloader):\n",
- "#             # global_step = ep * len(dataloader) + i\n",
- "#             with accelerator.accumulate(nn_model):\n",
- "#                 # optim.zero_grad()\n",
- "#                 x = x.to(config.device)\n",
- "#                 xt, noise, ts = ddpm.add_noise(x)\n",
- "\n",
- "#                 # noise = torch.randn(x.shape, device=x.device)\n",
- "#                 # ts = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device, dtype=torch.int64)\n",
- "#                 # xt = ddpm.add_noise(x, noise, ts)\n",
+ "    # initialize the DDPM\n",
+ "    # logger = SummaryWriter(os.path.join(\"runs\", config.run_name)) # To log\n",
+ "\n",
+ "    # ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
+ "\n",
+ "    # # initialize the unet\n",
+ "    # nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
+ "    # # nn_model = ContextUnet(n_param=1, image_size=28)\n",
+ "    # nn_model.train()\n",
+ "    # nn_model.to(ddpm.device)\n",
+ "\n",
+ "    # parameters to be optimized\n",
+ "    # params_to_optimize = [\n",
+ "    # {'params': nn_model.parameters()}\n",
+ "    # ]\n",
+ "\n",
+ "    # number of parameters to be trained\n",
+ "    number_of_params = sum(x.numel() for x in nn_model.parameters())\n",
+ "    print(f\"Number of parameters for unet: {number_of_params}\")\n",
+ "\n",
+ "    # # optionally load a model\n",
+ "    # if config.resume:\n",
+ "    # ddpm.load_state_dict(torch.load(os.path.join(config.save_dir, f\"train-{ep}xscale_test_{run_name}.npy\")))\n",
+ "\n",
+ "    # define the loss function\n",
+ "    loss_mse = nn.MSELoss()\n",
+ "\n",
+ "\n",
+ "    # initialize optimizer\n",
+ "    # optim = torch.optim.Adam(params_to_optimize, lr=config.lrate)\n",
+ "\n",
+ "    # whether to use ema\n",
+ "    if config.ema:\n",
+ "        ema = EMA(config.ema_rate)\n",
+ "        if config.resume:\n",
+ "            print(\"resuming ema_model\")\n",
+ "            # ema_model = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
+ "            ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
+ "            # print(\"ema_model.device =\", ema_model.device)\n",
+ "            ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
+ "            # ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"train-{ep}xscale_test_{config.run_name}_ema.npy\")))\n",
+ "        else:\n",
+ "            ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
+ "\n",
+ "    ################### \n",
+ "    ## training loop ##\n",
+ "    ###################\n",
+ "    # plot_unet = True\n",
+ "    global_step = 0\n",
+ "    for ep in range(config.n_epoch):\n",
+ "        # print(f'epoch {ep}')\n",
+ "        # print(\"ddpm.train()\")\n",
+ "        ddpm.train()\n",
+ "        # linear lrate decay\n",
+ "        # if config.lr_decay:\n",
+ "        # optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n",
+ "\n",
+ "        # data loader with progress bar\n",
+ "        pbar_train = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)\n",
+ "        pbar_train.set_description(f\"Epoch {ep}\")\n",
+ "        for i, (x, c) in enumerate(dataloader):\n",
+ "            # global_step = ep * len(dataloader) + i\n",
+ "            with accelerator.accumulate(nn_model):\n",
+ "                # optim.zero_grad()\n",
+ "                x = x.to(config.device)\n",
+ "                xt, noise, ts = ddpm.add_noise(x)\n",
+ "\n",
+ "                # noise = torch.randn(x.shape, device=x.device)\n",
+ "                # ts = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device, dtype=torch.int64)\n",
+ "                # xt = ddpm.add_noise(x, noise, ts)\n",
  " \n",
- "#                 if config.guide_w == -1:\n",
- "#                     # noise_pred = nn_model(xt, ts, return_dict=False)[0]\n",
- "#                     noise_pred = nn_model(xt, ts)\n",
- "#                 else:\n",
- "#                     c = c.to(config.device)\n",
- "#                     noise_pred = nn_model(xt, ts, c)\n",
+ "                if config.guide_w == -1:\n",
+ "                    # noise_pred = nn_model(xt, ts, return_dict=False)[0]\n",
+ "                    noise_pred = nn_model(xt, ts)\n",
+ "                else:\n",
+ "                    c = c.to(config.device)\n",
+ "                    noise_pred = nn_model(xt, ts, c)\n",
  " \n",
- "#                 loss = loss_mse(noise, noise_pred)\n",
- "#                 accelerator.backward(loss)\n",
- "#                 # loss.backward()\n",
- "#                 # optim.step()\n",
- "#                 accelerator.clip_grad_norm_(nn_model.parameters(), 1)\n",
- "#                 optimizer.step()\n",
- "#                 lr_scheduler.step()\n",
- "#                 optimizer.zero_grad()\n",
- "\n",
- "#                 # ema update\n",
- "#                 if config.ema:\n",
- "#                     ema.step_ema(ema_model, nn_model)\n",
- "\n",
- "#             # pbar.set_description(f\"epoch {ep} loss {loss.item():.4f}\")\n",
- "#             pbar_train.update(1)\n",
- "#             logs = dict(\n",
- "#                 loss=loss.detach().item(),\n",
- "#                 lr=optimizer.param_groups[0]['lr'],\n",
- "#                 step=global_step\n",
- "#             )\n",
- "#             pbar_train.set_postfix(**logs)\n",
- "\n",
- "#             # logging loss\n",
- "#             # logger.add_scalar(\"MSE\", loss.item(), global_step=global_step)\n",
- "#             accelerator.log(logs, step=global_step)\n",
- "#             global_step += 1\n",
- "\n",
- "\n",
- "#         if accelerator.is_main_process:\n",
- "#             # sample the image\n",
- "#             if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
- "#                 nn_model.eval()\n",
- "#                 with torch.no_grad():\n",
- "#                     # save model\n",
- "#                     if config.push_to_hub:\n",
- "#                         upload_folder(\n",
- "#                             repo_id = repo_id,\n",
- "#                             folder_path = \".\",#config.output_dir,\n",
- "#                             commit_message = f\"{config.date}\",\n",
- "#                             ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\"],\n",
- "#                         )\n",
- "#                     if config.save_model:\n",
- "#                         model_state = {\n",
- "#                             'epoch': ep,\n",
- "#                             'unet_state_dict': nn_model.state_dict(),\n",
- "#                             'ema_unet_state_dict': ema_model.state_dict(),\n",
- "#                         }\n",
- "#                         torch.save(model_state, config.output_dir + f\"model_state.pth\")\n",
- "#                         print('saved model at ' + config.output_dir + f\"model_state.pth\")\n",
- "#                         # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
- "\n",
- "#                     # loop over the guidance scale\n",
- "#                     # for w in config.ws_test: \n",
+ "                loss = loss_mse(noise, noise_pred)\n",
+ "                accelerator.backward(loss)\n",
+ "                # loss.backward()\n",
+ "                # optim.step()\n",
+ "                accelerator.clip_grad_norm_(nn_model.parameters(), 1)\n",
+ "                optimizer.step()\n",
+ "                lr_scheduler.step()\n",
+ "                optimizer.zero_grad()\n",
+ "\n",
+ "                # ema update\n",
+ "                if config.ema:\n",
+ "                    ema.step_ema(ema_model, nn_model)\n",
+ "\n",
+ "            # pbar.set_description(f\"epoch {ep} loss {loss.item():.4f}\")\n",
+ "            pbar_train.update(1)\n",
+ "            logs = dict(\n",
+ "                loss=loss.detach().item(),\n",
+ "                lr=optimizer.param_groups[0]['lr'],\n",
+ "                step=global_step\n",
+ "            )\n",
+ "            pbar_train.set_postfix(**logs)\n",
+ "\n",
+ "            # logging loss\n",
+ "            # logger.add_scalar(\"MSE\", loss.item(), global_step=global_step)\n",
+ "            accelerator.log(logs, step=global_step)\n",
+ "            global_step += 1\n",
+ "\n",
+ "\n",
+ "        if accelerator.is_main_process:\n",
+ "            # sample the image\n",
+ "            if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
+ "                nn_model.eval()\n",
+ "                with torch.no_grad():\n",
+ "                    # save model\n",
+ "                    if config.push_to_hub:\n",
+ "                        upload_folder(\n",
+ "                            repo_id = repo_id,\n",
+ "                            folder_path = \".\",#config.output_dir,\n",
+ "                            commit_message = f\"{config.date}\",\n",
+ "                            ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\"],\n",
+ "                        )\n",
+ "                    if config.save_model:\n",
+ "                        model_state = {\n",
+ "                            'epoch': ep,\n",
+ "                            'unet_state_dict': nn_model.state_dict(),\n",
+ "                            'ema_unet_state_dict': ema_model.state_dict(),\n",
+ "                        }\n",
+ "                        torch.save(model_state, config.output_dir + f\"model_state.pth\")\n",
+ "                        print('saved model at ' + config.output_dir + f\"model_state.pth\")\n",
+ "                        # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
+ "\n",
+ "                    # loop over the guidance scale\n",
+ "                    # for w in config.ws_test: \n",
  " \n",
- "#                     # pipeline = DDPMPipeline(unet=nn_model, scheduler=ddpm)\n",
- "#                     # evaluate(config, ep, pipeline)\n",
+ "                    # pipeline = DDPMPipeline(unet=nn_model, scheduler=ddpm)\n",
+ "                    # evaluate(config, ep, pipeline)\n",
  "\n",
- "#                     # only output the image x0, omit the stored intermediate steps, OTHERWISE, uncomment \n",
- "#                     # line 142, 143 and output 'x_last, x_store = ' here.\n",
+ "                    # only output the image x0, omit the stored intermediate steps, OTHERWISE, uncomment \n",
+ "                    # line 142, 143 and output 'x_last, x_store = ' here.\n",
  "\n",
- "#                     # x_last_tot = []\n",
- "#                     x_last, x_entire = ddpm.sample(nn_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
+ "                    # x_last_tot = []\n",
+ "                    x_last, x_entire = ddpm.sample(nn_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
  "\n",
- "#                     # sample_save_dir = os.path.join(config.save_dir, f\"{config.run_name}.npy\")\n",
- "#                     np.save(os.path.join(config.output_dir, f\"{config.run_name}.npy\"), x_last)\n",
- "#                     # np.save(os.path.join(config.save_dir, f\"{config.run_name}_entire.npy\"), x_entire)\n",
- "#                     # print(f\"saved to {config.save_dir}\")\n",
+ "                    # sample_save_dir = os.path.join(config.save_dir, f\"{config.run_name}.npy\")\n",
+ "                    np.save(os.path.join(config.output_dir, f\"{config.run_name}.npy\"), x_last)\n",
+ "                    # np.save(os.path.join(config.save_dir, f\"{config.run_name}_entire.npy\"), x_entire)\n",
+ "                    # print(f\"saved to {config.save_dir}\")\n",
  "\n",
- "#                     if config.ema:\n",
- "#                         # x_last_tot_ema = []\n",
- "#                         x_last_ema, x_entire_ema = ddpm.sample(ema_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
+ "                    if config.ema:\n",
+ "                        # x_last_tot_ema = []\n",
+ "                        x_last_ema, x_entire_ema = ddpm.sample(ema_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
  "\n",
- "#                         np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)\n",
- "#                         # np.save(os.path.join(config.save_dir, f\"{config.run_name}_ema_entire.npy\"), x_entire_ema)\n",
- "#                         # print(f\"saved to {config.save_dir}\")\n",
+ "                        np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)\n",
+ "                        # np.save(os.path.join(config.save_dir, f\"{config.run_name}_ema_entire.npy\"), x_entire_ema)\n",
+ "                        # print(f\"saved to {config.save_dir}\")\n",
  "\n",
- "#                     # x_last_tot.append(np.array(x_last.cpu()))\n",
- "#                     # x_last_tot=np.array(x_last_tot)\n",
- "#                     # x_last_tot_ema.append(np.array(x_last_ema.cpu()))\n",
- "#                     # x_last_tot_ema=np.array(x_last_tot_ema)\n",
+ "                    # x_last_tot.append(np.array(x_last.cpu()))\n",
+ "                    # x_last_tot=np.array(x_last_tot)\n",
+ "                    # x_last_tot_ema.append(np.array(x_last_ema.cpu()))\n",
+ "                    # x_last_tot_ema=np.array(x_last_tot_ema)\n",
  "\n"
  ]
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 67,
  "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
+ "51200 images can be loaded\n",
+ "field.shape = (64, 64, 514)\n",
+ "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
+ "loading 20 images randomly\n",
+ "images loaded: (20, 1, 64, 512)\n",
+ "params loaded: (20, 2)\n",
+ "images rescaled to [-1.0, 0.8148523569107056]\n",
+ "params rescaled to [0.0, 0.9062639012309924]\n",
+ "resumed nn_model from model_state.pth\n",
+ "Number of parameters for nn_model: 111048705\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "resumed ema_model from model_state.pth\n"
+ ]
+ }
+ ],
  "source": [
  "# @dataclass\n",
  "class DDPM21CM:\n",
@@ -1257,6 +1289,8 @@
  "        self.config = config\n",
  "\n",
  "        dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
+ "        self.shape_loaded = dataset.images.shape\n",
+ "        # print(\"shape_loaded =\", self.shape_loaded)\n",
  "        self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
  "        del dataset\n",
  "\n",
@@ -1386,8 +1420,15 @@
  "                        print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
  "                        # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
  "\n",
- "    def sample(self):\n",
- "        pass\n",
+ "    def sample(self, n_sample=12, ema=False, entire=False):\n",
+ "        model = self.ema_model if ema else self.nn_model\n",
+ "\n",
+ "        x_last, x_entire = self.ddpm.sample(model, n_sample, self.shape_loaded[1:], self.config.device, test_param=self.config.test_param, guide_w=self.config.guide_w)\n",
+ "\n",
+ "        np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'_ema' if ema else ''}.npy\"), x_last)\n",
+ "        if entire:\n",
+ "            np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'_ema' if ema else ''}_entire.npy\"), x_entire)\n",
+ "\n",
  "\n",
  "ddpm21cm = DDPM21CM(TrainConfig())\n",
  "# print(\"device =\", config.device)"
@@ -1395,13 +1436,13 @@
  },
  {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 68,
  "metadata": {},
  "outputs": [
  {
  "data": {
  "application/vnd.jupyter.widget-view+json": {
- "model_id": "b900cc8b147649d89832371147da66f1",
+ "model_id": "4874b29272224f1aa6bcbead8dc5d11f",
  "version_major": 2,
  "version_minor": 0
  },
@@ -1413,23 +1454,75 @@
  "output_type": "display_data"
  },
  {
- "ename": "AttributeError",
- "evalue": "'TrainConfig' object has no attribute 'save_freq'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[56], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49mtrain()\n",
- "Cell \u001b[0;32mIn[55], line 111\u001b[0m, in \u001b[0;36mDDPM21CM.train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 108\u001b[0m global_step \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 110\u001b[0m \u001b[39m# if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\u001b[39;00m\n\u001b[0;32m--> 111\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49msave(ep)\n",
- "Cell \u001b[0;32mIn[55], line 116\u001b[0m, in \u001b[0;36mDDPM21CM.save\u001b[0;34m(self, ep)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39msave\u001b[39m(\u001b[39mself\u001b[39m, ep):\n\u001b[1;32m 114\u001b[0m \u001b[39m# save model\u001b[39;00m\n\u001b[1;32m 115\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maccelerator\u001b[39m.\u001b[39mis_main_process:\n\u001b[0;32m--> 116\u001b[0m \u001b[39mif\u001b[39;00m ep \u001b[39m==\u001b[39m config\u001b[39m.\u001b[39mn_epoch\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m \u001b[39mor\u001b[39;00m (ep\u001b[39m+\u001b[39m\u001b[39m1\u001b[39m)\u001b[39m*\u001b[39mconfig\u001b[39m.\u001b[39;49msave_freq\u001b[39m==\u001b[39m\u001b[39m1\u001b[39m:\n\u001b[1;32m 117\u001b[0m nn_model\u001b[39m.\u001b[39meval()\n\u001b[1;32m 118\u001b[0m \u001b[39mwith\u001b[39;00m torch\u001b[39m.\u001b[39mno_grad():\n",
- "\u001b[0;31mAttributeError\u001b[0m: 'TrainConfig' object has no attribute 'save_freq'"
- ]
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1caa087a9a7b4252a633054990ff76d8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "81644fce614d49f5aa63e291ec458ccf",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a011d7ece27e42128a6ec51227313e60",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "6d2f56d5d27443e589a8cca5d45892e3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/2 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
  }
  ],
  "source": [
  "ddpm21cm.train()"
  ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ddpm21cm.sample()"
+ ]
+ },
  {
  "cell_type": "code",
  "execution_count": 45,
 