from typing import Dict, Any, Optional, List
import os
import attr
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger

logger = get_logger(__name__)


@attr.s(auto_attribs=True)
class ModelCheckpoint:
    steps: int
    file_path: str
    reward: Optional[float]
    creation_time: float
    auxillary_file_paths: List[str] = attr.ib(factory=list)
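
# Field notes: `steps` is the training step at which the checkpoint was
# written, `file_path` is the saved model file, `reward` is the measured
# reward at save time (if available), and `auxillary_file_paths` lists any
# extra files saved alongside the model.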


class ModelCheckpointManager:
    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Fetches the checkpoint list for a behavior name, initializing an
        empty list in the training status if none has been stored yet.

        :param behavior_name: The behavior name whose checkpoints to fetch.
        :return: A list of checkpoint dictionaries with the same keys as
            attr.asdict(ModelCheckpoint).
        """
        checkpoint_list = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if not checkpoint_list:
            checkpoint_list = []
            GlobalTrainingStatus.set_parameter_state(
                behavior_name, StatusType.CHECKPOINTS, checkpoint_list
            )
        return checkpoint_list

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Removes a checkpoint stored in checkpoint_list.
        If the checkpoint cannot be found, no action is taken.

        :param checkpoint: A checkpoint stored in checkpoint_list.
        """
        file_paths: List[str] = [checkpoint["file_path"]]
        file_paths.extend(checkpoint["auxillary_file_paths"])
        for file_path in file_paths:
            if os.path.exists(file_path):
                os.remove(file_path)
                logger.debug(f"Removed checkpoint model {file_path}.")
            else:
                logger.debug(f"Checkpoint at {file_path} could not be found.")

    @classmethod
    def _cleanup_extra_checkpoints(
        cls, checkpoints: List[Dict], keep_checkpoints: int
    ) -> List[Dict]:
        """
        Ensures that the number of checkpoints stored is within the limit
        the user defines. If the limit is hit, the oldest checkpoints are
        removed to create room for the next checkpoint to be inserted.

        :param checkpoints: The list of checkpoints to prune, oldest first.
        :param keep_checkpoints: Number of checkpoints to keep (user-defined).
        :return: The pruned checkpoint list.
        """
        # A non-positive limit disables pruning entirely.
        if keep_checkpoints <= 0:
            return checkpoints
        # Remove the oldest checkpoints first until the limit is respected.
        while len(checkpoints) > keep_checkpoints:
            cls.remove_checkpoint(checkpoints.pop(0))
        return checkpoints
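
    # Example: with keep_checkpoints=5 and six stored checkpoints, the single
    # oldest entry (index 0) is deleted from disk and dropped from the list.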

    @classmethod
    def add_checkpoint(
        cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
    ) -> None:
        """
        Make room for a new checkpoint if needed and insert the new checkpoint
        information.

        :param behavior_name: Behavior name for the checkpoint.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        """
        new_checkpoint_dict = attr.asdict(new_checkpoint)
        checkpoints = cls.get_checkpoints(behavior_name)
        checkpoints.append(new_checkpoint_dict)
        cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, checkpoints
        )

    @classmethod
    def track_final_checkpoint(
        cls, behavior_name: str, final_checkpoint: ModelCheckpoint
    ) -> None:
        """
        Stores information about the final model (or the latest model, if
        training was interrupted) in the training status.

        :param behavior_name: Behavior name of the model.
        :param final_checkpoint: Checkpoint information for the final model.
        """
        final_model_dict = attr.asdict(final_checkpoint)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict
        )
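

# A hedged end-to-end usage sketch (illustrative only: the behavior name,
# step count, path, and reward below are hypothetical, and this assumes the
# trainer has already set up GlobalTrainingStatus and written the model
# file to disk):
#
#     import time
#
#     ckpt = ModelCheckpoint(
#         steps=10000,
#         file_path="results/run_id/MyBehavior/MyBehavior-10000.onnx",
#         reward=2.0,
#         creation_time=time.time(),
#     )
#     ModelCheckpointManager.add_checkpoint("MyBehavior", ckpt, keep_checkpoints=5)
#     ModelCheckpointManager.track_final_checkpoint("MyBehavior", ckpt)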