| | |
| | |
| |
|
| |
|
| |
|
| |
|
| |
|
| | import os |
| | import logging |
| | from caffe2.python import core, context |
| | from caffe2.python.net_builder import ops |
| | from caffe2.python.task import ( |
| | final_output, |
| | Node, |
| | Task, |
| | TaskGroup, |
| | TaskOutput, |
| | WorkspaceType, |
| | ) |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
| |
|
class Job(context.Managed):
    """
    A Job bundles the four TaskGroups that a JobRunner executes:
    `init_group`, `epoch_group`, `download_group` and `exit_group`.

    `init_group` runs exactly once at startup. It is meant to initialize
    globally persistent blobs such as model weights, accumulators and data
    file lists.

    `epoch_group` runs in a loop after `init_group`. The loop ends as soon as
    any stop signal registered through `add_stop_condition` is True at the
    end of an epoch.

    `download_group` runs exactly once, after every execution of
    `epoch_group` has finished. It is meant to collect the distributed,
    scattered parameters back after training.

    `exit_group` runs exactly once at the very end of the job. It is meant to
    save the results of training.

    Jobs are context-driven: Tasks can be added to the active Job without
    passing the job object around explicitly. Entering a Job also enters its
    `epoch_group`, so that group is the default target for new Tasks.

    Example of usage:

    def build_reader(partitions):
        with Job.current().init_group:
            reader = HiveReader(init_reader, ..., partitions)
            Task(step=init_reader)
        with Job.current().epoch_group:
            limited_reader = ReaderWithLimit(reader, num_iter=10000)
            data_queue = pipe(limited_reader, num_threads=8)
            Job.current().add_stop_condition(limited_reader.data_finished())
        return data_queue

    def build_hogwild_trainer(reader, model):
        with Job.current().init_group:
            Task(step=model.param_init_net)
        with Job.current().epoch_group:
            pipe(reader, processor=model, num_threads=8)
        with Job.current().exit_group:
            Task(step=model.save_model_net)

    with Job() as job:
        reader = build_reader(partitions)
        model = build_model(params)
        build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 download_group=None, exit_group=None,
                 stop_conditions=None, nodes_to_checkpoint=None):
        # Any group that is not supplied is created here, in a fixed order
        # (init, epoch, download, exit).
        self.init_group = (
            init_group if init_group
            else TaskGroup(workspace_type=WorkspaceType.GLOBAL))
        self.epoch_group = epoch_group if epoch_group else TaskGroup()
        self.download_group = (
            download_group if download_group else TaskGroup())
        self.exit_group = exit_group if exit_group else TaskGroup()
        self.stop_conditions = stop_conditions if stop_conditions else []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        """
        Return the nodes to checkpoint: the explicitly configured ones when
        set, otherwise the nodes used by `init_group`.
        """
        explicit_nodes = self._nodes_to_checkpoint
        if not explicit_nodes:
            return self.init_group.used_nodes()
        return explicit_nodes

    def compile(self, session_class):
        """
        Replace each task group with the result of
        `session_class.compile(group)`.

        The nodes to checkpoint are resolved and stored first, i.e. before
        `init_group` is replaced by its compiled form.
        """
        self._nodes_to_checkpoint = self.nodes_to_checkpoint()
        for group_attr in (
                'init_group', 'epoch_group', 'download_group', 'exit_group'):
            compiled = session_class.compile(getattr(self, group_attr))
            setattr(self, group_attr, compiled)

    def __enter__(self):
        # Activate the job context first, then make epoch_group current.
        super(Job, self).__enter__()
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        # Leave in reverse order of __enter__.
        self.epoch_group.__exit__()
        super(Job, self).__exit__(*args)

    def add_stop_condition(self, output):
        """
        Register `output` as a stop signal for the epoch loop.

        A `core.BlobReference` is first wrapped into a Task of `epoch_group`
        and that task's first output is registered instead. Whatever ends up
        being registered must be a `TaskOutput`.
        """
        if isinstance(output, core.BlobReference):
            wrapper_task = Task(outputs=[output], group=self.epoch_group)
            output = wrapper_task.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_conditions.append(output)
|
| |
|
def get_ckpt_filename(node_name, epoch):
    """Builds the checkpoint filename `<node_name>.<epoch>`.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        ckpt_filename: A string. The filename of the checkpoint.
    """
    epoch_suffix = '.' + str(epoch)
    return node_name + epoch_suffix
| |
|
| |
|
def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Builds the full db name under which a checkpoint is stored.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or path
            where checkpoint files are stored.
    Returns:
        db_name: A string. When `path_prefix` is given, it is
            `path_prefix` directly concatenated with the checkpoint filename
            (no path separator is inserted); otherwise it is `db_prefix`
            joined with the checkpoint filename via `os.path.join`.
    """
    ckpt_filename = get_ckpt_filename(node_name, epoch)
    if not path_prefix:
        return os.path.join(db_prefix, ckpt_filename)
    return path_prefix + ckpt_filename
| |
|
| |
|
class CheckpointManager(object):
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        node_name: Name of the node where this checkpoint_manager is used.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """

    BLOB_NAMES = "blob_names"

    def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
        self._db_prefix = db_prefix
        self._node_name = node_name
        self._db_type = db_type
        self._metadata_handler = metadata_handler

        # Net that owns the external input blob holding the names of the
        # blobs to checkpoint.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput(self.BLOB_NAMES)

        # Populated by init().
        self._names_output = None
        # Populated by set_params().
        self._path_prefix = None
        self._path_type = None
        # Populated by load()/load_blobs_from_checkpoint()/save() and by
        # _timed_task() respectively; read by collect_checkpoint_stats().
        self._current_db_name = None
        self._current_checkpoint_duration = None

    def init(
        self,
        nodes=None,
        retrieve_from_epoch=None,
        path_prefix=None,
        path_type=None
    ):
        """
        Build a Task that will be run once after the job's `init_group` is run.
        This task determines which blobs need to be checkpointed: either by
        listing the blob names currently present, or, if
        `retrieve_from_epoch` is not None, by loading the list of names from
        a previously saved checkpoint.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Should only contain a single node.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored. Falls back to the manager's `db_type` when not given.

        Returns:
            The Task; its first output is remembered so that `blob_list` can
            fetch the names later.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is not None:
                source_db = db_name(
                    retrieve_from_epoch,
                    self._node_name,
                    self._db_prefix,
                    path_prefix)
                source_db_type = path_type or self._db_type
                logger.info(
                    "Initializing checkpoints from = %s" % source_db)
                ops.Load(
                    [],
                    self._blob_names,
                    db=source_db,
                    db_type=source_db_type,
                    absolute_path=True,
                    keep_device=True,
                )
            else:
                ops.GetAllBlobNames(
                    [], self._blob_names, include_shared=False)
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        """
        Fetch the blob names produced by the `init` task as a Python list.
        Requires that `init` has been called.
        """
        assert self._names_output
        fetched_names = self._names_output.fetch()
        return fetched_names.tolist()

    def _timed_task(self, cp_op_name, add_op):
        """
        Build a Task that wraps a checkpoint operation with a timer. The
        measured time span is exposed through `_current_checkpoint_duration`.

        Args:
            cp_op_name: A string name of the checkpoint operation.
            add_op: A functor to add the checkpoint operation.

        Returns:
            A task with timer.
        """
        with Task(name=cp_op_name) as timed:
            with ops.task_init():
                timer_blob = ops.TimerBegin(
                    [], counter_name=self._node_name)
            # The checkpoint operation itself is added between timer start
            # (task init) and timer stop (task exit).
            add_op()
            with ops.task_exit():
                elapsed_blob = ops.TimerGetAndEnd(timer_blob)
            self._current_checkpoint_duration = final_output(elapsed_blob)
        return timed

    def collect_checkpoint_stats(self, stats):
        """
        Add the stats of the most recent checkpoint operation into `stats`,
        keyed by the db name that operation used.

        Args:
            stats: A dict of checkpoint stats that will be reported.
        """
        current_db = self._current_db_name
        duration = self._current_checkpoint_duration
        if not (current_db and duration):
            logger.info(
                "Failed to collect checkpoint stats: {}".format(current_db)
            )
            return
        stats[current_db] = duration.fetch()[0]

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from a persistent storage.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix, path_prefix
        )
        load_db_type = path_type or self._db_type
        logger.info("Loading checkpoints from = %s" % self._current_db_name)

        def add_load_op():
            ops.Load(
                [],
                self.blob_list(),
                db=self._current_db_name,
                db_type=load_db_type,
                absolute_path=True,
                keep_device=True,
            )

        return self._timed_task('checkpoint_load', add_load_op)

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint of
        the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of the
            given epoch.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix)
        logger.info('Load from %s' % self._current_db_name)

        def add_partial_load_op():
            ops.Load(
                [],
                blob_names,
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True,
                allow_incomplete=True)

        return self._timed_task(
            'checkpoint_partial_load', add_partial_load_op)

    def check_db_exists(self, epoch):
        """
        Build a Task whose output tells whether the checkpoint db of the
        given epoch exists.
        """
        target_db = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Check existence of %s' % target_db)
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=target_db,
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def report_checkpoint_stats(self, action_name):
        """
        Report checkpoint operation stats for current node.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        node_stats = {}
        self.collect_checkpoint_stats(node_stats)
        if self._metadata_handler:
            self._metadata_handler.report(action_name, node_stats)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save ops to serialize and persist
        blobs present in the global workspace.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix)
        logger.info('Saving to %s' % self._current_db_name)

        def add_save_op():
            ops.Save(
                self.blob_list(),
                [],
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True)

        return self._timed_task('checkpoint_save', add_save_op)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint. Does nothing when no metadata handler
        was supplied.

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is None:
            return
        self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written. Without a metadata
                handler this is simply `user_epoch`.
        """
        if self._metadata_handler is None:
            return user_epoch
        return self._metadata_handler.last_epoch(user_epoch=user_epoch)

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
                Not used by this single-node manager.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        # Falsy values leave the previously stored setting untouched.
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if not self._metadata_handler:
            return
        self._metadata_handler.set_params(
            db_prefix=self._db_prefix,
            db_type=self._db_type,
            node_names=[str(self._node_name)],
            path_prefix=self._path_prefix,
            path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is
                accessible. Always True when there is no metadata handler.
        """
        if self._metadata_handler is None:
            return True
        return self._metadata_handler.cp_accessible(epoch)
| |
|
| |
|
class MultiNodeCheckpointManager(object):
    """
    Coordinates checkpointing across multiple nodes.
    Each of `init`, `load` and `save` will build TaskGroups which will
    trigger checkpointing on each of the nodes involved in a distributed job.

    Args:
        db_prefix: The prefix used to construct full db name. Since `absolute_path`
            is set to True, this will be used as db_name in SaveOp.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """
    def __init__(self, db_prefix, db_type, metadata_handler=None):
        # List of (node, CheckpointManager) pairs; None until `init` or
        # `load_blobs_locally` builds it.
        self._node_managers = None
        # String names of the nodes; set by `set_params`.
        self._node_names = None
        self._db_prefix = db_prefix
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        self._path_prefix = None
        self._path_type = None

    def _create_node_managers(self, nodes):
        """Build one per-node CheckpointManager for each entry of `nodes`.

        Returns:
            A list of `(node, manager)` pairs, in the order of `nodes`. Each
            manager is constructed inside the `Node(node)` context.
        """
        node_managers = []
        for node in nodes:
            with Node(node):
                manager = CheckpointManager(
                    db_prefix=self._db_prefix,
                    node_name=str(node),
                    db_type=self._db_type)
                node_managers.append((node, manager))
        return node_managers

    def _task_group(self, func, *args, **kw):
        """Call `func(manager, *args, **kw)` for every node manager, each
        inside its `Node` context, and return the enclosing TaskGroup."""
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
        return task_group

    def init(
        self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
    ):
        """
        Create the per-node checkpoint managers and build the TaskGroup that
        initializes each of them.

        If the managers already exist, `nodes` must match the nodes they were
        created for and an empty TaskGroup is returned.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            retrieve_from_epoch: Set to a number to load blobs from this epoch.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            return TaskGroup(WorkspaceType.GLOBAL)
        # Assign only once fully built, so a failure part-way through does
        # not leave a partial list behind.
        self._node_managers = self._create_node_managers(nodes)
        # Each manager is initialized with its *own* node. Relying on the
        # loop variable left over from building the managers would hand the
        # last node to every manager and fail outright on an empty `nodes`.
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    manager.init(
                        nodes=[node],
                        retrieve_from_epoch=retrieve_from_epoch,
                        path_prefix=path_prefix,
                        path_type=path_type)
        return task_group

    def load(self, epoch, path_prefix=None, path_type=None):
        """Build the TaskGroup that loads the checkpoint of `epoch` on every
        node. `init` must have been called first."""
        return self._task_group(
            CheckpointManager.load,
            epoch,
            path_prefix=path_prefix,
            path_type=path_type)

    def load_blobs_locally(self, nodes, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints to the current node.

        Args:
            nodes: An array of nodes whose checkpoints are read. Must match
                the nodes of already existing managers, if any.
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the Load ops.

        Returns:
            False as soon as the checkpoint db of one node does not exist
            (blobs from earlier nodes have already been loaded at that
            point); True once all nodes were loaded.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
        else:
            self._node_managers = self._create_node_managers(nodes)
        assert self._node_managers is not None, 'must initialize node managers'
        for _, manager in self._node_managers:
            existence_task = manager.check_db_exists(epoch)
            session.run(existence_task)
            existence = existence_task.outputs()[0].fetch()
            if not existence:
                logger.info('DB %s does not exist!' %
                            db_name(epoch, manager._node_name, manager._db_prefix))
                return False
            load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
            session.run(load_task)
        logger.info('Successfully loaded from checkpoints.')
        return True

    def get_ckpt_db_name(self, node_name, epoch):
        """Returns the DB name of the given node and the given epoch.

        The DB name is effectively the checkpoint path of the given node and
        the given epoch.

        Args:
            node_name: A string. The node name of interest.
            epoch: An integer. The epoch of the checkpoint.

        Returns:
            checkpoint_db_name: A string. The checkpoint path of the given
                node and the given epoch, or None when no manager exists for
                `node_name` (including when no managers were created yet).
        """
        for node, manager in self._node_managers or ():
            if str(node) == node_name:
                return db_name(epoch, manager._node_name, manager._db_prefix)
        return None

    def report_checkpoint_stats(self, action_name):
        """
        Report the checkpoint stats for all the nodes, we need to aggregate all
        the node's stats together so that we know which node's checkpoint
        operation dominates.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        # Tolerate being called before any manager exists: report no stats.
        for _, manager in self._node_managers or ():
            manager.collect_checkpoint_stats(all_stats)
        logger.debug("checkpoint stats: {}".format(all_stats))
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a TaskGroup that will execute Save ops on every node to
        serialize and persist blobs present in the global workspace.
        """
        return self._task_group(CheckpointManager.save, epoch)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint. Does nothing when no metadata handler
        was supplied.

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from
        Returns:
            epoch: the epoch-id to load checkpoints from
                or None if no checkpoints were written. Without a metadata
                handler this is simply `user_epoch`.
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            path_prefix: Used to construct db name or path where checkpoint files are
                stored.
            path_type: Indicate the type of path where checkpoint files are stored.
        """
        self._node_names = [str(node) for node in nodes]
        # Falsy values leave the previously stored setting untouched.
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=self._node_names,
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if Checkpoint data is accessible

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if checkpoint directory is accessible

        Returns:
            is_cp_accessible: A boolean. Returns True if Checkpoint data is
                accessible. Always True when there is no metadata handler.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        return True
| |
|
| |
|
class UploadTaskGroupBuilder(object):
    """Interface for objects that build a task group uploading checkpoints."""

    def build(self, epoch, checkpoint_manager):
        """Builds the task group to upload checkpoints.

        Subclasses are expected to override this method; the base
        implementation always raises.

        Args:
            epoch: An integer. The checkpoint epoch to be uploaded.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.

        Raises:
            NotImplementedError: Always, since this base class only defines
                the interface.
        """
        raise NotImplementedError()
| |
|
| |
|
class JobRunner(object):
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. Job
    runner is a callable to be called once from the master, passing a session
    as an argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.job = job
        self.checkpoint_manager = checkpoint_manager
        self.resume_from_epoch = resume_from_epoch
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup a time.

        Returns:
            The number of the last epoch that was run.
        """
        manager = self.checkpoint_manager

        # Let the checkpoint manager decide which epoch to resume from.
        if manager:
            manager.set_params(nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = manager.get_resume_from_epoch_id(
                self.resume_from_epoch)
            if self.resume_from_epoch is not None:
                logger.info(
                    'Resuming from epoch {}'.format(self.resume_from_epoch))

        # The init group only runs when there is nothing to resume from.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if manager:
            logger.info('Preparing checkpoints ...')
            init_checkpoints = manager.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch)
            session.run(init_checkpoints)
            # Either persist the freshly initialized state as epoch 0, or
            # restore the state of the epoch being resumed.
            if not from_scratch:
                logger.info('Loading checkpoints for epoch {} ...'.format(
                    self.resume_from_epoch))
                session.run(manager.load(self.resume_from_epoch))
                manager.report_checkpoint_stats('checkpoint_load')
                logger.info('Checkpoint loaded')
            else:
                self.save_checkpoints(0, session)

        logger.info("Finished initializing")

        # Epoch loop: runs at least once and ends after the first epoch for
        # which any stop condition is set.
        if from_scratch:
            epoch = 1
        else:
            epoch = self.resume_from_epoch + 1
        should_stop = False
        while not should_stop:
            logger.info('Starting epoch %d' % epoch)
            session.run(self.job.epoch_group)
            logger.info('Finished epoch %d' % epoch)
            # Stop signals are fetched before the checkpoint is saved.
            stop_signals = [
                signal.fetch() for signal in self.job.stop_conditions]

            if manager:
                self.save_checkpoints(epoch, session)

            should_stop = any(stop_signals)
            if should_stop:
                logger.info('Stopping')
            else:
                epoch += 1
        logger.info('Finished training')

        # Upload the checkpoints of the last epoch, if requested.
        if self.upload_task_group_builder:
            session.run(
                self.upload_task_group_builder.build(epoch, manager))
            logger.info('Finished uploading the checkpoints')

        session.run(self.job.download_group)
        logger.info('Finished downloading the parameters')

        session.run(self.job.exit_group)
        logger.info('Finished running the exit group')
        return epoch

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store the snapshots of the workspace in each node.
        Sometimes we only need to load a subset of the blobs from the
        checkpoints. One common scenario is to load only the model blobs from
        the checkpoints for evaluation purpose. Given the names of the
        necessary blobs, this function goes over all the checkpoints of all the
        nodes, but only loads the blobs specified in the blob_names to the
        current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the load ops.

        Returns:
            Whatever the checkpoint manager's `load_blobs_locally` returns.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        manager = self.checkpoint_manager
        if not manager:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch {} ...'.format(epoch))
        nodes = self.job.nodes_to_checkpoint()
        result = manager.load_blobs_locally(nodes, blob_names, epoch, session)
        manager.report_checkpoint_stats('checkpoint_partial_load')
        return result

    def save_checkpoints(self, epoch, session):
        """Triggers operation to save checkpoints

        This method will trigger the Save ops to serialize and persist the
        blobs present in the global workspace. Saving is best-effort: any
        exception raised while saving is logged as a warning and swallowed.

        Args:
            epoch: An integer. The checkpoint epoch-id that we are saving.
            session: A Session object to execute the save ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        manager = self.checkpoint_manager
        if not manager:
            raise ValueError('Checkpoint manager is None')
        try:
            # epoch=None asks about the checkpoint location as a whole.
            if not manager.cp_accessible(epoch=None):
                logger.warning("Checkpoint files cannot be accessed!")
            else:
                logger.info('Saving checkpoints for epoch {}'.format(epoch))
                session.run(manager.save(epoch))
                manager.write_checkpoint_metadata(epoch)
                logger.info('Checkpoints saved')
                manager.report_checkpoint_stats('checkpoint_save')
        except Exception as ex:
            logger.warning(
                "Unable to write checkpoint for epoch {}. Error={}".format(
                    epoch, ex))
| |
|
| |
|
def epoch_limiter(job, num_epochs):
    """
    Registers on `job` a stop condition that is meant to become True once
    the given number of epochs has finished.

    A counter initialized to `num_epochs - 1` is created by a task in the
    job's `init_group`; a task in the job's `epoch_group` counts it down on
    every epoch, and the countdown's output is added as a stop condition.
    """
    with job.init_group:
        counter_init_net = core.Net('epoch_counter_init')
        epochs_left = counter_init_net.CreateCounter(
            [], init_count=num_epochs - 1)
        Task(step=counter_init_net)

    with job.epoch_group:
        countdown_net = core.Net('epoch_countdown')
        done_blob = countdown_net.CountDown(epochs_left)
        countdown_task = Task(step=countdown_net, outputs=done_blob)
        stop_signal = countdown_task.outputs()[0]
    job.add_stop_condition(stop_signal)
| |
|