Spaces:
Runtime error
Runtime error
feat(data): support accumulation in non-streaming
Browse files - src/dalle_mini/data.py +10 -2
src/dalle_mini/data.py
CHANGED
|
@@ -161,13 +161,16 @@ class Dataset:
|
|
| 161 |
def _dataloader_datasets_non_streaming(
|
| 162 |
dataset: Dataset,
|
| 163 |
per_device_batch_size: int,
|
|
|
|
| 164 |
rng: jax.random.PRNGKey = None,
|
| 165 |
):
|
| 166 |
"""
|
| 167 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
| 168 |
Shuffle batches if rng is set.
|
| 169 |
"""
|
| 170 |
-
batch_size = per_device_batch_size * num_devices
|
|
|
|
|
|
|
| 171 |
steps_per_epoch = len(dataset) // batch_size
|
| 172 |
|
| 173 |
if rng is not None:
|
|
@@ -183,6 +186,11 @@ class Dataset:
|
|
| 183 |
for idx in batch_idx:
|
| 184 |
batch = dataset[idx]
|
| 185 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
batch = shard(batch)
|
| 187 |
yield batch
|
| 188 |
|
|
@@ -244,7 +252,7 @@ class Dataset:
|
|
| 244 |
if split == "train":
|
| 245 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
| 246 |
return _dataloader_datasets_non_streaming(
|
| 247 |
-
ds, per_device_batch_size, input_rng
|
| 248 |
)
|
| 249 |
|
| 250 |
@property
|
|
|
|
| 161 |
def _dataloader_datasets_non_streaming(
|
| 162 |
dataset: Dataset,
|
| 163 |
per_device_batch_size: int,
|
| 164 |
+
gradient_accumulation_steps: int,
|
| 165 |
rng: jax.random.PRNGKey = None,
|
| 166 |
):
|
| 167 |
"""
|
| 168 |
Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
|
| 169 |
Shuffle batches if rng is set.
|
| 170 |
"""
|
| 171 |
+
batch_size = (
|
| 172 |
+
per_device_batch_size * num_devices * gradient_accumulation_steps
|
| 173 |
+
)
|
| 174 |
steps_per_epoch = len(dataset) // batch_size
|
| 175 |
|
| 176 |
if rng is not None:
|
|
|
|
| 186 |
for idx in batch_idx:
|
| 187 |
batch = dataset[idx]
|
| 188 |
batch = {k: jnp.array(v) for k, v in batch.items()}
|
| 189 |
+
if gradient_accumulation_steps is not None:
|
| 190 |
+
batch = jax.tree_map(
|
| 191 |
+
lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
|
| 192 |
+
batch,
|
| 193 |
+
)
|
| 194 |
batch = shard(batch)
|
| 195 |
yield batch
|
| 196 |
|
|
|
|
| 252 |
if split == "train":
|
| 253 |
self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
|
| 254 |
return _dataloader_datasets_non_streaming(
|
| 255 |
+
ds, per_device_batch_size, gradient_accumulation_steps, input_rng
|
| 256 |
)
|
| 257 |
|
| 258 |
@property
|