| | |
| | from mlagents import torch_utils |
| | import yaml |
| |
|
| | import os |
| | import numpy as np |
| | import json |
| |
|
| | from typing import Callable, Optional, List |
| |
|
| | import mlagents.trainers |
| | import mlagents_envs |
| | from mlagents.trainers.trainer_controller import TrainerController |
| | from mlagents.trainers.environment_parameter_manager import EnvironmentParameterManager |
| | from mlagents.trainers.trainer import TrainerFactory |
| | from mlagents.trainers.directory_utils import ( |
| | validate_existing_directories, |
| | setup_init_path, |
| | ) |
| | from mlagents.trainers.stats import StatsReporter |
| | from mlagents.trainers.cli_utils import parser |
| | from mlagents_envs.environment import UnityEnvironment |
| | from mlagents.trainers.settings import RunOptions |
| |
|
| | from mlagents.trainers.training_status import GlobalTrainingStatus |
| | from mlagents_envs.base_env import BaseEnv |
| | from mlagents.trainers.subprocess_env_manager import SubprocessEnvManager |
| | from mlagents_envs.side_channel.side_channel import SideChannel |
| | from mlagents_envs.timers import ( |
| | hierarchical_timer, |
| | get_timer_tree, |
| | add_metadata as add_timer_metadata, |
| | ) |
| | from mlagents_envs import logging_util |
| | from mlagents.plugins.stats_writer import register_stats_writer_plugins |
| | from mlagents.plugins.trainer_type import register_trainer_plugins |
| |
|
# File name (relative to a run's logs directory) under which
# write_training_status() saves GlobalTrainingStatus.
TRAINING_STATUS_FILE_NAME = "training_status.json"

# Module logger, obtained through ml-agents' logging helper.
logger = logging_util.get_logger(__name__)
| |
|
| |
|
def get_version_string() -> str:
    """
    Build a human-readable, multi-line summary of the versions in use:
    ml-agents, ml-agents-envs, the Unity communicator API and PyTorch.
    """
    summary_lines = [
        " Version information:",
        f"ml-agents: {mlagents.trainers.__version__},",
        f"ml-agents-envs: {mlagents_envs.__version__},",
        f"Communicator API: {UnityEnvironment.API_VERSION},",
        f"PyTorch: {torch_utils.torch.__version__}",
    ]
    return "\n".join(summary_lines)
| |
|
| |
|
def parse_command_line(
    argv: Optional[List[str]] = None,
) -> RunOptions:
    """
    Turn command-line arguments into a RunOptions instance.

    Trainer plugins are registered before the arguments are parsed.

    :param argv: Argument list to parse; None is forwarded unchanged to
        ``parser.parse_args`` (presumably argparse semantics: fall back to
        ``sys.argv``).
    :return: The RunOptions built from the parsed arguments.
    """
    # The registration result is not used here; unpacking it into two names
    # keeps the original expectation that it yields exactly two values.
    _trainer_types, _trainer_settings = register_trainer_plugins()
    parsed = parser.parse_args(argv)
    return RunOptions.from_argparse(parsed)
| |
|
| |
|
def run_training(run_seed: int, options: RunOptions, num_areas: int) -> None:
    """
    Launches training session.

    :param run_seed: Random seed used for training.
    :param options: parsed command line arguments
    :param num_areas: Number of training areas to instantiate
    """
    with hierarchical_timer("run_training.setup"):
        torch_utils.set_torch_config(options.torch_settings)
        checkpoint_settings = options.checkpoint_settings
        env_settings = options.env_settings
        engine_settings = options.engine_settings

        run_logs_dir = checkpoint_settings.run_logs_dir
        port: Optional[int] = env_settings.base_port

        # Validate the output directory against the resume / force / init flags.
        validate_existing_directories(
            checkpoint_settings.write_path,
            checkpoint_settings.resume,
            checkpoint_settings.force,
            checkpoint_settings.maybe_init_path,
        )

        # Ensure the run-logs directory exists before anything writes into it.
        os.makedirs(run_logs_dir, exist_ok=True)

        if checkpoint_settings.resume:
            # Restore the status file that write_training_status() saved
            # (same file name constant, so load and save always agree).
            GlobalTrainingStatus.load_state(
                os.path.join(run_logs_dir, TRAINING_STATUS_FILE_NAME)
            )
        elif checkpoint_settings.maybe_init_path is not None:
            # Initializing from a previous run: set the init path per behavior.
            setup_init_path(options.behaviors, checkpoint_settings.maybe_init_path)

        # Register every stats writer supplied by the installed plugins.
        for stats_writer in register_stats_writer_plugins(options):
            StatsReporter.add_writer(stats_writer)

        # When no environment executable is given, no base port is forwarded.
        if env_settings.env_path is None:
            port = None
        env_factory = create_environment_factory(
            env_settings.env_path,
            engine_settings.no_graphics,
            run_seed,
            num_areas,
            port,
            env_settings.env_args,
            os.path.abspath(run_logs_dir),
        )

        env_manager = SubprocessEnvManager(env_factory, options, env_settings.num_envs)
        try:
            env_parameter_manager = EnvironmentParameterManager(
                options.environment_parameters,
                run_seed,
                restore=checkpoint_settings.resume,
            )

            trainer_factory = TrainerFactory(
                trainer_config=options.behaviors,
                output_path=checkpoint_settings.write_path,
                train_model=not checkpoint_settings.inference,
                load_model=checkpoint_settings.resume,
                seed=run_seed,
                param_manager=env_parameter_manager,
                init_path=checkpoint_settings.maybe_init_path,
                multi_gpu=False,
            )

            tc = TrainerController(
                trainer_factory,
                checkpoint_settings.write_path,
                checkpoint_settings.run_id,
                env_parameter_manager,
                not checkpoint_settings.inference,
                run_seed,
            )
        except BaseException:
            # env_manager already exists at this point (and presumably owns
            # worker subprocesses); close it instead of leaking it when the
            # rest of the setup fails, then propagate the original error.
            env_manager.close()
            raise

    # Train. Whatever happens, close the environments and then persist the
    # configuration, timers and training status -- the nested try/finally
    # makes sure the files are still written if close() itself fails.
    try:
        tc.start_learning(env_manager)
    finally:
        try:
            env_manager.close()
        finally:
            write_run_options(checkpoint_settings.write_path, options)
            write_timing_tree(run_logs_dir)
            write_training_status(run_logs_dir)
| |
|
| |
|
def write_run_options(output_dir: str, run_options: RunOptions) -> None:
    """
    Save the run configuration as ``configuration.yaml`` inside ``output_dir``.

    Write failures are logged as warnings instead of raised. run_training()
    calls this from a ``finally`` clause, where an escaping exception would
    replace whatever error ended training.

    :param output_dir: Directory that receives ``configuration.yaml``.
    :param run_options: Options to serialize (via ``as_dict()``).
    """
    run_options_path = os.path.join(output_dir, "configuration.yaml")
    run_options_dict = run_options.as_dict()
    try:
        with open(run_options_path, "w", encoding="utf-8") as f:
            try:
                yaml.dump(run_options_dict, f, sort_keys=False)
            except TypeError:
                # Fallback for PyYAML versions that reject ``sort_keys``.
                # Discard anything partially written so the second dump does
                # not append to it.
                f.seek(0)
                f.truncate()
                yaml.dump(run_options_dict, f)
    except FileNotFoundError:
        logger.warning(
            f"Unable to save configuration to {run_options_path}. Make sure the directory exists"
        )
    except OSError as err:
        # e.g. permission denied or a directory in the way.
        logger.warning(f"Unable to save configuration to {run_options_path}: {err}")
| |
|
| |
|
def write_training_status(output_dir: str) -> None:
    """
    Save GlobalTrainingStatus into ``output_dir`` under
    TRAINING_STATUS_FILE_NAME.

    :param output_dir: Directory that receives the status file.
    """
    status_path = os.path.join(output_dir, TRAINING_STATUS_FILE_NAME)
    GlobalTrainingStatus.save_state(status_path)
| |
|
| |
|
def write_timing_tree(output_dir: str) -> None:
    """
    Dump the hierarchical timer tree as ``timers.json`` inside ``output_dir``.

    Write failures are logged as warnings instead of raised. run_training()
    calls this from a ``finally`` clause, where an escaping exception would
    replace whatever error ended training.

    :param output_dir: Directory that receives ``timers.json``.
    """
    timing_path = os.path.join(output_dir, "timers.json")
    try:
        with open(timing_path, "w", encoding="utf-8") as f:
            json.dump(get_timer_tree(), f, indent=4)
    except FileNotFoundError:
        logger.warning(
            f"Unable to save to {timing_path}. Make sure the directory exists"
        )
    except OSError as err:
        # e.g. permission denied or a directory in the way.
        logger.warning(f"Unable to save to {timing_path}: {err}")
| |
|
| |
|
def create_environment_factory(
    env_path: Optional[str],
    no_graphics: bool,
    seed: int,
    num_areas: int,
    start_port: Optional[int],
    env_args: Optional[List[str]],
    log_folder: str,
) -> Callable[[int, List[SideChannel]], BaseEnv]:
    """
    Return a callable that builds one UnityEnvironment per worker.

    All arguments are captured by the returned closure. Only ``worker_id``
    and ``side_channels`` are supplied per call. Each environment is seeded
    with ``seed + worker_id``.
    """

    def create_unity_environment(
        worker_id: int, side_channels: List[SideChannel]
    ) -> UnityEnvironment:
        # Offset the base seed by the worker id so each worker id maps to its
        # own seed.
        return UnityEnvironment(
            file_name=env_path,
            worker_id=worker_id,
            seed=seed + worker_id,
            num_areas=num_areas,
            no_graphics=no_graphics,
            base_port=start_port,
            additional_args=env_args,
            side_channels=side_channels,
            log_folder=log_folder,
        )

    return create_unity_environment
| |
|
| |
|
def run_cli(options: RunOptions) -> None:
    """
    Run the command-line front end.

    Prints the banner and version information, configures the log level,
    warns about deprecated flags, records version metadata on the timer tree,
    resolves the run seed (-1 means "pick a random seed") and starts training.

    :param options: Parsed run options.
    """
    try:
        # The banner is drawn with Unicode box-drawing characters.
        print(
            """
            ┐  ╖
        ╓╖╬│╡  ││╬╖╖
    ╓╖╬│││││┘  ╬│││││╬╖
 ╖╬│││││╬╜        ╙╬│││││╖╖                               ╗╗╗
 ╬╬╬╬╖││╦╖        ╖╬││╗╣╣╣╬      ╟╣╣╬    ╟╣╣╣             ╜╜╜  ╟╣╣
 ╬╬╬╬╬╬╬╬╖│╬╖╖╓╬╪│╓╣╣╣╣╣╣╣╬      ╟╣╣╬    ╟╣╣╣ ╒╣╣╖╗╣╣╣╗   ╣╣╣ ╣╣╣╣╣╣ ╟╣╣╖   ╣╣╣
 ╬╬╬╬┐  ╙╬╬╬╬│╓╣╣╣╝╜  ╫╣╣╣╬      ╟╣╣╬    ╟╣╣╣ ╟╣╣╣╙ ╙╣╣╣  ╣╣╣ ╙╟╣╣╜╙  ╫╣╣  ╟╣╣
 ╬╬╬╬┐     ╙╬╬╣╣      ╫╣╣╣╬      ╟╣╣╬    ╟╣╣╣ ╟╣╣╬   ╣╣╣  ╣╣╣  ╟╣╣     ╣╣╣┌╣╣╜
 ╬╬╬╜       ╬╬╣╣      ╙╝╣╣╬      ╙╣╣╣╗╖╓╗╣╣╣╜ ╟╣╣╬   ╣╣╣  ╣╣╣  ╟╣╣╦╓    ╣╣╣╣╣
 ╙   ╓╦╖    ╬╬╣╣   ╓╗╗╖            ╙╝╣╣╣╣╝╜   ╘╝╝╜   ╝╝╝  ╝╝╝   ╙╣╣╣    ╟╣╣╣
   ╩╬╬╬╬╬╬╦╦╬╬╣╣╗╣╣╣╣╣╣╣╝                                             ╫╣╣╣╣
      ╙╬╬╬╬╬╬╬╣╣╣╣╣╣╝╜
          ╙╬╬╬╣╣╣╜
             ╙
        """
        )
    except Exception:
        # Presumably for consoles whose encoding cannot represent the
        # box-drawing characters: fall back to a plain-text banner.
        print("\n\n\tUnity Technologies\n")
    print(get_version_string())

    if options.debug:
        log_level = logging_util.DEBUG
    else:
        log_level = logging_util.INFO

    logging_util.set_log_level(log_level)

    logger.debug("Configuration for this run:")
    logger.debug(json.dumps(options.as_dict(), indent=4))

    # Warn about flags that are deprecated.
    if options.checkpoint_settings.load_model:
        logger.warning(
            "The --load option has been deprecated. Please use the --resume option instead."
        )
    if options.checkpoint_settings.train_model:
        logger.warning(
            "The --train option has been deprecated. Train mode is now the default. Use "
            "--inference to run in inference mode."
        )

    run_seed = options.env_settings.seed
    num_areas = options.env_settings.num_areas

    # Record the versions in use on the timer tree.
    add_timer_metadata("mlagents_version", mlagents.trainers.__version__)
    add_timer_metadata("mlagents_envs_version", mlagents_envs.__version__)
    add_timer_metadata("communication_protocol_version", UnityEnvironment.API_VERSION)
    add_timer_metadata("pytorch_version", torch_utils.torch.__version__)
    add_timer_metadata("numpy_version", np.__version__)

    # A configured seed of -1 means: draw a random seed for this run.
    if options.env_settings.seed == -1:
        run_seed = np.random.randint(0, 10000)
        logger.debug(f"run_seed set to {run_seed}")
    run_training(run_seed, options, num_areas)
| |
|
| |
|
def main():
    """Script entry point: parse the command line, then run the CLI with it."""
    options = parse_command_line()
    run_cli(options)
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Launch the CLI only when this file is executed directly, not on import.
    main()
| |
|