Merge pull request #122 from borisdayma/feat-acccum
2 changed files:
- src/dalle_mini/data.py  +41 -6
- tools/train/train.py    +93 -46
src/dalle_mini/data.py
CHANGED
@@ -153,16 +153,24 @@ class Dataset:
             ),
         )
 
-    def dataloader(self, split, batch_size, epoch=None):
+    def dataloader(
+        self, split, per_device_batch_size, gradient_accumulation_steps=None, epoch=None
+    ):
+        num_devices = jax.local_device_count()
+
         def _dataloader_datasets_non_streaming(
             dataset: Dataset,
-            batch_size: int,
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
             rng: jax.random.PRNGKey = None,
         ):
             """
             Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
             Shuffle batches if rng is set.
             """
+            batch_size = (
+                per_device_batch_size * num_devices * gradient_accumulation_steps
+            )
             steps_per_epoch = len(dataset) // batch_size
 
             if rng is not None:

@@ -178,11 +186,20 @@ class Dataset:
             for idx in batch_idx:
                 batch = dataset[idx]
                 batch = {k: jnp.array(v) for k, v in batch.items()}
+                if gradient_accumulation_steps is not None:
+                    batch = jax.tree_map(
+                        lambda x: x.reshape((-1, per_device_batch_size) + x.shape[1:]),
+                        batch,
+                    )
                 batch = shard(batch)
                 yield batch
 
         def _dataloader_datasets_streaming(
-            dataset: Dataset, epoch: int
+            dataset: Dataset,
+            split: str,
+            per_device_batch_size: int,
+            gradient_accumulation_steps: int,
+            epoch: int,
         ):
             keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
             batch = {k: [] for k in keys}

@@ -199,8 +216,22 @@ class Dataset:
             for item in dataset:
                 for k, v in item.items():
                     batch[k].append(v)
-                if len(batch[keys[0]]) == batch_size:
+                # batch = 5, devices = 8, accumulation = 2 / batch_size = 5 x 8
+                # (40, 3, 3) -> shard 8 x (5, 3, 3)
+                # (16, 5, 3, 3) -> shard 8 x (2, 5, 3, 3)
+                if len(batch[keys[0]]) == per_device_batch_size * num_devices * (
+                    gradient_accumulation_steps
+                    if gradient_accumulation_steps is not None
+                    else 1
+                ):
                     batch = {k: jnp.array(v) for k, v in batch.items()}
+                    if gradient_accumulation_steps is not None:
+                        batch = jax.tree_map(
+                            lambda x: x.reshape(
+                                (-1, per_device_batch_size) + x.shape[1:]
+                            ),
+                            batch,
+                        )
                     batch = shard(batch)
                     yield batch
                     batch = {k: [] for k in keys}

@@ -214,11 +245,15 @@ class Dataset:
             raise ValueError(f'split must be "train" or "eval", got {split}')
 
         if self.streaming:
-            return _dataloader_datasets_streaming(ds, epoch)
+            return _dataloader_datasets_streaming(
+                ds, split, per_device_batch_size, gradient_accumulation_steps, epoch
+            )
         else:
             if split == "train":
                 self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
-            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)
+            return _dataloader_datasets_non_streaming(
+                ds, per_device_batch_size, gradient_accumulation_steps, input_rng
+            )
 
     @property
     def length(self):
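The reshape-then-shard step above is the core of the data change: with gradient accumulation, a host assembles per_device_batch_size x num_devices x gradient_accumulation_steps examples, folds them into minibatches of per_device_batch_size along a new leading axis, and lets `shard` split the outermost axis across local devices. A minimal standalone sketch of that shape arithmetic (sizes are the illustrative ones from the comments in the streaming loader; `flax.training.common_utils.shard` reshapes the leading axis to `(jax.local_device_count(), -1, ...)`):

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

per_device_batch_size = 5
num_devices = jax.local_device_count()  # 8 in the comments above
grad_accum = 2

# no accumulation: (num_devices * per_device, ...) -> (num_devices, per_device, ...)
batch = jnp.zeros((per_device_batch_size * num_devices, 3, 3))  # e.g. (40, 3, 3)
assert shard(batch).shape == (num_devices, per_device_batch_size, 3, 3)  # 8 x (5, 3, 3)

# with accumulation: fold minibatches into an extra leading axis first
batch = jnp.zeros((per_device_batch_size * num_devices * grad_accum, 3, 3))  # (80, 3, 3)
batch = batch.reshape((-1, per_device_batch_size) + batch.shape[1:])  # (16, 5, 3, 3)
assert shard(batch).shape == (num_devices, grad_accum, per_device_batch_size, 3, 3)  # 8 x (2, 5, 3, 3)

Each device therefore receives either a (per_device_batch_size, ...) batch or a (gradient_accumulation_steps, per_device_batch_size, ...) stack of minibatches; that leading axis is exactly what the reworked train_step in tools/train/train.py indexes with x[i].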
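The streaming loader's buffering is easy to get wrong, so here is a self-contained toy version of the pattern it implements (names and sizes are illustrative, not from the PR): collect items until a full optimizer step's worth of data is buffered, then reshape for accumulation and shard.

import jax
import jax.numpy as jnp
from flax.training.common_utils import shard

def toy_streaming_loader(items, per_device_batch_size, grad_accum=None):
    num_devices = jax.local_device_count()
    # `or 1` plays the role of the `if ... is not None else 1` guard above
    step_size = per_device_batch_size * num_devices * (grad_accum or 1)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == step_size:
            batch = jnp.array(buffer)
            if grad_accum is not None:
                batch = batch.reshape((-1, per_device_batch_size) + batch.shape[1:])
            yield shard(batch)
            buffer = []

stream = (jnp.ones(3) for _ in range(64))
for batch in toy_streaming_loader(stream, per_device_batch_size=2, grad_accum=4):
    print(batch.shape)  # (num_devices, 4, 2, 3)
    break

One caveat in the diff itself: `_dataloader_datasets_streaming` guards the batch-size product with `gradient_accumulation_steps if gradient_accumulation_steps is not None else 1`, but `_dataloader_datasets_non_streaming` multiplies by `gradient_accumulation_steps` directly, so the eval call site in tools/train/train.py, which leaves it as None, appears to rely on the streaming path.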
tools/train/train.py
CHANGED
@@ -277,8 +277,8 @@ class TrainingArguments:
         },
     )
 
-    num_train_epochs: float = field(
-        default=3.0, metadata={"help": "Total number of training epochs to perform."}
+    num_train_epochs: int = field(
+        default=3, metadata={"help": "Total number of training epochs to perform."}
     )
     warmup_steps: int = field(
         default=0, metadata={"help": "Linear warmup over warmup_steps."}

@@ -310,12 +310,40 @@ class TrainingArguments:
         metadata={"help": "Reference to a wandb artifact for resuming training."},
     )
 
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={"help": "The wandb entity to use (for teams)."},
+    )
+    wandb_project: str = field(
+        default="dalle-mini",
+        metadata={"help": "The name of the wandb project."},
+    )
+    wandb_job_type: str = field(
+        default="Seq2Seq",
+        metadata={"help": "The name of the wandb job type."},
+    )
+
+    assert_TPU_available: bool = field(
+        default=False,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     def __post_init__(self):
         assert self.optim in [
             "distributed_shampoo",
             "adam",
             "adafactor",
         ], f"Selected optimizer not supported: {self.optim}"
+        if (
+            os.path.exists(self.output_dir)
+            and os.listdir(self.output_dir)
+            and self.do_train
+            and not self.overwrite_output_dir
+        ):
+            raise ValueError(
+                f"Output directory ({self.output_dir}) already exists and is not empty."
+                "Use --overwrite_output_dir to overcome."
+            )
 
 
 class TrainState(train_state.TrainState):

@@ -396,17 +424,6 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty."
-            "Use --overwrite_output_dir to overcome."
-        )
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",

@@ -433,14 +450,18 @@ def main():
     )
 
     logger.info(f"Local TPUs: {jax.local_device_count()}")
-    assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
+    logger.info(f"Global TPUs: {jax.device_count()}")
+    if training_args.assert_TPU_available:
+        assert (
+            jax.local_device_count() == 8
+        ), "TPUs in use, please check running processes"
 
     # Set up wandb run
     if jax.process_index() == 0:
         wandb.init(
-            entity="dalle-mini",
-            project="dalle-mini",
-            job_type="Seq2Seq",
+            entity=training_args.wandb_entity,
+            project=training_args.wandb_project,
+            job_type=training_args.wandb_job_type,
             config=parser.parse_args(),
         )
@@ -515,22 +536,19 @@ def main():
     rng, dropout_rng = jax.random.split(rng)
 
     # Store some constant
-    num_epochs = int(training_args.num_train_epochs)
+    num_epochs = training_args.num_train_epochs
     # batch size per node
     train_batch_size = (
         training_args.per_device_train_batch_size * jax.local_device_count()
     )
-    batch_size_per_update = (
-        train_batch_size
-        * training_args.gradient_accumulation_steps
-        * jax.process_count()
-    )
+    batch_size_per_node = train_batch_size * training_args.gradient_accumulation_steps
+    batch_size_per_step = batch_size_per_node * jax.process_count()
     eval_batch_size = (
         training_args.per_device_eval_batch_size * jax.local_device_count()
     )
     len_train_dataset, len_eval_dataset = dataset.length
     steps_per_epoch = (
-        len_train_dataset // train_batch_size
+        len_train_dataset // batch_size_per_node
         if len_train_dataset is not None
         else None
     )

@@ -645,12 +663,6 @@ def main():
         clipping_threshold=training_args.max_grad_norm,
     )
 
-    # add gradient accumulation
-    if training_args.gradient_accumulation_steps > 1:
-        optimizer = optax.chain(
-            optax.apply_every(training_args.gradient_accumulation_steps), optimizer
-        )
-
     # Setup train state
     state = TrainState.create(
         apply_fn=model.__call__,

@@ -673,22 +685,48 @@ def main():
     def train_step(state, batch, delta_time):
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
-        def compute_loss(params, batch):
-            labels = batch.pop("labels")
+        def compute_loss(params, minibatch):
+            labels = minibatch.pop("labels")
             logits = state.apply_fn(
-                **batch, params=params, dropout_rng=dropout_rng, train=True
+                **minibatch, params=params, dropout_rng=dropout_rng, train=True
             )[0]
-            loss = loss_fn(logits, labels)
-            return loss
+            return loss_fn(logits, labels)
 
         grad_fn = jax.value_and_grad(compute_loss)
-        loss, grads = grad_fn(state.params, batch)
+
+        if training_args.gradient_accumulation_steps == 1:
+            minibatch = jax.tree_map(lambda x: x[0], batch)
+            loss, grads = grad_fn(state.params, minibatch)
+        else:
+
+            def _cumul_loss_grads(i, cumul_loss_grads):
+                minibatch = jax.tree_map(lambda x: x[i], batch)
+                return jax.tree_map(
+                    lambda x, y: x + y,
+                    cumul_loss_grads,
+                    grad_fn(state.params, minibatch),
+                )
+
+            init_loss_grads = (
+                0.0,
+                jax.tree_map(jnp.zeros_like, state.params),
+            )
+            loss, grads = jax.tree_map(
+                lambda x: x / training_args.gradient_accumulation_steps,
+                jax.lax.fori_loop(
+                    0,
+                    training_args.gradient_accumulation_steps,
+                    _cumul_loss_grads,
+                    init_loss_grads,
+                ),
+            )
+
         grads = jax.lax.pmean(grads, "batch")
         state = state.apply_gradients(
             grads=grads,
             dropout_rng=new_dropout_rng,
             train_time=state.train_time + delta_time,
-            train_samples=state.train_samples + train_batch_size,
+            train_samples=state.train_samples + batch_size_per_step,
         )
 
         metrics = {

@@ -711,19 +749,20 @@ def main():
         return metrics
 
     # Create parallel version of the train and eval step
-    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
-    p_eval_step = jax.pmap(eval_step, "batch")
+    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, 1))
+    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(1,))
 
     logger.info("***** Running training *****")
     logger.info(f"  Num examples = {len_train_dataset}")
     logger.info(f"  Num Epochs = {num_epochs}")
     logger.info(
-        f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
+        f"  Batch size per device = {training_args.per_device_train_batch_size}"
     )
     logger.info(f"  Number of devices = {jax.device_count()}")
     logger.info(
-        f"  Total train batch size (w. parallel & distributed) = {batch_size_per_update}"
+        f"  Gradient accumulation steps = {training_args.gradient_accumulation_steps}"
     )
+    logger.info(f"  Batch size per update = {batch_size_per_step}")
     logger.info(f"  Model parameters = {num_params:,}")
     epochs = tqdm(
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0

@@ -740,8 +779,9 @@ def main():
         {
             "len_train_dataset": len_train_dataset,
             "len_eval_dataset": len_eval_dataset,
-            "batch_size_per_update": batch_size_per_update,
+            "batch_size_per_step": batch_size_per_step,
             "num_params": num_params,
+            "num_devices": jax.device_count(),
         }
     )

@@ -752,7 +792,9 @@ def main():
     # ======================== Evaluating ==============================
     eval_metrics = []
     if training_args.do_eval:
-        eval_loader = dataset.dataloader("eval", eval_batch_size)
+        eval_loader = dataset.dataloader(
+            "eval", training_args.per_device_eval_batch_size
+        )
         eval_steps = (
             len_eval_dataset // eval_batch_size
             if len_eval_dataset is not None

@@ -869,7 +911,12 @@ def main():
         metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
 
         # Generate an epoch by shuffling sampling indices from the train dataset
-        train_loader = dataset.dataloader("train", train_batch_size, epoch)
+        train_loader = dataset.dataloader(
+            "train",
+            training_args.per_device_train_batch_size,
+            training_args.gradient_accumulation_steps,
+            epoch,
+        )
         # train
         for batch in tqdm(
             train_loader,
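The train_step rewrite is the heart of the PR: rather than letting optax.apply_every swallow and delay updates, each pmapped step now receives a batch with a leading minibatch axis and folds it into a single averaged (loss, grads) pair with jax.lax.fori_loop. A self-contained sketch of that pattern, with toy parameters and a toy loss standing in for the model and loss_fn (everything here is illustrative except the accumulation structure, which mirrors the diff):

import jax
import jax.numpy as jnp

accum_steps = 4
params = {"w": jnp.ones((3,))}
# one per-device batch = `accum_steps` minibatches stacked on the leading axis,
# exactly the layout the reworked dataloader yields
batch = {
    "x": jnp.arange(accum_steps * 2 * 3, dtype=jnp.float32).reshape(accum_steps, 2, 3)
}

def compute_loss(params, minibatch):
    # toy loss: mean squared activation
    return jnp.mean((minibatch["x"] @ params["w"]) ** 2)

grad_fn = jax.value_and_grad(compute_loss)

def _cumul_loss_grads(i, cumul_loss_grads):
    # grab minibatch i from the leading axis and add its (loss, grads) pair
    minibatch = jax.tree_util.tree_map(lambda x: x[i], batch)
    return jax.tree_util.tree_map(
        lambda x, y: x + y, cumul_loss_grads, grad_fn(params, minibatch)
    )

init = (0.0, jax.tree_util.tree_map(jnp.zeros_like, params))
loss, grads = jax.tree_util.tree_map(
    lambda x: x / accum_steps,
    jax.lax.fori_loop(0, accum_steps, _cumul_loss_grads, init),
)
print(loss, grads["w"])  # averaged loss and gradients, as if one large batch

A plausible motivation for dropping optax.apply_every: chained in front of the optimizer it emits summed updates every k steps and zeros in between, which is only exactly equivalent to large-batch training for SGD-style transforms that are linear in the gradient; averaging the gradients before a single optimizer update keeps Adam/Adafactor statistics and the learning-rate schedule aligned with real updates. Note also that with this layout even gradient_accumulation_steps == 1 batches carry a leading axis of size 1, which is why that branch starts with jax.tree_map(lambda x: x[0], batch).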
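The renamed batch-size constants make the scaling bookkeeping explicit: train_batch_size is what one node loads per minibatch, batch_size_per_node is what it consumes per optimizer step, and batch_size_per_step is the global effective batch size. A worked example with illustrative numbers (say, a 4-host pod slice with 8 local devices per host):

per_device_train_batch_size = 8
local_device_count = 8   # jax.local_device_count()
process_count = 4        # jax.process_count()
gradient_accumulation_steps = 2

train_batch_size = per_device_train_batch_size * local_device_count    # 64
batch_size_per_node = train_batch_size * gradient_accumulation_steps   # 128
batch_size_per_step = batch_size_per_node * process_count              # 512
# mirroring the diff: epoch length is counted in per-node optimizer steps
steps_per_epoch = 1_000_000 // batch_size_per_node                     # 7812

Since state.train_samples now advances by batch_size_per_step on every update, the logged sample count stays correct regardless of device count, host count, or accumulation steps.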