0710-1257
Browse files- diffusion.py +10 -3
diffusion.py
CHANGED
|
@@ -240,7 +240,7 @@ class TrainConfig:
|
|
| 240 |
batch_size = 2#2#50#20#2#100 # 10
|
| 241 |
n_epoch = 2#10#50#20#20#2#5#25 # 120
|
| 242 |
HII_DIM = 28#64
|
| 243 |
-
num_redshift =
|
| 244 |
channel = 1
|
| 245 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 246 |
|
|
@@ -359,7 +359,14 @@ class DDPM21CM:
|
|
| 359 |
dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
|
| 360 |
# self.shape_loaded = dataset.images.shape
|
| 361 |
# print("shape_loaded =", self.shape_loaded)
|
| 362 |
-
self.dataloader = DataLoader(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
# del dataset
|
| 364 |
# self.accelerate(self.config)
|
| 365 |
del dataset
|
|
@@ -547,7 +554,7 @@ def single_main(rank, world_size):
|
|
| 547 |
if __name__ == "__main__":
|
| 548 |
# torch.multiprocessing.set_start_method("spawn")
|
| 549 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 550 |
-
world_size =
|
| 551 |
|
| 552 |
mp.spawn(single_main, args=(world_size,), nprocs=world_size)
|
| 553 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|
|
|
|
| 240 |
batch_size = 2#2#50#20#2#100 # 10
|
| 241 |
n_epoch = 2#10#50#20#20#2#5#25 # 120
|
| 242 |
HII_DIM = 28#64
|
| 243 |
+
num_redshift = 2#128#64#512#256#256#64#512#128
|
| 244 |
channel = 1
|
| 245 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 246 |
|
|
|
|
| 359 |
dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
|
| 360 |
# self.shape_loaded = dataset.images.shape
|
| 361 |
# print("shape_loaded =", self.shape_loaded)
|
| 362 |
+
self.dataloader = DataLoader(
|
| 363 |
+
dataset,
|
| 364 |
+
batch_size=self.config.batch_size,
|
| 365 |
+
shuffle=True,
|
| 366 |
+
num_workers=1,#len(os.sched_getaffinity(0)),
|
| 367 |
+
pin_memory=True,
|
| 368 |
+
persistent_workers=True,
|
| 369 |
+
)
|
| 370 |
# del dataset
|
| 371 |
# self.accelerate(self.config)
|
| 372 |
del dataset
|
|
|
|
| 554 |
if __name__ == "__main__":
|
| 555 |
# torch.multiprocessing.set_start_method("spawn")
|
| 556 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
| 557 |
+
world_size = 2#torch.cuda.device_count()
|
| 558 |
|
| 559 |
mp.spawn(single_main, args=(world_size,), nprocs=world_size)
|
| 560 |
# notebook_launcher(ddpm21cm.train, num_processes=1, mixed_precision='fp16')
|