# Copyright 2024 The Scenic Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utility functions for Training.""" import collections.abc as collections import copy import functools import os import re import time from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union from absl import logging from clu import metric_writers import flax from flax import jax_utils from flax import struct import flax.linen as nn from flax.training import checkpoints import jax import jax.numpy as jnp import ml_collections import numpy as np import optax from scenic.common_lib import debug_utils from scenic.dataset_lib import dataset_utils from scenic.dataset_lib import datasets from scenic.train_lib import optimizers from tensorflow.io import gfile # JAX team is working on type annotation for pytree: # https://github.com/jax-ml/jax/issues/1555 PyTree = Any PRNGKey = jnp.ndarray @struct.dataclass class TrainState: """Dataclass to keep track of state of training. The state of training is structured as a struct.dataclass, which enables instances of this class to be passed into jax transformations like tree_map and pmap. """ tx: Optional[optax.GradientTransformation] = struct.field( default=None, pytree_node=False ) opt_state: Optional[optax.OptState] = None params: Optional[Any] = struct.field(default_factory=dict) global_step: Optional[int] = 0 model_state: Optional[Any] = struct.field(default_factory=dict) rng: Optional[jnp.ndarray] = None metadata: Optional[Dict[str, Any]] = None # NOTE: When using the raw TrainState as the target for checkpoint restoration # in Flax, you should provide the pytree structure, otherwise it might just # silenty ignore restoring the checkpoint subtree if you use with an empty # dict when setting `allow_partial_mpa_restoration=True` and if you set it # to None (e.g., for `metadata`` above), Flax replaces it with a state dict. def __getitem__(self, item): """Make TrainState a subscriptable object.""" return getattr(self, item) def get(self, keyname: str, default: Optional[Any] = None) -> Any: """Return the value for key if it exists otherwise the default.""" try: return self[keyname] except KeyError: return default def expand_dims_for_specs(xs, specs): return jax.tree.map( lambda s, x: jax.tree.map( functools.partial(jnp.expand_dims, axis=tuple(range(len(s)))), x, ), specs, xs, ) def squeeze_for_specs(xs, specs): return jax.tree.map( lambda s, x: jax.tree.map( functools.partial(jnp.squeeze, axis=tuple(range(len(s)))), x, ), specs, xs, ) def initialize_model( *, model_def: nn.Module, input_spec: Sequence[ Union[Tuple[Tuple[int, ...], jnp.dtype], Tuple[int, ...], None] ], config: ml_collections.ConfigDict, rngs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]], train: Optional[bool] = False, **model_kwargs, ) -> Tuple[PyTree, PyTree, int, Optional[float]]: """Initializes parameters and model state. Args: model_def: Definition of a model. input_spec: An iterable of (shape, dtype) pairs specifying the shape and dtype of the inputs. If unspecified the dtype is float32. config: Configurations of the initialization. rngs: Jax rng keys. train: If the scenic model should be initialized in the train mode. **model_kwargs: Kwargs passed to flax model initialization. Returns: Initial params, Init model_state, number of trainable_params, and gflops. """ batch_size = ( (config.batch_size // jax.device_count()) if config.get('batch_size') else None ) dummy_input = [] for spec in input_spec: if spec is not None: in_st = debug_utils.input_spec_to_jax_shape_dtype_struct( spec, batch_size=batch_size ) dummy_input.append(jnp.zeros(in_st.shape, in_st.dtype)) else: dummy_input.append(None) # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def _initialize_model(rngs): """Initialization function to be jitted.""" init_model_state, init_params = flax.core.pop( flax.core.freeze( model_def.init( rngs, *dummy_input, train=train, debug=False, **model_kwargs ) ), 'params', ) # Set bias in the head to low value, such that loss is small initially. if config.get('init_head_bias', None) is not None: init_params = flax.core.unfreeze(init_params) init_params['output_projection'] = optimizers.tree_map_with_names( lambda p: jnp.full_like(p, config.init_head_bias), init_params['output_projection'], match_name_fn=lambda name: 'bias' in name, ) init_params = flax.core.freeze(init_params) return init_params, init_model_state if not isinstance(rngs, dict): rngs = {'params': rngs} init_params, init_model_state = _initialize_model(rngs) # Pop out params rng: rngs.pop('params') # Count number of trainable parameters: num_trainable_params = debug_utils.log_param_shapes(init_params) # Count gflops: count_flops = config.get( 'count_flops', ml_collections.ConfigDict({'count_flops': True}) ) if count_flops: variables = {'params': init_params, **init_model_state} flops = debug_utils.compute_flops( flax_model_apply_fn=functools.partial( model_def.apply, variables, train=False, debug=False, rngs=rngs, **model_kwargs, ), input_spec=count_flops.get('input_spec', input_spec), fuse_multiply_add=count_flops.get('fuse_multiply_add', True), ) gflops = flops / (10**9) else: gflops = None return init_params, init_model_state, num_trainable_params, gflops def initialize_model_with_pytree( *, model_def: nn.Module, input_spec: PyTree, config: ml_collections.ConfigDict, rngs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]], unpack_input: bool = True, **model_kwargs, ) -> Tuple[PyTree, PyTree, int, Optional[float]]: """Initializes parameters and model state with a pytree input_spec. This is an extension of the above initialize_model function where we can put pytree `input_spec`. We keep the original function for backward compatibility. If the root type of `input_spec` is `Sequence`, each element is fed to the model as position arguments whereas they are fed as keyword arguments if the root type is `dict`. Args: model_def: Definition of a model. input_spec: A PyTree whose leaves are (shape, dtype) pairs specifying the shape and dtype of the inputs. If unspecified the dtype is float32. config: Configurations of the initialization. rngs: Jax rng keys. unpack_input: Unpack the pytree when feeding it to the model. **model_kwargs: Kwargs passed to flax model initialization. Returns: Initial params, Init model_state, number of trainable_params, and gflops. """ batch_size = ( (config.batch_size // jax.device_count()) if config.get('batch_size') else None ) def check_leaf_spec(spec: Sequence[PyTree]) -> bool: return ( len(spec) == 2 and isinstance(spec[0], collections.Sequence) and all(isinstance(i, int) for i in spec[0]) and isinstance(spec[1], jnp.dtype) ) or (all(isinstance(i, int) for i in spec[0])) def create_dummy_input(spec: PyTree) -> PyTree: if isinstance(spec, dict): return {k: create_dummy_input(v) for k, v in spec.items()} elif isinstance(spec, collections.Sequence): if check_leaf_spec(spec): in_st = debug_utils.input_spec_to_jax_shape_dtype_struct( spec, batch_size=batch_size ) return jnp.zeros(in_st.shape, in_st.dtype) else: return tuple(create_dummy_input(child) for child in spec) elif spec is None: return None else: raise NotImplementedError('Unsupported spec type.', type(spec)) dummy_input = create_dummy_input(input_spec) # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def _initialize_model(rngs): """Initialization function to be jitted.""" # If dummy_input is a dict, we feed inputs as keyword arguments, otherwise # feed as position arguments. if isinstance(dummy_input, dict) and unpack_input: init_model_state, init_params = flax.core.pop( flax.core.freeze( model_def.init( rngs, **dummy_input, train=False, debug=False, **model_kwargs ) ), 'params', ) elif isinstance(dummy_input, collections.Sequence) and unpack_input: init_model_state, init_params = flax.core.pop( flax.core.freeze( model_def.init( rngs, *dummy_input, train=False, debug=False, **model_kwargs ) ), 'params', ) else: init_model_state, init_params = flax.core.pop( flax.core.freeze( model_def.init( rngs, dummy_input, train=False, debug=False, **model_kwargs ) ), 'params', ) # Set bias in the head to low value, such that loss is small initially. if config.get('init_head_bias', None) is not None: init_params = flax.core.unfreeze(init_params) init_params['output_projection'] = optimizers.tree_map_with_names( lambda p: jnp.full_like(p, config.init_head_bias), init_params['output_projection'], match_name_fn=lambda name: 'bias' in name, ) init_params = flax.core.freeze(init_params) return init_params, init_model_state if not isinstance(rngs, dict): rngs = {'params': rngs} init_params, init_model_state = _initialize_model(rngs) # Pop out params rng: rngs.pop('params') # Count number of trainable parameters: num_trainable_params = debug_utils.log_param_shapes(init_params) # Count gflops: count_flops = config.get( 'count_flops', ml_collections.ConfigDict({'count_flops': True}) ) if count_flops: variables = {'params': init_params, **init_model_state} flops = debug_utils.compute_flops_with_pytree( flax_model_apply_fn=functools.partial( model_def.apply, variables, train=False, debug=False, rngs=rngs, **model_kwargs, ), input_spec=count_flops.get('input_spec', input_spec), unpack_input=unpack_input, fuse_multiply_add=count_flops.get('fuse_multiply_add', True), ) gflops = flops / (10**9) else: gflops = None return init_params, init_model_state, num_trainable_params, gflops def get_dataset( config: ml_collections.ConfigDict, data_rng: PRNGKey, *, num_local_shards: Optional[int] = None, dataset_service_address: Optional[str] = None, dataset_name: Optional[str] = None, dataset_configs: Optional[ml_collections.ConfigDict] = None, **kwargs: Any, ) -> dataset_utils.Dataset: """Creates dataset. By default, the values in the config file are used. However, if the optional `dataset_name` and `dataset_configs` are passed, those are used instead. Args: config: The configuration of the experiment. data_rng: Random number generator key to use for the dataset. num_local_shards: Number of shards for each batch. So (bs, ...) becomes (num_local_shards, bs//num_local_shards, ...). If not specified, it will be number of local devices. dataset_service_address: Used when using the tf.data.experimental.service dataset_name: Name of dataset to load, if not reading from the config. dataset_configs: Configuration of the dataset, if not reading directly from the config. **kwargs: Keyword arguments passed to the dataset builders. Returns: A dataset_utils.Dataset object. """ device_count = jax.device_count() logging.info('device_count: %d', device_count) logging.info('num_hosts : %d', jax.process_count()) logging.info('host_id : %d', jax.process_index()) dataset_name = dataset_name or config.dataset_name dataset_builder = datasets.get_dataset(dataset_name) batch_size = config.batch_size if batch_size % device_count > 0: raise ValueError( f'Batch size ({batch_size}) must be divisible by the ' f'number of devices ({device_count})' ) eval_batch_size = config.get('eval_batch_size', batch_size) if eval_batch_size % device_count > 0: raise ValueError( f'Eval batch size ({eval_batch_size}) must be divisible ' f'by the number of devices ({device_count})' ) local_batch_size = batch_size // jax.process_count() eval_local_batch_size = eval_batch_size // jax.process_count() device_batch_size = batch_size // device_count logging.info('local_batch_size : %d', local_batch_size) logging.info('device_batch_size : %d', device_batch_size) shuffle_seed = config.get('shuffle_seed', None) if dataset_service_address and shuffle_seed is not None: raise ValueError( 'Using dataset service with a random seed causes each ' 'worker to produce exactly the same data. Add ' 'config.shuffle_seed = None to your config if you want ' 'to run with dataset service.' ) dataset_configs = dataset_configs or config.get('dataset_configs', {}) num_local_shards = num_local_shards or jax.local_device_count() dataset = dataset_builder( batch_size=local_batch_size, eval_batch_size=eval_local_batch_size, num_shards=num_local_shards, dtype_str=config.data_dtype_str, rng=data_rng, shuffle_seed=shuffle_seed, dataset_configs=dataset_configs, dataset_service_address=dataset_service_address, **kwargs, ) return dataset def initialize_multitask_model( *, model_def: nn.Module, input_spec: Dict[ Tuple[Tuple[str, Any], ...], Sequence[Union[Tuple[Tuple[int, ...], jnp.dtype], Tuple[int, ...]]], ], config: ml_collections.ConfigDict, rngs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]], ) -> Tuple[PyTree, PyTree, int, Optional[Dict[str, float]]]: """Initializes parameters and model state. Args: model_def: Definition of a model. input_spec: A dictionary from a dict of keyword arguments to an iterable of (shape, dtype) pairs specifying the shape and dtype of the inputs. If unspecified the dtype is float32. config: Configurations of the initialization. rngs: Jax rng keys. Returns: Initial params, Init model_state, and number of trainable_params. """ def init_fn(model_def): for kwargs, in_spec in input_spec.items(): if config.get('batch_sizes') is not None: batch_size = config.batch_sizes.get(dict(kwargs)['dataset']) else: batch_size = config.batch_size batch_size = (batch_size // jax.device_count()) if batch_size else None input_shapetype = [ debug_utils.input_spec_to_jax_shape_dtype_struct( spec, batch_size=batch_size ) for spec in in_spec ] dummy_input = [] for in_st in input_shapetype: dummy_input.append(jnp.zeros(in_st.shape, in_st.dtype)) model_def(*dummy_input, train=False, debug=False, **dict(kwargs)) # We want all parameters to be created in host RAM, not on any device, they'll # be sent there later as needed, otherwise we already encountered two # situations where we allocate them twice. @functools.partial(jax.jit, backend='cpu') def _initialize_model(rngs): """Initialization function to be jitted.""" init_model_state, init_params = flax.core.pop( flax.core.freeze(nn.init(fn=init_fn, module=model_def)(rngs)), 'params' ) # Set bias in the head to low value, such that loss is small initially. if ( config.get('init_head_bias', None) is not None and 'output_projection' in init_params ): init_params = flax.core.unfreeze(init_params) init_params['output_projection'] = optimizers.tree_map_with_names( lambda p: jnp.full_like(p, config.init_head_bias), init_params['output_projection'], match_name_fn=lambda name: 'bias' in name, ) init_params = flax.core.freeze(init_params) return init_params, init_model_state if not isinstance(rngs, dict): rngs = {'params': rngs} init_params, init_model_state = _initialize_model(rngs) # Pop out params rng: rngs.pop('params') # Count number of trainable parameters: num_trainable_params = debug_utils.log_param_shapes(init_params) # Count gflops: count_flops = config.get('count_flops', ml_collections.ConfigDict()) if count_flops: variables = {'params': init_params, **init_model_state} gflops_dict = {} gflops_all = 0 for kwargs, in_spec in input_spec.items(): flops = debug_utils.compute_flops( flax_model_apply_fn=functools.partial( model_def.apply, variables, train=False, debug=False, rngs=rngs, **dict(kwargs), ), input_spec=count_flops.get('input_spec', in_spec), fuse_multiply_add=count_flops.get('fuse_multiply_add', True), ) gflops = flops / (10**9) gflops_key = 'gflops/' + '/'.join(f'{x}={y}' for x, y in kwargs) gflops_dict[gflops_key] = gflops gflops_all += gflops gflops_dict['gflops'] = gflops_all else: gflops_dict = None return init_params, init_model_state, num_trainable_params, gflops_dict def get_num_training_steps( config: ml_collections.ConfigDict, dataset_metadata: Dict[str, Any] ) -> Tuple[int, Optional[int]]: """Calculates the total number of training step and possibly steps_per_epoch. The main raining loop is based on number of training steps. Thus, for datasets that we want to train based on number of epochs, we need to calculate the total number of training steps. This function looks for `num_training_steps` in config, if it exists it returns that as the total step and `None` as `steps_per_epoch`. If num_training_steps doesn't exist, then it looks for `num_training_epochs` and given the size of training data calculates the total steps and steps_per_epoch. In this computation, we assume that drop_remainder=True. Args: config: Configuration of the experiment. dataset_metadata: Meta-data that is generated by the dataset_builder. Returns: total_steps: Total number of training steps. steps_per_epoch: Number of steps in every epoch. """ # We either use num_training_epochs or num_training_steps. steps_per_epoch = ( dataset_metadata.get('num_train_examples', 0) // config.batch_size ) if config.get('num_training_steps') is not None: assert not config.get('num_training_epochs') return config.num_training_steps, steps_per_epoch or None else: assert config.num_training_epochs and not config.get('num_training_steps') assert steps_per_epoch > 0, 'num_train_examples should be defined.' return int(steps_per_epoch * config.num_training_epochs), steps_per_epoch @functools.partial(jax.pmap, axis_name='x') def pmap_mean(x: PyTree) -> PyTree: # An axis_name is passed to pmap which can then be used by pmean. # In this case each device has its own version of the batch statistics and # we average them. return jax.lax.pmean(x, 'x') def sync_model_state_across_replicas(train_state: TrainState) -> TrainState: """Sync the model_state (like batch statistics) across replicas. Args: train_state: TrainState; Current state of training. Returns: Updated state of training in which model_state is synced across replicas. """ # TODO(dehghani): We simply do "mean" here and this doesn't work with # statistics like variance. (check the discussion in Flax for fixing this). if jax.tree_util.tree_leaves(train_state.model_state): # If the model_state is not empty. new_model_state = flax.core.copy( train_state.model_state, {'batch_stats': pmap_mean(train_state.model_state['batch_stats'])}, ) return train_state.replace( # pytype: disable=attribute-error model_state=new_model_state ) else: return train_state def save_checkpoint( workdir: str, train_state: TrainState, max_to_keep: int = 3, overwrite: bool = False, **kwargs, ): """Saves a checkpoint. Args: workdir: Experiment directory for saving the checkpoint. train_state: An instance of TrainState that holds the state of training. max_to_keep: The number of checkpoints to keep. overwrite: Overwrite existing checkpoint if a checkpoint at the current or a later step already exits (default: False). **kwargs: Passed on to flax.training.checkpoints.save_checkpoint. """ if jax.process_index() == 0: # Get train state from the first replica. checkpoint_state = jax.device_get(train_state) checkpoints.save_checkpoint( workdir, checkpoint_state, int(checkpoint_state.global_step), overwrite=overwrite, keep=max_to_keep, **kwargs, ) SIGNED_FLOAT_RE = re.compile(r'([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)') def checkpoint_path_step(path: str) -> Optional[float]: """Returns the step number of a checkpoint path. Copied from flax/training/checkpoints.PyTree Args: path: The path to the checkpoint. Returns: The global step corresponding to that checkpoint, or None if it can't be determined. """ for s in SIGNED_FLOAT_RE.split(path)[::-1]: if SIGNED_FLOAT_RE.match(s): return float(s) return None def restore_checkpoint( checkpoint_path: str, train_state: Optional[TrainState] = None, assert_exist: bool = False, step: Optional[int] = None, ) -> Tuple[TrainState, int]: """Restores the last checkpoint. First restores the checkpoint, which is an instance of TrainState that holds the state of training. Args: checkpoint_path: Directory or filename to restore the checkpoint from. train_state: An instance of TrainState that holds the state of training. assert_exist: Assert that there is at least one checkpoint in the given path. step: Step number to load or None to load latest. If specified, checkpoint_path must be a directory. Returns: training state and an int which is the current step. """ if assert_exist: if 'checkpoint_' in checkpoint_path.split('/')[-1]: glob_path = checkpoint_path else: glob_path = os.path.join(checkpoint_path, 'checkpoint_*') if not gfile.glob(glob_path): raise ValueError( 'No checkpoint for the pretrained model is found in: ' f'{checkpoint_path}' ) if train_state is None: raise ValueError( 'Please use `restore_pretrained_checkpoint` for loading' 'a checkpoint without providing a Scenic TrainState.' ) train_state = checkpoints.restore_checkpoint( checkpoint_path, train_state, step ) return train_state, int(train_state.global_step) def bind_rng_to_host_device( rng: jnp.ndarray, axis_name: Union[str, Tuple[str, ...]], bind_to: Optional[str] = None, ) -> jnp.ndarray: """Binds a rng to the host/device we are on. Must be called from within a pmapped function. Note that when binding to "device", we also bind the rng to hosts, as we fold_in the rng with axis_index which is unique for devices across all hosts. Args: rng: A jax.random.PRNGKey. axis_name: The axis of the devices we are binding rng across. bind_to: Must be one of the 'host' or 'device'. None means no binding. Returns: jax.random.PRNGKey specialized to host/device. """ if bind_to is None: return rng if bind_to == 'host': return jax.random.fold_in(rng, jax.process_index()) elif bind_to == 'device': return jax.random.fold_in(rng, jax.lax.axis_index(axis_name)) else: raise ValueError( "`bind_to` should be one of the `[None, 'host', 'device']`" ) class TrainingDivergedError(Exception): pass def normalize_metrics_summary( metrics_summary: Dict[str, Tuple[float, int]], split: str ) -> Dict[str, float]: """Normalize the metrics in summary by its normalizer. Args: metrics_summary: A dictionary mapping metric name to (value, normalizer). split: Split for which we normalize the metrics. Used for logging. Returns: Normalized metrics summary. Raises: TrainingDivergedError: Due to observing a NaN in the metrics. """ # TODO(dehghani): Currently we only support metrics of the form 1/N sum # f(x_i). We may need a more general framework for metrics like # precision and recall. Note in particular that while we're normalizing by # the "metric normalization value" that is val[1], this value is previously # summed up and is defined to be an integer. normalized_metrics_summary = {} for key, val in metrics_summary.items(): normalized_metrics_summary[key] = val[0] / (val[1] + 1e-9) if np.isnan(normalized_metrics_summary[key]): msg = f'NaN detected in {split}_{key} (Unnormalized values: {val})' if split == 'train': raise TrainingDivergedError(msg) else: logging.error('WARNING: Split %s %s', split, msg) return normalized_metrics_summary def stack_forest(forest: PyTree) -> PyTree: """Transposes a list of dicts to dict of lists. For example, given [{'a':1,'b':2}, {'a':3,'b':4}], the output is: {'a': ([1, 3]), 'b': ([2, 4])} Args: forest: a list of dicts Returns: a dict of lists. """ if not forest: return {} stack_args = lambda *args: np.stack(args) return jax.tree_util.tree_map(stack_args, *forest) def unreplicate_and_get(x: PyTree) -> PyTree: return jax.device_get(jax_utils.unreplicate(x)) def process_and_fetch_to_host( pred_or_tgt: Union[jnp.ndarray, Dict[str, jnp.ndarray]], batch_mask: jnp.ndarray, ) -> Union[Sequence[jnp.ndarray], Dict[str, jnp.ndarray]]: """Used to collect predictions and targets of the whole valid/test set. Args: pred_or_tgt: A jnp-array or dict of arrays, each of shape `[n_dev, bs, X,...,Y]. batch_mask: A nd-array of shape `[nun_devices, bs]`, where zero values indicate padded examples. Returns: A list of length n_dev*bs of items, where each item is a dictionary with same keys as `pred_or_tgt` & values are normal np-arrays of shape [X,...,Y]. """ def _split_mini_batches(x): # Fetch to host and filter out padded examples. x = jax.device_get(x)[np.array(batch_mask).astype(bool)] # Split minibatch of examples into a list of examples. x_list = jnp.split(x, x.shape[0], axis=0) # Squeeze out the dummy dimension. return jax.tree_util.tree_map(lambda x: jnp.squeeze(x, axis=0), x_list) pred_or_tgt = jax.tree_util.tree_map(_split_mini_batches, pred_or_tgt) if isinstance(pred_or_tgt, list): # Pred_or_tgt was a single array, so just return the list: return pred_or_tgt else: # Pred_or_tgt was dict of arrays, so convert dict of lists to list of dicts: keys, values = zip(*pred_or_tgt.items()) return [dict(zip(keys, v)) for v in zip(*values)] # pytype: disable=bad-return-type # jax-ndarray @functools.partial(jax.pmap, axis_name='i') def _barrier(x): return jax.lax.psum(x, axis_name='i') def barrier(): """MPI-like barrier.""" jax.device_get(_barrier(jnp.ones((jax.local_device_count(),)))) def log_eval_summary( step: int, *, writer: metric_writers.MetricWriter, eval_metrics: Sequence[Dict[str, Tuple[float, int]]], extra_eval_summary: Optional[Mapping[str, float]] = None, metrics_normalizer_fn: Optional[ Callable[[Dict[str, Tuple[float, int]], str], Dict[str, float]] ] = None, prefix: str = 'valid', key_separator: str = '_', flush_writer: bool = True, ) -> Dict[str, float]: """Computes and logs eval metrics. Args: step: Current step. writer: Metric writer object. eval_metrics: List of dictionaries of calculated metrics. Usually the sequence is the concatenation of the per-eval-step metrics, and every dictionary maps a metric name to an array of (value, normalizer) - where the array index is usually the batch index. extra_eval_summary: A dict containing summaries that are already ready to be logged, e.g. global metrics from eval set, like precision/recall. metrics_normalizer_fn: Used for normalizing metrics. The API for this function is: `new_metrics_dict = metrics_normalizer_fn(metrics_dict, split)`. If set to None, we use the `normalize_metrics_summary` which uses the normalizer paired with each metric to normalize it (after summing both metric and normalizer values). prefix: str; Prefix added to the name of the summaries writen by this function. key_separator: Separator added between the prefix and key. flush_writer: If True, flush the writer after logging. Returns: A dictionary of metrics, mapping both `eval_metrics` and `extra_eval_summary` from metric name (incl. `prefix`) to float value. """ eval_metrics = stack_forest(eval_metrics) # Compute the sum over all examples in all batches. eval_metrics_summary = jax.tree_util.tree_map(lambda x: x.sum(), eval_metrics) # Normalize metrics by the total number of examples. metrics_normalizer_fn = metrics_normalizer_fn or normalize_metrics_summary eval_metrics_summary = metrics_normalizer_fn(eval_metrics_summary, 'eval') # If None, set to an empty dictionary. extra_eval_summary = extra_eval_summary or {} # Adds extra_eval_summary to the returned eval_summary. eval_metrics_summary.update(extra_eval_summary) writer.write_scalars( step, { key_separator.join((prefix, key)): val for key, val in eval_metrics_summary.items() }, ) if flush_writer: writer.flush() return eval_metrics_summary def log_train_summary( step: int, *, writer: metric_writers.MetricWriter, train_metrics: Sequence[Dict[str, Tuple[float, int]]], extra_training_logs: Optional[Sequence[Dict[str, Any]]] = None, metrics_normalizer_fn: Optional[ Callable[[Dict[str, Tuple[float, int]], str], Dict[str, float]] ] = None, prefix: str = 'train', key_separator: str = '_', flush_writer: bool = True, ) -> Dict[str, float]: """Computes and logs train metrics. Args: step: Current step. writer: Summary writer. train_metrics: List of dictionaries of calculated metrics. Usually the sequence is the concatenation of the per-eval-step metrics, and every dictionary maps a metric name to an array of (value, normalizer) - where the array index is usually the batch index. extra_training_logs: List of dictionaries, containing additional training logs, from every train step, e.g. learning rate, Time, num parameters, etc. Their mean will be logged. metrics_normalizer_fn: Used for normalizing metrics. The API for this function is: `new_metrics_dict = metrics_normalizer_fn(metrics_dict, split)`. If set to None, we use the normalize_metrics_summary which uses the normalizer paired with each metric to normalize it. prefix: str; Prefix added to the name of the summaries writen by this function. key_separator: Separator added between the prefix and key. flush_writer: If True, flush the writer after logging. Returns: A dictionary of metrics, mapping `train_metrics from metric name (incl. `prefix`) to float value. """ ##### Prepare metrics: # Get metrics from devices: train_metrics = stack_forest(train_metrics) # Compute the sum over all examples in all batches: train_metrics_summary = jax.tree_util.tree_map( lambda x: x.sum(), train_metrics ) # Normalize metrics by the total number of examples: metrics_normalizer_fn = metrics_normalizer_fn or normalize_metrics_summary train_metrics_summary = metrics_normalizer_fn(train_metrics_summary, 'train') ##### Prepare additional training logs: # If None, set to an empty dictionary. extra_training_logs = extra_training_logs or [{}] train_logs = stack_forest(extra_training_logs) # Metrics: writer.write_scalars( step, { key_separator.join((prefix, key)): val for key, val in train_metrics_summary.items() }, ) # Additional logs: writer.write_scalars( step, {key: val.mean() for key, val in train_logs.items()} ) if flush_writer: writer.flush() return train_metrics_summary def accumulate_gradients( compute_gradient_fn: Callable[ [TrainState, Dict[str, jnp.ndarray], jnp.ndarray], Tuple[Any, jnp.ndarray], ], metrics_fn: Callable[ [jnp.ndarray, Dict[str, jnp.ndarray]], Dict[str, Tuple[float, int]] ], train_state: TrainState, batch: Dict[str, jnp.ndarray], dropout_rng: jnp.ndarray, accum_steps: Optional[int], ) -> Tuple[ Optional[jnp.ndarray], jnp.ndarray, jnp.ndarray, Dict[str, Tuple[float, int]], ]: """Accumulate gradients over multiple steps. This enables training with larger effective batch sizes. Note that currently, gradient accumulation is not supported when the `model_state` is used, e.g., for models that have batch normalization and store batch statistics in the `model_state`. Note that if `accum_steps` <= 1 or is None, then the gradient of a single step is simply returned. Args: compute_gradient_fn: Gradient function, e.g., `jax.value_and_grad( training_loss_fn, ...). metrics_fn: A metrics function that given logits and batch of data, calculates the metrics. train_state: An instance of TrainState that has the parameters of the model, state of the model, etc. batch: A single batch of data. The buffer of this argument can be donated to the computation. dropout_rng: JAX rng key used for dropout. accum_steps: Number of accumulating steps (number of micro batches). When set to None or =<1, no accumulation is done. Returns: A tuple of model_state (e.g., batch statistics), computed gradients, training loss, and calculated metrics. """ params = train_state.params if accum_steps and accum_steps > 1: batch_size = next(iter(batch.values())).shape[0] microbatch_size = batch_size // accum_steps if batch_size % accum_steps != 0: raise ValueError( f'Bad accum_steps {accum_steps} for batch size {batch_size}' ) logging.info( 'Using microbatches: %d microbatches, %d size', accum_steps, microbatch_size, ) def get_microbatch( batch: Dict[str, jnp.ndarray], idx: int ) -> Dict[str, jnp.ndarray]: """Fetch microbatch slice from the given batch.""" return jax.tree_util.tree_map( lambda x: x.reshape((-1, microbatch_size) + x.shape[1:])[idx], batch ) def per_microbatch_compute_gradient_fn( loop_cnt: int, loop_state: Tuple[ jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[str, Tuple[float, int]] ], ) -> Tuple[ jnp.ndarray, jnp.ndarray, Dict[str, Tuple[float, int]], jnp.ndarray ]: dropout_rng, grad_accum, train_loss_acc, metrics_acc = loop_state dropout_rng, sub_dropout_rng = jax.random.split(dropout_rng) mbatch = get_microbatch(batch, loop_cnt) (train_loss, (_, mlogits)), grad = compute_gradient_fn( params, mbatch, sub_dropout_rng ) metrics = metrics_fn(mlogits, mbatch) # Accumulate gradients and metrics. grad = jax.tree_util.tree_map(jnp.add, grad_accum, grad) metrics = jax.tree_util.tree_map(jnp.add, metrics, metrics_acc) train_loss = jax.tree_util.tree_map(jnp.add, train_loss, train_loss_acc) return dropout_rng, grad, train_loss, metrics # Initialize gradient accumulation loop state. dropout_rng, sub_dropout_rng = jax.random.split(dropout_rng) init_mbatch = get_microbatch(batch, 0) (init_train_loss, (model_state, init_logits)), grad_init = ( compute_gradient_fn(params, init_mbatch, sub_dropout_rng) ) if jax.tree_util.tree_leaves(model_state): # If the model_state is not empty. raise ValueError( 'Gradient accumulation is not supported when the ' 'model_state is in used (e.g. models w/ batch norm).' ) metrics_init = metrics_fn(init_logits, init_mbatch) del init_logits, init_mbatch # Run gradient accumulation loop. loop_init = (dropout_rng, grad_init, init_train_loss, metrics_init) _, grad_acc, train_loss, metrics_acc = jax.lax.fori_loop( 1, accum_steps, per_microbatch_compute_gradient_fn, loop_init ) grad_acc = jax.tree_util.tree_map(lambda x: x / accum_steps, grad_acc) train_loss = jax.tree_util.tree_map(lambda x: x / accum_steps, train_loss) return model_state, grad_acc, train_loss, metrics_acc else: (train_loss, (model_state, logits)), grad = compute_gradient_fn( params, batch, dropout_rng ) metrics = metrics_fn(logits, batch) return model_state, grad, train_loss, metrics class Chrono: """Measures time and reports progress. This is a modified fork of Chrono class from big_vision codebase: https://github.com/google-research/big_vision/blob/main/big_vision/utils.py Some concepts: 1. This differentiates between three "types" of time: - training time: the time spent on actual training (fprop/bprop/update) - program time: overall time the program runs, including all overheads - pause time: the chronometer can be paused (eg during evals). 2. This handles a "warmup": the first step is skipped for training time purposes, as it includes significant compilation overheads, which distort estimates. 3. `accumulates` (i.e. integrates) timings, and saves/loads them across restarts. """ def __init__(self, example_type: str = 'img', warmup: int = 2): self.program_start_time = time.monotonic() self.train_start_time = None self.train_start_step = None # When we started timing (after warmup) self.prev_time = None self.prev_step = None self.pause_start = None self.paused_time = 0 self.warmup = warmup # How many calls to `tick` to skip. self.load() # Inits accum integrators. self.note = 'Chrono n/a' self.example_type = example_type def inform( self, first_step: int, total_steps: int, global_bs: int, steps_per_epoch: int, ): """Provide some extra info that's only known later in the program.""" self.prev_step = copy.deepcopy(first_step) self.first_step = copy.deepcopy(first_step) self.total_steps = total_steps self.steps_per_epoch = steps_per_epoch self.global_bs = global_bs if total_steps: self.note = ( f'Steps:{first_step}/{total_steps} [{first_step/total_steps:.1%}]' ) def tick( self, step: int, writer: metric_writers.MetricWriter, write_note: Callable[[str], None], ): """A chronometer tick.""" summary = {} def hms(s): """Format time in hours/minutes/seconds.""" if s < 60: return f'{s:.0f}s' m, s = divmod(s, 60) if m < 60: return f'{m:.0f}m{s:.0f}s' h, m = divmod(m, 60) return f'{h:.0f}h{m:.0f}m' # Seconds intentionally omitted. now = time.monotonic() summary.update({'uptime': now - self.program_start_time}) # We always count examples, regardless of the timing-related warmup that # happens a few lines below. ds = step - self.prev_step # Steps between ticks self.prev_step = step self.accum_examples_seen += ds * self.global_bs summary.update({'examples_seen': self.accum_examples_seen}) if self.steps_per_epoch: summary.update({'epoch': step / self.steps_per_epoch}) # We take the start as the second time `tick` is called, so we avoid # measuring the overhead of compilation and don't include it in time # estimates. if self.warmup > 1: self.warmup -= 1 write_note(self.note) # This can help debugging. return if self.warmup == 1: self.train_start_time = self.prev_time = now self.train_start_step = step self.accum_program_time += now - self.program_start_time self.paused_time = 0 # Drop pauses that happened before timing starts. self.warmup = 0 write_note(self.note) # This can help debugging. return # Measurement with micro-timings of current training steps speed. # Time between ticks (ignoring pause) if self.prev_time is None: raise ValueError('prev_time is None, possible warmup was skipped') dt = now - self.prev_time - self.paused_time ncores = jax.device_count() # Global device count summary.update({ f'{self.example_type}/sec/core': self.global_bs * ds / dt / ncores, f'{self.example_type}/sec': self.global_bs * ds / dt, }) # Accumulate (integrate) times, good for plots. self.accum_train_time += dt self.accum_pause_time += self.paused_time self.accum_program_time += dt + self.paused_time # Convert to, and log as, core hours. core_hours = self.accum_train_time * ncores / 60 / 60 devtype = jax.devices()[0].device_kind summary.update({ f'core_hours_{devtype}': core_hours, 'core_hours': core_hours, # For convenience as x-axis in sweeps. }) # Progress note with "global" full-program average timings # (eg in program-time minus warmup) dt = now - self.train_start_time # Time elapsed since end of warmup. steps_timed = step - self.train_start_step steps_todo = self.total_steps - step self.note = f'Steps:{step}/{self.total_steps} [{step/self.total_steps:.1%}]' self.note += f'\nWalltime:{hms(self.accum_program_time)}' self.note += f' ({hms(self.accum_pause_time)} Not-train)' self.note += f'\nETA:{hms(dt / steps_timed * steps_todo)}' self.note += ( f'\nTotal train time:{hms(dt / steps_timed * self.total_steps)}' ) write_note(self.note) writer.write_scalars(step, summary) self.prev_time = now self.paused_time = 0 def pause(self, wait_for=()): assert self.pause_start is None, "Don't pause twice." jax.block_until_ready(wait_for) self.pause_start = time.monotonic() def resume(self): assert self.pause_start is not None, 'Cannot resume without pausing first.' self.paused_time += time.monotonic() - self.pause_start self.pause_start = None def save(self): return dict( accum_program_time=self.accum_program_time, accum_train_time=self.accum_train_time, accum_pause_time=self.accum_pause_time, accum_examples_seen=self.accum_examples_seen, ) def load(self, ckpt={}): # pylint: disable=dangerous-default-value self.accum_program_time = ckpt.get('accum_program_time', 0.0) self.accum_train_time = ckpt.get('accum_train_time', 0.0) self.accum_pause_time = ckpt.get('accum_pause_time', 0.0) self.accum_examples_seen = ckpt.get('accum_examples_seen', 0) def barrier_across_hosts(): """Ensure all hosts stay up until the end, otherwise the program may hang.""" if jax.process_count() > 1: x = jnp.ones([jax.local_device_count()]) x = jax.device_get(jax.pmap(lambda x: jax.lax.psum(x, 'i'), 'i')(x)) assert x[0] == jax.device_count() def handle_checkpointing( train_state: TrainState, chrono: Chrono, workdir: str, max_checkpoints_to_keep=3, ): """Handles all the bookkeeping around checkpointing. Syncs the training state and unreplicates it, stops & restarts Chrono (and handles its metadata) and writes the actual checkpoint. Args: train_state: A replicated TrainState. chrono: The Chrono object. workdir: the workdir of the process. max_checkpoints_to_keep: how many checkpoints to keep. """ train_state = sync_model_state_across_replicas(train_state) if jax.process_index() == 0: unrep_train_state = jax_utils.unreplicate(train_state) metadata = unrep_train_state.metadata metadata['chrono'] = chrono.save() unrep_train_state = unrep_train_state.replace(metadata=metadata) save_checkpoint( workdir, unrep_train_state, max_to_keep=max_checkpoints_to_keep ) del unrep_train_state