| import math |
| import tempfile |
| import numpy as np |
| from typing import Dict |
| from mlagents.trainers.trainer_controller import TrainerController |
| from mlagents.trainers.trainer import TrainerFactory |
| from mlagents.trainers.simple_env_manager import SimpleEnvManager |
| from mlagents.trainers.stats import StatsReporter, StatsWriter, StatsSummary |
| from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager |
| from mlagents_envs.side_channel.environment_parameters_channel import ( |
| EnvironmentParametersChannel, |
| ) |
|
|
|
|
class DebugWriter(StatsWriter):
    """
    Stats writer that echoes reward stats to stdout so they are visible
    in pytest output, while remembering the most recent value per category.
    """

    def __init__(self):
        # Maps stats category -> most recently written aggregated reward.
        self._last_reward_summary: Dict[str, float] = {}

    def get_last_rewards(self):
        """Return the latest recorded aggregated reward for each category."""
        return self._last_reward_summary

    def write_stats(
        self, category: str, values: Dict[str, StatsSummary], step: int
    ) -> None:
        """Print cumulative-reward stats and record the last value seen."""
        tracked = (
            "Environment/Cumulative Reward",
            "Environment/Group Cumulative Reward",
        )
        for stat_name, summary in values.items():
            if stat_name in tracked:
                print(step, stat_name, summary.aggregated_value)
                self._last_reward_summary[category] = summary.aggregated_value
|
|
|
|
| |
| |
| |
| |
| |
def default_reward_processor(rewards, last_n_rewards=5):
    """
    Reduce a reward history to a single score: the float32 mean of the
    last `last_n_rewards` entries (or all of them if fewer are available).

    :param rewards: sequence of per-episode rewards, oldest first.
    :param last_n_rewards: how many trailing rewards to average.
    :return: numpy float32 mean of the trailing window.
    """
    rewards_to_use = rewards[-last_n_rewards:]
    # Print for debugging/visibility in pytest output.
    print(f"Last {last_n_rewards} rewards:", rewards_to_use)
    # Reuse the slice computed above instead of re-slicing `rewards`.
    return np.array(rewards_to_use, dtype=np.float32).mean()
|
|
|
|
def check_environment_trains(
    env,
    trainer_config,
    reward_processor=default_reward_processor,
    env_parameter_manager=None,
    success_threshold=0.9,
    env_manager=None,
    training_seed=None,
):
    """
    Run a full (short) training session against `env` and assert that every
    behavior's processed final reward exceeds `success_threshold`.

    :param env: test environment exposing `final_rewards` after training.
    :param trainer_config: trainer configuration passed to TrainerFactory.
    :param reward_processor: reduces a reward history to a single score.
    :param env_parameter_manager: optional EnvironmentParameterManager; a
        default one is created if None.
    :param success_threshold: minimum processed reward; pass None to skip
        the reward assertions (e.g. when only checking training runs).
    :param env_manager: optional env manager; defaults to a SimpleEnvManager
        wrapping `env`.
    :param training_seed: optional seed override (defaults to 1337).
    """
    if env_parameter_manager is None:
        env_parameter_manager = EnvironmentParameterManager()
    # Write all model output into a throwaway directory. Named `output_dir`
    # rather than `dir` to avoid shadowing the builtin.
    with tempfile.TemporaryDirectory() as output_dir:
        run_id = "id"
        seed = 1337 if training_seed is None else training_seed
        # Drop any writers registered by earlier tests so stats from this
        # run are captured only by our DebugWriter.
        StatsReporter.writers.clear()
        debug_writer = DebugWriter()
        StatsReporter.add_writer(debug_writer)
        if env_manager is None:
            env_manager = SimpleEnvManager(env, EnvironmentParametersChannel())
        trainer_factory = TrainerFactory(
            trainer_config=trainer_config,
            output_path=output_dir,
            train_model=True,
            load_model=False,
            seed=seed,
            param_manager=env_parameter_manager,
            multi_gpu=False,
        )

        tc = TrainerController(
            trainer_factory=trainer_factory,
            output_path=output_dir,
            run_id=run_id,
            param_manager=env_parameter_manager,
            train=True,
            training_seed=seed,
        )

        # Begin training; blocks until the session completes.
        tc.start_learning(env_manager)
        if success_threshold is not None:
            processed_rewards = [
                reward_processor(rewards) for rewards in env.final_rewards.values()
            ]
            # Training must produce finite rewards above the threshold for
            # every behavior in the environment.
            assert all(not math.isnan(reward) for reward in processed_rewards)
            assert all(reward > success_threshold for reward in processed_rewards)
|
|