0719-1205
Browse files- context_unet.py +1 -1
- diffusion.py +5 -3
- quantify_results.ipynb +0 -0
context_unet.py
CHANGED
|
@@ -330,7 +330,7 @@ class ContextUnet(nn.Module):
|
|
| 330 |
elif image_size == 128:
|
| 331 |
channel_mult = (1, 1, 2, 3, 4)
|
| 332 |
elif image_size == 64:
|
| 333 |
-
channel_mult = (1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
|
| 334 |
elif image_size == 32:
|
| 335 |
channel_mult = (1, 2, 2, 4)
|
| 336 |
elif image_size == 28:
|
|
|
|
| 330 |
elif image_size == 128:
|
| 331 |
channel_mult = (1, 1, 2, 3, 4)
|
| 332 |
elif image_size == 64:
|
| 333 |
+
channel_mult = (1, 2, 3, 4)#(1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
|
| 334 |
elif image_size == 32:
|
| 335 |
channel_mult = (1, 2, 2, 4)
|
| 336 |
elif image_size == 28:
|
diffusion.py
CHANGED
|
@@ -23,6 +23,8 @@
|
|
| 23 |
# - it takes 62 mins to generate 8 images with shape of (64,64,64), which is even slower than simulation, which takes ~5 mins for each image. Besides, the batch_size during training and num of images to be generated are limited to 2 and 8, respectively.
|
| 24 |
# - the slowness can be solved by using multi-GPUs, and the limited number of images can be solved by multi-accuracy and multi-GPUs.
|
| 25 |
# - In addition, the performance of DDPM can look better compared to computation-intensive simulations.
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# %%
|
| 28 |
from dataclasses import dataclass
|
|
@@ -581,7 +583,7 @@ class DDPM21CM:
|
|
| 581 |
return x_last
|
| 582 |
# %%
|
| 583 |
|
| 584 |
-
num_train_image_list = [
|
| 585 |
|
| 586 |
def train(rank, world_size):
|
| 587 |
config = TrainConfig()
|
|
@@ -602,7 +604,7 @@ def train(rank, world_size):
|
|
| 602 |
|
| 603 |
|
| 604 |
if __name__ == "__main__":
|
| 605 |
-
world_size = torch.cuda.device_count()
|
| 606 |
print(f" training, world_size = {world_size} ".center(100,'-'))
|
| 607 |
# torch.multiprocessing.set_start_method("spawn")
|
| 608 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
|
@@ -667,7 +669,7 @@ if __name__ == "__main__":
|
|
| 667 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 668 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 669 |
# num_train_image_list = [5000]
|
| 670 |
-
num_new_img_per_gpu =
|
| 671 |
max_num_img_per_gpu = 20
|
| 672 |
|
| 673 |
params = torch.tensor([4.4, 131.341])
|
|
|
|
| 23 |
# - it takes 62 mins to generate 8 images with shape of (64,64,64), which is even slower than simulation, which takes ~5 mins for each image. Besides, the batch_size during training and num of images to be generated are limited to 2 and 8, respectively.
|
| 24 |
# - the slowness can be solved by using multi-GPUs, and the limited number of images can be solved by multi-accuracy and multi-GPUs.
|
| 25 |
# - In addition, the performance of DDPM can look better compared to computation-intensive simulations.
|
| 26 |
+
# 1 GPU, batch_size = 10, num_image = 3200, 50s for each epoch
|
| 27 |
+
# 4 GPU, batch_size = 10, num_image = 3200,
|
| 28 |
|
| 29 |
# %%
|
| 30 |
from dataclasses import dataclass
|
|
|
|
| 583 |
return x_last
|
| 584 |
# %%
|
| 585 |
|
| 586 |
+
num_train_image_list = [3200]
|
| 587 |
|
| 588 |
def train(rank, world_size):
|
| 589 |
config = TrainConfig()
|
|
|
|
| 604 |
|
| 605 |
|
| 606 |
if __name__ == "__main__":
|
| 607 |
+
world_size = 1#torch.cuda.device_count()
|
| 608 |
print(f" training, world_size = {world_size} ".center(100,'-'))
|
| 609 |
# torch.multiprocessing.set_start_method("spawn")
|
| 610 |
# args = (config, nn_model, ddpm, optimizer, dataloader, lr_scheduler)
|
|
|
|
| 669 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 670 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 671 |
# num_train_image_list = [5000]
|
| 672 |
+
num_new_img_per_gpu = 400
|
| 673 |
max_num_img_per_gpu = 20
|
| 674 |
|
| 675 |
params = torch.tensor([4.4, 131.341])
|
quantify_results.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|