05205908
Browse files- diffusion.py +3 -0
diffusion.py
CHANGED
|
@@ -467,6 +467,9 @@ class DDPM21CM:
|
|
| 467 |
persistent_workers=True,
|
| 468 |
# sampler=DistributedSampler(dataset),
|
| 469 |
)
|
|
|
|
|
|
|
|
|
|
| 470 |
dataloader_end = time()
|
| 471 |
print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} dataloader costs {dataloader_end-dataloader_start:.3f}s")
|
| 472 |
|
|
|
|
| 467 |
persistent_workers=True,
|
| 468 |
# sampler=DistributedSampler(dataset),
|
| 469 |
)
|
| 470 |
+
if len(self.dataloader) % self.config.gradient_accumulation_steps != 0:
|
| 471 |
+
raise ValueError(f"len(self.dataloader) % self.config.gradient_accumulation_steps = {len(self.dataloader) % self.config.gradient_accumulation_steps} instead of 0")
|
| 472 |
+
|
| 473 |
dataloader_end = time()
|
| 474 |
print(f"cuda:{torch.cuda.current_device()}/{self.config.global_rank} dataloader costs {dataloader_end-dataloader_start:.3f}s")
|
| 475 |
|