0710-1816
Browse files- diffusion.py +7 -5
diffusion.py
CHANGED
|
@@ -361,7 +361,7 @@ class DDPM21CM:
|
|
| 361 |
dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
|
| 362 |
# self.shape_loaded = dataset.images.shape
|
| 363 |
# print("shape_loaded =", self.shape_loaded)
|
| 364 |
-
# print(f"load, current_device = {torch.cuda.current_device()}")
|
| 365 |
self.dataloader = DataLoader(
|
| 366 |
dataset=dataset,
|
| 367 |
batch_size=self.config.batch_size,
|
|
@@ -394,7 +394,8 @@ class DDPM21CM:
|
|
| 394 |
# distributed_type="MULTI_GPU",
|
| 395 |
)
|
| 396 |
# print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
|
| 397 |
-
if self.accelerator.is_main_process:
|
|
|
|
| 398 |
if self.config.output_dir is not None:
|
| 399 |
os.makedirs(self.config.output_dir, exist_ok=True)
|
| 400 |
if self.config.push_to_hub:
|
|
@@ -414,7 +415,7 @@ class DDPM21CM:
|
|
| 414 |
self.nn_model, self.optimizer, self.lr_scheduler
|
| 415 |
)
|
| 416 |
|
| 417 |
-
print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
|
| 418 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
|
| 419 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.DistributedSampler =", self.dataloader.DistributedSampler)
|
| 420 |
|
|
@@ -468,7 +469,8 @@ class DDPM21CM:
|
|
| 468 |
|
| 469 |
def save(self, ep):
|
| 470 |
# save model
|
| 471 |
-
if self.accelerator.is_main_process:
|
|
|
|
| 472 |
if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:
|
| 473 |
self.nn_model.eval()
|
| 474 |
with torch.no_grad():
|
|
@@ -486,7 +488,7 @@ class DDPM21CM:
|
|
| 486 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 487 |
}
|
| 488 |
torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
|
| 489 |
-
print('saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
|
| 490 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
| 491 |
|
| 492 |
# def rescale(self, value, type='params', to_ranges=[0,1]):
|
|
|
|
| 361 |
dataset = Dataset4h5(self.config.dataset_name, num_image=self.config.num_image, HII_DIM=self.config.HII_DIM, num_redshift=self.config.num_redshift, drop_prob=self.config.drop_prob, dim=self.config.dim, ranges_dict=self.ranges_dict)
|
| 362 |
# self.shape_loaded = dataset.images.shape
|
| 363 |
# print("shape_loaded =", self.shape_loaded)
|
| 364 |
+
# print(f"load, current_device() = {torch.cuda.current_device()}")
|
| 365 |
self.dataloader = DataLoader(
|
| 366 |
dataset=dataset,
|
| 367 |
batch_size=self.config.batch_size,
|
|
|
|
| 394 |
# distributed_type="MULTI_GPU",
|
| 395 |
)
|
| 396 |
# print("!!!!!!!!!!!!!!!!!!!self.accelerator.device:", self.accelerator.device)
|
| 397 |
+
# if self.accelerator.is_main_process:
|
| 398 |
+
if torch.cuda.current_device() == 0:
|
| 399 |
if self.config.output_dir is not None:
|
| 400 |
os.makedirs(self.config.output_dir, exist_ok=True)
|
| 401 |
if self.config.push_to_hub:
|
|
|
|
| 415 |
self.nn_model, self.optimizer, self.lr_scheduler
|
| 416 |
)
|
| 417 |
|
| 418 |
+
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.sampler =", self.dataloader.sampler)
|
| 419 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.batch_sampler =", self.dataloader.batch_sampler)
|
| 420 |
# print("!!!!!!!!!!!!!!!!, after prepare, self.dataloader.DistributedSampler =", self.dataloader.DistributedSampler)
|
| 421 |
|
|
|
|
| 469 |
|
| 470 |
def save(self, ep):
|
| 471 |
# save model
|
| 472 |
+
# if self.accelerator.is_main_process:
|
| 473 |
+
if torch.cuda.current_device() == 0:
|
| 474 |
if ep == self.config.n_epoch-1 or (ep+1)*self.config.save_freq==1:
|
| 475 |
self.nn_model.eval()
|
| 476 |
with torch.no_grad():
|
|
|
|
| 488 |
# 'ema_unet_state_dict': self.ema_model.state_dict(),
|
| 489 |
}
|
| 490 |
torch.save(model_state, self.config.save_name+f"-N{self.config.num_image}")
|
| 491 |
+
print(f'device {torch.cuda.current_device()} saved model at ' + self.config.save_name+f"-N{self.config.num_image}")
|
| 492 |
# print('saved model at ' + config.save_dir + f"model_epoch_{ep}_test_{config.run_name}.pth")
|
| 493 |
|
| 494 |
# def rescale(self, value, type='params', to_ranges=[0,1]):
|