0713-1516
Browse files- diffusion.py +15 -12
diffusion.py
CHANGED
|
@@ -274,13 +274,12 @@ class TrainConfig:
|
|
| 274 |
# save_period = 1 #10 # the period of saving model
|
| 275 |
# cond = True # if training using the conditional information
|
| 276 |
# lr_decay = False #True# if using the learning rate decay
|
| 277 |
-
resume =
|
| 278 |
# params_single = torch.tensor([0.2,0.80000023])
|
| 279 |
# params = torch.tile(params_single,(n_sample,1)).to(device)
|
| 280 |
# params = params
|
| 281 |
# data_dir = './data' # data directory
|
| 282 |
|
| 283 |
-
|
| 284 |
mixed_precision = "fp16"
|
| 285 |
gradient_accumulation_steps = 1
|
| 286 |
|
|
@@ -322,10 +321,6 @@ class DDPM21CM:
|
|
| 322 |
# initialize the unet
|
| 323 |
self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)
|
| 324 |
|
| 325 |
-
if config.resume and os.path.exists(config.resume):
|
| 326 |
-
# resume_file = os.path.join(config.output_dir, f"{config.resume}")
|
| 327 |
-
self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 328 |
-
print(f"resumed nn_model from {config.resume}")
|
| 329 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 330 |
self.nn_model.train()
|
| 331 |
# print("self.ddpm.device =", self.ddpm.device)
|
|
@@ -333,6 +328,14 @@ class DDPM21CM:
|
|
| 333 |
self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
|
| 334 |
# print("nn_model.device =", ddpm.device)
|
| 335 |
# number of parameters to be trained
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
|
| 337 |
print(f"Number of parameters for nn_model: {self.number_of_params}")
|
| 338 |
|
|
@@ -508,7 +511,7 @@ class DDPM21CM:
|
|
| 508 |
|
| 509 |
def sample(self, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
|
| 510 |
# n_sample = params.shape[0]
|
| 511 |
-
file = self.config.resume
|
| 512 |
|
| 513 |
if params is None:
|
| 514 |
params = torch.tensor([0.20000000000000018, 0.5055875000000001])
|
|
@@ -528,11 +531,11 @@ class DDPM21CM:
|
|
| 528 |
# params = torch.tile(params, (n_sample,1)).to(device)
|
| 529 |
|
| 530 |
# nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
|
| 531 |
-
if ema:
|
| 532 |
-
|
| 533 |
-
else:
|
| 534 |
-
|
| 535 |
-
print(f"device {torch.cuda.current_device()} resumed nn_model from {file}")
|
| 536 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 537 |
# nn_model.train()
|
| 538 |
# self.nn_model.to(self.ddpm.device)
|
|
|
|
| 274 |
# save_period = 1 #10 # the period of saving model
|
| 275 |
# cond = True # if training using the conditional information
|
| 276 |
# lr_decay = False #True# if using the learning rate decay
|
| 277 |
+
resume = False # if resume from the trained checkpoints
|
| 278 |
# params_single = torch.tensor([0.2,0.80000023])
|
| 279 |
# params = torch.tile(params_single,(n_sample,1)).to(device)
|
| 280 |
# params = params
|
| 281 |
# data_dir = './data' # data directory
|
| 282 |
|
|
|
|
| 283 |
mixed_precision = "fp16"
|
| 284 |
gradient_accumulation_steps = 1
|
| 285 |
|
|
|
|
| 321 |
# initialize the unet
|
| 322 |
self.nn_model = ContextUnet(n_param=config.n_param, image_size=config.HII_DIM, dim=config.dim, stride=config.stride)
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 325 |
self.nn_model.train()
|
| 326 |
# print("self.ddpm.device =", self.ddpm.device)
|
|
|
|
| 328 |
self.nn_model = DDP(self.nn_model, device_ids=[self.ddpm.device])
|
| 329 |
# print("nn_model.device =", ddpm.device)
|
| 330 |
# number of parameters to be trained
|
| 331 |
+
|
| 332 |
+
if config.resume and os.path.exists(config.resume):
|
| 333 |
+
# resume_file = os.path.join(config.output_dir, f"{config.resume}")
|
| 334 |
+
# self.nn_model.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 335 |
+
# print(f"resumed nn_model from {config.resume}")
|
| 336 |
+
self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 337 |
+
print(f"device {torch.cuda.current_device()} resumed nn_model from {config.resume}")
|
| 338 |
+
|
| 339 |
self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
|
| 340 |
print(f"Number of parameters for nn_model: {self.number_of_params}")
|
| 341 |
|
|
|
|
| 511 |
|
| 512 |
def sample(self, params:torch.tensor=None, num_new_img=192, ema=False, entire=False, save=False):
|
| 513 |
# n_sample = params.shape[0]
|
| 514 |
+
# file = self.config.resume
|
| 515 |
|
| 516 |
if params is None:
|
| 517 |
params = torch.tensor([0.20000000000000018, 0.5055875000000001])
|
|
|
|
| 531 |
# params = torch.tile(params, (n_sample,1)).to(device)
|
| 532 |
|
| 533 |
# nn_model = ContextUnet(n_param=self.config.n_param, image_size=self.config.HII_DIM, dim=self.config.dim, stride=self.config.stride).to(self.config.device)
|
| 534 |
+
# if ema:
|
| 535 |
+
# self.nn_model.module.load_state_dict(torch.load(file)['ema_unet_state_dict'])
|
| 536 |
+
# else:
|
| 537 |
+
# self.nn_model.module.load_state_dict(torch.load(file)['unet_state_dict'])
|
| 538 |
+
# print(f"device {torch.cuda.current_device()} resumed nn_model from {file}")
|
| 539 |
# nn_model = ContextUnet(n_param=1, image_size=28)
|
| 540 |
# nn_model.train()
|
| 541 |
# self.nn_model.to(self.ddpm.device)
|