0716-2205
Browse files- diffusion.py +7 -7
diffusion.py
CHANGED
|
@@ -236,12 +236,12 @@ class TrainConfig:
|
|
| 236 |
|
| 237 |
# dim = 2
|
| 238 |
dim = 3
|
| 239 |
-
stride = (2,2) if dim == 2 else (2,2,
|
| 240 |
num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 1#2#50#20#2#100 # 10
|
| 242 |
n_epoch = 8#4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
-
num_redshift =
|
| 245 |
channel = 1
|
| 246 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 247 |
|
|
@@ -586,7 +586,7 @@ def train(rank, world_size):
|
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
-
num_train_image_list = [
|
| 590 |
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|
|
@@ -677,9 +677,9 @@ if __name__ == "__main__":
|
|
| 677 |
world_size = torch.cuda.device_count()
|
| 678 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 679 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 680 |
-
num_train_image_list = [
|
| 681 |
-
num_new_img_per_gpu =
|
| 682 |
-
max_num_img_per_gpu =
|
| 683 |
|
| 684 |
params = torch.tensor([4.4, 131.341])
|
| 685 |
|
|
@@ -690,7 +690,7 @@ if __name__ == "__main__":
|
|
| 690 |
|
| 691 |
for num_image in num_train_image_list:
|
| 692 |
config.num_image = num_image
|
| 693 |
-
config.resume = f"./outputs/model_state-N{num_image}-
|
| 694 |
|
| 695 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 696 |
manager = mp.Manager()
|
|
|
|
| 236 |
|
| 237 |
# dim = 2
|
| 238 |
dim = 3
|
| 239 |
+
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#32000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 1#2#50#20#2#100 # 10
|
| 242 |
n_epoch = 8#4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
+
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
| 246 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 247 |
|
|
|
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
+
num_train_image_list = [3200]#[200]#[1600,3200,6400,12800,25600]
|
| 590 |
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|
|
|
|
| 677 |
world_size = torch.cuda.device_count()
|
| 678 |
print(f" sampling, world_size = {world_size} ".center(100,'-'))
|
| 679 |
# num_train_image_list = [1600,3200,6400,12800,25600]
|
| 680 |
+
num_train_image_list = [3200]
|
| 681 |
+
num_new_img_per_gpu = 9
|
| 682 |
+
max_num_img_per_gpu = 1
|
| 683 |
|
| 684 |
params = torch.tensor([4.4, 131.341])
|
| 685 |
|
|
|
|
| 690 |
|
| 691 |
for num_image in num_train_image_list:
|
| 692 |
config.num_image = num_image
|
| 693 |
+
config.resume = f"./outputs/model_state-N{num_image}-epoch7-device0"
|
| 694 |
|
| 695 |
# print("ddpm21cm = DDPM21CM(config)")
|
| 696 |
manager = mp.Manager()
|