0521-2304
diffusion.ipynb (CHANGED: +78 -73)
@@ -32,7 +32,7 @@
|   32 |   {
|   33 |   "data": {
|   34 |   "application/vnd.jupyter.widget-view+json": {
|   35 | - "model_id": "
|   36 |   "version_major": 2,
|   37 |   "version_minor": 0
|   38 |   },

@@ -962,7 +962,7 @@
|  962 |   },
|  963 |   {
|  964 |   "cell_type": "code",
|  965 | - "execution_count":
|  966 |   "metadata": {},
|  967 |   "outputs": [],
|  968 |   "source": [

@@ -987,7 +987,7 @@
|  987 |   "\n",
|  988 |   " n_epoch = 10#2#5#25 # 120\n",
|  989 |   " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|  990 | - "
|  991 |   " # n_sample = 24 # 64, the number of samples in sampling process\n",
|  992 |   " n_param = 2\n",
|  993 |   " guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",

@@ -1032,7 +1032,7 @@
| 1032 |   "outputs": [],
| 1033 |   "source": [
| 1034 |   "# dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
| 1035 | - "# dataloader = DataLoader(dataset, batch_size=config.
| 1036 |   ]
| 1037 |   },
| 1038 |   {

@@ -1083,10 +1083,10 @@
| 1083 |   "# # HII_DIM = 64\n",
| 1084 |   "# # num_redshift = 64#512#128\n",
| 1085 |   "# # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
| 1086 | - "# # dataloader = DataLoader(dataset, batch_size=config.
| 1087 |   "# # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
| 1088 |   "# # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
| 1089 | - "# # dataloader = DataLoader(dataset, batch_size=config.
| 1090 |   "# accelerator = Accelerator(\n",
| 1091 |   "# mixed_precision=config.mixed_precision,\n",
| 1092 |   "# gradient_accumulation_steps=config.gradient_accumulation_steps,\n",

@@ -1268,37 +1268,15 @@
| 1268 |   },
| 1269 |   {
| 1270 |   "cell_type": "code",
| 1271 | - "execution_count":
| 1272 |   "metadata": {},
| 1273 |   "outputs": [
| 1274 |   {
| 1275 |   "name": "stdout",
| 1276 |   "output_type": "stream",
| 1277 |   "text": [
| 1278 | - "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
| 1279 | - "51200 images can be loaded\n",
| 1280 | - "field.shape = (64, 64, 514)\n",
| 1281 | - "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
| 1282 | - "loading 200 images randomly\n",
| 1283 | - "images loaded: (200, 1, 64, 512)\n",
| 1284 | - "params loaded: (200, 2)\n",
| 1285 | - "images rescaled to [-1.0, 1.082756519317627]\n",
| 1286 | - "params rescaled to [0.0, 0.9938162632551855]\n",
| 1287 |   "resumed nn_model from model_state.pth\n",
| 1288 | - "Number of parameters for nn_model: 111048705\n"
| 1289 | - ]
| 1290 | - },
| 1291 | - {
| 1292 | - "name": "stderr",
| 1293 | - "output_type": "stream",
| 1294 | - "text": [
| 1295 | - "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
| 1296 | - ]
| 1297 | - },
| 1298 | - {
| 1299 | - "name": "stdout",
| 1300 | - "output_type": "stream",
| 1301 | - "text": [
| 1302 |   "resumed ema_model from model_state.pth\n"
| 1303 |   ]
| 1304 |   }

@@ -1306,15 +1284,14 @@
| 1306 |   "source": [
| 1307 |   "# @dataclass\n",
| 1308 |   "class DDPM21CM:\n",
| 1309 | - " def __init__(self
| 1310 |   " self.config = config\n",
| 1311 | - "\n",
| 1312 | - "
| 1313 | - " #
| 1314 | - " #
| 1315 | - "
| 1316 | - " del dataset\n",
| 1317 | - "\n",
| 1318 |   " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
| 1319 |   "\n",
| 1320 |   " # initialize the unet\n",

@@ -1345,10 +1322,17 @@
| 1345 |   " self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
| 1346 |   " optimizer=self.optimizer,\n",
| 1347 |   " num_warmup_steps=config.lr_warmup_steps,\n",
| 1348 | - " num_training_steps=(
| 1349 |   " )\n",
| 1350 |   "\n",
| 1351 | - "
| 1352 |   "\n",
| 1353 |   " def accelerate(self, config):\n",
| 1354 |   " self.accelerator = Accelerator(\n",

@@ -1376,6 +1360,7 @@
| 1376 |   " ## training loop ##\n",
| 1377 |   " ###################\n",
| 1378 |   " # plot_unet = True\n",
| 1379 |   " global_step = 0\n",
| 1380 |   " for ep in range(self.config.n_epoch):\n",
| 1381 |   " self.ddpm.train()\n",

@@ -1446,26 +1431,54 @@
| 1446 |   " model = self.ema_model if ema else self.nn_model\n",
| 1447 |   " # params = torch.tile(params, (n_sample,1)).to(device)\n",
| 1448 |   "\n",
| 1449 | - " x_last, x_entire = self.ddpm.sample(model, params=params,
| 1450 |   "\n",
| 1451 |   " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
| 1452 |   " if entire:\n",
| 1453 |   " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}_entire.npy\"), x_last)\n",
| 1454 |   "\n",
| 1455 |   "\n",
| 1456 | - "ddpm21cm = DDPM21CM(
| 1457 |   "# print(\"device =\", config.device)"
| 1458 |   ]
| 1459 |   },
| 1460 |   {
| 1461 |   "cell_type": "code",
| 1462 | - "execution_count":
| 1463 |   "metadata": {},
| 1464 |   "outputs": [
| 1465 |   {
| 1466 |   "data": {
| 1467 |   "application/vnd.jupyter.widget-view+json": {
| 1468 | - "model_id": "
| 1469 |   "version_major": 2,
| 1470 |   "version_minor": 0
| 1471 |   },

@@ -1479,7 +1492,7 @@
| 1479 |   {
| 1480 |   "data": {
| 1481 |   "application/vnd.jupyter.widget-view+json": {
| 1482 | - "model_id": "
| 1483 |   "version_major": 2,
| 1484 |   "version_minor": 0
| 1485 |   },

@@ -1493,7 +1506,7 @@
| 1493 |   {
| 1494 |   "data": {
| 1495 |   "application/vnd.jupyter.widget-view+json": {
| 1496 | - "model_id": "
| 1497 |   "version_major": 2,
| 1498 |   "version_minor": 0
| 1499 |   },

@@ -1507,7 +1520,7 @@
| 1507 |   {
| 1508 |   "data": {
| 1509 |   "application/vnd.jupyter.widget-view+json": {
| 1510 | - "model_id": "
| 1511 |   "version_major": 2,
| 1512 |   "version_minor": 0
| 1513 |   },

@@ -1521,7 +1534,7 @@
| 1521 |   {
| 1522 |   "data": {
| 1523 |   "application/vnd.jupyter.widget-view+json": {
| 1524 | - "model_id": "
| 1525 |   "version_major": 2,
| 1526 |   "version_minor": 0
| 1527 |   },

@@ -1535,7 +1548,7 @@
| 1535 |   {
| 1536 |   "data": {
| 1537 |   "application/vnd.jupyter.widget-view+json": {
| 1538 | - "model_id": "
| 1539 |   "version_major": 2,
| 1540 |   "version_minor": 0
| 1541 |   },

@@ -1549,7 +1562,7 @@
| 1549 |   {
| 1550 |   "data": {
| 1551 |   "application/vnd.jupyter.widget-view+json": {
| 1552 | - "model_id": "
| 1553 |   "version_major": 2,
| 1554 |   "version_minor": 0
| 1555 |   },

@@ -1563,7 +1576,7 @@
| 1563 |   {
| 1564 |   "data": {
| 1565 |   "application/vnd.jupyter.widget-view+json": {
| 1566 | - "model_id": "
| 1567 |   "version_major": 2,
| 1568 |   "version_minor": 0
| 1569 |   },

@@ -1577,7 +1590,7 @@
| 1577 |   {
| 1578 |   "data": {
| 1579 |   "application/vnd.jupyter.widget-view+json": {
| 1580 | - "model_id": "
| 1581 |   "version_major": 2,
| 1582 |   "version_minor": 0
| 1583 |   },

@@ -1591,7 +1604,7 @@
| 1591 |   {
| 1592 |   "data": {
| 1593 |   "application/vnd.jupyter.widget-view+json": {
| 1594 | - "model_id": "
| 1595 |   "version_major": 2,
| 1596 |   "version_minor": 0
| 1597 |   },

@@ -1609,28 +1622,20 @@
| 1609 |   },
| 1610 |   {
| 1611 |   "cell_type": "code",
| 1612 | - "execution_count":
| 1613 |   "metadata": {},
| 1614 |   "outputs": [
| 1615 |   {
| 1616 | - "
| 1617 | - "
| 1618 | - "text": [
| 1619 | - "params.shape[0], len(params) 2 2\n"
| 1620 | - ]
| 1621 | - },
| 1622 | - {
| 1623 | - "ename": "AttributeError",
| 1624 | - "evalue": "'DDPMScheduler' object has no attribute 'shape'",
| 1625 |   "output_type": "error",
| 1626 |   "traceback": [
| 1627 |   "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
| 1628 | - "\u001b[0;
| 1629 | - "Cell \u001b[0;32mIn[
| 1630 | - "Cell \u001b[0;32mIn[
| 1631 | - "Cell \u001b[0;32mIn[
| 1632 | - "
| 1633 | - "\u001b[0;31mAttributeError\u001b[0m: 'DDPMScheduler' object has no attribute 'shape'"
| 1634 |   ]
| 1635 |   }
| 1636 |   ],

@@ -2030,7 +2035,7 @@
| 2030 |   "@dataclass\n",
| 2031 |   "class TrainingConfig:\n",
| 2032 |   " image_size = 128\n",
| 2033 | - "
| 2034 |   " eval_batch_size = 16\n",
| 2035 |   " num_epochs = 50\n",
| 2036 |   " gradient_accumulation_steps = 1\n",

@@ -2067,7 +2072,7 @@
| 2067 |   " return {\"images\": images}\n",
| 2068 |   "\n",
| 2069 |   "dataset.set_transform(transform)\n",
| 2070 | - "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.
| 2071 |   "\n",
| 2072 |   "model = UNet2DModel(\n",
| 2073 |   " sample_size = config.image_size,\n",

@@ -2268,7 +2273,7 @@
| 2268 |   "class TrainingConfig:\n",
| 2269 |   " num_images = 600\n",
| 2270 |   " image_size = [64,512]\n",
| 2271 | - "
| 2272 |   " eval_batch_size = 24\n",
| 2273 |   " num_epochs = 20\n",
| 2274 |   " gradient_accumulation_steps = 1\n",

@@ -2547,7 +2552,7 @@
| 2547 |   "outputs": [],
| 2548 |   "source": [
| 2549 |   "# dataset.set_transform(transform)\n",
| 2550 | - "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.
| 2551 |   ]
| 2552 |   },
| 2553 |   {

|   32 |   {
|   33 |   "data": {
|   34 |   "application/vnd.jupyter.widget-view+json": {
|   35 | + "model_id": "f925cb378800455fb1216e84a2900e58",
|   36 |   "version_major": 2,
|   37 |   "version_minor": 0
|   38 |   },

|  962 |   },
|  963 |   {
|  964 |   "cell_type": "code",
|  965 | + "execution_count": 29,
|  966 |   "metadata": {},
|  967 |   "outputs": [],
|  968 |   "source": [

|  987 |   "\n",
|  988 |   " n_epoch = 10#2#5#25 # 120\n",
|  989 |   " num_timesteps = 1000#1000 # 1000, 500; DDPM time steps\n",
|  990 | + " batch_size = 10#10#20#2#100 # 10\n",
|  991 |   " # n_sample = 24 # 64, the number of samples in sampling process\n",
|  992 |   " n_param = 2\n",
|  993 |   " guide_w = 0#-1#0#-1#0.1#[0,0.1] #[0,0.5,2] strength of generative guidance\n",

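The cell above now carries the training hyperparameters (n_epoch, num_timesteps, the newly added batch_size, n_param, guide_w), and later hunks read them back as attributes of a TrainConfig() object. A minimal sketch of how such a config could be declared; the field names and the values mirror this diff, while the dataclass itself and the extra fields (num_image, device) are assumptions for illustration:

    from dataclasses import dataclass

    import torch

    @dataclass
    class TrainConfig:
        n_epoch: int = 10            # number of training epochs
        num_timesteps: int = 1000    # DDPM time steps
        batch_size: int = 10         # newly added; consumed by the DataLoader and the LR schedule
        n_param: int = 2             # conditioning parameters (ION_Tvir_MIN, HII_EFF_FACTOR)
        guide_w: float = 0.0         # strength of classifier-free guidance
        num_image: int = 200         # images to draw from the h5 dataset (assumed field)
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
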
| 1032 |   "outputs": [],
| 1033 |   "source": [
| 1034 |   "# dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
| 1035 | + "# dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)"
| 1036 |   ]
| 1037 |   },
| 1038 |   {

| 1083 |   "# # HII_DIM = 64\n",
| 1084 |   "# # num_redshift = 64#512#128\n",
| 1085 |   "# # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob)\n",
| 1086 | + "# # dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
| 1087 |   "# # tf = transforms.Compose([transforms.ToTensor()]) # mnist is already normalised 0 to 1\n",
| 1088 |   "# # dataset = MNIST(\"./data\", train=True, download=True, transform=tf)\n",
| 1089 | + "# # dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=5) # Initialize accelerator and tensorboard logging\n",
| 1090 |   "# accelerator = Accelerator(\n",
| 1091 |   "# mixed_precision=config.mixed_precision,\n",
| 1092 |   "# gradient_accumulation_steps=config.gradient_accumulation_steps,\n",

| 1268 |   },
| 1269 |   {
| 1270 |   "cell_type": "code",
| 1271 | + "execution_count": 31,
| 1272 |   "metadata": {},
| 1273 |   "outputs": [
| 1274 |   {
| 1275 |   "name": "stdout",
| 1276 |   "output_type": "stream",
| 1277 |   "text": [
| 1278 |   "resumed nn_model from model_state.pth\n",
| 1279 | + "Number of parameters for nn_model: 111048705\n",
| 1280 |   "resumed ema_model from model_state.pth\n"
| 1281 |   ]
| 1282 |   }

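The stdout above ("resumed nn_model from model_state.pth", "Number of parameters for nn_model: 111048705", then the same for ema_model) suggests the constructor restores both networks from a saved state dict and reports their size. A rough sketch of that resume step; the file name comes from the log, while the helper function and checkpoint layout are assumptions:

    import torch

    def resume(model: torch.nn.Module, path: str = "model_state.pth", name: str = "nn_model") -> torch.nn.Module:
        # Restore previously saved weights and report the model size,
        # matching the "resumed ..." / "Number of parameters ..." log lines.
        state = torch.load(path, map_location="cpu")
        model.load_state_dict(state)
        print(f"resumed {name} from {path}")
        print(f"Number of parameters for {name}: {sum(p.numel() for p in model.parameters())}")
        return model
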
| 1284 |   "source": [
| 1285 |   "# @dataclass\n",
| 1286 |   "class DDPM21CM:\n",
| 1287 | + " def __init__(self):\n",
| 1288 | + " config = TrainConfig()\n",
| 1289 |   " self.config = config\n",
| 1290 | + " # dataset = Dataset4h5(config.dataset_name, num_image=config.num_image, HII_DIM=config.HII_DIM, num_redshift=config.num_redshift, drop_prob=config.drop_prob, dim=config.dim)\n",
| 1291 | + " # # self.shape_loaded = dataset.images.shape\n",
| 1292 | + " # # print(\"shape_loaded =\", self.shape_loaded)\n",
| 1293 | + " # self.dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
| 1294 | + " # del dataset\n",
| 1295 |   " self.ddpm = DDPMScheduler(betas=(1e-4, 0.02), num_timesteps=config.num_timesteps, img_shape=config.img_shape, device=config.device)\n",
| 1296 |   "\n",
| 1297 |   " # initialize the unet\n",

| 1322 |   " self.lr_scheduler = get_cosine_schedule_with_warmup(\n",
| 1323 |   " optimizer=self.optimizer,\n",
| 1324 |   " num_warmup_steps=config.lr_warmup_steps,\n",
| 1325 | + " num_training_steps=(int(config.num_image/config.batch_size) * config.n_epoch),\n",
| 1326 | + " # num_training_steps=(len(self.dataloader) * config.n_epoch),\n",
| 1327 |   " )\n",
| 1328 |   "\n",
| 1329 | + " def load(self):\n",
| 1330 | + " dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim)\n",
| 1331 | + " # self.shape_loaded = dataset.images.shape\n",
| 1332 | + " # print(\"shape_loaded =\", self.shape_loaded)\n",
| 1333 | + " self.dataloader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True)\n",
| 1334 | + " # del dataset\n",
| 1335 | + " self.accelerate(self.config)\n",
| 1336 |   "\n",
| 1337 |   " def accelerate(self, config):\n",
| 1338 |   " self.accelerator = Accelerator(\n",

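With the dataset now loaded lazily in the new load() method, the scheduler length can no longer come from len(self.dataloader) at construction time, so it is estimated from the config as int(config.num_image / config.batch_size) * config.n_epoch. The two counts agree only when num_image is a multiple of batch_size: a DataLoader with the default drop_last=False yields ceil(num_image / batch_size) batches per epoch. A small check, using the values that appear elsewhere in this commit:

    import math

    num_image, batch_size, n_epoch = 200, 10, 10                     # values from this commit
    steps_from_config = int(num_image / batch_size) * n_epoch         # what the new code passes to the scheduler
    steps_from_loader = math.ceil(num_image / batch_size) * n_epoch   # what len(dataloader) * n_epoch would give
    print(steps_from_config, steps_from_loader)                       # 200 200 here; they differ if num_image % batch_size != 0
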
| 1360 |   " ## training loop ##\n",
| 1361 |   " ###################\n",
| 1362 |   " # plot_unet = True\n",
| 1363 | + " self.load()\n",
| 1364 |   " global_step = 0\n",
| 1365 |   " for ep in range(self.config.n_epoch):\n",
| 1366 |   " self.ddpm.train()\n",

| 1431 |   " model = self.ema_model if ema else self.nn_model\n",
| 1432 |   " # params = torch.tile(params, (n_sample,1)).to(device)\n",
| 1433 |   "\n",
| 1434 | + " x_last, x_entire = self.ddpm.sample(model, params=params, device=self.config.device, guide_w=self.config.guide_w)\n",
| 1435 |   "\n",
| 1436 |   " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}.npy\"), x_last)\n",
| 1437 |   " if entire:\n",
| 1438 |   " np.save(os.path.join(self.config.output_dir, f\"{self.config.run_name}{'ema' if ema else None}_entire.npy\"), x_last)\n",
| 1439 |   "\n",
| 1440 |   "\n",
| 1441 | + "ddpm21cm = DDPM21CM()\n",
| 1442 |   "# print(\"device =\", config.device)"
| 1443 |   ]
| 1444 |   },
| 1445 |   {
| 1446 |   "cell_type": "code",
| 1447 | + "execution_count": 32,
| 1448 |   "metadata": {},
| 1449 |   "outputs": [
| 1450 | + {
| 1451 | + "name": "stdout",
| 1452 | + "output_type": "stream",
| 1453 | + "text": [
| 1454 | + "dataset content: <KeysViewHDF5 ['brightness_temp', 'density', 'kwargs', 'params', 'redshifts_distances', 'seeds', 'xH_box']>\n",
| 1455 | + "51200 images can be loaded\n",
| 1456 | + "field.shape = (64, 64, 514)\n",
| 1457 | + "params keys = [b'ION_Tvir_MIN', b'HII_EFF_FACTOR']\n",
| 1458 | + "loading 200 images randomly\n",
| 1459 | + "images loaded: (200, 1, 64, 512)\n"
| 1460 | + ]
| 1461 | + },
| 1462 | + {
| 1463 | + "name": "stderr",
| 1464 | + "output_type": "stream",
| 1465 | + "text": [
| 1466 | + "Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
| 1467 | + ]
| 1468 | + },
| 1469 | + {
| 1470 | + "name": "stdout",
| 1471 | + "output_type": "stream",
| 1472 | + "text": [
| 1473 | + "params loaded: (200, 2)\n",
| 1474 | + "images rescaled to [-1.0, 1.095154047012329]\n",
| 1475 | + "params rescaled to [0.0, 0.997810682944812]\n"
| 1476 | + ]
| 1477 | + },
| 1478 |   {
| 1479 |   "data": {
| 1480 |   "application/vnd.jupyter.widget-view+json": {
| 1481 | + "model_id": "e6947686dc1344b4b446910cca5326dc",
| 1482 |   "version_major": 2,
| 1483 |   "version_minor": 0
| 1484 |   },

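guide_w from the config is now forwarded into self.ddpm.sample(...). The commit does not show the blending step inside DDPMScheduler.sample, but a common classifier-free guidance convention, in which guide_w = 0 reduces to the purely conditional prediction, looks roughly like the sketch below; this is an illustration of the idea, not necessarily the formula used in this repository:

    import torch

    def guided_noise(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guide_w: float) -> torch.Tensor:
        # Blend the conditional and unconditional noise predictions; guide_w = 0
        # returns the conditional prediction unchanged, larger guide_w pushes the
        # sample harder toward the conditioning parameters.
        return (1 + guide_w) * eps_cond - guide_w * eps_uncond
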
| 1492 |   {
| 1493 |   "data": {
| 1494 |   "application/vnd.jupyter.widget-view+json": {
| 1495 | + "model_id": "90757bce414c44778e816eb75fbeb106",
| 1496 |   "version_major": 2,
| 1497 |   "version_minor": 0
| 1498 |   },

| 1506 |   {
| 1507 |   "data": {
| 1508 |   "application/vnd.jupyter.widget-view+json": {
| 1509 | + "model_id": "cc51568c1fb94c82a96726973b103b36",
| 1510 |   "version_major": 2,
| 1511 |   "version_minor": 0
| 1512 |   },

| 1520 |   {
| 1521 |   "data": {
| 1522 |   "application/vnd.jupyter.widget-view+json": {
| 1523 | + "model_id": "62c75108d8344ec69b34873181e3d308",
| 1524 |   "version_major": 2,
| 1525 |   "version_minor": 0
| 1526 |   },

| 1534 |   {
| 1535 |   "data": {
| 1536 |   "application/vnd.jupyter.widget-view+json": {
| 1537 | + "model_id": "6bece94eb3cd45c7933bae67ad774f06",
| 1538 |   "version_major": 2,
| 1539 |   "version_minor": 0
| 1540 |   },

| 1548 |   {
| 1549 |   "data": {
| 1550 |   "application/vnd.jupyter.widget-view+json": {
| 1551 | + "model_id": "d3c093174fab4863af65f477f8baee8f",
| 1552 |   "version_major": 2,
| 1553 |   "version_minor": 0
| 1554 |   },

| 1562 |   {
| 1563 |   "data": {
| 1564 |   "application/vnd.jupyter.widget-view+json": {
| 1565 | + "model_id": "f2565e70177a402ba90215f4d80f717e",
| 1566 |   "version_major": 2,
| 1567 |   "version_minor": 0
| 1568 |   },

| 1576 |   {
| 1577 |   "data": {
| 1578 |   "application/vnd.jupyter.widget-view+json": {
| 1579 | + "model_id": "f11d5a63fb1040c2a4366207e53476b7",
| 1580 |   "version_major": 2,
| 1581 |   "version_minor": 0
| 1582 |   },

| 1590 |   {
| 1591 |   "data": {
| 1592 |   "application/vnd.jupyter.widget-view+json": {
| 1593 | + "model_id": "4641851ffed541edb4bfd99d156bf000",
| 1594 |   "version_major": 2,
| 1595 |   "version_minor": 0
| 1596 |   },

| 1604 |   {
| 1605 |   "data": {
| 1606 |   "application/vnd.jupyter.widget-view+json": {
| 1607 | + "model_id": "9338bea79f5d49c4b79a2327b9833954",
| 1608 |   "version_major": 2,
| 1609 |   "version_minor": 0
| 1610 |   },

| 1622 |   },
| 1623 |   {
| 1624 |   "cell_type": "code",
| 1625 | + "execution_count": 29,
| 1626 |   "metadata": {},
| 1627 |   "outputs": [
| 1628 |   {
| 1629 | + "ename": "IndexError",
| 1630 | + "evalue": "tuple index out of range",
| 1631 |   "output_type": "error",
| 1632 |   "traceback": [
| 1633 |   "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
| 1634 | + "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
| 1635 | + "Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m ddpm21cm\u001b[39m.\u001b[39;49msample(\u001b[39m\"\u001b[39;49m\u001b[39m./outputs/model_state_09.pth\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
| 1636 | + "Cell \u001b[0;32mIn[28], line 143\u001b[0m, in \u001b[0;36mDDPM21CM.sample\u001b[0;34m(self, file, params, ema, entire)\u001b[0m\n\u001b[1;32m 140\u001b[0m model \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mema_model \u001b[39mif\u001b[39;00m ema \u001b[39melse\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mnn_model\n\u001b[1;32m 141\u001b[0m \u001b[39m# params = torch.tile(params, (n_sample,1)).to(device)\u001b[39;00m\n\u001b[0;32m--> 143\u001b[0m x_last, x_entire \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mddpm\u001b[39m.\u001b[39;49msample(model, params\u001b[39m=\u001b[39;49mparams, device\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mdevice, guide_w\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mconfig\u001b[39m.\u001b[39;49mguide_w)\n\u001b[1;32m 145\u001b[0m np\u001b[39m.\u001b[39msave(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39moutput_dir, \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mrun_name\u001b[39m}\u001b[39;00m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39mema\u001b[39m\u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39mif\u001b[39;00m\u001b[39m \u001b[39mema\u001b[39m \u001b[39m\u001b[39melse\u001b[39;00m\u001b[39m \u001b[39m\u001b[39mNone\u001b[39;00m\u001b[39m}\u001b[39;00m\u001b[39m.npy\u001b[39m\u001b[39m\"\u001b[39m), x_last)\n\u001b[1;32m 146\u001b[0m \u001b[39mif\u001b[39;00m entire:\n",
| 1637 | + "Cell \u001b[0;32mIn[7], line 45\u001b[0m, in \u001b[0;36mDDPMScheduler.sample\u001b[0;34m(self, nn_model, params, device, guide_w)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[39mif\u001b[39;00m guide_w \u001b[39m!=\u001b[39m \u001b[39m-\u001b[39m\u001b[39m1\u001b[39m:\n\u001b[1;32m 44\u001b[0m c_i \u001b[39m=\u001b[39m params\n\u001b[0;32m---> 45\u001b[0m uncond_tokens \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mzeros(\u001b[39mint\u001b[39m(n_sample), params\u001b[39m.\u001b[39;49mshape[\u001b[39m1\u001b[39;49m])\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 46\u001b[0m \u001b[39m# uncond_tokens = torch.tensor(np.float32(np.array([0,0]))).to(device)\u001b[39;00m\n\u001b[1;32m 47\u001b[0m \u001b[39m# uncond_tokens = uncond_tokens.repeat(int(n_sample),1)\u001b[39;00m\n\u001b[1;32m 48\u001b[0m c_i \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mcat((c_i, uncond_tokens), \u001b[39m0\u001b[39m)\n",
| 1638 | + "\u001b[0;31mIndexError\u001b[0m: tuple index out of range"
| 1639 |   ]
| 1640 |   }
| 1641 |   ],

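The new traceback dies at params.shape[1] inside DDPMScheduler.sample, which only exists when params is 2-D with shape (n_sample, n_param); passing a single parameter pair as a 1-D tensor of shape (2,) reproduces the "tuple index out of range". A minimal guard in the spirit of the commented-out torch.tile(params, (n_sample, 1)) line; this illustrates the failure and one possible fix, not the fix committed here:

    import torch

    n_sample = 4
    params = torch.tensor([0.5, 0.3])          # one (ION_Tvir_MIN, HII_EFF_FACTOR) pair, shape (2,)
    # params.shape[1]                          # IndexError: tuple index out of range

    if params.dim() == 1:
        # promote to (1, n_param) and repeat per sample, as the commented torch.tile line did
        params = params.unsqueeze(0).repeat(n_sample, 1)

    uncond_tokens = torch.zeros(n_sample, params.shape[1])   # now well defined: shape (4, 2)
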
| 2035 |   "@dataclass\n",
| 2036 |   "class TrainingConfig:\n",
| 2037 |   " image_size = 128\n",
| 2038 | + " batch_size = 16\n",
| 2039 |   " eval_batch_size = 16\n",
| 2040 |   " num_epochs = 50\n",
| 2041 |   " gradient_accumulation_steps = 1\n",

| 2072 |   " return {\"images\": images}\n",
| 2073 |   "\n",
| 2074 |   "dataset.set_transform(transform)\n",
| 2075 | + "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True)\n",
| 2076 |   "\n",
| 2077 |   "model = UNet2DModel(\n",
| 2078 |   " sample_size = config.image_size,\n",

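The diff context ends right after model = UNet2DModel(sample_size = config.image_size, ...). This part of the notebook follows the standard diffusers training example, where UNet2DModel is typically configured along the following lines; the channel counts and block types below are illustrative values from that example, not read from this commit:

    from diffusers import UNet2DModel

    model = UNet2DModel(
        sample_size=128,                       # config.image_size in the notebook
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=("DownBlock2D", "DownBlock2D", "DownBlock2D",
                          "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "AttnUpBlock2D", "UpBlock2D",
                        "UpBlock2D", "UpBlock2D", "UpBlock2D"),
    )
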
| 2273 |   "class TrainingConfig:\n",
| 2274 |   " num_images = 600\n",
| 2275 |   " image_size = [64,512]\n",
| 2276 | + " batch_size = 10\n",
| 2277 |   " eval_batch_size = 24\n",
| 2278 |   " num_epochs = 20\n",
| 2279 |   " gradient_accumulation_steps = 1\n",

| 2552 |   "outputs": [],
| 2553 |   "source": [
| 2554 |   "# dataset.set_transform(transform)\n",
| 2555 | + "dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True)"
| 2556 |   ]
| 2557 |   },
| 2558 |   {