import pprint

from tqdm import tqdm, trange
import mlxu

import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.sharding import PartitionSpec as PS
from flax.training.train_state import TrainState

from EasyLM.data import DatasetFactory
from EasyLM.checkpoint import StreamingCheckpointer
from EasyLM.optimizers import OptimizerFactory
from EasyLM.jax_utils import (
    JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules,
    get_float_dtype_by_name, cross_entropy_loss_and_accuracy, global_norm,
    set_random_seed, average_metrics, get_weight_decay_mask,
    make_shard_and_gather_fns, tree_apply
)
from EasyLM.models.roberta.roberta_model import (
    RobertaConfig, FlaxRobertaForMaskedLMModule
)


FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
    seed=42,
    mesh_dim='-1,1,1',
    dtype='fp32',
    mask_token_probability=0.15,
    total_steps=10000,
    load_roberta_config='',
    update_roberta_config='',
    load_checkpoint='',
    load_dataset_state='',
    log_freq=50,
    save_model_freq=0,
    save_milestone_freq=0,
    eval_steps=0,
    tokenizer=RobertaConfig.get_tokenizer_config(),
    train_dataset=DatasetFactory.get_default_config(),
    eval_dataset=DatasetFactory.get_default_config(),
    optimizer=OptimizerFactory.get_default_config(),
    checkpointer=StreamingCheckpointer.get_default_config(),
    roberta=RobertaConfig.get_default_config(),
    logger=mlxu.WandBLogger.get_default_config(),
    log_all_worker=False,
    jax_distributed=JaxDistributedConfig.get_default_config(),
)
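
# Illustrative invocation. The module path and nested flag names below are
# assumptions based on the repo layout and mlxu's dotted-flag convention,
# not verified commands:
#
#   python -m EasyLM.models.roberta.roberta_train \
#       --mesh_dim='-1,1,1' \
#       --total_steps=10000 \
#       --train_dataset.type='huggingface' \
#       --logger.output_dir='/tmp/roberta_output'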


def main(argv):
    JaxDistributedConfig.initialize(FLAGS.jax_distributed)
    variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
    flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
    logger = mlxu.WandBLogger(
        config=FLAGS.logger,
        variant=variant,
        # Unless log_all_worker is set, only process 0 logs to W&B.
        enable=FLAGS.log_all_worker or (jax.process_index() == 0),
    )
    set_random_seed(FLAGS.seed)

    tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
    dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
    if FLAGS.load_dataset_state != '':
        dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))

    if FLAGS.eval_steps > 0:
        eval_dataset = DatasetFactory.load_dataset(
            FLAGS.eval_dataset, dataset.tokenizer
        )
        eval_iterator = iter(eval_dataset)

    seq_length = dataset.seq_length

    if FLAGS.load_roberta_config != '':
        roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
    else:
        roberta_config = RobertaConfig(**FLAGS.roberta)

    if FLAGS.update_roberta_config != '':
        # update_roberta_config is evaluated as a Python expression yielding
        # a dict of config overrides, e.g. "dict(num_hidden_layers=6)".
        roberta_config.update(dict(eval(FLAGS.update_roberta_config)))

    # Keep the special token ids and vocab size in sync with the tokenizer.
    roberta_config.update(dict(
        bos_token_id=dataset.tokenizer.bos_token_id,
        eos_token_id=dataset.tokenizer.eos_token_id,
        pad_token_id=dataset.tokenizer.pad_token_id,
        vocab_size=dataset.vocab_size,
    ))

    model = FlaxRobertaForMaskedLMModule(
        roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
    )

    optimizer, optimizer_info = OptimizerFactory.get_optimizer(
        FLAGS.optimizer,
        get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
    )
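
    # apply_fn is left as None in the TrainState because train_step and
    # eval_step below call model.apply directly instead of going through
    # the train state.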
    def create_trainstate_from_params(params):
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)

    def init_fn(rng):
        rng_generator = JaxRNG(rng)
        # Initialize parameters with dummy inputs; only their shapes matter.
        params = model.init(
            input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
            attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
            token_type_ids=None,
            head_mask=None,
            rngs=rng_generator(roberta_config.rng_keys()),
        )
        return TrainState.create(params=params, tx=optimizer, apply_fn=None)
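
    # BERT-style dynamic masking, applied on the fly to every batch:
    # mask_token_probability (15% by default) of the positions are selected
    # for prediction; of those, 80% are replaced with the mask token, 10%
    # with a uniformly random token, and the remaining 10% are left
    # unchanged. The loss is computed only over the selected positions.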
    def train_step(train_state, rng, batch):
        rng_generator = JaxRNG(rng)
        tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))

        def loss_and_accuracy(params):
            # Select which positions to corrupt and predict.
            altered_tokens = jax.random.uniform(
                rng_generator(), shape=tokens.shape
            ) < FLAGS.mask_token_probability
            random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
            # Of the selected positions: 80% -> mask token, 10% -> random
            # token, remaining 10% -> unchanged.
            altered_by_mask = altered_tokens & (random_uniform < 0.8)
            altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
            inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
            random_tokens = jax.random.randint(
                rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
            )
            inputs = jnp.where(altered_by_random, random_tokens, inputs)
            logits = model.apply(
                params, inputs,
                attention_mask=jnp.ones_like(inputs),
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                deterministic=False,
                rngs=rng_generator(roberta_config.rng_keys()),
            ).logits
            # Loss and accuracy are measured only on the selected positions.
            return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)

        grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
        (loss, accuracy), grads = grad_fn(train_state.params)
        train_state = train_state.apply_gradients(grads=grads)
        metrics = dict(
            loss=loss,
            accuracy=accuracy,
            learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
            gradient_norm=global_norm(grads),
            param_norm=global_norm(train_state.params),
        )
        return train_state, rng_generator(), metrics
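
    # Evaluation applies the same dynamic masking so that eval_loss and
    # eval_accuracy are computed over comparably corrupted inputs.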
    def eval_step(train_state, rng, batch):
        rng_generator = JaxRNG(rng)
        tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
        altered_tokens = jax.random.uniform(
            rng_generator(), shape=tokens.shape
        ) < FLAGS.mask_token_probability
        random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
        altered_by_mask = altered_tokens & (random_uniform < 0.8)
        altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
        inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
        random_tokens = jax.random.randint(
            rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
        )
        inputs = jnp.where(altered_by_random, random_tokens, inputs)
        logits = model.apply(
            train_state.params, inputs,
            attention_mask=jnp.ones_like(inputs),
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            # Disable dropout during evaluation.
            deterministic=True,
            rngs=rng_generator(roberta_config.rng_keys()),
        ).logits
        loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
        metrics = dict(
            eval_loss=loss,
            eval_accuracy=accuracy,
        )
        return rng_generator(), metrics
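
    # jax.eval_shape traces init_fn abstractly, yielding the shapes and
    # dtypes of the full train state without allocating device memory; the
    # partition rules then map every leaf to a PartitionSpec.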
    train_state_shapes = jax.eval_shape(init_fn, next_rng())
    train_state_partition = match_partition_rules(
        RobertaConfig.get_partition_rules(), train_state_shapes
    )

    shard_fns, gather_fns = make_shard_and_gather_fns(
        train_state_partition, train_state_shapes
    )
    # Only process 0 writes checkpoints; others run with saving disabled.
    checkpointer = StreamingCheckpointer(
        FLAGS.checkpointer, logger.output_dir,
        enable=jax.process_index() == 0
    )
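
    # pjit compiles each function for the device mesh, with in_shardings /
    # out_shardings fixing how arguments and results are partitioned.
    # donate_argnums marks inputs whose buffers XLA may reuse for the
    # outputs (e.g. the previous train state), reducing peak memory.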
    sharded_init_fn = pjit(
        init_fn,
        in_shardings=PS(),
        out_shardings=train_state_partition
    )

    sharded_create_trainstate_from_params = pjit(
        create_trainstate_from_params,
        in_shardings=(train_state_partition.params, ),
        out_shardings=train_state_partition,
        donate_argnums=(0, ),
    )

    sharded_train_step = pjit(
        train_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(train_state_partition, PS(), PS()),
        donate_argnums=(0, 1),
    )

    sharded_eval_step = pjit(
        eval_step,
        in_shardings=(train_state_partition, PS(), PS()),
        out_shardings=(PS(), PS()),
        donate_argnums=(1,),
    )

    def save_checkpoint(train_state, milestone=False):
        step = int(jax.device_get(train_state.step))
        metadata = dict(
            step=step,
            variant=variant,
            flags=flags_config_dict,
            roberta_config=roberta_config.to_dict(),
        )
        # Gather the sharded train state with gather_fns and save it along
        # with the dataset state and run metadata.
        checkpointer.save_all(
            train_state=train_state,
            gather_fns=gather_fns,
            metadata=metadata,
            dataset=dataset.get_state_dict(),
            milestone=milestone,
        )
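
    # mesh_dim is a comma-separated device mesh shape; a -1 axis is sized
    # from the number of available devices. The partition specs in this
    # file shard over the mesh axes named 'dp' and 'fsdp'.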
    mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
    with mesh:
        train_state, restored_params = None, None
        if FLAGS.load_checkpoint != '':
            # load_checkpoint takes the form '<type>::<path>', e.g.
            # 'huggingface::<path>' to start from a pretrained HuggingFace
            # checkpoint.
            load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
            if load_type == 'huggingface':
                restored_params = tree_apply(
                    shard_fns.params, roberta_config.load_pretrained(load_path)
                )
                train_state = None
            else:
                train_state, restored_params = checkpointer.load_trainstate_checkpoint(
                    FLAGS.load_checkpoint, train_state_shapes, shard_fns
                )

        if train_state is None and restored_params is None:
            # No checkpoint was loaded; initialize everything from scratch.
            train_state = sharded_init_fn(next_rng())
        elif train_state is None and restored_params is not None:
            # Restore model params but create the optimizer state from scratch.
            train_state = sharded_create_trainstate_from_params(restored_params)
            del restored_params

        start_step = int(jax.device_get(train_state.step))

        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)

        sharded_rng = next_rng()

        step_counter = trange(start_step, FLAGS.total_steps, ncols=0)

        for step, (batch, dataset_metrics) in zip(step_counter, dataset):
            train_state, sharded_rng, metrics = sharded_train_step(
                train_state, sharded_rng, batch
            )

            if step % FLAGS.log_freq == 0:
                if FLAGS.eval_steps > 0:
                    eval_metric_list = []
                    for _ in range(FLAGS.eval_steps):
                        eval_batch, _ = next(eval_iterator)
                        sharded_rng, eval_metrics = sharded_eval_step(
                            train_state, sharded_rng, eval_batch
                        )
                        eval_metric_list.append(eval_metrics)
                    metrics.update(average_metrics(eval_metric_list))

                log_metrics = {'step': step}
                log_metrics.update(metrics)
                log_metrics.update(dataset_metrics)
                log_metrics = jax.device_get(log_metrics)
                logger.log(log_metrics)
                tqdm.write('\n' + pprint.pformat(log_metrics) + '\n')

            if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
                save_checkpoint(train_state, milestone=True)
            elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
                save_checkpoint(train_state)

        # Save a final checkpoint once training finishes.
        if FLAGS.save_model_freq > 0:
            save_checkpoint(train_state)


if __name__ == '__main__':
    mlxu.run(main)