add support for v3-32
Browse files- run_mlm_flax_stream.py +10 -1
run_mlm_flax_stream.py
CHANGED
|
@@ -551,6 +551,10 @@ if __name__ == "__main__":
|
|
| 551 |
# define number steps per stream epoch
|
| 552 |
num_train_steps = data_args.num_train_steps
|
| 553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
# Create learning rate schedule
|
| 555 |
warmup_fn = optax.linear_schedule(
|
| 556 |
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
|
@@ -714,8 +718,13 @@ if __name__ == "__main__":
|
|
| 714 |
# process input samples
|
| 715 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
| 716 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 717 |
# Model forward
|
| 718 |
-
model_inputs = shard(model_inputs.data)
|
| 719 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
| 720 |
|
| 721 |
train_metrics.append(train_metric)
|
|
|
|
| 551 |
# define number steps per stream epoch
|
| 552 |
num_train_steps = data_args.num_train_steps
|
| 553 |
|
| 554 |
+
num_of_hosts = jax.process_count()
|
| 555 |
+
current_host_idx = jax.process_index()
|
| 556 |
+
|
| 557 |
+
|
| 558 |
# Create learning rate schedule
|
| 559 |
warmup_fn = optax.linear_schedule(
|
| 560 |
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
|
|
|
| 718 |
# process input samples
|
| 719 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
| 720 |
|
| 721 |
+
local_host_model_inputs = {
|
| 722 |
+
key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
|
| 723 |
+
for key, value in model_inputs.data.items()
|
| 724 |
+
}
|
| 725 |
+
|
| 726 |
# Model forward
|
| 727 |
+
model_inputs = shard(local_host_model_inputs)
|
| 728 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
| 729 |
|
| 730 |
train_metrics.append(train_metric)
|