| """Training script for Streaming DVC models.""" |
|
|
| import functools |
| import time |
| from typing import Any |
|
|
| from absl import logging |
| from clu import metric_writers |
| from clu import periodic_actions |
| from clu import platform |
| from flax import jax_utils |
| from flax.training import checkpoints |
| import jax |
| import jax.numpy as jnp |
| import jax.profiler |
| import ml_collections |
| from scenic.common_lib import debug_utils |
| from scenic.dataset_lib import dataset_utils |
| from scenic.projects.streaming_dvc import evaluate |
| from scenic.projects.streaming_dvc import partition_utils |
| from scenic.projects.streaming_dvc import train_utils as streaming_dvc_train_utils |
| from scenic.train_lib import lr_schedules |
| from scenic.train_lib import train_utils |
|
|
|
|
def train_and_evaluate(
    *,
    rng: jnp.ndarray,
    config: ml_collections.ConfigDict,
    model_cls: Any,
    dataset: dataset_utils.Dataset,
    workdir: str,
    writer: metric_writers.MetricWriter):
| """Main training loop lives in this function. |
| |
| Args: |
| rng: JAX PRNGKey. |
| config: Configurations of the experiment. |
| model_cls: Model class; A model has a flax_module, a loss_fn, and a |
| metrics_fn associated with it. |
| dataset: The dataset that has train_iter, eval_iter, meta_data, and |
| optionally, test_iter. |
| workdir: Directory for checkpointing. |
| writer: CLU metrics writer instance. |
| |
| Returns: |
| train_sate that has the state of training (including current |
| global_step, model_state, rng, and the optimizer), train_summary |
| and eval_summary which are dict of metrics. These outputs are used for |
| regression testing. |
| """ |
  is_host = jax.process_index() == 0
  logging.info('Training with config: %s', config)
  logging.info('Dataset metadata %s', dataset.meta_data)

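  # Build the model and initialize its parameters and state from the input
  # spec (shape and dtype of each input).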
  model = model_cls(config, dataset.meta_data)
  rng, init_rng = jax.random.split(rng)
  input_spec = [
      (dataset.meta_data['input_shape'],
       dataset.meta_data.get('input_dtype', jnp.float32))]
  if config.get('additional_input_spec', []):
    input_spec.extend(config.additional_input_spec)
  (params, model_state, num_trainable_params, gflops) = (
      train_utils.initialize_model(
          model_def=model.flax_model,
          input_spec=input_spec,
          config=config,
          rngs=init_rng,
      )
  )

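  # Split the parameters into learnable and frozen subsets: parameters whose
  # names match the regex(es) in `config.frozen_params` are frozen, i.e.
  # excluded from optimizer updates. A hypothetical config entry (the exact
  # format is defined by partition_utils.create_frozen_mask_from_regex) might
  # freeze a pretrained backbone:
  #   config.frozen_params = ('.*image_encoder.*',)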
  frozen_mapping = partition_utils.create_frozen_mask_from_regex(
      params, config.get('frozen_params')
  )

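  # Build the learning-rate schedule and a partitioned train state that keeps
  # learnable and frozen parameters in separate collections.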
  lr_fn = lr_schedules.get_learning_rate_fn(config)
  _, train_rng = jax.random.split(rng)
  train_state, num_learnable_params, num_frozen_params = (
      partition_utils.create_partitioned_train_state(
          params, frozen_mapping, config, 0, model_state, train_rng, lr_fn))

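  # Checkpoint restoration below operates on a regular (non-partitioned)
  # TrainState, so merge the two parameter collections back together first.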
  train_state = partition_utils.convert_to_train_state(train_state)

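  # The `params_axes` metadata is not part of saved checkpoints, so pop it
  # before restoring and re-attach it afterwards.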
  train_state, params_axes = streaming_dvc_train_utils.pop_axes_names(
      train_state, 'params_axes')
  train_state = checkpoints.restore_checkpoint(workdir, train_state)
  train_state = streaming_dvc_train_utils.re_add_axis_names(
      train_state, params_axes, 'params_axes')

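  # If no checkpoint was found in workdir (global_step is still 0), optionally
  # initialize from pretrained weights given in the config, and log the
  # parameter counts once.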
  start_step = int(train_state.global_step)
  if start_step == 0:
    train_state, start_step = streaming_dvc_train_utils.load_weights(
        train_state, config)
    step0_log = {'num_params': num_trainable_params,
                 'num_learnable_params': num_learnable_params,
                 'num_frozen_params': num_frozen_params}
    if gflops:
      step0_log['gflops'] = gflops
    writer.write_scalars(1, step0_log)
  train_state = jax_utils.replicate(train_state)
  del params

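  # Switch back to the partitioned representation for training.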
  train_state = partition_utils.convert_to_partitioned_train_state(
      train_state, frozen_mapping)

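  # pmap the partitioned train step across local devices; the train state is
  # donated so its buffers can be reused for the updated state.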
  train_step_pmapped = jax.pmap(
      functools.partial(
          partition_utils.train_step_partitioned,
          flax_model=model.flax_model,
          loss_and_metrics_fn=model.loss_function,
          learning_rate_fn=lr_fn,
          debug=config.debug_train,
      ),
      axis_name='batch', donate_argnums=(0,),
  )

  total_steps, steps_per_epoch = train_utils.get_num_training_steps(
      config, dataset.meta_data)
  log_eval_steps = config.get('log_eval_steps', steps_per_epoch)
  log_summary_steps = config.get('log_summary_steps', 20)
  checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps

  train_metrics, extra_training_logs = [], []
  train_summary, eval_summary = None, None
  eval_batch_size = config.get('eval_batch_size', config.batch_size)
  chrono = train_utils.Chrono()
  chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch)

  def write_note(note):
    if is_host:
      platform.work_unit().set_notes(note)
      logging.info(note)

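  # Periodic host-side actions: progress reporting and, optionally, profiling.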
  logging.info('Starting training loop at step %d.', start_step + 1)
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=total_steps, writer=writer)
  hooks = []
  if is_host:
    hooks.append(report_progress)
  if config.get('xprof', True) and is_host:
    hooks.append(periodic_actions.Profile(num_profile_steps=5, logdir=workdir))

  for step in range(start_step + 1, total_steps + 1):
    with jax.profiler.StepTraceAnnotation('train', step_num=step):
      train_batch = next(dataset.train_iter)
      train_state, lr, train_predictions, metrics = train_step_pmapped(
          train_state, train_batch)
    train_metrics.append(metrics)
    extra_training_logs.append({'learning_rate': lr})
    for h in hooks:
      h(step)
    chrono.pause()
    del train_predictions

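    # Periodically write the accumulated train metrics and reset the buffers.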
    if (step % log_summary_steps == 0) or (step == total_steps - 1):
      if is_host:
        chrono.tick(step, writer, write_note)
      train_summary = train_utils.log_train_summary(
          step=step,
          train_metrics=jax.tree_util.tree_map(
              train_utils.unreplicate_and_get, train_metrics),
          extra_training_logs=jax.tree_util.tree_map(
              train_utils.unreplicate_and_get, extra_training_logs),
          writer=writer)
      train_metrics, extra_training_logs = [], []

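    # Run evaluation at the configured cadence, at the last step, and
    # (optionally) at step 1 to surface evaluation failures early.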
    eval_first_step = config.get('eval_first_step', True) and step == 1
    do_eval = not config.get('not_eval', False)
    if ((step % log_eval_steps == 0) or (step == total_steps) or
        eval_first_step) and do_eval:
      logging.info('Starting evaluation ...')
      # Evaluation consumes a regular TrainState, so merge the partitioned
      # parameter collections before running inference.
      train_state = partition_utils.convert_to_train_state(train_state)
      start_time = time.time()
      with report_progress.timed('eval'):
        train_state = train_utils.sync_model_state_across_replicas(train_state)
        last_eval_results, last_eval_metrics = evaluate.inference_on_dataset(
            model.flax_model,
            train_state, dataset,
            eval_batch_size=eval_batch_size,
            is_host=is_host,
            save_dir=workdir,
            step=step,
            config=config)
        last_eval_step = step
        train_utils.log_eval_summary(
            step=last_eval_step, eval_metrics=last_eval_metrics,
            extra_eval_summary=last_eval_results, writer=writer)
      duration = time.time() - start_time
      logging.info('Done with evaluation: %.4f sec.', duration)
      # Restore the partitioned representation before resuming training.
      train_state = partition_utils.convert_to_partitioned_train_state(
          train_state, frozen_mapping)
      writer.flush()

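    # Periodically checkpoint. Only the lead host writes; the state is first
    # synced across replicas and unreplicated.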
    if ((step % checkpoint_steps == 0 and step > 0) or
        (step == total_steps)):
      with report_progress.timed('checkpoint'):
        train_state = train_utils.sync_model_state_across_replicas(train_state)
        if is_host:
          unrep_train_state = jax_utils.unreplicate(train_state)
          logging.info('Parameter summary after checkpoint:')
          debug_utils.log_param_shapes(
              unrep_train_state.params_learned,
              description='Learned params')
          if unrep_train_state.params_frozen:
            debug_utils.log_param_shapes(
                unrep_train_state.params_frozen,
                description='Frozen params')
          # Checkpoints are saved in the regular (non-partitioned) format so
          # they can be restored without the frozen/learned split.
          unrep_train_state = partition_utils.convert_to_train_state(
              unrep_train_state)
          train_utils.save_checkpoint(
              workdir, unrep_train_state,
              max_to_keep=config.get('checkpoint_max_to_keep', 1))
          del unrep_train_state
    chrono.resume()

  train_utils.barrier()
  return train_state, train_summary, eval_summary