| from typing import Dict, Any |
| from enum import Enum |
| from collections import defaultdict |
| import json |
| import attr |
| import cattr |
|
|
| from mlagents.torch_utils import torch |
| from mlagents_envs.logging_util import get_logger |
| from mlagents.trainers import __version__ |
| from mlagents.trainers.exception import TrainerError |
|
|
logger = get_logger(__name__)


# Layout version of the training status file. It is the default value of
# StatusMetaData.stats_format_version, which is saved with the file's metadata.
STATUS_FORMAT_VERSION = "0.3.0"
|
|
|
|
class StatusType(Enum):
    """
    Keys for values stored in the global training status.

    Each member's string value is the dictionary key used in the saved JSON
    (see GlobalTrainingStatus), so changing a value breaks reading older files.
    """

    LESSON_NUM = "lesson_num"
    # Top-level key under which the StatusMetaData dict is stored.
    STATS_METADATA = "metadata"
    # NOTE(review): CHECKPOINTS, FINAL_CHECKPOINT and ELO are not read or written
    # in this file; presumably used by trainer code elsewhere -- confirm.
    CHECKPOINTS = "checkpoints"
    FINAL_CHECKPOINT = "final_checkpoint"
    ELO = "elo"
|
|
|
|
@attr.s(auto_attribs=True)
class StatusMetaData:
    """
    Version information saved alongside the global training status.

    Every field defaults to the value of the currently running process, so
    ``StatusMetaData()`` describes the current environment, while
    ``StatusMetaData.from_dict(...)`` describes a previously saved one.
    """

    # Layout version of the training status file.
    stats_format_version: str = STATUS_FORMAT_VERSION
    # Version of the ML-Agents trainers package.
    mlagents_version: str = __version__
    # Version of PyTorch.
    torch_version: str = torch.__version__

    def to_dict(self) -> Dict[str, str]:
        """
        Convert this metadata into a plain dict suitable for JSON serialization.
        :return: Mapping of field name to version string.
        """
        return cattr.unstructure(self)

    @staticmethod
    def from_dict(import_dict: Dict[str, str]) -> "StatusMetaData":
        """
        Build a StatusMetaData from a dict such as one produced by ``to_dict``.
        :param import_dict: Mapping of field name to version string.
        :return: The structured metadata object.
        """
        return cattr.structure(import_dict, StatusMetaData)

    def check_compatibility(self, other: "StatusMetaData") -> None:
        """
        Check compatibility with a loaded StatusMetaData and warn the user
        if versions mismatch. This is used for resuming from old checkpoints.

        Each differing version field produces one warning; nothing is raised.
        :param other: The metadata to compare against.
        """
        # The format version was recorded but never compared before; a layout
        # change in the status file deserves the same warning as a package change.
        if self.stats_format_version != other.stats_format_version:
            logger.warning(
                "Training status file was saved with a different format version. "
                "Some values may not resume properly."
            )
        if self.mlagents_version != other.mlagents_version:
            logger.warning(
                "Checkpoint was loaded from a different version of ML-Agents. Some things may not resume properly."
            )
        if self.torch_version != other.torch_version:
            logger.warning(
                "PyTorch checkpoint was saved with a different version of PyTorch. Model may not resume properly."
            )
|
|
|
|
class GlobalTrainingStatus:
    """
    GlobalTrainingStatus class that contains static methods to save global training status and
    load it on a resume. These are values that might be needed for the training resume that
    cannot/should not be captured in a model checkpoint, such as curriculum lesson.
    """

    # Class-level store shared by all callers: category -> {StatusType value -> value}.
    # save_state also places the metadata dict here under StatusType.STATS_METADATA.value.
    saved_state: Dict[str, Dict[str, Any]] = defaultdict(dict)

    @staticmethod
    def load_state(path: str) -> None:
        """
        Load a JSON file that contains saved state and merge it into saved_state.
        If the file does not exist, a warning is logged and nothing is loaded.
        :param path: Path to the JSON file containing the state.
        :raises TrainerError: If the file is not a JSON object containing metadata.
        """
        try:
            # save_state writes ASCII-only JSON, so UTF-8 reads it identically on
            # every platform regardless of the locale's default encoding.
            with open(path, encoding="utf-8") as f:
                loaded_dict = json.load(f)
        except FileNotFoundError:
            logger.warning(
                "Training status file not found. Not all functions will resume properly."
            )
            return

        # Check for metadata explicitly instead of catching KeyError around the
        # whole load, so unrelated errors below are not misreported as missing metadata.
        metadata_key = StatusType.STATS_METADATA.value
        if not isinstance(loaded_dict, dict) or metadata_key not in loaded_dict:
            raise TrainerError(
                "Metadata not found, resuming from an incompatible version of ML-Agents."
            )

        # Compare saved versions against the running ones (warnings only).
        StatusMetaData.from_dict(loaded_dict[metadata_key]).check_compatibility(
            StatusMetaData()
        )
        GlobalTrainingStatus.saved_state.update(loaded_dict)

    @staticmethod
    def save_state(path: str) -> None:
        """
        Save a JSON file that contains saved state.
        The current StatusMetaData is stored in saved_state before writing.
        :param path: Path to the JSON file containing the state.
        """
        GlobalTrainingStatus.saved_state[
            StatusType.STATS_METADATA.value
        ] = StatusMetaData().to_dict()
        # Serialize before opening the file: if a stored value is not JSON
        # serializable, the existing file is left intact instead of truncated.
        payload = json.dumps(GlobalTrainingStatus.saved_state, indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)

    @staticmethod
    def set_parameter_state(category: str, key: StatusType, value: Any) -> None:
        """
        Stores an arbitrary-named parameter in the global saved state.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The parameter, e.g. lesson number.
        :param value: The value. It must be JSON serializable for save_state to succeed.
        """
        GlobalTrainingStatus.saved_state[category][key.value] = value

    @staticmethod
    def get_parameter_state(category: str, key: StatusType) -> Any:
        """
        Looks up an arbitrary-named parameter in the in-memory saved state
        (filled by load_state and set_parameter_state).
        Looking up an unknown category does not create an entry for it.
        :param category: The category (usually behavior name) of the parameter.
        :param key: The statistic, e.g. lesson number.
        :return: The stored value, or None if not found.
        """
        return GlobalTrainingStatus.saved_state.get(category, {}).get(key.value)
|
|