# Copyright 2024 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training Script."""

import functools
from typing import Any, Callable, Dict, Optional, Tuple, Type

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform
import flax
from flax import jax_utils
import flax.linen as nn
import jax
from jax.example_libraries.optimizers import clip_grads
import jax.numpy as jnp
import jax.profiler
import ml_collections
import numpy as np
import optax
from scenic.dataset_lib import dataset_utils
from scenic.model_lib.base_models import base_model
from scenic.train_lib import lr_schedules
from scenic.train_lib import optimizers
from scenic.train_lib import train_utils

# Aliases for custom types:
Batch = Dict[str, jnp.ndarray]
MetricFn = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]],
                    Dict[str, Tuple[float, int]]]
LossFn = Callable[[jnp.ndarray, Batch, Optional[jnp.ndarray]], float]
LrFn = Callable[[jnp.ndarray], jnp.ndarray]

flax.config.update('flax_use_orbax_checkpointing', False)


def train_step(
    train_state: train_utils.TrainState,
    batch: Batch,
    *,
    flax_model: nn.Module,
    loss_fn: LossFn,
    lr_fn: LrFn,
    metrics_fn: MetricFn,
    config: ml_collections.ConfigDict,
    debug: Optional[bool] = False
) -> Tuple[train_utils.TrainState, Dict[str, Tuple[float, int]],
           Dict[str, Any]]:
  """Runs a single step of training.

  Given the state of the training and a batch of data, computes the loss and
  updates the parameters of the model.

  Note that in this code, the buffers of the first (train_state) and second
  (batch) arguments are donated to the computation.

  Args:
    train_state: The state of training including the current global_step,
      model_state, rng, params, and optimizer. The buffer of this argument can
      be donated to the computation.
    batch: A single batch of data. The buffer of this argument can be donated
      to the computation.
    flax_model: A Flax model.
    loss_fn: A loss function that, given logits, a batch, and parameters of
      the model, calculates the loss.
    lr_fn: The learning rate fn, used for logging the learning rate.
    metrics_fn: A metrics function that, given logits and a batch of data,
      calculates the metrics as well as the loss.
    config: Configurations of the experiment.
    debug: Whether the debug mode is enabled during training. `debug=True`
      enables model-specific logging/storing of some values using
      jax.host_callback.

  Returns:
    Updated state of training, computed metrics, and some training logs.
  """
  training_logs = {}
  new_rng, rng = jax.random.split(train_state.rng)

  if config.get('mixup') and config.mixup.alpha:
    mixup_rng, rng = jax.random.split(rng, 2)
    mixup_rng = train_utils.bind_rng_to_host_device(
        mixup_rng,
        axis_name='batch',
        bind_to=config.mixup.get('bind_to', 'device'))
    batch = dataset_utils.mixup(
        batch,
        config.mixup.alpha,
        config.mixup.get('image_format', 'NHWC'),
        rng=mixup_rng)
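
  # Illustrative sketch of the mixup config consumed above ('device' and
  # 'NHWC' are the defaults used by the `.get(...)` calls; the alpha value is
  # an arbitrary example, not a default from this file):
  #   config.mixup.alpha = 0.2            # Beta(alpha, alpha) mixing weight.
  #   config.mixup.bind_to = 'device'     # Where mixup_rng is bound.
  #   config.mixup.image_format = 'NHWC'  # Layout for dataset_utils.mixup.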
  # Bind the rng to the host/device we are on.
  dropout_rng = train_utils.bind_rng_to_host_device(
      rng, axis_name='batch', bind_to='device')

  def training_loss_fn(params):
    variables = {'params': params, **train_state.model_state}
    logits, new_model_state = flax_model.apply(
        variables,
        batch['inputs'],
        mutable=['batch_stats'],
        train=True,
        rngs={'dropout': dropout_rng},
        debug=debug)
    loss = loss_fn(logits, batch, variables['params'])
    return loss, (new_model_state, logits)

  compute_gradient_fn = jax.value_and_grad(training_loss_fn, has_aux=True)
  (train_cost,
   (new_model_state, logits)), grad = compute_gradient_fn(train_state.params)
  del train_cost

  # Re-use same axis_name as in the call to `pmap(...train_step...)` below.
  grad = jax.lax.pmean(grad, axis_name='batch')

  if config.get('max_grad_norm') is not None:
    grad = clip_grads(grad, config.max_grad_norm)

  tx = train_state.tx
  if tx is None:
    raise ValueError('train_state.tx, the gradient transformation, is None.')
  updates, new_opt_state = tx.update(
      grad, train_state.opt_state, train_state.params)
  new_params = optax.apply_updates(train_state.params, updates)

  training_logs['l2_grads'] = jnp.sqrt(
      sum([jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(grad)]))
  ps = jax.tree_util.tree_leaves(new_params)
  training_logs['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
  us = jax.tree_util.tree_leaves(updates)
  training_logs['l2_updates'] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us]))
  # TODO(dehghani): Can we get this from the optimizer instead?
  training_logs['learning_rate'] = lr_fn(
      jnp.asarray([train_state.global_step]))

  metrics = metrics_fn(logits, batch)
  new_train_state = train_state.replace(  # pytype: disable=attribute-error
      global_step=train_state.global_step + 1,
      opt_state=new_opt_state,
      params=new_params,
      model_state=new_model_state,
      rng=new_rng)
  return new_train_state, metrics, training_logs


def eval_step(
    train_state: train_utils.TrainState,
    batch: Batch,
    *,
    flax_model: nn.Module,
    metrics_fn: MetricFn,
    debug: Optional[bool] = False
) -> Tuple[Dict[str, Tuple[float, int]], jnp.ndarray]:
  """Runs a single step of evaluation.

  Note that in this code, the buffer of the second argument (batch) is donated
  to the computation.

  The assumed API of metrics_fn is:
  ```metrics = metrics_fn(logits, batch)```
  where batch is yielded by the batch iterator, and metrics is a dictionary
  mapping metric names to vectors of per-example measurements. eval_step
  aggregates (by summing) all per-example measurements and divides by the
  aggregated normalizers. For each given metric we compute
  1/N sum_{b in batch_iter} metric(b), where N is the sum of the normalizers
  over all batches.

  Args:
    train_state: TrainState, the state of training including the current
      global_step, model_state, rng, params, and optimizer state. The buffer
      of this argument can be donated to the computation.
    batch: A single batch of data.
    flax_model: A Flax model.
    metrics_fn: A metrics function that, given logits and a batch of data,
      calculates the metrics as well as the loss.
    debug: Whether the debug mode is enabled during evaluation. `debug=True`
      enables model-specific logging/storing of some values using
      jax.host_callback.
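
  Example: with a metric that yields (sum, normalizer) pairs per batch, two
  eval batches contributing (6.0, 8) and (7.0, 8) aggregate to
  (6.0 + 7.0) / (8 + 8) = 0.8125.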
""" variables = {'params': train_state.params, **train_state.model_state} logits = flax_model.apply( variables, batch['inputs'], train=False, mutable=False, debug=debug) metrics = metrics_fn(logits, batch) return metrics, logits def train( *, rng: jnp.ndarray, config: ml_collections.ConfigDict, model_cls: Type[base_model.BaseModel], dataset: dataset_utils.Dataset, workdir: str, writer: metric_writers.MetricWriter, ) -> Tuple[train_utils.TrainState, Dict[str, Any], Dict[str, Any]]: """Main training loop lives in this function. Given the model class and dataset, it prepares the items needed to run the training, including the TrainState. Args: rng: Jax rng key. config: Configurations of the experiment. model_cls: Model class; A model has a flax_module, a loss_fn, and a metrics_fn associated with it. dataset: The dataset that has train_iter, eval_iter, meta_data, and optionally, test_iter. workdir: Directory for checkpointing. writer: CLU metrics writer instance. Returns: train_state that has the state of training (including current global_step, model_state, rng, and the optimizer), train_summary and eval_summary which are dict of metrics. These outputs are used for regression testing. """ lead_host = jax.process_index() == 0 # Build the loss_fn, metrics, and flax_model. model = model_cls(config, dataset.meta_data) # Initialize model. rng, init_rng = jax.random.split(rng) (params, model_state, num_trainable_params, gflops) = train_utils.initialize_model( model_def=model.flax_model, input_spec=[(dataset.meta_data['input_shape'], dataset.meta_data.get('input_dtype', jnp.float32))], config=config, rngs=init_rng) # Create optimizer. lr_fn = lr_schedules.get_learning_rate_fn(config) optimizer_config = optimizers.get_optax_optimizer_config(config) # If the config is already an optax-compatible config, better call directly: # optimizers.get_optimizer(config.optimizer_configs, lr_fn) tx = optimizers.get_optimizer(optimizer_config, lr_fn, params=params) # We jit this, such that the arrays that are created on the same device as the # input is, in this case the CPU. Else they'd be on device[0]. opt_state = jax.jit(tx.init, backend='cpu')(params) rng, train_rng = jax.random.split(rng) # Create chrono class to track and store training statistics and metadata: chrono = train_utils.Chrono() train_state = train_utils.TrainState( global_step=0, opt_state=opt_state, tx=tx, params=params, model_state=model_state, rng=train_rng, metadata={'chrono': chrono.save()}) start_step = train_state.global_step if config.checkpoint: train_state, start_step = train_utils.restore_checkpoint( workdir, train_state) chrono.load(train_state.metadata['chrono']) train_state = train_state.replace(metadata={}) # Replicate the optimizer, state, and rng. train_state = jax_utils.replicate(train_state) del params # Do not keep a copy of the initial params. # Calculate the total number of training steps. total_steps, steps_per_epoch = train_utils.get_num_training_steps( config, dataset.meta_data) train_step_pmapped = jax.pmap( functools.partial( train_step, flax_model=model.flax_model, loss_fn=model.loss_function, lr_fn=lr_fn, metrics_fn=model.get_metrics_fn('train'), config=config, debug=config.debug_train), axis_name='batch', # We can donate both buffers of train_state and train_batch. donate_argnums=(0, 1), ) eval_step_pmapped = jax.pmap( functools.partial( eval_step, flax_model=model.flax_model, metrics_fn=model.get_metrics_fn('validation'), debug=config.debug_eval), axis_name='batch', # We can donate the eval_batch's buffer. 
  log_eval_steps = config.get('log_eval_steps') or steps_per_epoch
  if not log_eval_steps:
    raise ValueError("'log_eval_steps' should be specified in the config.")
  checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps
  max_checkpoint_keep = config.get('max_checkpoint_keep', 3)
  log_summary_steps = config.get('log_summary_steps') or log_eval_steps

  # Ceil rounding such that we include the last incomplete batch.
  eval_batch_size = config.get('eval_batch_size', config.batch_size)
  total_eval_steps = int(
      np.ceil(dataset.meta_data['num_eval_examples'] / eval_batch_size))
  steps_per_eval = config.get('steps_per_eval') or total_eval_steps

  train_metrics, extra_training_logs = [], []
  train_summary, eval_summary = None, None

  chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch)
  logging.info('Starting training loop at step %d.', start_step + 1)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=total_steps,
      writer=writer,
      every_secs=None,
      every_steps=config.get('report_progress_step', log_summary_steps),
  )

  def write_note(note):
    if lead_host:
      platform.work_unit().set_notes(note)

  hooks = []
  if lead_host:
    hooks.append(report_progress)
  if config.get('xprof', True) and lead_host:
    hooks.append(periodic_actions.Profile(num_profile_steps=5, logdir=workdir))

  if start_step == 0:
    step0_log = {'num_trainable_params': num_trainable_params}
    if gflops:
      step0_log['gflops'] = gflops
    writer.write_scalars(1, step0_log)

  write_note(f'First step compilations...\n{chrono.note}')
  for step in range(start_step + 1, total_steps + 1):
    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      train_batch = next(dataset.train_iter)
      train_state, t_metrics, t_logs = train_step_pmapped(
          train_state, train_batch)
      # This accumulates metrics in TPU memory up to the point that we log
      # them. This is no problem for small metrics, but may be a problem for
      # large (e.g. segmentation) metrics. An alternative is to set
      # `log_summary_steps` to a small number, or to use
      # `train_utils.unreplicate_and_get` here instead of right before writing
      # summaries, but that means we have data transfer between TPU and host
      # at every step, which might slow down the training.
      train_metrics.append(t_metrics)
      # Additional training logs: learning rate:
      t_logs = jax.tree_util.tree_map(jax_utils.unreplicate, t_logs)
      extra_training_logs.append(t_logs)
    for h in hooks:
      h(step)
    # Below are once-in-a-while ops -> pause.
    ###################### LOG TRAIN SUMMARY ########################
    if ((step % log_summary_steps == 1) or (step == total_steps) or
        (lead_host and chrono.warmup)):
      chrono.pause(wait_for=(train_metrics))
      if lead_host:
        chrono.tick(step, writer, write_note)
      # train_metrics is a list of dictionaries of metrics, where the shape of
      # metrics[key] is [n_local_devices]. However, because the metric
      # functions have a psum, we have already summed across the whole sharded
      # batch, so what is returned is n_local_devices copies of the same
      # summed metric. We therefore unreplicate and fetch them to the host
      # using `unreplicate_and_get`.
      train_summary = train_utils.log_train_summary(
          step=step,
          train_metrics=jax.tree_util.tree_map(train_utils.unreplicate_and_get,
                                               train_metrics),
          extra_training_logs=jax.tree_util.tree_map(jax.device_get,
                                                     extra_training_logs),
          writer=writer)
      # Reset metric accumulation for the next evaluation cycle.
      train_metrics, extra_training_logs = [], []
      chrono.resume()
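
    # With the `% ... == 1` pattern, summaries are written at steps 1,
    # log_summary_steps + 1, 2 * log_summary_steps + 1, ..., and once more at
    # `total_steps`; e.g. log_summary_steps=100 gives steps 1, 101, 201, etc.
    # The evaluation and checkpointing blocks below follow the same cadence
    # with `log_eval_steps` and `checkpoint_steps`.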
    ################### EVALUATION #######################
    if (step % log_eval_steps == 1) or (step == total_steps):
      chrono.pause(wait_for=(train_state.params))
      with report_progress.timed('eval'):
        eval_metrics = []
        # Sync model state across replicas.
        train_state = train_utils.sync_model_state_across_replicas(
            train_state)
        for _ in range(steps_per_eval):
          eval_batch = next(dataset.valid_iter)
          e_metrics, _ = eval_step_pmapped(train_state, eval_batch)
          eval_metrics.append(train_utils.unreplicate_and_get(e_metrics))
        eval_summary = train_utils.log_eval_summary(
            step=step, eval_metrics=eval_metrics, writer=writer)
      writer.flush()
      del eval_metrics
      chrono.resume()
    ##################### CHECKPOINTING ###################
    if ((step % checkpoint_steps == 1 and step > 1) or
        (step == total_steps)) and config.checkpoint:
      chrono.pause(wait_for=(train_state.params, train_state.opt_state))
      with report_progress.timed('checkpoint'):
        train_utils.handle_checkpointing(train_state, chrono, workdir,
                                         max_checkpoint_keep)
      chrono.resume()

  # Wait until computations are done before exiting.
  train_utils.barrier_across_hosts()
  # Return the train and eval summary after the last step for regression
  # testing.
  assert train_summary is not None
  assert eval_summary is not None
  return train_state, train_summary, eval_summary
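

# A minimal usage sketch (illustrative only; `MyModel` and `my_dataset` are
# hypothetical stand-ins, and Scenic binaries normally reach `train` via their
# main entry point rather than calling it directly):
#
#   rng = jax.random.PRNGKey(config.get('rng_seed', 0))
#   writer = metric_writers.create_default_writer(workdir)
#   train_state, train_summary, eval_summary = train(
#       rng=rng,
#       config=config,
#       model_cls=MyModel,      # a subclass of base_model.BaseModel
#       dataset=my_dataset,     # a dataset_utils.Dataset instance
#       workdir=workdir,
#       writer=writer)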