0710-2338
Browse files- diffusion.py +4 -2
diffusion.py
CHANGED
|
@@ -230,7 +230,7 @@ class TrainConfig:
|
|
| 230 |
hub_private_repo = False
|
| 231 |
dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
|
| 232 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 233 |
-
|
| 234 |
# repeat = 2
|
| 235 |
|
| 236 |
# dim = 2
|
|
@@ -351,7 +351,7 @@ class DDPM21CM:
|
|
| 351 |
self.lr_scheduler = get_cosine_schedule_with_warmup(
|
| 352 |
optimizer=self.optimizer,
|
| 353 |
num_warmup_steps=config.lr_warmup_steps,
|
| 354 |
-
num_training_steps=
|
| 355 |
# num_training_steps=(len(self.dataloader) * config.n_epoch),
|
| 356 |
)
|
| 357 |
|
|
@@ -558,6 +558,8 @@ class DDPM21CM:
|
|
| 558 |
# %%
|
| 559 |
def main(rank, world_size):
|
| 560 |
config = TrainConfig()
|
|
|
|
|
|
|
| 561 |
ddp_setup(rank, world_size)
|
| 562 |
|
| 563 |
num_image_list = [5000]#[200]#[1600,3200,6400,12800,25600]
|
|
|
|
| 230 |
hub_private_repo = False
|
| 231 |
dataset_name = "/storage/home/hcoda1/3/bxia34/scratch/LEN128-DIM64-CUB8.h5"
|
| 232 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 233 |
+
world_size = torch.cuda.device_count()
|
| 234 |
# repeat = 2
|
| 235 |
|
| 236 |
# dim = 2
|
|
|
|
| 351 |
self.lr_scheduler = get_cosine_schedule_with_warmup(
|
| 352 |
optimizer=self.optimizer,
|
| 353 |
num_warmup_steps=config.lr_warmup_steps,
|
| 354 |
+
num_training_steps=int(config.num_image / config.world_size / config.batch_size * config.n_epoch),
|
| 355 |
# num_training_steps=(len(self.dataloader) * config.n_epoch),
|
| 356 |
)
|
| 357 |
|
|
|
|
| 558 |
# %%
|
| 559 |
def main(rank, world_size):
|
| 560 |
config = TrainConfig()
|
| 561 |
+
config.world_size = world_size
|
| 562 |
+
|
| 563 |
ddp_setup(rank, world_size)
|
| 564 |
|
| 565 |
num_image_list = [5000]#[200]#[1600,3200,6400,12800,25600]
|