0721-2348
Browse files
- diffusion.py +4 -2
- quantify_results.ipynb +0 -0
diffusion.py
CHANGED
|
@@ -341,6 +341,8 @@ class DDPM21CM:
|
|
| 341 |
# print(f"resumed nn_model from {config.resume}")
|
| 342 |
self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 343 |
print(f"device {torch.cuda.current_device()} resumed nn_model from {config.resume}")
|
|
|
|
|
|
|
| 344 |
|
| 345 |
self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
|
| 346 |
print(f" Number of parameters for nn_model: {self.number_of_params} ".center(100,'-'))
|
|
@@ -614,7 +616,7 @@ def train(rank, world_size):
|
|
| 614 |
|
| 615 |
|
| 616 |
if __name__ == "__main__":
|
| 617 |
-
world_size = torch.cuda.device_count()
|
| 618 |
print(f" training, world_size = {world_size} ".center(100,'-'))
|
| 619 |
# torch.multiprocessing.set_start_method("spawn")
|
| 620 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
|
@@ -691,7 +693,7 @@ if __name__ == "__main__":
|
|
| 691 |
|
| 692 |
for num_image in num_train_image_list:
|
| 693 |
config.num_image = num_image // world_size
|
| 694 |
-
config.resume = f"./outputs/model_state-
|
| 695 |
|
| 696 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 697 |
manager = mp.Manager()
|
|
|
|
| 341 |
# print(f"resumed nn_model from {config.resume}")
|
| 342 |
self.nn_model.module.load_state_dict(torch.load(config.resume)['unet_state_dict'])
|
| 343 |
print(f"device {torch.cuda.current_device()} resumed nn_model from {config.resume}")
|
| 344 |
+
else:
|
| 345 |
+
print(f"device {torch.cuda.current_device()} initialized nn_model randomly")
|
| 346 |
|
| 347 |
self.number_of_params = sum(x.numel() for x in self.nn_model.parameters())
|
| 348 |
print(f" Number of parameters for nn_model: {self.number_of_params} ".center(100,'-'))
|
|
|
|
| 616 |
|
| 617 |
|
| 618 |
if __name__ == "__main__":
|
| 619 |
+
world_size = 2#torch.cuda.device_count()
|
| 620 |
print(f" training, world_size = {world_size} ".center(100,'-'))
|
| 621 |
# torch.multiprocessing.set_start_method("spawn")
|
| 622 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
|
|
|
| 693 |
|
| 694 |
for num_image in num_train_image_list:
|
| 695 |
config.num_image = num_image // world_size
|
| 696 |
+
config.resume = f"./outputs/model_state-N4000-device_count2-epoch{config.n_epoch-1}"
|
| 697 |
|
| 698 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 699 |
manager = mp.Manager()
|
quantify_results.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|