| | from typing import Dict, List, Tuple, Optional |
| | from mlagents.trainers.settings import ( |
| | EnvironmentParameterSettings, |
| | ParameterRandomizationSettings, |
| | ) |
| | from collections import defaultdict |
| | from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType |
| |
|
| | from mlagents_envs.logging_util import get_logger |
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
class EnvironmentParameterManager:
    """
    Manages the environment parameters of a training session: it decides when a
    parameter moves on to its next lesson and exposes the sampler of each
    parameter's current lesson.

    The current lesson number of every parameter is read from and written to
    GlobalTrainingStatus (under StatusType.LESSON_NUM); it is not kept on the
    instance.
    """

    def __init__(
        self,
        settings: Optional[Dict[str, EnvironmentParameterSettings]] = None,
        run_seed: int = -1,
        restore: bool = False,
    ):
        """
        EnvironmentParameterManager manages all the environment parameters of a training
        session. It determines when parameters should change and gives access to the
        current sampler of each parameter.
        :param settings: A dictionary from environment parameter to
        EnvironmentParameterSettings. None is treated as an empty dictionary.
        :param run_seed: When the seed is not provided for an environment parameter,
        this seed will be used instead.
        :param restore: If true, the EnvironmentParameterManager will use the
        GlobalTrainingStatus to try and reload the lesson status of each environment
        parameter.
        """
        if settings is None:
            settings = {}
        self._dict_settings = settings
        for parameter_name in self._dict_settings:
            initial_lesson = self._get_lesson_num(parameter_name)
            # Start at lesson 0 unless a lesson number is already stored AND the
            # caller asked for it to be restored.
            if initial_lesson is None or not restore:
                GlobalTrainingStatus.set_parameter_state(
                    parameter_name, StatusType.LESSON_NUM, 0
                )
        # Smoothed measure handed to (and returned by) the completion criteria of
        # each parameter; every known parameter starts at 0.0.
        self._smoothed_values: Dict[str, float] = defaultdict(
            float, dict.fromkeys(self._dict_settings, 0.0)
        )
        self._set_sampler_seeds(run_seed)

    @staticmethod
    def _get_lesson_num(parameter_name: str):
        """
        Returns the lesson number stored in GlobalTrainingStatus for parameter_name
        (whatever get_parameter_state returns; __init__ handles a None result).
        """
        return GlobalTrainingStatus.get_parameter_state(
            parameter_name, StatusType.LESSON_NUM
        )

    def _set_sampler_seeds(self, seed: int) -> None:
        """
        Sets the seeds for the samplers (if no seed was already present). A sampler
        whose seed is -1 is considered unseeded; each unseeded sampler receives the
        provided seed plus an offset that grows by one per sampler seeded here, so
        the samplers seeded by this method all get different seeds.
        :param seed: The base seed the per-sampler seeds are derived from.
        """
        offset = 0
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                if lesson.value.seed == -1:
                    lesson.value.seed = seed + offset
                    offset += 1

    def get_minimum_reward_buffer_size(self, behavior_name: str) -> int:
        """
        Calculates the minimum size of the reward buffer a behavior must use. This
        is the largest 'min_lesson_length' among the completion criteria that target
        the behavior, and never less than 1.
        :param behavior_name: The name of the behavior the minimum reward buffer
        size corresponds to.
        :returns: The minimum reward buffer size for behavior_name.
        """
        result = 1
        for settings in self._dict_settings.values():
            for lesson in settings.curriculum:
                criteria = lesson.completion_criteria
                if criteria is not None and criteria.behavior == behavior_name:
                    result = max(result, criteria.min_lesson_length)
        return result

    def get_current_samplers(self) -> Dict[str, ParameterRandomizationSettings]:
        """
        Creates a dictionary from environment parameter name to their corresponding
        ParameterRandomizationSettings. If curriculum is used, the
        ParameterRandomizationSettings corresponds to the sampler of the current lesson.
        """
        samplers: Dict[str, ParameterRandomizationSettings] = {}
        for param_name, settings in self._dict_settings.items():
            lesson = settings.curriculum[self._get_lesson_num(param_name)]
            samplers[param_name] = lesson.value
        return samplers

    def get_current_lesson_number(self) -> Dict[str, int]:
        """
        Creates a dictionary from environment parameter to the current lesson number.
        If not using curriculum, this number is always 0 for that environment parameter.
        """
        return {
            parameter_name: self._get_lesson_num(parameter_name)
            for parameter_name in self._dict_settings
        }

    def _log_lesson(self, parameter_name: str) -> None:
        """
        Logs the name and sampler value of the current lesson of one parameter.
        """
        settings = self._dict_settings[parameter_name]
        lesson_number = self._get_lesson_num(parameter_name)
        lesson_name = settings.curriculum[lesson_number].name
        lesson_value = settings.curriculum[lesson_number].value
        logger.info(
            f"Parameter '{parameter_name}' is in lesson '{lesson_name}' "
            f"and has value '{lesson_value}'."
        )

    def log_current_lesson(self, parameter_name: Optional[str] = None) -> None:
        """
        Logs the current lesson number and sampler value of the parameter with name
        parameter_name. If no parameter_name is provided, the values and lesson
        numbers of all parameters will be displayed.
        :param parameter_name: The parameter to log, or None to log all of them.
        """
        if parameter_name is not None:
            self._log_lesson(parameter_name)
            return
        for name in self._dict_settings:
            self._log_lesson(name)

    def update_lessons(
        self,
        trainer_steps: Dict[str, int],
        trainer_max_steps: Dict[str, int],
        trainer_reward_buffer: Dict[str, List[float]],
    ) -> Tuple[bool, bool]:
        """
        Given progress metrics, calculates if at least one environment parameter is
        in a new lesson and if at least one environment parameter requires the env
        to reset.
        :param trainer_steps: A dictionary from behavior_name to the number of training
        steps this behavior's trainer has performed.
        :param trainer_max_steps: A dictionary from behavior_name to the maximum number
        of training steps for this behavior's trainer. It is used as the denominator
        of the training progress ratio, so it must contain every behavior that is
        looked up in trainer_steps.
        :param trainer_reward_buffer: A dictionary from behavior_name to the list of
        the most recent episode returns for this behavior's trainer.
        :returns: A tuple of two booleans : (True if any lesson has changed, True if
        environment needs to reset)
        """
        must_reset = False
        updated = False
        for param_name, settings in self._dict_settings.items():
            lesson_num = self._get_lesson_num(param_name)
            next_lesson_num = lesson_num + 1
            lesson = settings.curriculum[lesson_num]
            criteria = lesson.completion_criteria
            # A lesson without completion criteria, or the last lesson of the
            # curriculum, can never be advanced.
            if criteria is None or next_lesson_num >= len(settings.curriculum):
                continue
            behavior_to_consider = criteria.behavior
            # Only evaluate the criteria once the behavior reports progress.
            if behavior_to_consider not in trainer_steps:
                continue
            progress = float(trainer_steps[behavior_to_consider]) / float(
                trainer_max_steps[behavior_to_consider]
            )
            must_increment, new_smoothing = criteria.need_increment(
                progress,
                trainer_reward_buffer[behavior_to_consider],
                self._smoothed_values[param_name],
            )
            # The smoothed value is stored on every evaluation, whether or not
            # the lesson advances.
            self._smoothed_values[param_name] = new_smoothing
            if not must_increment:
                continue
            GlobalTrainingStatus.set_parameter_state(
                param_name, StatusType.LESSON_NUM, next_lesson_num
            )
            self.log_current_lesson(param_name)
            updated = True
            # The reset flag comes from the criteria of the lesson just completed.
            if criteria.require_reset:
                must_reset = True
        return updated, must_reset
| |
|