Spaces:
Runtime error
Runtime error
feat(train) - handle multiple nodes (#130)
Browse files- src/dalle_mini/data.py +1 -1
- tools/train/train.py +69 -54
src/dalle_mini/data.py
CHANGED
|
@@ -94,7 +94,7 @@ class Dataset:
|
|
| 94 |
if self.streaming:
|
| 95 |
# we need to shuffle early in streaming mode
|
| 96 |
if hasattr(self, "train_dataset"):
|
| 97 |
-
self.train_dataset = self.train_dataset.shuffle(
|
| 98 |
else:
|
| 99 |
# prepare rng for later shuffling
|
| 100 |
if self.seed_dataset is None:
|
|
|
|
| 94 |
if self.streaming:
|
| 95 |
# we need to shuffle early in streaming mode
|
| 96 |
if hasattr(self, "train_dataset"):
|
| 97 |
+
self.train_dataset = self.train_dataset.shuffle(5000, self.seed_dataset)
|
| 98 |
else:
|
| 99 |
# prepare rng for later shuffling
|
| 100 |
if self.seed_dataset is None:
|
tools/train/train.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding=utf-8
|
| 3 |
-
# Copyright 2021 The HuggingFace Team All rights reserved.
|
| 4 |
#
|
| 5 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
# you may not use this file except in compliance with the License.
|
|
@@ -14,7 +14,7 @@
|
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
Script adapted from run_summarization_flax.py
|
| 19 |
"""
|
| 20 |
|
|
@@ -527,23 +527,29 @@ def main():
|
|
| 527 |
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
| 528 |
|
| 529 |
# Initialize our training
|
| 530 |
-
|
| 531 |
-
rng, dropout_rng = jax.random.split(rng)
|
| 532 |
|
| 533 |
# Store some constant
|
| 534 |
num_epochs = training_args.num_train_epochs
|
| 535 |
# batch size
|
| 536 |
-
|
| 537 |
-
training_args.per_device_train_batch_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
)
|
| 539 |
-
batch_size_per_node = minibatch_size * training_args.gradient_accumulation_steps
|
| 540 |
batch_size_per_step = batch_size_per_node * jax.process_count()
|
| 541 |
-
|
| 542 |
-
training_args.per_device_eval_batch_size
|
|
|
|
|
|
|
| 543 |
)
|
|
|
|
| 544 |
len_train_dataset, len_eval_dataset = dataset.length
|
| 545 |
steps_per_epoch = (
|
| 546 |
-
len_train_dataset //
|
| 547 |
if len_train_dataset is not None
|
| 548 |
else None
|
| 549 |
)
|
|
@@ -763,13 +769,21 @@ def main():
|
|
| 763 |
|
| 764 |
# Define gradient update step fn
|
| 765 |
def train_step(state, batch, delta_time):
|
| 766 |
-
#
|
| 767 |
-
#
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
|
| 774 |
# get a minibatch (one gradient accumulation slice)
|
| 775 |
def get_minibatch(batch, grad_idx):
|
|
@@ -791,54 +805,45 @@ def main():
|
|
| 791 |
def loss_and_grad(grad_idx, dropout_rng):
|
| 792 |
# minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
|
| 793 |
minibatch = get_minibatch(batch, grad_idx)
|
| 794 |
-
#
|
|
|
|
|
|
|
| 795 |
minibatch = jax.tree_map(
|
| 796 |
-
lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
|
|
|
|
| 797 |
)
|
| 798 |
-
#
|
| 799 |
loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
|
| 800 |
state.params, minibatch, dropout_rng
|
| 801 |
)
|
| 802 |
-
# ensure
|
| 803 |
loss_grads = jax.tree_map(
|
| 804 |
lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
|
| 805 |
loss_grads,
|
| 806 |
)
|
| 807 |
-
|
| 808 |
# average across all devices
|
| 809 |
loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
|
| 810 |
-
|
| 811 |
# return loss and grads
|
| 812 |
-
return loss_grads
|
| 813 |
-
|
| 814 |
-
# create a new rng
|
| 815 |
-
dropout_rng, _ = jax.random.split(state.dropout_rng)
|
| 816 |
-
# use a different rng per node
|
| 817 |
-
dropout_rng = jax.random.fold_in(dropout_rng, jax.process_index())
|
| 818 |
|
| 819 |
if training_args.gradient_accumulation_steps == 1:
|
| 820 |
-
|
| 821 |
-
def batch_step(dropout_rng):
|
| 822 |
-
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
| 823 |
-
loss_grad = loss_and_grad(0, dropout_rng)
|
| 824 |
-
return loss_grad, new_dropout_rng
|
| 825 |
-
|
| 826 |
-
loss_grad, dropout_rng = batch_step(dropout_rng)
|
| 827 |
else:
|
| 828 |
-
# create initial state for
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
|
|
|
|
|
|
|
|
|
| 832 |
)
|
| 833 |
-
init_minibatch_step = (init_cumul_loss_grad, dropout_rng)
|
| 834 |
|
| 835 |
# accumulate gradients
|
| 836 |
def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
|
| 837 |
cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
|
| 838 |
-
|
| 839 |
-
loss_grad = loss_and_grad(grad_idx, dropout_rng)
|
| 840 |
cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
|
| 841 |
-
return cumul_loss_grad,
|
| 842 |
|
| 843 |
# loop over gradients
|
| 844 |
loss_grad, dropout_rng = jax.lax.fori_loop(
|
|
@@ -870,6 +875,20 @@ def main():
|
|
| 870 |
|
| 871 |
# Define eval fn
|
| 872 |
def eval_step(state, batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
def compute_eval_loss(batch):
|
| 874 |
batch, labels = batch.pop("labels")
|
| 875 |
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
|
|
@@ -936,9 +955,9 @@ def main():
|
|
| 936 |
def run_evaluation():
|
| 937 |
# ======================== Evaluating ==============================
|
| 938 |
if training_args.do_eval:
|
| 939 |
-
eval_loader = dataset.dataloader("eval",
|
| 940 |
eval_steps = (
|
| 941 |
-
len_eval_dataset //
|
| 942 |
if len_eval_dataset is not None
|
| 943 |
else None
|
| 944 |
)
|
|
@@ -950,17 +969,14 @@ def main():
|
|
| 950 |
leave=False,
|
| 951 |
total=eval_steps,
|
| 952 |
):
|
| 953 |
-
#
|
| 954 |
batch = jax.tree_map(
|
| 955 |
lambda x: x.reshape(
|
| 956 |
-
(
|
| 957 |
-
training_args.dp_devices,
|
| 958 |
-
training_args.per_device_eval_batch_size,
|
| 959 |
-
)
|
| 960 |
-
+ x.shape[1:]
|
| 961 |
),
|
| 962 |
batch,
|
| 963 |
)
|
|
|
|
| 964 |
# freeze batch to pass safely to jax transforms
|
| 965 |
batch = freeze(batch)
|
| 966 |
# accumulate losses async
|
|
@@ -1081,8 +1097,7 @@ def main():
|
|
| 1081 |
lambda x: x.reshape(
|
| 1082 |
(
|
| 1083 |
training_args.gradient_accumulation_steps,
|
| 1084 |
-
|
| 1085 |
-
training_args.per_device_train_batch_size,
|
| 1086 |
)
|
| 1087 |
+ x.shape[1:]
|
| 1088 |
),
|
|
|
|
| 1 |
#!/usr/bin/env python
|
| 2 |
# coding=utf-8
|
| 3 |
+
# Copyright 2021-2022 The HuggingFace & DALL·E Mini Team All rights reserved.
|
| 4 |
#
|
| 5 |
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
# you may not use this file except in compliance with the License.
|
|
|
|
| 14 |
# See the License for the specific language governing permissions and
|
| 15 |
# limitations under the License.
|
| 16 |
"""
|
| 17 |
+
Training DALL·E Mini.
|
| 18 |
Script adapted from run_summarization_flax.py
|
| 19 |
"""
|
| 20 |
|
|
|
|
| 527 |
dataset.preprocess(tokenizer=tokenizer, config=model.config)
|
| 528 |
|
| 529 |
# Initialize our training
|
| 530 |
+
dropout_rng = jax.random.PRNGKey(training_args.seed_model)
|
|
|
|
| 531 |
|
| 532 |
# Store some constant
|
| 533 |
num_epochs = training_args.num_train_epochs
|
| 534 |
# batch size
|
| 535 |
+
batch_size_per_node_per_grad_step = (
|
| 536 |
+
training_args.per_device_train_batch_size
|
| 537 |
+
* jax.local_device_count()
|
| 538 |
+
// training_args.mp_devices
|
| 539 |
+
)
|
| 540 |
+
batch_size_per_node = (
|
| 541 |
+
batch_size_per_node_per_grad_step * training_args.gradient_accumulation_steps
|
| 542 |
)
|
|
|
|
| 543 |
batch_size_per_step = batch_size_per_node * jax.process_count()
|
| 544 |
+
eval_batch_size_per_node = (
|
| 545 |
+
training_args.per_device_eval_batch_size
|
| 546 |
+
* jax.local_device_count()
|
| 547 |
+
// training_args.mp_devices
|
| 548 |
)
|
| 549 |
+
eval_batch_size_per_step = eval_batch_size_per_node * jax.process_count()
|
| 550 |
len_train_dataset, len_eval_dataset = dataset.length
|
| 551 |
steps_per_epoch = (
|
| 552 |
+
len_train_dataset // batch_size_per_step
|
| 553 |
if len_train_dataset is not None
|
| 554 |
else None
|
| 555 |
)
|
|
|
|
| 769 |
|
| 770 |
# Define gradient update step fn
|
| 771 |
def train_step(state, batch, delta_time):
|
| 772 |
+
# we reshape to (gradient_accumulation_steps, dp_devices, ...)
|
| 773 |
+
# allows feeding partial batch size per node for full model parallel
|
| 774 |
+
batch = jax.tree_map(
|
| 775 |
+
lambda x: x.reshape(
|
| 776 |
+
(
|
| 777 |
+
training_args.gradient_accumulation_steps,
|
| 778 |
+
training_args.dp_devices,
|
| 779 |
+
training_args.per_device_train_batch_size,
|
| 780 |
+
)
|
| 781 |
+
+ x.shape[2:]
|
| 782 |
+
),
|
| 783 |
+
batch,
|
| 784 |
+
)
|
| 785 |
+
# ensure data is sharded correctly per dp device
|
| 786 |
+
batch = with_sharding_constraint(batch, grad_batch_spec)
|
| 787 |
|
| 788 |
# get a minibatch (one gradient accumulation slice)
|
| 789 |
def get_minibatch(batch, grad_idx):
|
|
|
|
| 805 |
def loss_and_grad(grad_idx, dropout_rng):
|
| 806 |
# minibatch at grad_idx, shape (dp_devices, per_device_train_batch_size, ...)
|
| 807 |
minibatch = get_minibatch(batch, grad_idx)
|
| 808 |
+
# calculate loss and grads independently per dp_device
|
| 809 |
+
dropout_rng, _ = jax.random.split(dropout_rng)
|
| 810 |
+
# ensure inputs are sharded per device
|
| 811 |
minibatch = jax.tree_map(
|
| 812 |
+
lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
|
| 813 |
+
minibatch,
|
| 814 |
)
|
| 815 |
+
# only 1 single rng per grad step, let us handle larger batch size
|
| 816 |
loss_grads = jax.vmap(grad_fn, in_axes=(None, 0, None), out_axes=(0, 0))(
|
| 817 |
state.params, minibatch, dropout_rng
|
| 818 |
)
|
| 819 |
+
# ensure outputs are sharded per device
|
| 820 |
loss_grads = jax.tree_map(
|
| 821 |
lambda x: with_sharding_constraint(x, PartitionSpec("batch")),
|
| 822 |
loss_grads,
|
| 823 |
)
|
|
|
|
| 824 |
# average across all devices
|
| 825 |
loss_grads = jax.tree_map(lambda x: jnp.mean(x, axis=0), loss_grads)
|
|
|
|
| 826 |
# return loss and grads
|
| 827 |
+
return loss_grads, dropout_rng
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 828 |
|
| 829 |
if training_args.gradient_accumulation_steps == 1:
|
| 830 |
+
loss_grad, dropout_rng = loss_and_grad(0, state.dropout_rng)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 831 |
else:
|
| 832 |
+
# create initial state for cumul_minibatch_step loop
|
| 833 |
+
init_minibatch_step = (
|
| 834 |
+
(
|
| 835 |
+
0.0,
|
| 836 |
+
jax.tree_map(jnp.zeros_like, state.params),
|
| 837 |
+
),
|
| 838 |
+
state.dropout_rng,
|
| 839 |
)
|
|
|
|
| 840 |
|
| 841 |
# accumulate gradients
|
| 842 |
def cumul_minibatch_step(grad_idx, cumul_loss_grad_dropout):
|
| 843 |
cumul_loss_grad, dropout_rng = cumul_loss_grad_dropout
|
| 844 |
+
loss_grad, dropout_rng = loss_and_grad(grad_idx, dropout_rng)
|
|
|
|
| 845 |
cumul_loss_grad = jax.tree_map(jnp.add, cumul_loss_grad, loss_grad)
|
| 846 |
+
return cumul_loss_grad, dropout_rng
|
| 847 |
|
| 848 |
# loop over gradients
|
| 849 |
loss_grad, dropout_rng = jax.lax.fori_loop(
|
|
|
|
| 875 |
|
| 876 |
# Define eval fn
|
| 877 |
def eval_step(state, batch):
|
| 878 |
+
# we reshape to (dp_devices, ...)
|
| 879 |
+
batch = jax.tree_map(
|
| 880 |
+
lambda x: x.reshape(
|
| 881 |
+
(
|
| 882 |
+
training_args.dp_devices,
|
| 883 |
+
training_args.per_device_eval_batch_size,
|
| 884 |
+
)
|
| 885 |
+
+ x.shape[1:]
|
| 886 |
+
),
|
| 887 |
+
batch,
|
| 888 |
+
)
|
| 889 |
+
# ensure data is sharded correctly per dp device
|
| 890 |
+
batch = with_sharding_constraint(batch, batch_spec)
|
| 891 |
+
|
| 892 |
def compute_eval_loss(batch):
|
| 893 |
batch, labels = batch.pop("labels")
|
| 894 |
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
|
|
|
|
| 955 |
def run_evaluation():
|
| 956 |
# ======================== Evaluating ==============================
|
| 957 |
if training_args.do_eval:
|
| 958 |
+
eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
|
| 959 |
eval_steps = (
|
| 960 |
+
len_eval_dataset // eval_batch_size_per_step
|
| 961 |
if len_eval_dataset is not None
|
| 962 |
else None
|
| 963 |
)
|
|
|
|
| 969 |
leave=False,
|
| 970 |
total=eval_steps,
|
| 971 |
):
|
| 972 |
+
# need to keep only eval_batch_size_per_node items relevant to the node
|
| 973 |
batch = jax.tree_map(
|
| 974 |
lambda x: x.reshape(
|
| 975 |
+
(jax.process_count(), eval_batch_size_per_node) + x.shape[1:]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 976 |
),
|
| 977 |
batch,
|
| 978 |
)
|
| 979 |
+
batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
|
| 980 |
# freeze batch to pass safely to jax transforms
|
| 981 |
batch = freeze(batch)
|
| 982 |
# accumulate losses async
|
|
|
|
| 1097 |
lambda x: x.reshape(
|
| 1098 |
(
|
| 1099 |
training_args.gradient_accumulation_steps,
|
| 1100 |
+
batch_size_per_node_per_grad_step,
|
|
|
|
| 1101 |
)
|
| 1102 |
+ x.shape[1:]
|
| 1103 |
),
|