| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Training Script.""" |
|
|
| import functools |
| from typing import Any, Callable, Dict, Tuple, Optional, Type |
|
|
| from absl import logging |
| from clu import metric_writers |
| from clu import periodic_actions |
| from clu import platform |
| from flax import jax_utils |
| import flax.linen as nn |
| import jax |
| from jax.example_libraries.optimizers import clip_grads |
| import jax.numpy as jnp |
| import jax.profiler |
| import ml_collections |
| import numpy as np |
| import optax |
| from scenic.dataset_lib import dataset_utils |
| from scenic.model_lib.base_models import base_model |
| from scenic.train_lib import lr_schedules |
| from scenic.train_lib import optimizers |
| from scenic.train_lib import train_utils |
|
|
| |
# A batch maps feature name (e.g. 'inputs', 'label') to a device array.
Batch = Dict[str, jnp.ndarray]
# Given (logits, batch), returns {metric_name: (value_sum, normalizer)} pairs
# that are aggregated (summed, then divided) across batches by the caller.
MetricFn = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]],
                    Dict[str, Tuple[float, int]]]
# Given (logits, batch, model params, auxiliary model outputs), returns the
# scalar training loss.
LossFn = Callable[
    [jnp.ndarray, Batch, Optional[jnp.ndarray], Any], float]
# Maps the global step to the learning rate used at that step (for logging).
LrFn = Callable[[jnp.ndarray], jnp.ndarray]
|
|
|
|
def train_step(
    train_state: train_utils.TrainState,
    batch: Batch,
    *,
    flax_model: nn.Module,
    loss_fn: LossFn,
    lr_fn: LrFn,
    metrics_fn: MetricFn,
    config: ml_collections.ConfigDict,
    debug: Optional[bool] = False
) -> Tuple[train_utils.TrainState, Dict[str, Tuple[float, int]], Dict[str,
                                                                      Any]]:
  """Runs a single step of training.

  Given the state of the training and a batch of data, computes
  the loss and updates the parameters of the model.

  Note that in this code, the buffers of the first (train_state) and second
  (batch) arguments are donated to the computation.

  Args:
    train_state: The state of training including the current global_step,
      model_state, rng, params, and optimizer. The buffer of this argument can
      be donated to the computation.
    batch: A single batch of data. The buffer of this argument can be donated to
      the computation.
    flax_model: A Flax model.
    loss_fn: A loss function that given logits, a batch, and parameters of the
      model calculates the loss.
    lr_fn: The learning rate fn used for the logging the learning rate.
    metrics_fn: A metrics function that given logits and batch of data,
      calculates the metrics as well as the loss.
    config: Configurations of the experiment.
    debug: Whether the debug mode is enabled during training. `debug=True`
      enables model specific logging/storing some values using
      jax.host_callback.

  Returns:
    Updated state of training and computed metrics and some training logs.
  """
  training_logs = {}
  # Split off the rng carried to the next step *before* consuming `rng` below,
  # so each step sees a fresh stream.
  new_rng, rng = jax.random.split(train_state.rng)

  if config.get('mixup') and config.mixup.alpha:
    mixup_rng, rng = jax.random.split(rng, 2)
    # Bind the mixup rng per host/device so replicas don't all draw the same
    # mixing coefficients.
    mixup_rng = train_utils.bind_rng_to_host_device(
        mixup_rng,
        axis_name='batch',
        bind_to=config.mixup.get('bind_to', 'device'))
    batch = dataset_utils.mixup(
        batch,
        config.mixup.alpha,
        config.mixup.get('image_format', 'NHWC'),
        rng=mixup_rng)

  # Bind the dropout rng to the device so dropout masks differ across replicas.
  dropout_rng = train_utils.bind_rng_to_host_device(
      rng, axis_name='batch', bind_to='device')

  def training_loss_fn(params):
    # Model expects params alongside any non-trainable collections (e.g.
    # batch_stats); only 'batch_stats' is allowed to mutate during apply.
    variables = {'params': params, **train_state.model_state}
    (logits, auxiliary_outputs), new_model_state = flax_model.apply(
        variables,
        batch['inputs'],
        mutable=['batch_stats'],
        train=True,
        rngs={'dropout': dropout_rng},
        debug=debug)
    loss = loss_fn(logits, batch, variables['params'], auxiliary_outputs)
    # Auxiliary values are threaded out via `has_aux` below.
    return loss, (new_model_state, logits, auxiliary_outputs)

  compute_gradient_fn = jax.value_and_grad(training_loss_fn, has_aux=True)
  (train_cost, compute_outputs), grad = compute_gradient_fn(train_state.params)
  (new_model_state, logits, auxiliary_outputs) = compute_outputs

  # The scalar loss is recomputed by metrics_fn; the raw value is unused.
  del train_cost

  # Average gradients across all replicas participating in the pmap.
  grad = jax.lax.pmean(grad, axis_name='batch')

  if config.get('max_grad_norm') is not None:
    grad = clip_grads(grad, config.max_grad_norm)

  updates, new_opt_state = train_state.tx.update(grad, train_state.opt_state,
                                                 train_state.params)
  new_params = optax.apply_updates(train_state.params, updates)

  # Log global L2 norms of gradients, parameters, and updates for diagnostics.
  training_logs['l2_grads'] = jnp.sqrt(
      sum([jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(grad)]))
  ps = jax.tree_util.tree_leaves(new_params)
  training_logs['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
  us = jax.tree_util.tree_leaves(updates)
  training_logs['l2_updates'] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us]))
  training_logs['learning_rate'] = lr_fn(train_state.global_step)

  # NOTE(review): plain ml_collections `ConfigDict.get` does not traverse
  # dotted paths, so this likely only fires if the literal key
  # 'model.ac_config.dynamic_tape_length' exists — confirm intent.
  if config.get('model.ac_config.dynamic_tape_length') is not None:
    # assumes auxiliary_outputs[0] is a 0/1 token mask of shape
    # (..., seq_len) — TODO confirm against the model's output contract.
    input_masks = auxiliary_outputs[0]
    input_len = jnp.sum(input_masks, axis=-1)
    avg_input_len = jnp.mean(input_len)
    training_logs['sequence length'] = avg_input_len
    training_logs['sequence length var'] = jnp.var(input_len)

  metrics = metrics_fn(logits, batch)

  # Advance the step counter and swap in the updated params/state/rng.
  new_train_state = train_state.replace(
      global_step=train_state.global_step + 1,
      opt_state=new_opt_state,
      params=new_params,
      model_state=new_model_state,
      rng=new_rng)

  return new_train_state, metrics, training_logs
|
|
|
|
def eval_step(
    train_state: train_utils.TrainState,
    batch: Batch,
    *,
    flax_model: nn.Module,
    metrics_fn: MetricFn,
    debug: Optional[bool] = False
) -> Tuple[Dict[str, Tuple[float, int]], jnp.ndarray]:
  """Runs a single step of evaluation.

  Note that in this code, the buffer of the second argument (batch) is donated
  to the computation.

  Assumed API of metrics_fn is:
  ```metrics = metrics_fn(logits, batch)
  where batch is yielded by the batch iterator, and metrics is a dictionary
  mapping metric name to a vector of per example measurements. eval_step will
  aggregate (by summing) all per example measurements and divide by the
  aggregated normalizers. For each given metric we compute:
  1/N sum_{b in batch_iter} metric(b), where N is the sum of normalizer
  over all batches.

  Args:
    train_state: TrainState, the state of training including the current
      global_step, model_state, rng, params and optimizer state. The buffer of
      this argument can be donated to the computation.
    batch: A single batch of data.
    flax_model: A Flax model.
    metrics_fn: A metrics function, that given logits and batch of data,
      calculates the metrics as well as the loss.
    debug: Whether the debug mode is enabled during evaluation. `debug=True`
      enables model specific logging/storing some values using
      jax.host_callback.

  Returns:
    Calculated metrics and logits.
  """
  # Assemble the full variable collection; entries from model_state (e.g.
  # batch_stats) take precedence on key collision, as in training.
  eval_variables = {'params': train_state.params}
  eval_variables.update(train_state.model_state)
  # Forward pass only: no mutable collections, no train-time stochasticity.
  outputs = flax_model.apply(
      eval_variables, batch['inputs'], train=False, mutable=False, debug=debug)
  logits, _ = outputs
  return metrics_fn(logits, batch), logits
|
|
|
|
def train(
    *,
    rng: jnp.ndarray,
    config: ml_collections.ConfigDict,
    model_cls: Type[base_model.BaseModel],
    dataset: dataset_utils.Dataset,
    workdir: str,
    writer: metric_writers.MetricWriter,
) -> Tuple[train_utils.TrainState, Dict[str, Any], Dict[str, Any]]:
  """Main training loop lives in this function.

  Given the model class and dataset, it prepares the items needed to run the
  training, including the TrainState.

  Args:
    rng: Jax rng key.
    config: Configurations of the experiment.
    model_cls: Model class; A model has a flax_module, a loss_fn, and a
      metrics_fn associated with it.
    dataset: The dataset that has train_iter, eval_iter, meta_data, and
      optionally, test_iter.
    workdir: Directory for checkpointing.
    writer: CLU metrics writer instance.

  Returns:
    train_state that has the state of training (including current
    global_step, model_state, rng, and the optimizer), train_summary
    and eval_summary which are dict of metrics. These outputs are used for
    regression testing.
  """
  lead_host = jax.process_index() == 0

  model = model_cls(config, dataset.meta_data)

  # Initialize model parameters and state; also get param count and GFLOPs
  # for the step-0 log below.
  rng, init_rng = jax.random.split(rng)
  (params, model_state, num_trainable_params,
   gflops) = train_utils.initialize_model(
       model_def=model.flax_model,
       input_spec=[(dataset.meta_data['input_shape'],
                    dataset.meta_data.get('input_dtype', jnp.float32))],
       config=config,
       rngs=init_rng)

  # Build the learning-rate schedule and the optax optimizer.
  lr_fn = lr_schedules.get_learning_rate_fn(config)
  optimizer_config = optimizers.get_optax_optimizer_config(config)
  tx = optimizers.get_optimizer(optimizer_config, lr_fn, params=params)
  # Initialize optimizer state on CPU to avoid transient device-memory spikes.
  opt_state = jax.jit(tx.init, backend='cpu')(params)

  rng, train_rng = jax.random.split(rng)

  # Chrono tracks wall-clock/throughput; its state rides along in TrainState
  # metadata so it survives checkpoint/restore.
  chrono = train_utils.Chrono()

  train_state = train_utils.TrainState(
      global_step=0,
      opt_state=opt_state,
      tx=tx,
      params=params,
      model_state=model_state,
      rng=train_rng,
      metadata={'chrono': chrono.save()})
  start_step = train_state.global_step
  if config.checkpoint:
    train_state, start_step = train_utils.restore_checkpoint(
        workdir, train_state)
  chrono.load(train_state.metadata['chrono'])
  # Metadata is host-side only; clear it before replicating across devices.
  train_state = train_state.replace(metadata={})
  train_state = jax_utils.replicate(train_state)
  del params

  total_steps, steps_per_epoch = train_utils.get_num_training_steps(
      config, dataset.meta_data)

  # Both pmapped functions donate their input buffers (train_state/batch) to
  # reduce peak device memory.
  train_step_pmapped = jax.pmap(
      functools.partial(
          train_step,
          flax_model=model.flax_model,
          loss_fn=model.loss_function,
          lr_fn=lr_fn,
          metrics_fn=model.get_metrics_fn('train'),
          config=config,
          debug=config.debug_train),
      axis_name='batch',
      donate_argnums=(0, 1),
  )
  eval_step_pmapped = jax.pmap(
      functools.partial(
          eval_step,
          flax_model=model.flax_model,
          metrics_fn=model.get_metrics_fn('validation'),
          debug=config.debug_eval),
      axis_name='batch',
      donate_argnums=(1,),
  )
  log_eval_steps = config.get('log_eval_steps') or steps_per_epoch
  if not log_eval_steps:
    raise ValueError("'log_eval_steps' should be specified in the config.")
  checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps
  log_summary_steps = config.get('log_summary_steps') or log_eval_steps

  # Number of eval batches needed to cover the whole validation set once.
  eval_batch_size = config.get('eval_batch_size', config.batch_size)
  total_eval_steps = int(
      np.ceil(dataset.meta_data['num_eval_examples'] / eval_batch_size))
  steps_per_eval = config.get('steps_per_eval') or total_eval_steps

  train_metrics, extra_training_logs = [], []
  train_summary, eval_summary = None, None

  chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch)
  logging.info('Starting training loop at step %d.', start_step + 1)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=total_steps, writer=writer)

  def write_note(note):
    # Work-unit notes are only meaningful (and writable) on the lead host.
    if lead_host:
      platform.work_unit().set_notes(note)

  hooks = []
  if lead_host:
    hooks.append(report_progress)
  if config.get('xprof', True) and lead_host:
    hooks.append(periodic_actions.Profile(num_profile_steps=5, logdir=workdir))

  if start_step == 0:
    step0_log = {'num_trainable_params': num_trainable_params}
    if gflops:
      step0_log['gflops'] = gflops
    writer.write_scalars(1, step0_log)

  write_note(f'First step compilations...\n{chrono.note}')
  for step in range(start_step + 1, total_steps + 1):
    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      train_batch = next(dataset.train_iter)
      train_state, t_metrics, t_logs = train_step_pmapped(
          train_state, train_batch)
      # Accumulate metrics on-device; they are fetched only at logging time.
      train_metrics.append(t_metrics)
      # Extra logs are identical across replicas, so unreplicate eagerly.
      t_logs = jax.tree_util.tree_map(jax_utils.unreplicate, t_logs)
      extra_training_logs.append(t_logs)
    for h in hooks:
      h(step)

    # `% == 1` so the first logged step is right after compilation finishes.
    if ((step % log_summary_steps == 1) or (step == total_steps) or
        (lead_host and chrono.warmup)):
      chrono.pause(wait_for=(train_metrics))
      if lead_host:
        chrono.tick(step, writer, write_note)
      train_summary = train_utils.log_train_summary(
          step=step,
          train_metrics=jax.tree_util.tree_map(train_utils.unreplicate_and_get,
                                               train_metrics),
          extra_training_logs=jax.tree_util.tree_map(jax.device_get,
                                                     extra_training_logs),
          writer=writer)
      train_metrics, extra_training_logs = [], []
      chrono.resume()

    if (step % log_eval_steps == 1) or (step == total_steps):
      chrono.pause(wait_for=(train_state.params))
      with report_progress.timed('eval'):
        eval_metrics = []
        # Sync batch stats across replicas before evaluating.
        train_state = train_utils.sync_model_state_across_replicas(train_state)
        for _ in range(steps_per_eval):
          eval_batch = next(dataset.valid_iter)
          e_metrics, _ = eval_step_pmapped(train_state, eval_batch)
          eval_metrics.append(train_utils.unreplicate_and_get(e_metrics))
        eval_summary = train_utils.log_eval_summary(
            step=step, eval_metrics=eval_metrics, writer=writer)
      writer.flush()
      del eval_metrics
      chrono.resume()

    if ((step % checkpoint_steps == 0 and step > 0) or
        (step == total_steps)) and config.checkpoint:
      chrono.pause(wait_for=(train_state.params, train_state.opt_state))
      with report_progress.timed('checkpoint'):
        train_state = train_utils.sync_model_state_across_replicas(train_state)
        if lead_host:
          unrep_train_state = jax_utils.unreplicate(train_state)
          metadata = unrep_train_state.metadata
          metadata['chrono'] = chrono.save()
          # BUG FIX: `replace` is functional and its result was previously
          # discarded; checkpointing only worked via in-place mutation of the
          # metadata dict. Assign the result to make the intent explicit.
          unrep_train_state = unrep_train_state.replace(metadata=metadata)
          train_utils.save_checkpoint(workdir, unrep_train_state)
          del unrep_train_state
      chrono.resume()

  # Ensure all hosts finish (e.g. the lead host's checkpoint write) before
  # returning.
  train_utils.barrier_across_hosts()
  return train_state, train_summary, eval_summary
|
|