Xsmos commited on
Commit
d1e5b1a
·
verified ·
1 Parent(s): 99273da
Files changed (1) hide show
  1. diffusion.ipynb +279 -228
diffusion.ipynb CHANGED
@@ -234,7 +234,7 @@
234
  },
235
  {
236
  "cell_type": "code",
237
- "execution_count": 7,
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
@@ -256,30 +256,32 @@
256
  " self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
257
  "\n",
258
  " def add_noise(self, clean_images):\n",
259
- " shape = clean_images.shape\n",
260
- " expand = torch.ones(len(shape)-1, dtype=int)\n",
261
  " # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
262
  " # expand = [1 for i in range(len(shape)-1)]\n",
263
  "\n",
264
  " noise = torch.randn_like(clean_images).to(self.device)\n",
265
- " ts = torch.randint(0, self.num_timesteps, (shape[0],)).to(self.device)\n",
266
  " \n",
267
  " # test_expand = test.view(test.shape[0],*expand)\n",
268
  " # extend_dim = [None for i in range(shape.dim()-1)]\n",
269
  " noisy_images = (\n",
270
- " clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
271
- " + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(shape[0], *expand.tolist())\n",
272
  " )\n",
273
  " # print(x_t.shape)\n",
274
  "\n",
275
  " return noisy_images, noise, ts\n",
276
  "\n",
277
- " def sample(self, nn_model, n_sample, shape, device, test_param, guide_w = 0):\n",
278
- " x_i = torch.randn(n_sample, *shape).to(device)\n",
 
 
279
  " # print(\"x_i.shape =\", x_i.shape)\n",
280
  " if guide_w != -1:\n",
281
- " c_i = test_param\n",
282
- " uncond_tokens = torch.zeros(int(n_sample), test_param.shape[1]).to(device)\n",
283
  " # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)\n",
284
  " # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)\n",
285
  " c_i = torch.cat((c_i, uncond_tokens), 0)\n",
@@ -295,7 +297,7 @@
295
  " t_is = torch.tensor([i]).to(device)\n",
296
  " t_is = t_is.repeat(n_sample)\n",
297
  "\n",
298
- " z = torch.randn(n_sample, *shape).to(device) if i > 0 else 0\n",
299
  "\n",
300
  " if guide_w == -1:\n",
301
  " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
@@ -303,7 +305,7 @@
303
  " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
304
  " else:\n",
305
  " # double batch\n",
306
- " x_i = x_i.repeat(2, *torch.ones(len(shape), dtype=int).tolist())\n",
307
  " t_is = t_is.repeat(2)\n",
308
  "\n",
309
  " # split predictions and compute weighting\n",
@@ -336,7 +338,7 @@
336
  },
337
  {
338
  "cell_type": "code",
339
- "execution_count": 8,
340
  "metadata": {},
341
  "outputs": [],
342
  "source": [
@@ -378,7 +380,7 @@
378
  },
379
  {
380
  "cell_type": "code",
381
- "execution_count": 9,
382
  "metadata": {},
383
  "outputs": [],
384
  "source": [
@@ -403,7 +405,7 @@
403
  },
404
  {
405
  "cell_type": "code",
406
- "execution_count": 10,
407
  "metadata": {},
408
  "outputs": [],
409
  "source": [
@@ -432,7 +434,7 @@
432
  },
433
  {
434
  "cell_type": "code",
435
- "execution_count": 11,
436
  "metadata": {},
437
  "outputs": [],
438
  "source": [
@@ -447,7 +449,7 @@
447
  },
448
  {
449
  "cell_type": "code",
450
- "execution_count": 12,
451
  "metadata": {},
452
  "outputs": [],
453
  "source": [
@@ -461,7 +463,7 @@
461
  },
462
  {
463
  "cell_type": "code",
464
- "execution_count": 13,
465
  "metadata": {},
466
  "outputs": [],
467
  "source": [
@@ -479,7 +481,7 @@
479
  },
480
  {
481
  "cell_type": "code",
482
- "execution_count": 14,
483
  "metadata": {},
484
  "outputs": [],
485
  "source": [
@@ -560,7 +562,7 @@
560
  },
561
  {
562
  "cell_type": "code",
563
- "execution_count": 15,
564
  "metadata": {},
565
  "outputs": [],
566
  "source": [
@@ -593,7 +595,7 @@
593
  },
594
  {
595
  "cell_type": "code",
596
- "execution_count": 16,
597
  "metadata": {},
598
  "outputs": [],
599
  "source": [
@@ -642,7 +644,7 @@
642
  },
643
  {
644
  "cell_type": "code",
645
- "execution_count": 17,
646
  "metadata": {},
647
  "outputs": [],
648
  "source": [
@@ -671,7 +673,7 @@
671
  },
672
  {
673
  "cell_type": "code",
674
- "execution_count": 18,
675
  "metadata": {},
676
  "outputs": [],
677
  "source": [
@@ -913,7 +915,7 @@
913
  },
914
  {
915
  "cell_type": "code",
916
- "execution_count": 19,
917
  "metadata": {},
918
  "outputs": [],
919
  "source": [
@@ -943,7 +945,7 @@
943
  },
944
  {
945
  "cell_type": "code",
946
- "execution_count": 57,
947
  "metadata": {},
948
  "outputs": [],
949
  "source": [
@@ -967,7 +969,7 @@
967
  " n_epoch = 10#2#5#25 # 120\n",
968
  " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
969
  " train_batch_size = 10#10#20#2#100 # 10\n",
970
- " n_sample = 24 # 64, the number of samples in sampling process\n",
971
  " n_param = 2\n",
972
  " guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
973
  " drop_prob = 0.28 # only takes effect when guide_w != -1\n",
@@ -987,9 +989,9 @@
987
  " # cond = True # if training using the conditional information\n",
988
  " # lr_decay = False #True# if using the learning rate decay\n",
989
  " resume = 'model_state.pth' # if resume from the trained checkpoints\n",
990
- " test_param_single = torch.tensor([0.2,0.80000023])\n",
991
- " test_param = torch.tile(test_param_single,(n_sample,1)).to(device)\n",
992
- " # test_param = test_param\n",
993
  " # data_dir = './data' # data directory\n",
994
  "\n",
995
  " output_dir = \"./outputs/\"\n",
@@ -1006,7 +1008,7 @@
1006
  },
1007
  {
1008
  "cell_type": "code",
1009
- "execution_count": 58,
1010
  "metadata": {},
1011
  "outputs": [],
1012
  "source": [
@@ -1016,7 +1018,7 @@
1016
  },
1017
  {
1018
  "cell_type": "code",
1019
- "execution_count": 59,
1020
  "metadata": {},
1021
  "outputs": [],
1022
  "source": [
@@ -1025,7 +1027,7 @@
1025
  },
1026
  {
1027
  "cell_type": "code",
1028
- "execution_count": 60,
1029
  "metadata": {},
1030
  "outputs": [],
1031
  "source": [
@@ -1049,205 +1051,205 @@
1049
  },
1050
  {
1051
  "cell_type": "code",
1052
- "execution_count": 61,
1053
  "metadata": {},
1054
  "outputs": [],
1055
  "source": [
1056
- "def train_loop(config, nn_model, ddpm, optimizer, dataloader, lr_scheduler): \n",
1057
- " ########################\n",
1058
- " ## ready for training ##\n",
1059
- " ########################\n",
1060
- " # initialize the dataset\n",
1061
- " # num_image = 600\n",
1062
- " # HII_DIM = 64\n",
1063
- " # num_redshift = 64#512#128\n",
1064
- " # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
1065
- " # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
1066
- " # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
1067
- " # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
1068
- " # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
1069
- " accelerator = Accelerator(\n",
1070
- " mixed_precision=config.mixed_precision,\n",
1071
- " gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
1072
- " log_with=\"tensorboard\",\n",
1073
- " project_dir=os.path.join(config.output_dir, \"logs\"),\n",
1074
- " )\n",
1075
- " if accelerator.is_main_process:\n",
1076
- " if config.output_dir is not None:\n",
1077
- " os.makedirs(config.output_dir, exist_ok=True)\n",
1078
- " if config.push_to_hub:\n",
1079
- " repo_id = create_repo(\n",
1080
- " repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
1081
- " ).repo_id\n",
1082
- " accelerator.init_trackers(f\"{config.date}\")\n",
1083
- "\n",
1084
- " nn_model, optimizer, dataloader, lr_scheduler = accelerator.prepare(\n",
1085
- " nn_model, optimizer, dataloader, lr_scheduler)\n",
1086
  " \n",
1087
- " # initialize the DDPM\n",
1088
- " # logger = SummaryWriter(os.path.join(\"runs\", config.run_name)) # To log\n",
1089
- "\n",
1090
- " # ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
1091
- "\n",
1092
- " # # initialize the unet\n",
1093
- " # nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
1094
- " # # nn_model = ContextUnet(n_param=1, image_size=28)\n",
1095
- " # nn_model.train()\n",
1096
- " # nn_model.to(ddpm.device)\n",
1097
- "\n",
1098
- " # parameters to be optimized\n",
1099
- " # params_to_optimize = [\n",
1100
- " # {'params': nn_model.parameters()}\n",
1101
- " # ]\n",
1102
- "\n",
1103
- " # number of parameters to be trained\n",
1104
- " number_of_params = sum(x.numel() for x in nn_model.parameters())\n",
1105
- " print(f\"Number of parameters for unet: {number_of_params}\")\n",
1106
- "\n",
1107
- " # # optionally load a model\n",
1108
- " # if config.resume:\n",
1109
- " # ddpm.load_state_dict(torch.load(os.path.join(config.save_dir, f\"train-{ep}xscale_test_{run_name}.npy\")))\n",
1110
- "\n",
1111
- " # define the loss function\n",
1112
- " loss_mse = nn.MSELoss()\n",
1113
- "\n",
1114
- "\n",
1115
- " # initialize optimizer\n",
1116
- " # optim = torch.optim.Adam(params_to_optimize, lr=config.lrate)\n",
1117
- "\n",
1118
- " # whether to use ema\n",
1119
- " if config.ema:\n",
1120
- " ema = EMA(config.ema_rate)\n",
1121
- " if config.resume:\n",
1122
- " print(\"resuming ema_model\")\n",
1123
- " # ema_model = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
1124
- " ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
1125
- " # print(\"ema_model.device =\", ema_model.device)\n",
1126
- " ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
1127
- " # ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"train-{ep}xscale_test_{config.run_name}_ema.npy\")))\n",
1128
- " else:\n",
1129
- " ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
1130
- "\n",
1131
- " ################### \n",
1132
- " ## training loop ##\n",
1133
- " ###################\n",
1134
- " # plot_unet = True\n",
1135
- " global_step = 0\n",
1136
- " for ep in range(config.n_epoch):\n",
1137
- " # print(f'epoch {ep}')\n",
1138
- " # print(\"ddpm.train()\")\n",
1139
- " ddpm.train()\n",
1140
- " # linear lrate decay\n",
1141
- " # if config.lr_decay:\n",
1142
- " # optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n",
1143
- "\n",
1144
- " # data loader with progress bar\n",
1145
- " pbar_train = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)\n",
1146
- " pbar_train.set_description(f\"Epoch {ep}\")\n",
1147
- " for i, (x, c) in enumerate(dataloader):\n",
1148
- " # global_step = ep * len(dataloader) + i\n",
1149
- " with accelerator.accumulate(nn_model):\n",
1150
- " # optim.zero_grad()\n",
1151
- " x = x.to(config.device)\n",
1152
- " xt, noise, ts = ddpm.add_noise(x)\n",
1153
- "\n",
1154
- " # noise = torch.randn(x.shape, device=x.device)\n",
1155
- " # ts = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device, dtype=torch.int64)\n",
1156
- " # xt = ddpm.add_noise(x, noise, ts)\n",
1157
  " \n",
1158
- " if config.guide_w == -1:\n",
1159
- " # noise_pred = nn_model(xt, ts, return_dict=False)[0]\n",
1160
- " noise_pred = nn_model(xt, ts)\n",
1161
- " else:\n",
1162
- " c = c.to(config.device)\n",
1163
- " noise_pred = nn_model(xt, ts, c)\n",
1164
  " \n",
1165
- " loss = loss_mse(noise, noise_pred)\n",
1166
- " accelerator.backward(loss)\n",
1167
- " # loss.backward()\n",
1168
- " # optim.step()\n",
1169
- " accelerator.clip_grad_norm_(nn_model.parameters(), 1)\n",
1170
- " optimizer.step()\n",
1171
- " lr_scheduler.step()\n",
1172
- " optimizer.zero_grad()\n",
1173
- "\n",
1174
- " # ema update\n",
1175
- " if config.ema:\n",
1176
- " ema.step_ema(ema_model, nn_model)\n",
1177
- "\n",
1178
- " # pbar.set_description(f\"epoch {ep} loss {loss.item():.4f}\")\n",
1179
- " pbar_train.update(1)\n",
1180
- " logs = dict(\n",
1181
- " loss=loss.detach().item(),\n",
1182
- " lr=optimizer.param_groups[0]['lr'],\n",
1183
- " step=global_step\n",
1184
- " )\n",
1185
- " pbar_train.set_postfix(**logs)\n",
1186
- "\n",
1187
- " # logging loss\n",
1188
- " # logger.add_scalar(\"MSE\", loss.item(), global_step=global_step)\n",
1189
- " accelerator.log(logs, step=global_step)\n",
1190
- " global_step += 1\n",
1191
- "\n",
1192
- "\n",
1193
- " if accelerator.is_main_process:\n",
1194
- " # sample the image\n",
1195
- " if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
1196
- " nn_model.eval()\n",
1197
- " with torch.no_grad():\n",
1198
- " # save model\n",
1199
- " if config.push_to_hub:\n",
1200
- " upload_folder(\n",
1201
- " repo_id = repo_id,\n",
1202
- " folder_path = \".\",#config.output_dir,\n",
1203
- " commit_message = f\"{config.date}\",\n",
1204
- " ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\"],\n",
1205
- " )\n",
1206
- " if config.save_model:\n",
1207
- " model_state = {\n",
1208
- " 'epoch': ep,\n",
1209
- " 'unet_state_dict': nn_model.state_dict(),\n",
1210
- " 'ema_unet_state_dict': ema_model.state_dict(),\n",
1211
- " }\n",
1212
- " torch.save(model_state, config.output_dir + f\"model_state.pth\")\n",
1213
- " print('saved model at ' + config.output_dir + f\"model_state.pth\")\n",
1214
- " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1215
- "\n",
1216
- " # loop over the guidance scale\n",
1217
- " # for w in config.ws_test: \n",
1218
  " \n",
1219
- " # pipeline = DDPMPipeline(unet=nn_model, scheduler=ddpm)\n",
1220
- " # evaluate(config, ep, pipeline)\n",
1221
  "\n",
1222
- " # only output the image x0, omit the stored intermediate steps, OTHERWISE, uncomment \n",
1223
- " # line 142, 143 and output 'x_last, x_store = ' here.\n",
1224
  "\n",
1225
- " # x_last_tot = []\n",
1226
- " x_last, x_entire = ddpm.sample(nn_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
1227
  "\n",
1228
- " # sample_save_dir = os.path.join(config.save_dir, f\"{config.run_name}.npy\")\n",
1229
- " np.save(os.path.join(config.output_dir, f\"{config.run_name}.npy\"), x_last)\n",
1230
- " # np.save(os.path.join(config.save_dir, f\"{config.run_name}_entire.npy\"), x_entire)\n",
1231
- " # print(f\"saved to {config.save_dir}\")\n",
1232
  "\n",
1233
- " if config.ema:\n",
1234
- " # x_last_tot_ema = []\n",
1235
- " x_last_ema, x_entire_ema = ddpm.sample(ema_model,config.n_sample, x.shape[1:], config.device, test_param=config.test_param, guide_w=config.guide_w)\n",
1236
  "\n",
1237
- " np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)\n",
1238
- " # np.save(os.path.join(config.save_dir, f\"{config.run_name}_ema_entire.npy\"), x_entire_ema)\n",
1239
- " # print(f\"saved to {config.save_dir}\")\n",
1240
  "\n",
1241
- " # x_last_tot.append(np.array(x_last.cpu()))\n",
1242
- " # x_last_tot=np.array(x_last_tot)\n",
1243
- " # x_last_tot_ema.append(np.array(x_last_ema.cpu()))\n",
1244
- " # x_last_tot_ema=np.array(x_last_tot_ema)\n",
1245
  "\n"
1246
  ]
1247
  },
1248
  {
1249
  "cell_type": "code",
1250
- "execution_count": 67,
1251
  "metadata": {},
1252
  "outputs": [
1253
  {
@@ -1261,8 +1263,8 @@
1261
  "loading 20 images randomly\n",
1262
  "images loaded: (20, 1, 64, 512)\n",
1263
  "params loaded: (20, 2)\n",
1264
- "images rescaled to [-1.0, 0.8148523569107056]\n",
1265
- "params rescaled to [0.0, 0.9062639012309924]\n",
1266
  "resumed nn_model from model_state.pth\n",
1267
  "Number of parameters for nn_model: 111048705\n"
1268
  ]
@@ -1289,7 +1291,7 @@
1289
  " self.config = config\n",
1290
  "\n",
1291
  " dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
1292
- " self.shape_loaded = dataset.images.shape\n",
1293
  " # print(\"shape_loaded =\", self.shape_loaded)\n",
1294
  " self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
1295
  " del dataset\n",
@@ -1420,10 +1422,12 @@
1420
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1421
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1422
  "\n",
1423
- " def sample(self, file, n_sample=12, ema=False, entire=False):\n",
 
1424
  " model = self.ema_model if ema else self.nn_model\n",
 
1425
  "\n",
1426
- " x_last, x_entire = self.ddpm.sample(model, n_sample, self.shape_loaded[1:], self.config.device, test_param=self.config.test_param, guide_w=self.config.guide_w)\n",
1427
  "\n",
1428
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1429
  " if entire:\n",
@@ -1436,13 +1440,27 @@
1436
  },
1437
  {
1438
  "cell_type": "code",
1439
- "execution_count": 68,
1440
  "metadata": {},
1441
  "outputs": [
1442
  {
1443
  "data": {
1444
  "application/vnd.jupyter.widget-view+json": {
1445
- "model_id": "4874b29272224f1aa6bcbead8dc5d11f",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1446
  "version_major": 2,
1447
  "version_minor": 0
1448
  },
@@ -1456,7 +1474,7 @@
1456
  {
1457
  "data": {
1458
  "application/vnd.jupyter.widget-view+json": {
1459
- "model_id": "1caa087a9a7b4252a633054990ff76d8",
1460
  "version_major": 2,
1461
  "version_minor": 0
1462
  },
@@ -1470,7 +1488,7 @@
1470
  {
1471
  "data": {
1472
  "application/vnd.jupyter.widget-view+json": {
1473
- "model_id": "81644fce614d49f5aa63e291ec458ccf",
1474
  "version_major": 2,
1475
  "version_minor": 0
1476
  },
@@ -1484,7 +1502,7 @@
1484
  {
1485
  "data": {
1486
  "application/vnd.jupyter.widget-view+json": {
1487
- "model_id": "a011d7ece27e42128a6ec51227313e60",
1488
  "version_major": 2,
1489
  "version_minor": 0
1490
  },
@@ -1498,7 +1516,7 @@
1498
  {
1499
  "data": {
1500
  "application/vnd.jupyter.widget-view+json": {
1501
- "model_id": "6d2f56d5d27443e589a8cca5d45892e3",
1502
  "version_major": 2,
1503
  "version_minor": 0
1504
  },
@@ -1516,11 +1534,44 @@
1516
  },
1517
  {
1518
  "cell_type": "code",
1519
- "execution_count": null,
1520
  "metadata": {},
1521
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1522
  "source": [
1523
- "ddpm21cm.sample()"
1524
  ]
1525
  },
1526
  {
@@ -1718,7 +1769,7 @@
1718
  "\n",
1719
  "n_sample = 20\n",
1720
  "with torch.no_grad():\n",
1721
- " x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, test_param = torch.tile(config.test_param_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
1722
  "\n",
1723
  "np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
1724
  ]
 
234
  },
235
  {
236
  "cell_type": "code",
237
+ "execution_count": 90,
238
  "metadata": {},
239
  "outputs": [],
240
  "source": [
 
256
  " self.bar_alpha_t = torch.cumprod(self.alpha_t, dim=0)\n",
257
  "\n",
258
  " def add_noise(self, clean_images):\n",
259
+ " self.shape = clean_images.shape\n",
260
+ " expand = torch.ones(len(self.shape)-1, dtype=int)\n",
261
  " # ts_expand = ts.view(ts.shape[0], *expand.tolist())\n",
262
  " # expand = [1 for i in range(len(shape)-1)]\n",
263
  "\n",
264
  " noise = torch.randn_like(clean_images).to(self.device)\n",
265
+ " ts = torch.randint(0, self.num_timesteps, (self.shape[0],)).to(self.device)\n",
266
  " \n",
267
  " # test_expand = test.view(test.shape[0],*expand)\n",
268
  " # extend_dim = [None for i in range(shape.dim()-1)]\n",
269
  " noisy_images = (\n",
270
+ " clean_images * torch.sqrt(self.bar_alpha_t[ts]).view(self.shape[0], *expand.tolist())\n",
271
+ " + noise * torch.sqrt(1-self.bar_alpha_t[ts]).view(self.shape[0], *expand.tolist())\n",
272
  " )\n",
273
  " # print(x_t.shape)\n",
274
  "\n",
275
  " return noisy_images, noise, ts\n",
276
  "\n",
277
+ " def sample(self, nn_model, params, device, guide_w = 0):\n",
278
+ " n_sample = params.shape[0]\n",
279
+ " print(\"params.shape[0], len(params)\", params.shape[0], len(params))\n",
280
+ " x_i = torch.randn(n_sample, *self.shape[1:]).to(device)\n",
281
  " # print(\"x_i.shape =\", x_i.shape)\n",
282
  " if guide_w != -1:\n",
283
+ " c_i = params\n",
284
+ " uncond_tokens = torch.zeros(int(n_sample), params.shape[1]).to(device)\n",
285
  " # uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)\n",
286
  " # uncond_tokens = uncond_tokens.repeat(int(n_sample),1)\n",
287
  " c_i = torch.cat((c_i, uncond_tokens), 0)\n",
 
297
  " t_is = torch.tensor([i]).to(device)\n",
298
  " t_is = t_is.repeat(n_sample)\n",
299
  "\n",
300
+ " z = torch.randn(n_sample, *self.shape[1:]).to(device) if i > 0 else 0\n",
301
  "\n",
302
  " if guide_w == -1:\n",
303
  " # eps = nn_model(x_i, t_is, return_dict=False)[0]\n",
 
305
  " # x_i = 1/torch.sqrt(self.alpha_t[i])*(x_i-eps*self.beta_t[i]/torch.sqrt(1-self.bar_alpha_t[i])) + torch.sqrt(self.beta_t[i])*z\n",
306
  " else:\n",
307
  " # double batch\n",
308
+ " x_i = x_i.repeat(2, *torch.ones(len(self.shape[1:]), dtype=int).tolist())\n",
309
  " t_is = t_is.repeat(2)\n",
310
  "\n",
311
  " # split predictions and compute weighting\n",
 
338
  },
339
  {
340
  "cell_type": "code",
341
+ "execution_count": 91,
342
  "metadata": {},
343
  "outputs": [],
344
  "source": [
 
380
  },
381
  {
382
  "cell_type": "code",
383
+ "execution_count": 92,
384
  "metadata": {},
385
  "outputs": [],
386
  "source": [
 
405
  },
406
  {
407
  "cell_type": "code",
408
+ "execution_count": 93,
409
  "metadata": {},
410
  "outputs": [],
411
  "source": [
 
434
  },
435
  {
436
  "cell_type": "code",
437
+ "execution_count": 94,
438
  "metadata": {},
439
  "outputs": [],
440
  "source": [
 
449
  },
450
  {
451
  "cell_type": "code",
452
+ "execution_count": 95,
453
  "metadata": {},
454
  "outputs": [],
455
  "source": [
 
463
  },
464
  {
465
  "cell_type": "code",
466
+ "execution_count": 96,
467
  "metadata": {},
468
  "outputs": [],
469
  "source": [
 
481
  },
482
  {
483
  "cell_type": "code",
484
+ "execution_count": 97,
485
  "metadata": {},
486
  "outputs": [],
487
  "source": [
 
562
  },
563
  {
564
  "cell_type": "code",
565
+ "execution_count": 98,
566
  "metadata": {},
567
  "outputs": [],
568
  "source": [
 
595
  },
596
  {
597
  "cell_type": "code",
598
+ "execution_count": 99,
599
  "metadata": {},
600
  "outputs": [],
601
  "source": [
 
644
  },
645
  {
646
  "cell_type": "code",
647
+ "execution_count": 100,
648
  "metadata": {},
649
  "outputs": [],
650
  "source": [
 
673
  },
674
  {
675
  "cell_type": "code",
676
+ "execution_count": 101,
677
  "metadata": {},
678
  "outputs": [],
679
  "source": [
 
915
  },
916
  {
917
  "cell_type": "code",
918
+ "execution_count": 102,
919
  "metadata": {},
920
  "outputs": [],
921
  "source": [
 
945
  },
946
  {
947
  "cell_type": "code",
948
+ "execution_count": 103,
949
  "metadata": {},
950
  "outputs": [],
951
  "source": [
 
969
  " n_epoch = 10#2#5#25 # 120\n",
970
  " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
971
  " train_batch_size = 10#10#20#2#100 # 10\n",
972
+ " # n_sample = 24 # 64, the number of samples in sampling process\n",
973
  " n_param = 2\n",
974
  " guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",
975
  " drop_prob = 0.28 # only takes effect when guide_w != -1\n",
 
989
  " # cond = True # if training using the conditional information\n",
990
  " # lr_decay = False #True# if using the learning rate decay\n",
991
  " resume = 'model_state.pth' # if resume from the trained checkpoints\n",
992
+ " # params_single = torch.tensor([0.2,0.80000023])\n",
993
+ " # params = torch.tile(params_single,(n_sample,1)).to(device)\n",
994
+ " # params = params\n",
995
  " # data_dir = './data' # data directory\n",
996
  "\n",
997
  " output_dir = \"./outputs/\"\n",
 
1008
  },
1009
  {
1010
  "cell_type": "code",
1011
+ "execution_count": 104,
1012
  "metadata": {},
1013
  "outputs": [],
1014
  "source": [
 
1018
  },
1019
  {
1020
  "cell_type": "code",
1021
+ "execution_count": 105,
1022
  "metadata": {},
1023
  "outputs": [],
1024
  "source": [
 
1027
  },
1028
  {
1029
  "cell_type": "code",
1030
+ "execution_count": 106,
1031
  "metadata": {},
1032
  "outputs": [],
1033
  "source": [
 
1051
  },
1052
  {
1053
  "cell_type": "code",
1054
+ "execution_count": 107,
1055
  "metadata": {},
1056
  "outputs": [],
1057
  "source": [
1058
+ "# def train_loop(config, nn_model, ddpm, optimizer, dataloader, lr_scheduler): \n",
1059
+ "# ########################\n",
1060
+ "# ## ready for training ##\n",
1061
+ "# ########################\n",
1062
+ "# # initialize the dataset\n",
1063
+ "# # num_image = 600\n",
1064
+ "# # HII_DIM = 64\n",
1065
+ "# # num_redshift = 64#512#128\n",
1066
+ "# # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
1067
+ "# # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
1068
+ "# # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
1069
+ "# # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
1070
+ "# # dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
1071
+ "# accelerator = Accelerator(\n",
1072
+ "# mixed_precision=config.mixed_precision,\n",
1073
+ "# gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
1074
+ "# log_with=\"tensorboard\",\n",
1075
+ "# project_dir=os.path.join(config.output_dir, \"logs\"),\n",
1076
+ "# )\n",
1077
+ "# if accelerator.is_main_process:\n",
1078
+ "# if config.output_dir is not None:\n",
1079
+ "# os.makedirs(config.output_dir, exist_ok=True)\n",
1080
+ "# if config.push_to_hub:\n",
1081
+ "# repo_id = create_repo(\n",
1082
+ "# repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n",
1083
+ "# ).repo_id\n",
1084
+ "# accelerator.init_trackers(f\"{config.date}\")\n",
1085
+ "\n",
1086
+ "# nn_model, optimizer, dataloader, lr_scheduler = accelerator.prepare(\n",
1087
+ "# nn_model, optimizer, dataloader, lr_scheduler)\n",
1088
  " \n",
1089
+ "# # initialize the DDPM\n",
1090
+ "# # logger = SummaryWriter(os.path.join(\"runs\", config.run_name)) # To log\n",
1091
+ "\n",
1092
+ "# # ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
1093
+ "\n",
1094
+ "# # # initialize the unet\n",
1095
+ "# # nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM)\n",
1096
+ "# # # nn_model = ContextUnet(n_param=1, image_size=28)\n",
1097
+ "# # nn_model.train()\n",
1098
+ "# # nn_model.to(ddpm.device)\n",
1099
+ "\n",
1100
+ "# # parameters to be optimized\n",
1101
+ "# # params_to_optimize = [\n",
1102
+ "# # {'params': nn_model.parameters()}\n",
1103
+ "# # ]\n",
1104
+ "\n",
1105
+ "# # number of parameters to be trained\n",
1106
+ "# number_of_params = sum(x.numel() for x in nn_model.parameters())\n",
1107
+ "# print(f\"Number of parameters for unet: {number_of_params}\")\n",
1108
+ "\n",
1109
+ "# # # optionally load a model\n",
1110
+ "# # if config.resume:\n",
1111
+ "# # ddpm.load_state_dict(torch.load(os.path.join(config.save_dir, f\"train-{ep}xscale_test_{run_name}.npy\")))\n",
1112
+ "\n",
1113
+ "# # define the loss function\n",
1114
+ "# loss_mse = nn.MSELoss()\n",
1115
+ "\n",
1116
+ "\n",
1117
+ "# # initialize optimizer\n",
1118
+ "# # optim = torch.optim.Adam(params_to_optimize, lr=config.lrate)\n",
1119
+ "\n",
1120
+ "# # whether to use ema\n",
1121
+ "# if config.ema:\n",
1122
+ "# ema = EMA(config.ema_rate)\n",
1123
+ "# if config.resume:\n",
1124
+ "# print(\"resuming ema_model\")\n",
1125
+ "# # ema_model = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, device=config.device)\n",
1126
+ "# ema_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM).to(config.device)\n",
1127
+ "# # print(\"ema_model.device =\", ema_model.device)\n",
1128
+ "# ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"model_state.pth\"))['ema_unet_state_dict'])\n",
1129
+ "# # ema_model.load_state_dict(torch.load(os.path.join(config.output_dir, f\"train-{ep}xscale_test_{config.run_name}_ema.npy\")))\n",
1130
+ "# else:\n",
1131
+ "# ema_model = copy.deepcopy(nn_model).eval().requires_grad_(False)\n",
1132
+ "\n",
1133
+ "# ################### \n",
1134
+ "# ## training loop ##\n",
1135
+ "# ###################\n",
1136
+ "# # plot_unet = True\n",
1137
+ "# global_step = 0\n",
1138
+ "# for ep in range(config.n_epoch):\n",
1139
+ "# # print(f'epoch {ep}')\n",
1140
+ "# # print(\"ddpm.train()\")\n",
1141
+ "# ddpm.train()\n",
1142
+ "# # linear lrate decay\n",
1143
+ "# # if config.lr_decay:\n",
1144
+ "# # optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n",
1145
+ "\n",
1146
+ "# # data loader with progress bar\n",
1147
+ "# pbar_train = tqdm(total=len(dataloader), disable=not accelerator.is_local_main_process)\n",
1148
+ "# pbar_train.set_description(f\"Epoch {ep}\")\n",
1149
+ "# for i, (x, c) in enumerate(dataloader):\n",
1150
+ "# # global_step = ep * len(dataloader) + i\n",
1151
+ "# with accelerator.accumulate(nn_model):\n",
1152
+ "# # optim.zero_grad()\n",
1153
+ "# x = x.to(config.device)\n",
1154
+ "# xt, noise, ts = ddpm.add_noise(x)\n",
1155
+ "\n",
1156
+ "# # noise = torch.randn(x.shape, device=x.device)\n",
1157
+ "# # ts = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device, dtype=torch.int64)\n",
1158
+ "# # xt = ddpm.add_noise(x, noise, ts)\n",
1159
  " \n",
1160
+ "# if config.guide_w == -1:\n",
1161
+ "# # noise_pred = nn_model(xt, ts, return_dict=False)[0]\n",
1162
+ "# noise_pred = nn_model(xt, ts)\n",
1163
+ "# else:\n",
1164
+ "# c = c.to(config.device)\n",
1165
+ "# noise_pred = nn_model(xt, ts, c)\n",
1166
  " \n",
1167
+ "# loss = loss_mse(noise, noise_pred)\n",
1168
+ "# accelerator.backward(loss)\n",
1169
+ "# # loss.backward()\n",
1170
+ "# # optim.step()\n",
1171
+ "# accelerator.clip_grad_norm_(nn_model.parameters(), 1)\n",
1172
+ "# optimizer.step()\n",
1173
+ "# lr_scheduler.step()\n",
1174
+ "# optimizer.zero_grad()\n",
1175
+ "\n",
1176
+ "# # ema update\n",
1177
+ "# if config.ema:\n",
1178
+ "# ema.step_ema(ema_model, nn_model)\n",
1179
+ "\n",
1180
+ "# # pbar.set_description(f\"epoch {ep} loss {loss.item():.4f}\")\n",
1181
+ "# pbar_train.update(1)\n",
1182
+ "# logs = dict(\n",
1183
+ "# loss=loss.detach().item(),\n",
1184
+ "# lr=optimizer.param_groups[0]['lr'],\n",
1185
+ "# step=global_step\n",
1186
+ "# )\n",
1187
+ "# pbar_train.set_postfix(**logs)\n",
1188
+ "\n",
1189
+ "# # logging loss\n",
1190
+ "# # logger.add_scalar(\"MSE\", loss.item(), global_step=global_step)\n",
1191
+ "# accelerator.log(logs, step=global_step)\n",
1192
+ "# global_step += 1\n",
1193
+ "\n",
1194
+ "\n",
1195
+ "# if accelerator.is_main_process:\n",
1196
+ "# # sample the image\n",
1197
+ "# if ep == config.n_epoch-1 or (ep+1)*config.save_freq==1:\n",
1198
+ "# nn_model.eval()\n",
1199
+ "# with torch.no_grad():\n",
1200
+ "# # save model\n",
1201
+ "# if config.push_to_hub:\n",
1202
+ "# upload_folder(\n",
1203
+ "# repo_id = repo_id,\n",
1204
+ "# folder_path = \".\",#config.output_dir,\n",
1205
+ "# commit_message = f\"{config.date}\",\n",
1206
+ "# ignore_patterns = [\"step_*\", \"epoch_*\", \"*.npy\"],\n",
1207
+ "# )\n",
1208
+ "# if config.save_model:\n",
1209
+ "# model_state = {\n",
1210
+ "# 'epoch': ep,\n",
1211
+ "# 'unet_state_dict': nn_model.state_dict(),\n",
1212
+ "# 'ema_unet_state_dict': ema_model.state_dict(),\n",
1213
+ "# }\n",
1214
+ "# torch.save(model_state, config.output_dir + f\"model_state.pth\")\n",
1215
+ "# print('saved model at ' + config.output_dir + f\"model_state.pth\")\n",
1216
+ "# # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1217
+ "\n",
1218
+ "# # loop over the guidance scale\n",
1219
+ "# # for w in config.ws_test: \n",
1220
  " \n",
1221
+ "# # pipeline = DDPMPipeline(unet=nn_model, scheduler=ddpm)\n",
1222
+ "# # evaluate(config, ep, pipeline)\n",
1223
  "\n",
1224
+ "# # only output the image x0, omit the stored intermediate steps, OTHERWISE, uncomment \n",
1225
+ "# # line 142, 143 and output 'x_last, x_store = ' here.\n",
1226
  "\n",
1227
+ "# # x_last_tot = []\n",
1228
+ "# x_last, x_entire = ddpm.sample(nn_model,config.n_sample, x.shape[1:], config.device, params=config.params, guide_w=config.guide_w)\n",
1229
  "\n",
1230
+ "# # sample_save_dir = os.path.join(config.save_dir, f\"{config.run_name}.npy\")\n",
1231
+ "# np.save(os.path.join(config.output_dir, f\"{config.run_name}.npy\"), x_last)\n",
1232
+ "# # np.save(os.path.join(config.save_dir, f\"{config.run_name}_entire.npy\"), x_entire)\n",
1233
+ "# # print(f\"saved to {config.save_dir}\")\n",
1234
  "\n",
1235
+ "# if config.ema:\n",
1236
+ "# # x_last_tot_ema = []\n",
1237
+ "# x_last_ema, x_entire_ema = ddpm.sample(ema_model,config.n_sample, x.shape[1:], config.device, params=config.params, guide_w=config.guide_w)\n",
1238
  "\n",
1239
+ "# np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)\n",
1240
+ "# # np.save(os.path.join(config.save_dir, f\"{config.run_name}_ema_entire.npy\"), x_entire_ema)\n",
1241
+ "# # print(f\"saved to {config.save_dir}\")\n",
1242
  "\n",
1243
+ "# # x_last_tot.append(np.array(x_last.cpu()))\n",
1244
+ "# # x_last_tot=np.array(x_last_tot)\n",
1245
+ "# # x_last_tot_ema.append(np.array(x_last_ema.cpu()))\n",
1246
+ "# # x_last_tot_ema=np.array(x_last_tot_ema)\n",
1247
  "\n"
1248
  ]
1249
  },
1250
  {
1251
  "cell_type": "code",
1252
+ "execution_count": 109,
1253
  "metadata": {},
1254
  "outputs": [
1255
  {
 
1263
  "loading 20 images randomly\n",
1264
  "images loaded: (20, 1, 64, 512)\n",
1265
  "params loaded: (20, 2)\n",
1266
+ "images rescaled to [-1.0, 0.946401834487915]\n",
1267
+ "params rescaled to [0.0, 0.9683106587269014]\n",
1268
  "resumed nn_model from model_state.pth\n",
1269
  "Number of parameters for nn_model: 111048705\n"
1270
  ]
 
1291
  " self.config = config\n",
1292
  "\n",
1293
  " dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
1294
+ " # self.shape_loaded = dataset.images.shape\n",
1295
  " # print(\"shape_loaded =\", self.shape_loaded)\n",
1296
  " self.dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)\n",
1297
  " del dataset\n",
 
1422
  " print('saved model at ' + self.config.output_dir + f\"model_state_{ep:02d}.pth\")\n",
1423
  " # print('saved model at ' + config.save_dir + f\"model_epoch_{ep}_test_{config.run_name}.pth\")\n",
1424
  "\n",
1425
+ " def sample(self, file, params=[0.2,0.8], ema=False, entire=False):\n",
1426
+ " n_sample = params.shape[0]\n",
1427
  " model = self.ema_model if ema else self.nn_model\n",
1428
+ " # params = torch.tile(params, (n_sample,1)).to(device)\n",
1429
  "\n",
1430
+ " x_last, x_entire = self.ddpm.sample(model, n_sample, self.config.device, params=params, guide_w=self.config.guide_w)\n",
1431
  "\n",
1432
  " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
1433
  " if entire:\n",
 
1440
  },
1441
  {
1442
  "cell_type": "code",
1443
+ "execution_count": 110,
1444
  "metadata": {},
1445
  "outputs": [
1446
  {
1447
  "data": {
1448
  "application/vnd.jupyter.widget-view+json": {
1449
+ "model_id": "7b26c7444a1144728f0299db5d5683b1",
1450
+ "version_major": 2,
1451
+ "version_minor": 0
1452
+ },
1453
+ "text/plain": [
1454
+ " 0%| | 0/2 [00:00<?, ?it/s]"
1455
+ ]
1456
+ },
1457
+ "metadata": {},
1458
+ "output_type": "display_data"
1459
+ },
1460
+ {
1461
+ "data": {
1462
+ "application/vnd.jupyter.widget-view+json": {
1463
+ "model_id": "e3b8c1a18460443d986282f284bf0b42",
1464
  "version_major": 2,
1465
  "version_minor": 0
1466
  },
 
1474
  {
1475
  "data": {
1476
  "application/vnd.jupyter.widget-view+json": {
1477
+ "model_id": "84d558fcea2f4b998f311da7452a2c93",
1478
  "version_major": 2,
1479
  "version_minor": 0
1480
  },
 
1488
  {
1489
  "data": {
1490
  "application/vnd.jupyter.widget-view+json": {
1491
+ "model_id": "364208bd1fa04859baf233ed30451638",
1492
  "version_major": 2,
1493
  "version_minor": 0
1494
  },
 
1502
  {
1503
  "data": {
1504
  "application/vnd.jupyter.widget-view+json": {
1505
+ "model_id": "e8aa693ca9b444d7b6adb77531f8b718",
1506
  "version_major": 2,
1507
  "version_minor": 0
1508
  },
 
1516
  {
1517
  "data": {
1518
  "application/vnd.jupyter.widget-view+json": {
1519
+ "model_id": "aab44a74df8c4a7a8baf975b8ec50b8b",
1520
  "version_major": 2,
1521
  "version_minor": 0
1522
  },
 
1534
  },
1535
  {
1536
  "cell_type": "code",
1537
+ "execution_count": 69,
1538
  "metadata": {},
1539
+ "outputs": [
1540
+ {
1541
+ "data": {
1542
+ "application/vnd.jupyter.widget-view+json": {
1543
+ "model_id": "89a5be983ade43d89a1be3d977750a40",
1544
+ "version_major": 2,
1545
+ "version_minor": 0
1546
+ },
1547
+ "text/plain": [
1548
+ " 0%| | 0/1000 [00:00<?, ?it/s]"
1549
+ ]
1550
+ },
1551
+ "metadata": {},
1552
+ "output_type": "display_data"
1553
+ },
1554
+ {
1555
+ "ename": "RuntimeError",
1556
+ "evalue": "The size of tensor a (24) must match the size of tensor b (36) at non-singleton dimension 0",
1557
+ "output_type": "error",
1558
+ "traceback": [
1559
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1560
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1561
+ "Cell \u001b[0;32mIn[69], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1562
+ "Cell \u001b[0;32mIn[67], line 141\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, n_sample, ema, entire)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39msample\u001b[39m(\u001b[39mself\u001b[39m, file, n_sample\u001b[39m=\u001b[39m\u001b[39m12\u001b[39m, ema\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m, entire\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m):\n\u001b[1;32m 139\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[0;32m--> 141\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, n_sample, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshape_loaded[\u001b[39m1\u001b[39;49m:], \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, test_param\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mtest_param, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 143\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 144\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
1563
+ "Cell \u001b[0;32mIn[7], line 70\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, n_sample, shape, device, test_param, guide_w)\u001b[0m\n\u001b[1;32m 67\u001b[0m t_is \u001b[39m=\u001b[39m t_is\u001b[39m.\u001b[39mrepeat(\u001b[39m2\u001b[39m)\n\u001b[1;32m 69\u001b[0m \u001b[39m# split predictions and compute weighting\u001b[39;00m\n\u001b[0;32m---> 70\u001b[0m eps \u001b[39m=\u001b[39m nn_model(x_i, t_is, c_i)\n\u001b[1;32m 71\u001b[0m eps1 \u001b[39m=\u001b[39m eps[:n_sample]\n\u001b[1;32m 72\u001b[0m eps2 \u001b[39m=\u001b[39m eps[n_sample:]\n",
1564
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[39m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[39m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m (\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_backward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_hooks \u001b[39mor\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_forward_pre_hooks \u001b[39mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[39mor\u001b[39;00m _global_forward_hooks \u001b[39mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[39mreturn\u001b[39;00m forward_call(\u001b[39m*\u001b[39;49m\u001b[39minput\u001b[39;49m, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 1131\u001b[0m \u001b[39m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[39m=\u001b[39m [], []\n",
1565
+ "File \u001b[0;32m~/.conda/envs/diffusers/lib/python3.9/site-packages/accelerate/utils/operations.py:822\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mforward\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 822\u001b[0m \u001b[39mreturn\u001b[39;00m model_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
1566
+ "File \u001b[0;32m~/.conda/envs/diffusers/lib/python3.9/site-packages/accelerate/utils/operations.py:810\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[0;32m--> 810\u001b[0m \u001b[39mreturn\u001b[39;00m convert_to_fp32(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmodel_forward(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs))\n",
1567
+ "File \u001b[0;32m/usr/local/pace-apps/manual/packages/pytorch/1.12.0/lib/python3.9/site-packages/torch/amp/autocast_mode.py:12\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[39m@functools\u001b[39m\u001b[39m.\u001b[39mwraps(func)\n\u001b[1;32m 10\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mdecorate_autocast\u001b[39m(\u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs):\n\u001b[1;32m 11\u001b[0m \u001b[39mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 12\u001b[0m \u001b[39mreturn\u001b[39;00m func(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
1568
+ "Cell \u001b[0;32mIn[18], line 211\u001b[0m, in \u001b[0;36mContextUnet.forward\u001b[0;34m(self, x, timesteps, y)\u001b[0m\n\u001b[1;32m 209\u001b[0m \u001b[39mif\u001b[39;00m y \u001b[39m!=\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 210\u001b[0m text_outputs \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mtoken_embedding(y\u001b[39m.\u001b[39mfloat())\n\u001b[0;32m--> 211\u001b[0m emb \u001b[39m=\u001b[39m emb \u001b[39m+\u001b[39;49m text_outputs\u001b[39m.\u001b[39;49mto(emb)\n\u001b[1;32m 213\u001b[0m h \u001b[39m=\u001b[39m x\u001b[39m.\u001b[39mtype(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mdtype)\n\u001b[1;32m 214\u001b[0m \u001b[39m# print(\"0,h.shape =\", h.shape)\u001b[39;00m\n",
1569
+ "\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (24) must match the size of tensor b (36) at non-singleton dimension 0"
1570
+ ]
1571
+ }
1572
+ ],
1573
  "source": [
1574
+ "ddpm21cm.sample(\"./outputs/model_state_09.pth\")"
1575
  ]
1576
  },
1577
  {
 
1769
  "\n",
1770
  "n_sample = 20\n",
1771
  "with torch.no_grad():\n",
1772
+ " x_last_ema, x_ema_entire = ddpm.sample(nn_model, n_sample, (1,config.HII_DIM, config.num_redshift), config.device, params = torch.tile(config.params_single,(n_sample,1)).to(config.device), guide_w=config.guide_w)\n",
1773
  "\n",
1774
  "np.save(os.path.join(config.output_dir, f\"{config.run_name}_ema.npy\"), x_last_ema)"
1775
  ]