Spaces:
Sleeping
Sleeping
| # Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """TFM common training driver library.""" | |
| # pytype: disable=attribute-error | |
| import os | |
| import tempfile | |
| from typing import Any, List, Mapping, Optional, Tuple | |
| # Import libraries | |
| from absl import logging | |
| import orbit | |
| import tensorflow as tf, tf_keras | |
| from official.core import actions | |
| from official.core import base_task | |
| from official.core import base_trainer | |
| from official.core import config_definitions | |
| from official.core import train_utils | |
# Module-level alias for `train_utils.maybe_create_best_ckpt_exporter`.
# `OrbitExperimentRunner._build_best_checkpoint_exporter` calls the exporter
# factory through this name.
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
class OrbitExperimentRunner:
  """Runs experiment with Orbit training loop.

  The default experiment runner for model garden experiments. User can
  customize the experiment pipeline by subclassing this class and replacing
  components or functions.

  For example, an experiment runner with customized checkpoint manager:

  ```python
  class MyExpRunnerWithExporter(OrbitExperimentRunner):
    def _maybe_build_checkpoint_manager(self):
      # Replaces the default CheckpointManager with a customized one.
      return MyCheckpointManager(*args)

  # In user code, instead of the original
  # `OrbitExperimentRunner(..).run()`, now user can do:
  MyExpRunnerWithExporter(**needed_kwargs).run()
  ```

  Similar override can be done to other components.
  """

  def __init__(
      self,
      distribution_strategy: tf.distribute.Strategy,
      task: base_task.Task,
      mode: str,
      params: config_definitions.ExperimentConfig,
      model_dir: str,
      run_post_eval: bool = False,
      save_summary: bool = True,
      train_actions: Optional[List[orbit.Action]] = None,
      eval_actions: Optional[List[orbit.Action]] = None,
      trainer: Optional[base_trainer.Trainer] = None,
      controller_cls=orbit.Controller,
      summary_manager: Optional[orbit.utils.SummaryManager] = None,
      eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
      enable_async_checkpointing: bool = False,
  ):
    """Constructor.

    Args:
      distribution_strategy: A distribution strategy. When falsy, the result
        of `tf.distribute.get_strategy()` is used instead.
      task: A Task instance.
      mode: A 'str', specifying the mode. Can be 'train', 'eval',
        'train_and_eval', 'train_and_post_eval' or 'continuous_eval'.
      params: ExperimentConfig instance.
      model_dir: A 'str', a path to store model checkpoints and summaries.
      run_post_eval: Whether to run post eval once after training, metrics logs
        are returned.
      save_summary: Whether to save train and validation summary.
      train_actions: Optional list of Orbit train actions.
      eval_actions: Optional list of Orbit eval actions.
      trainer: the base_trainer.Trainer instance. It should be created within
        the strategy.scope(). When omitted, one is built from `task`.
      controller_cls: The controller class to manage the train and eval process.
        Must be a orbit.Controller subclass.
      summary_manager: Instance of the summary manager to override default
        summary manager.
      eval_summary_manager: Instance of the eval summary manager to override
        default eval summary manager.
      enable_async_checkpointing: Optional boolean indicating whether to enable
        async checkpoint saving.
    """
    self.strategy = distribution_strategy or tf.distribute.get_strategy()
    self._params = params
    self._model_dir = model_dir
    self._mode = mode
    self._run_post_eval = run_post_eval

    # Substring checks: 'train_and_eval' and 'train_and_post_eval' enable both
    # the train and the evaluate side of the trainer.
    self._trainer = trainer or self._build_trainer(
        task,
        train='train' in mode,
        evaluate=('eval' in mode) or run_post_eval)
    assert self.trainer is not None
    self._checkpoint_manager = self._maybe_build_checkpoint_manager()

    # The summary managers must be stored before `_build_controller` runs,
    # because it forwards them to the controller.
    self._summary_manager = summary_manager
    self._eval_summary_manager = eval_summary_manager
    self._controller = self._build_controller(
        trainer=self.trainer if 'train' in mode else None,
        evaluator=self.trainer,
        save_summary=save_summary,
        train_actions=train_actions,
        eval_actions=eval_actions,
        controller_cls=controller_cls,
        enable_async_checkpointing=enable_async_checkpointing)

  @property
  def params(self) -> config_definitions.ExperimentConfig:
    """The whole experiment parameters object."""
    return self._params

  @property
  def model_dir(self) -> str:
    """Path to the model folder, which stores checkpoints, params, log, etc."""
    return self._model_dir

  @property
  def trainer(self) -> base_trainer.Trainer:
    """The underlying Orbit Trainer object."""
    return self._trainer

  @property
  def checkpoint_manager(self) -> Optional[tf.train.CheckpointManager]:
    """The CheckpointManager that stores the checkpoints in a train job."""
    return self._checkpoint_manager

  @property
  def controller(self) -> orbit.Controller:
    """The Orbit controller object."""
    return self._controller

  def _build_trainer(self, task: base_task.Task, train: bool,
                     evaluate: bool) -> base_trainer.Trainer:
    """Creates the trainer inside the distribution strategy scope.

    Args:
      task: A Task instance passed through to `train_utils.create_trainer`.
      train: Whether the trainer is created with training enabled.
      evaluate: Whether the trainer is created with evaluation enabled.

    Returns:
      The trainer returned by `train_utils.create_trainer`.
    """
    with self.strategy.scope():
      trainer = train_utils.create_trainer(
          self.params,
          task,
          train=train,
          evaluate=evaluate,
          checkpoint_exporter=self._build_best_checkpoint_exporter())
    return trainer

  def _build_best_checkpoint_exporter(self):
    """Returns the result of `maybe_create_best_ckpt_exporter` for this run."""
    return maybe_create_best_ckpt_exporter(self.params, self.model_dir)

  def _maybe_build_checkpoint_manager(
      self) -> Optional[tf.train.CheckpointManager]:
    """Maybe create a CheckpointManager.

    Returns:
      A `tf.train.CheckpointManager` wrapping the trainer's checkpoint, or
      None when the trainer's `checkpoint` attribute is falsy.

    Raises:
      ValueError: If the trainer has a checkpoint but `model_dir` is None.
    """
    assert self.trainer is not None
    if not self.trainer.checkpoint:
      return None

    if self.model_dir is None:
      raise ValueError('model_dir must be specified, but got None')

    if (not self.strategy) or self.strategy.extended.should_checkpoint:
      ckpt_path = self.model_dir
      max_to_keep = self.params.trainer.max_to_keep
    else:
      # In multi worker training we need every worker to save checkpoint,
      # because variables can trigger synchronization on read and
      # synchronization needs all workers to participate. To avoid workers
      # overriding each other we save to a temporary directory on non-chief
      # workers.
      ckpt_path = tempfile.mkdtemp()
      max_to_keep = 1

    return tf.train.CheckpointManager(
        self.trainer.checkpoint,
        directory=ckpt_path,
        max_to_keep=max_to_keep,
        step_counter=self.trainer.global_step,
        checkpoint_interval=self.params.trainer.checkpoint_interval,
        init_fn=self.trainer.initialize)

  def _build_controller(
      self,
      trainer,
      evaluator,
      save_summary: bool = True,
      train_actions: Optional[List[orbit.Action]] = None,
      eval_actions: Optional[List[orbit.Action]] = None,
      controller_cls=orbit.Controller,
      enable_async_checkpointing: bool = False,
  ) -> orbit.Controller:
    """Builds an Orbit controller.

    Args:
      trainer: The trainer to drive training with, or None when the run does
        not train.
      evaluator: The object to drive evaluation with, or None.
      save_summary: Whether train and validation summary directories and the
        summary interval are passed to the controller.
      train_actions: Optional list of Orbit train actions. The list is copied,
        never modified in place.
      eval_actions: Optional list of Orbit eval actions. The list is copied,
        never modified in place.
      controller_cls: The controller class to instantiate.
      enable_async_checkpointing: Forwarded to the controller.

    Returns:
      The constructed controller instance.
    """
    # Copy the caller's lists so that appending the default actions below does
    # not mutate objects owned by the caller.
    train_actions = list(train_actions) if train_actions else []
    if trainer:
      checkpoint_manager = self.checkpoint_manager
      assert checkpoint_manager, 'Checkpoint manager required but undefined.'
      train_actions += actions.get_train_actions(
          self.params,
          trainer,
          self.model_dir,
          checkpoint_manager=checkpoint_manager,
      )

    eval_actions = list(eval_actions) if eval_actions else []
    if evaluator:
      eval_actions += actions.get_eval_actions(self.params, evaluator,
                                               self.model_dir)

    if save_summary:
      summary_dir = os.path.join(self.model_dir, 'train')
      eval_summary_dir = os.path.join(
          self.model_dir, self.params.trainer.validation_summary_subdir
      )
      summary_interval = self.params.trainer.summary_interval
    else:
      summary_dir = None
      eval_summary_dir = None
      summary_interval = None

    controller = controller_cls(
        strategy=self.strategy,
        trainer=trainer,
        evaluator=evaluator,
        global_step=self.trainer.global_step,
        steps_per_loop=self.params.trainer.steps_per_loop,
        checkpoint_manager=self.checkpoint_manager,
        enable_async_checkpointing=enable_async_checkpointing,
        summary_dir=summary_dir,
        eval_summary_dir=eval_summary_dir,
        summary_interval=summary_interval,
        train_actions=train_actions,
        eval_actions=eval_actions,
        # `getattr` with a default tolerates subclasses whose `__init__` never
        # set these attributes.
        summary_manager=getattr(self, '_summary_manager', None),
        eval_summary_manager=getattr(self, '_eval_summary_manager', None),
    )
    return controller

  def run(self) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
    """Run experiments by mode.

    Returns:
      A 2-tuple of (model, eval_logs).
        model: `tf_keras.Model` instance.
        eval_logs: returns eval metrics logs when run_post_eval is set to True
          or the mode is 'train_and_post_eval', otherwise, returns {}.

    Raises:
      NotImplementedError: If the mode given at construction is not one of
        the supported modes.
    """
    mode = self._mode
    params = self.params
    logging.info('Starts to execute mode: %s', mode)
    with self.strategy.scope():
      if mode in ('train', 'train_and_post_eval'):
        self.controller.train(steps=params.trainer.train_steps)
      elif mode == 'train_and_eval':
        self.controller.train_and_evaluate(
            train_steps=params.trainer.train_steps,
            eval_steps=params.trainer.validation_steps,
            eval_interval=params.trainer.validation_interval)
      elif mode == 'eval':
        self.controller.evaluate(steps=params.trainer.validation_steps)
      elif mode == 'continuous_eval':

        def timeout_fn():
          # Reports True once the global step has reached the configured
          # number of train steps.
          return bool(
              self.trainer.global_step.numpy() >= params.trainer.train_steps)

        self.controller.evaluate_continuously(
            steps=params.trainer.validation_steps,
            timeout=params.trainer.continuous_eval_timeout,
            timeout_fn=timeout_fn)
      else:
        raise NotImplementedError('The mode is not implemented: %s' % mode)

    num_params = train_utils.try_count_params(self.trainer.model)
    if num_params is not None:
      logging.info('Number of trainable params in model: %f Millions.',
                   num_params / 10.**6)

    flops = train_utils.try_count_flops(self.trainer.model)
    if flops is not None:
      logging.info('FLOPs (multi-adds) in model: %f Billions.',
                   flops / 10.**9 / 2)

    if self._run_post_eval or mode == 'train_and_post_eval':
      with self.strategy.scope():
        return self.trainer.model, self.controller.evaluate(
            steps=params.trainer.validation_steps)
    else:
      return self.trainer.model, {}
def run_experiment(
    distribution_strategy: tf.distribute.Strategy,
    task: base_task.Task,
    mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[base_trainer.Trainer] = None,
    controller_cls=orbit.Controller,
    summary_manager: Optional[orbit.utils.SummaryManager] = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
    enable_async_checkpointing: bool = False,
) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
  """Runs train/eval configured by the experiment params.

  Convenience wrapper: builds an `OrbitExperimentRunner` from the given
  arguments and immediately calls its `run()` method.

  Args:
    distribution_strategy: A distribution strategy.
    task: A Task instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
      or 'continuous_eval'.
    params: ExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.
    run_post_eval: Whether to run post eval once after training, metrics logs
      are returned.
    save_summary: Whether to save train and validation summary.
    train_actions: Optional list of Orbit train actions.
    eval_actions: Optional list of Orbit eval actions.
    trainer: the base_trainer.Trainer instance. It should be created within the
      strategy.scope().
    controller_cls: The controller class to manage the train and eval process.
      Must be a orbit.Controller subclass.
    summary_manager: Instance of the summary manager to override default summary
      manager.
    eval_summary_manager: Instance of the eval summary manager to override
      default eval summary manager.
    enable_async_checkpointing: Optional boolean indicating whether to enable
      async checkpoint saving.

  Returns:
    A 2-tuple of (model, eval_logs).
      model: `tf_keras.Model` instance.
      eval_logs: returns eval metrics logs when run_post_eval is set to True,
        otherwise, returns {}.
  """
  # Every argument is forwarded unchanged, by keyword, to the runner.
  return OrbitExperimentRunner(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=mode,
      params=params,
      model_dir=model_dir,
      run_post_eval=run_post_eval,
      save_summary=save_summary,
      train_actions=train_actions,
      eval_actions=eval_actions,
      trainer=trainer,
      controller_cls=controller_cls,
      summary_manager=summary_manager,
      eval_summary_manager=eval_summary_manager,
      enable_async_checkpointing=enable_async_checkpointing,
  ).run()