Spaces:
Runtime error
Runtime error
feat(train): start pjit support
Browse files- tools/train/train.py +144 -106
tools/train/train.py
CHANGED
|
@@ -30,21 +30,30 @@ from typing import Callable, Optional
|
|
| 30 |
import datasets
|
| 31 |
import jax
|
| 32 |
import jax.numpy as jnp
|
|
|
|
| 33 |
import optax
|
| 34 |
import transformers
|
| 35 |
-
import wandb
|
| 36 |
from datasets import Dataset
|
| 37 |
-
from distributed_shampoo import GraftingType, distributed_shampoo
|
| 38 |
from flax import jax_utils, traverse_util
|
|
|
|
| 39 |
from flax.jax_utils import unreplicate
|
| 40 |
from flax.serialization import from_bytes, to_bytes
|
| 41 |
from flax.training import train_state
|
| 42 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
|
|
|
|
|
|
| 43 |
from tqdm import tqdm
|
| 44 |
from transformers import AutoTokenizer, HfArgumentParser
|
| 45 |
|
|
|
|
| 46 |
from dalle_mini.data import Dataset
|
| 47 |
-
from dalle_mini.model import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
logger = logging.getLogger(__name__)
|
| 50 |
|
|
@@ -223,7 +232,6 @@ class TrainingArguments:
|
|
| 223 |
"help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
|
| 224 |
},
|
| 225 |
)
|
| 226 |
-
weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
|
| 227 |
beta1: float = field(
|
| 228 |
default=0.9,
|
| 229 |
metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
|
|
@@ -332,6 +340,13 @@ class TrainingArguments:
|
|
| 332 |
metadata={"help": "Verify that TPU is not in use."},
|
| 333 |
)
|
| 334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
def __post_init__(self):
|
| 336 |
assert self.optim in [
|
| 337 |
"distributed_shampoo",
|
|
@@ -340,9 +355,6 @@ class TrainingArguments:
|
|
| 340 |
], f"Selected optimizer not supported: {self.optim}"
|
| 341 |
if self.per_device_eval_batch_size is None:
|
| 342 |
self.per_device_eval_batch_size = self.per_device_train_batch_size
|
| 343 |
-
if self.weight_decay is None:
|
| 344 |
-
if self.optim in ["distributed_shampoo", "adam"]:
|
| 345 |
-
self.weight_decay = 0.0
|
| 346 |
if (
|
| 347 |
os.path.exists(self.output_dir)
|
| 348 |
and os.listdir(self.output_dir)
|
|
@@ -353,6 +365,10 @@ class TrainingArguments:
|
|
| 353 |
f"Output directory ({self.output_dir}) already exists and is not empty."
|
| 354 |
"Use --overwrite_output_dir to overcome."
|
| 355 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
|
| 358 |
class TrainState(train_state.TrainState):
|
|
@@ -361,11 +377,6 @@ class TrainState(train_state.TrainState):
|
|
| 361 |
train_time: float = 0.0 # total time the model trained
|
| 362 |
train_samples: int = 0 # number of samples seen
|
| 363 |
|
| 364 |
-
def replicate(self):
|
| 365 |
-
return jax_utils.replicate(self).replace(
|
| 366 |
-
dropout_rng=shard_prng_key(self.dropout_rng)
|
| 367 |
-
)
|
| 368 |
-
|
| 369 |
def restore_state(self, artifact_dir):
|
| 370 |
# restore optimizer state
|
| 371 |
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
|
@@ -487,8 +498,6 @@ def main():
|
|
| 487 |
dtype=getattr(jnp, model_args.dtype),
|
| 488 |
abstract_init=True,
|
| 489 |
)
|
| 490 |
-
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 491 |
-
print(model.params)
|
| 492 |
|
| 493 |
# load tokenizer
|
| 494 |
tokenizer = DalleBartTokenizer.from_pretrained(
|
|
@@ -512,8 +521,6 @@ def main():
|
|
| 512 |
dtype=getattr(jnp, model_args.dtype),
|
| 513 |
abstract_init=True,
|
| 514 |
)
|
| 515 |
-
# avoid OOM on TPU: see https://github.com/google/flax/issues/1658
|
| 516 |
-
print(model.params)
|
| 517 |
else:
|
| 518 |
model = DalleBart(
|
| 519 |
config,
|
|
@@ -544,7 +551,7 @@ def main():
|
|
| 544 |
|
| 545 |
# Initialize our training
|
| 546 |
rng = jax.random.PRNGKey(training_args.seed_model)
|
| 547 |
-
rng, dropout_rng = jax.random.split(rng)
|
| 548 |
|
| 549 |
# Store some constant
|
| 550 |
num_epochs = training_args.num_train_epochs
|
|
@@ -601,32 +608,9 @@ def main():
|
|
| 601 |
|
| 602 |
learning_rate_fn = create_learning_rate_fn()
|
| 603 |
|
| 604 |
-
# We use Optax's "masking" functionality to not apply weight decay
|
| 605 |
-
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
| 606 |
-
# mask boolean with the same structure as the parameters.
|
| 607 |
-
# The mask is True for parameters that should be decayed.
|
| 608 |
-
# Note that this mask is specifically adapted for FlaxBart.
|
| 609 |
-
def decay_mask_fn(params):
|
| 610 |
-
flat_params = traverse_util.flatten_dict(params)
|
| 611 |
-
layer_norm_params = [
|
| 612 |
-
(name, "scale")
|
| 613 |
-
for name in [
|
| 614 |
-
"self_attn_layer_norm",
|
| 615 |
-
"layernorm_embedding",
|
| 616 |
-
"final_layer_norm",
|
| 617 |
-
]
|
| 618 |
-
]
|
| 619 |
-
flat_mask = {
|
| 620 |
-
path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
|
| 621 |
-
for path in flat_params
|
| 622 |
-
}
|
| 623 |
-
return traverse_util.unflatten_dict(flat_mask)
|
| 624 |
-
|
| 625 |
# create adam optimizer
|
| 626 |
if training_args.optim == "distributed_shampoo":
|
| 627 |
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
|
| 628 |
-
# Notes:
|
| 629 |
-
# - mask for weight decay is not implemented
|
| 630 |
optimizer = distributed_shampoo(
|
| 631 |
learning_rate_fn,
|
| 632 |
block_size=training_args.block_size,
|
|
@@ -634,7 +618,6 @@ def main():
|
|
| 634 |
beta2=training_args.beta2,
|
| 635 |
diagonal_epsilon=1e-10,
|
| 636 |
matrix_epsilon=1e-8,
|
| 637 |
-
weight_decay=training_args.weight_decay,
|
| 638 |
start_preconditioning_step=training_args.warmup_steps,
|
| 639 |
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
|
| 640 |
statistics_compute_steps=1,
|
|
@@ -657,26 +640,76 @@ def main():
|
|
| 657 |
b1=training_args.beta1,
|
| 658 |
b2=training_args.beta2,
|
| 659 |
eps=training_args.adam_epsilon,
|
| 660 |
-
weight_decay=training_args.weight_decay,
|
| 661 |
-
mask=decay_mask_fn,
|
| 662 |
)
|
| 663 |
elif training_args.optim == "adafactor":
|
| 664 |
# We use the default parameters here to initialize adafactor,
|
| 665 |
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
| 666 |
optimizer = optax.adafactor(
|
| 667 |
learning_rate=learning_rate_fn,
|
| 668 |
-
weight_decay_rate=training_args.weight_decay,
|
| 669 |
-
weight_decay_mask=decay_mask_fn,
|
| 670 |
clipping_threshold=training_args.max_grad_norm,
|
| 671 |
)
|
| 672 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 673 |
# Setup train state
|
| 674 |
-
state = TrainState
|
| 675 |
apply_fn=model.__call__,
|
| 676 |
-
params=
|
|
|
|
| 677 |
tx=optimizer,
|
| 678 |
dropout_rng=dropout_rng,
|
|
|
|
| 679 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 680 |
if training_args.resume_from_checkpoint is not None:
|
| 681 |
# restore optimizer state and other parameters
|
| 682 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
|
@@ -756,8 +789,17 @@ def main():
|
|
| 756 |
return metrics
|
| 757 |
|
| 758 |
# Create parallel version of the train and eval step
|
| 759 |
-
p_train_step =
|
| 760 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 761 |
|
| 762 |
logger.info("***** Running training *****")
|
| 763 |
logger.info(f" Num examples = {len_train_dataset}")
|
|
@@ -792,9 +834,6 @@ def main():
|
|
| 792 |
}
|
| 793 |
)
|
| 794 |
|
| 795 |
-
# replicate state on each device
|
| 796 |
-
state = state.replicate()
|
| 797 |
-
|
| 798 |
def run_evaluation():
|
| 799 |
# ======================== Evaluating ==============================
|
| 800 |
eval_metrics = []
|
|
@@ -823,9 +862,7 @@ def main():
|
|
| 823 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
| 824 |
|
| 825 |
# log metrics
|
| 826 |
-
metrics_logger.log(
|
| 827 |
-
eval_metrics, step=unreplicate(state.step), prefix="eval"
|
| 828 |
-
)
|
| 829 |
|
| 830 |
# Print metrics and update progress bar
|
| 831 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
|
@@ -836,7 +873,7 @@ def main():
|
|
| 836 |
|
| 837 |
def run_save_model(state, eval_metrics=None):
|
| 838 |
if jax.process_index() == 0:
|
| 839 |
-
params = jax.device_get(
|
| 840 |
# save model locally
|
| 841 |
model.save_pretrained(
|
| 842 |
training_args.output_dir,
|
|
@@ -847,11 +884,11 @@ def main():
|
|
| 847 |
tokenizer.save_pretrained(training_args.output_dir)
|
| 848 |
|
| 849 |
# save state
|
| 850 |
-
opt_state =
|
| 851 |
with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
|
| 852 |
f.write(to_bytes(opt_state))
|
| 853 |
state_dict = {
|
| 854 |
-
k: jax.device_get(
|
| 855 |
for k in ["step", "epoch", "train_time", "train_samples"]
|
| 856 |
}
|
| 857 |
with (Path(training_args.output_dir) / "training_state.json").open(
|
|
@@ -912,63 +949,64 @@ def main():
|
|
| 912 |
last_time = time.perf_counter()
|
| 913 |
train_metrics = None
|
| 914 |
|
| 915 |
-
|
| 916 |
-
|
| 917 |
-
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
-
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
-
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
|
|
|
| 935 |
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
|
| 941 |
-
|
| 942 |
-
|
| 943 |
-
|
| 944 |
-
)
|
| 945 |
-
step = unreplicate(state.step)
|
| 946 |
|
| 947 |
-
|
| 948 |
-
|
| 949 |
-
|
|
|
|
|
|
|
| 950 |
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
|
| 955 |
-
|
| 956 |
-
|
| 957 |
|
| 958 |
-
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
|
| 962 |
|
| 963 |
-
|
| 964 |
-
|
| 965 |
-
|
| 966 |
|
| 967 |
-
|
| 968 |
-
|
| 969 |
|
| 970 |
-
|
| 971 |
-
|
| 972 |
|
| 973 |
|
| 974 |
if __name__ == "__main__":
|
|
|
|
| 30 |
import datasets
|
| 31 |
import jax
|
| 32 |
import jax.numpy as jnp
|
| 33 |
+
import numpy as np
|
| 34 |
import optax
|
| 35 |
import transformers
|
|
|
|
| 36 |
from datasets import Dataset
|
| 37 |
+
from distributed_shampoo import GraftingType, distributed_shampoo, pad_matrix
|
| 38 |
from flax import jax_utils, traverse_util
|
| 39 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
| 40 |
from flax.jax_utils import unreplicate
|
| 41 |
from flax.serialization import from_bytes, to_bytes
|
| 42 |
from flax.training import train_state
|
| 43 |
from flax.training.common_utils import get_metrics, onehot, shard_prng_key
|
| 44 |
+
from jax.experimental import PartitionSpec, maps
|
| 45 |
+
from jax.experimental.pjit import pjit
|
| 46 |
from tqdm import tqdm
|
| 47 |
from transformers import AutoTokenizer, HfArgumentParser
|
| 48 |
|
| 49 |
+
import wandb
|
| 50 |
from dalle_mini.data import Dataset
|
| 51 |
+
from dalle_mini.model import (
|
| 52 |
+
DalleBart,
|
| 53 |
+
DalleBartConfig,
|
| 54 |
+
DalleBartTokenizer,
|
| 55 |
+
set_partitions,
|
| 56 |
+
)
|
| 57 |
|
| 58 |
logger = logging.getLogger(__name__)
|
| 59 |
|
|
|
|
| 232 |
"help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
|
| 233 |
},
|
| 234 |
)
|
|
|
|
| 235 |
beta1: float = field(
|
| 236 |
default=0.9,
|
| 237 |
metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
|
|
|
|
| 340 |
metadata={"help": "Verify that TPU is not in use."},
|
| 341 |
)
|
| 342 |
|
| 343 |
+
mp_devices: Optional[int] = field(
|
| 344 |
+
default=1,
|
| 345 |
+
metadata={
|
| 346 |
+
"help": "Number of devices required for model parallelism. The other dimension of available devices is used for data parallelism."
|
| 347 |
+
},
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
def __post_init__(self):
|
| 351 |
assert self.optim in [
|
| 352 |
"distributed_shampoo",
|
|
|
|
| 355 |
], f"Selected optimizer not supported: {self.optim}"
|
| 356 |
if self.per_device_eval_batch_size is None:
|
| 357 |
self.per_device_eval_batch_size = self.per_device_train_batch_size
|
|
|
|
|
|
|
|
|
|
| 358 |
if (
|
| 359 |
os.path.exists(self.output_dir)
|
| 360 |
and os.listdir(self.output_dir)
|
|
|
|
| 365 |
f"Output directory ({self.output_dir}) already exists and is not empty."
|
| 366 |
"Use --overwrite_output_dir to overcome."
|
| 367 |
)
|
| 368 |
+
assert (
|
| 369 |
+
jax.device_count() % self.mp_devices == 0
|
| 370 |
+
), f"Number of available devices ({jax.device_count()} must be divisible by number of devices used for model parallelism ({self.mp_devices})."
|
| 371 |
+
self.dp_devices = jax.device_count() // self.mp_devices
|
| 372 |
|
| 373 |
|
| 374 |
class TrainState(train_state.TrainState):
|
|
|
|
| 377 |
train_time: float = 0.0 # total time the model trained
|
| 378 |
train_samples: int = 0 # number of samples seen
|
| 379 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
def restore_state(self, artifact_dir):
|
| 381 |
# restore optimizer state
|
| 382 |
with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
|
|
|
|
| 498 |
dtype=getattr(jnp, model_args.dtype),
|
| 499 |
abstract_init=True,
|
| 500 |
)
|
|
|
|
|
|
|
| 501 |
|
| 502 |
# load tokenizer
|
| 503 |
tokenizer = DalleBartTokenizer.from_pretrained(
|
|
|
|
| 521 |
dtype=getattr(jnp, model_args.dtype),
|
| 522 |
abstract_init=True,
|
| 523 |
)
|
|
|
|
|
|
|
| 524 |
else:
|
| 525 |
model = DalleBart(
|
| 526 |
config,
|
|
|
|
| 551 |
|
| 552 |
# Initialize our training
|
| 553 |
rng = jax.random.PRNGKey(training_args.seed_model)
|
| 554 |
+
rng, *dropout_rng = jax.random.split(rng, num=training_args.dp_devices + 1)
|
| 555 |
|
| 556 |
# Store some constant
|
| 557 |
num_epochs = training_args.num_train_epochs
|
|
|
|
| 608 |
|
| 609 |
learning_rate_fn = create_learning_rate_fn()
|
| 610 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
# create adam optimizer
|
| 612 |
if training_args.optim == "distributed_shampoo":
|
| 613 |
# parameters from https://github.com/tensorflow/lingvo/blob/03ee9d7cd50764b0424c7c863733c91fc0b053ec/lingvo/jax/optimizers.py#L729
|
|
|
|
|
|
|
| 614 |
optimizer = distributed_shampoo(
|
| 615 |
learning_rate_fn,
|
| 616 |
block_size=training_args.block_size,
|
|
|
|
| 618 |
beta2=training_args.beta2,
|
| 619 |
diagonal_epsilon=1e-10,
|
| 620 |
matrix_epsilon=1e-8,
|
|
|
|
| 621 |
start_preconditioning_step=training_args.warmup_steps,
|
| 622 |
preconditioning_compute_steps=training_args.preconditioning_compute_steps,
|
| 623 |
statistics_compute_steps=1,
|
|
|
|
| 640 |
b1=training_args.beta1,
|
| 641 |
b2=training_args.beta2,
|
| 642 |
eps=training_args.adam_epsilon,
|
|
|
|
|
|
|
| 643 |
)
|
| 644 |
elif training_args.optim == "adafactor":
|
| 645 |
# We use the default parameters here to initialize adafactor,
|
| 646 |
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
| 647 |
optimizer = optax.adafactor(
|
| 648 |
learning_rate=learning_rate_fn,
|
|
|
|
|
|
|
| 649 |
clipping_threshold=training_args.max_grad_norm,
|
| 650 |
)
|
| 651 |
|
| 652 |
+
# get opt_state shape without actual init
|
| 653 |
+
param_shape = jax.tree_map(lambda x: x.shape, model.params)
|
| 654 |
+
opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), param_shape)
|
| 655 |
+
|
| 656 |
+
# get PartitionSpec for model params
|
| 657 |
+
param_spec = set_partitions(model.params)
|
| 658 |
+
|
| 659 |
+
# create PartitionSpec for opt_state
|
| 660 |
+
def opt_state_spec_per_leaf(x):
|
| 661 |
+
if training_args.optim in ["adam", "adafactor"]:
|
| 662 |
+
if isinstance(x, dict):
|
| 663 |
+
# variables with same structure as params
|
| 664 |
+
return param_spec
|
| 665 |
+
else:
|
| 666 |
+
# other variables such as count
|
| 667 |
+
return None
|
| 668 |
+
else:
|
| 669 |
+
# TODO: create spec for Distributed Shampoo
|
| 670 |
+
raise NotImplementedError
|
| 671 |
+
|
| 672 |
+
opt_state_spec = jax.tree_map(
|
| 673 |
+
opt_state_spec_per_leaf,
|
| 674 |
+
opt_state_shape,
|
| 675 |
+
# return None spec for empty elements
|
| 676 |
+
is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
# create a mesh
|
| 680 |
+
mesh_shape = (training_args.dp_devices, training_args.mp_devices)
|
| 681 |
+
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
|
| 682 |
+
mesh = maps.Mesh(devices, ("batch", "mp"))
|
| 683 |
+
|
| 684 |
+
# move params & init opt_state over specified devices
|
| 685 |
+
with maps.mesh(mesh.devices, mesh.axis_names):
|
| 686 |
+
params, opt_state = pjit(
|
| 687 |
+
lambda x: (x, optimizer.init(x)),
|
| 688 |
+
in_axis_resources=None,
|
| 689 |
+
out_axis_resources=(param_spec, opt_state_spec),
|
| 690 |
+
)(freeze(model.params))
|
| 691 |
+
|
| 692 |
# Setup train state
|
| 693 |
+
state = TrainState(
|
| 694 |
apply_fn=model.__call__,
|
| 695 |
+
params=params,
|
| 696 |
+
opt_state=opt_state,
|
| 697 |
tx=optimizer,
|
| 698 |
dropout_rng=dropout_rng,
|
| 699 |
+
step=0,
|
| 700 |
)
|
| 701 |
+
|
| 702 |
+
# create PartitionSpec for state
|
| 703 |
+
state_spec = {
|
| 704 |
+
"params": param_spec,
|
| 705 |
+
"opt_state": opt_state_spec,
|
| 706 |
+
"dropout_rng": PartitionSpec("batch", None),
|
| 707 |
+
"epoch": None,
|
| 708 |
+
"step": None,
|
| 709 |
+
"train_samples": None,
|
| 710 |
+
"train_time": None,
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
if training_args.resume_from_checkpoint is not None:
|
| 714 |
# restore optimizer state and other parameters
|
| 715 |
# we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
|
|
|
|
| 789 |
return metrics
|
| 790 |
|
| 791 |
# Create parallel version of the train and eval step
|
| 792 |
+
p_train_step = pjit(
|
| 793 |
+
train_step,
|
| 794 |
+
in_axis_resources=(state_spec, None, None),
|
| 795 |
+
out_axis_resources=(state_spec, None),
|
| 796 |
+
donate_argnums=(0,),
|
| 797 |
+
)
|
| 798 |
+
p_eval_step = pjit(
|
| 799 |
+
eval_step,
|
| 800 |
+
in_axis_resources=(param_spec, PartitionSpec("batch", None)),
|
| 801 |
+
out_axis_resources=None,
|
| 802 |
+
)
|
| 803 |
|
| 804 |
logger.info("***** Running training *****")
|
| 805 |
logger.info(f" Num examples = {len_train_dataset}")
|
|
|
|
| 834 |
}
|
| 835 |
)
|
| 836 |
|
|
|
|
|
|
|
|
|
|
| 837 |
def run_evaluation():
|
| 838 |
# ======================== Evaluating ==============================
|
| 839 |
eval_metrics = []
|
|
|
|
| 862 |
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
| 863 |
|
| 864 |
# log metrics
|
| 865 |
+
metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
|
|
|
|
|
|
|
| 866 |
|
| 867 |
# Print metrics and update progress bar
|
| 868 |
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
|
|
|
|
| 873 |
|
| 874 |
def run_save_model(state, eval_metrics=None):
|
| 875 |
if jax.process_index() == 0:
|
| 876 |
+
params = jax.device_get(state.params)
|
| 877 |
# save model locally
|
| 878 |
model.save_pretrained(
|
| 879 |
training_args.output_dir,
|
|
|
|
| 884 |
tokenizer.save_pretrained(training_args.output_dir)
|
| 885 |
|
| 886 |
# save state
|
| 887 |
+
opt_state = jax.device_get(state.opt_state)
|
| 888 |
with (Path(training_args.output_dir) / "opt_state.msgpack").open("wb") as f:
|
| 889 |
f.write(to_bytes(opt_state))
|
| 890 |
state_dict = {
|
| 891 |
+
k: jax.device_get(getattr(state, k)).item()
|
| 892 |
for k in ["step", "epoch", "train_time", "train_samples"]
|
| 893 |
}
|
| 894 |
with (Path(training_args.output_dir) / "training_state.json").open(
|
|
|
|
| 949 |
last_time = time.perf_counter()
|
| 950 |
train_metrics = None
|
| 951 |
|
| 952 |
+
with maps.mesh(mesh.devices, mesh.axis_names):
|
| 953 |
+
for epoch in epochs:
|
| 954 |
+
state.replace(epoch=epoch)
|
| 955 |
+
# ======================== Training ================================
|
| 956 |
+
metrics_logger.log({"train/epoch": epoch}, step=state.step)
|
| 957 |
+
|
| 958 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
| 959 |
+
train_loader = dataset.dataloader(
|
| 960 |
+
"train",
|
| 961 |
+
training_args.per_device_train_batch_size,
|
| 962 |
+
training_args.gradient_accumulation_steps,
|
| 963 |
+
epoch,
|
| 964 |
+
)
|
| 965 |
+
# train
|
| 966 |
+
for batch in tqdm(
|
| 967 |
+
train_loader,
|
| 968 |
+
desc="Training...",
|
| 969 |
+
position=1,
|
| 970 |
+
leave=False,
|
| 971 |
+
total=steps_per_epoch,
|
| 972 |
+
):
|
| 973 |
|
| 974 |
+
# calculate delta time (we have a lag of one step but it's ok)
|
| 975 |
+
new_time = time.perf_counter()
|
| 976 |
+
delta_time = new_time - last_time
|
| 977 |
+
last_time = new_time
|
| 978 |
|
| 979 |
+
# train step
|
| 980 |
+
state, train_metrics = p_train_step(state, batch, delta_time)
|
| 981 |
+
step = state.step
|
|
|
|
|
|
|
| 982 |
|
| 983 |
+
if step % training_args.logging_steps == 0 and jax.process_index() == 0:
|
| 984 |
+
all_metrics = metrics_logger.get_all_train_metrics(
|
| 985 |
+
train_metrics, state
|
| 986 |
+
)
|
| 987 |
+
metrics_logger.log(all_metrics, step=step, prefix="train")
|
| 988 |
|
| 989 |
+
eval_metrics = None
|
| 990 |
+
if training_args.eval_steps and step % training_args.eval_steps == 0:
|
| 991 |
+
eval_metrics = run_evaluation()
|
| 992 |
|
| 993 |
+
if step % training_args.save_steps == 0:
|
| 994 |
+
run_save_model(state, eval_metrics)
|
| 995 |
|
| 996 |
+
# log final train metrics
|
| 997 |
+
if train_metrics is not None:
|
| 998 |
+
all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
|
| 999 |
+
metrics_logger.log(all_metrics, step=step, prefix="train")
|
| 1000 |
|
| 1001 |
+
epochs.write(
|
| 1002 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
|
| 1003 |
+
)
|
| 1004 |
|
| 1005 |
+
# Final evaluation
|
| 1006 |
+
eval_metrics = run_evaluation()
|
| 1007 |
|
| 1008 |
+
# save checkpoint after each epoch
|
| 1009 |
+
run_save_model(state, eval_metrics)
|
| 1010 |
|
| 1011 |
|
| 1012 |
if __name__ == "__main__":
|