0710-1537
Browse files- diffusion.py +3 -2
diffusion.py
CHANGED
|
@@ -362,12 +362,13 @@ class DDPM21CM:
|
|
| 362 |
# self.shape_loaded = dataset.images.shape
|
| 363 |
# print("shape_loaded =", self.shape_loaded)
|
| 364 |
self.dataloader = DataLoader(
|
| 365 |
-
dataset,
|
| 366 |
batch_size=self.config.batch_size,
|
| 367 |
-
shuffle=
|
| 368 |
num_workers=1,#len(os.sched_getaffinity(0)),
|
| 369 |
pin_memory=True,
|
| 370 |
persistent_workers=True,
|
|
|
|
| 371 |
)
|
| 372 |
# del dataset
|
| 373 |
# self.accelerate(self.config)
|
|
|
|
| 362 |
# self.shape_loaded = dataset.images.shape
|
| 363 |
# print("shape_loaded =", self.shape_loaded)
|
| 364 |
self.dataloader = DataLoader(
|
| 365 |
+
dataset=dataset,
|
| 366 |
batch_size=self.config.batch_size,
|
| 367 |
+
shuffle=False,
|
| 368 |
num_workers=1,#len(os.sched_getaffinity(0)),
|
| 369 |
pin_memory=True,
|
| 370 |
persistent_workers=True,
|
| 371 |
+
sampler=DistributedSampler(dataset),
|
| 372 |
)
|
| 373 |
# del dataset
|
| 374 |
# self.accelerate(self.config)
|