0718-1052
Browse files- context_unet.py +1 -1
- diffusion.py +2 -2
context_unet.py
CHANGED
|
@@ -330,7 +330,7 @@ class ContextUnet(nn.Module):
|
|
| 330 |
elif image_size == 128:
|
| 331 |
channel_mult = (1, 1, 2, 3, 4)
|
| 332 |
elif image_size == 64:
|
| 333 |
-
channel_mult = (1, 2, 4,
|
| 334 |
elif image_size == 32:
|
| 335 |
channel_mult = (1, 2, 2, 4)
|
| 336 |
elif image_size == 28:
|
|
|
|
| 330 |
elif image_size == 128:
|
| 331 |
channel_mult = (1, 1, 2, 3, 4)
|
| 332 |
elif image_size == 64:
|
| 333 |
+
channel_mult = (1, 2, 4, 6, 8)#(1, 2, 2, 4)#(1, 2, 8, 8, 8)#(1, 2, 4)#(1, 2, 2, 4)#(0.5,1,2,2,4,4)#(1, 1, 2, 2, 4, 4)#
|
| 334 |
elif image_size == 32:
|
| 335 |
channel_mult = (1, 2, 2, 4)
|
| 336 |
elif image_size == 28:
|
diffusion.py
CHANGED
|
@@ -239,7 +239,7 @@ class TrainConfig:
|
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 1#2#50#20#2#100 # 10
|
| 242 |
-
n_epoch =
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
@@ -586,7 +586,7 @@ def train(rank, world_size):
|
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
-
num_train_image_list = [
|
| 590 |
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|
|
|
|
| 239 |
stride = (2,2) if dim == 2 else (2,2,4)
|
| 240 |
num_image = 1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 241 |
batch_size = 1#2#50#20#2#100 # 10
|
| 242 |
+
n_epoch = 2#4# 10#50#20#20#2#5#25 # 120
|
| 243 |
HII_DIM = 64
|
| 244 |
num_redshift = 512#128#64#512#256#256#64#512#128
|
| 245 |
channel = 1
|
|
|
|
| 586 |
|
| 587 |
ddp_setup(rank, world_size)
|
| 588 |
|
| 589 |
+
num_train_image_list = [3200]#[3200]#[200]#[1600,3200,6400,12800,25600]
|
| 590 |
for i, num_image in enumerate(num_train_image_list):
|
| 591 |
config.num_image = num_image
|
| 592 |
# config.world_size = world_size
|