import argparse
import os
from typing import Any, Dict, Optional, Set, TextIO

import yaml

from mlagents.trainers.exception import TrainerConfigError
from mlagents_envs import logging_util
from mlagents_envs.environment import UnityEnvironment
| |
|
# Module-level logger for this file; RaiseRemovedWarning uses it to report
# command-line flags that have been removed.
logger = logging_util.get_logger(__name__)
| |
|
| |
|
class RaiseRemovedWarning(argparse.Action):
    """
    Argparse action for a flag that no longer has any effect.

    The flag consumes no values by default and, when given, only logs a
    warning naming the option; it does not write anything to the namespace.
    """

    def __init__(self, nargs=0, **kwargs):
        # Zero values by default so the removed flag can still be passed bare.
        super().__init__(nargs=nargs, **kwargs)

    def __call__(self, arg_parser, namespace, values, option_string=None):
        message = f"The command line argument {option_string} was removed."
        logger.warning(message)
| |
|
| |
|
class DetectDefault(argparse.Action):
    """
    Argparse action that stores the parsed value like a plain ``store`` action
    and additionally records the option's ``dest`` name. This lets callers tell
    which options were supplied explicitly instead of left at their default.
    """

    # Class-level set shared by every instance: ``dest`` names of options that
    # actually appeared on the command line.
    non_default_args: Set[str] = set()

    def __call__(self, arg_parser, namespace, values, option_string=None):
        destination = self.dest
        DetectDefault.non_default_args.add(destination)
        setattr(namespace, destination, values)
| |
|
| |
|
class DetectDefaultStoreTrue(DetectDefault):
    """
    ``store_true`` flavour of DetectDefault.

    The flag takes no values by default; when present it stores ``True`` and,
    through DetectDefault, is recorded as explicitly set.
    """

    def __init__(self, nargs=0, **kwargs):
        super().__init__(nargs=nargs, **kwargs)

    def __call__(self, arg_parser, namespace, values, option_string=None):
        # The parsed values are ignored: the flag's presence alone means True.
        super().__call__(arg_parser, namespace, True, option_string=option_string)
| |
|
| |
|
class StoreConfigFile(argparse.Action):
    """
    Custom Action to store the config file location not as part of the CLI args.
    This is because we want to maintain an equivalence between the config file's
    contents and the args themselves.

    The value is kept on the class attribute ``trainer_config_path`` and its
    entry is removed from the parsed namespace.
    """

    # Value most recently handed to this action. Starts as None (matching the
    # argument's ``default=None``) so reading it before the action has run
    # does not raise AttributeError.
    trainer_config_path: Optional[str] = None

    def __call__(self, arg_parser, namespace, values, option_string=None):
        # argparse pre-populates the namespace with the default; drop it so
        # the config path never appears among the regular CLI args.
        delattr(namespace, self.dest)
        StoreConfigFile.trainer_config_path = values
| |
|
| |
|
def _create_parser() -> argparse.ArgumentParser:
    """
    Build the training command-line argument parser.

    Most options use DetectDefault / DetectDefaultStoreTrue, so the ``dest``
    names of options given explicitly on the command line are recorded in
    ``DetectDefault.non_default_args``. The optional positional config path is
    handled by StoreConfigFile and therefore does not stay in the parsed
    namespace. ``--torch`` and ``--tensorflow`` only log a removal warning.

    :return: The configured ``argparse.ArgumentParser``.
    """
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    argparser.add_argument(
        "trainer_config_path", action=StoreConfigFile, nargs="?", default=None
    )
    argparser.add_argument(
        "--env",
        default=None,
        dest="env_path",
        help="Path to the Unity executable to train",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--load",
        default=False,
        dest="load_model",
        action=DetectDefaultStoreTrue,
        help=argparse.SUPPRESS,
    )
    argparser.add_argument(
        "--resume",
        default=False,
        dest="resume",
        action=DetectDefaultStoreTrue,
        help="Whether to resume training from a checkpoint. Specify a --run-id to use this option. "
        "If set, the training code loads an already trained model to initialize the neural network "
        "before resuming training. This option is only valid when the models exist, and have the same "
        "behavior names as the current agents in your scene.",
    )
    argparser.add_argument(
        "--deterministic",
        default=False,
        dest="deterministic",
        action=DetectDefaultStoreTrue,
        help="Whether to select actions deterministically in policy. `dist.mean` for continuous action "
        "space, and `dist.argmax` for discrete action space.",
    )
    argparser.add_argument(
        "--force",
        default=False,
        dest="force",
        action=DetectDefaultStoreTrue,
        help="Whether to force-overwrite this run-id's existing summary and model data. (Without "
        "this flag, attempting to train a model with a run-id that has been used before will throw "
        "an error.)",
    )
    argparser.add_argument(
        "--run-id",
        default="ppo",
        help="The identifier for the training run. This identifier is used to name the "
        "subdirectories in which the trained model and summary statistics are saved as well "
        "as the saved model itself. If you use TensorBoard to view the training statistics, "
        "always set a unique run-id for each training run. (The statistics for all runs with the "
        "same id are combined as if they were produced by the same session.)",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--initialize-from",
        metavar="RUN_ID",
        default=None,
        help="Specify a previously saved run ID from which to initialize the model. "
        "This can be used, for instance, to fine-tune an existing model on a new environment. "
        "Note that the previously saved models must have the same behavior parameters as your "
        "current environment.",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--seed",
        default=-1,
        type=int,
        help="A number to use as a seed for the random number generator used by the training code",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--train",
        default=False,
        dest="train_model",
        action=DetectDefaultStoreTrue,
        help=argparse.SUPPRESS,
    )
    argparser.add_argument(
        "--inference",
        default=False,
        dest="inference",
        action=DetectDefaultStoreTrue,
        help="Whether to run in Python inference mode (i.e. no training). Use with --resume to load "
        "a model trained with an existing run ID.",
    )
    argparser.add_argument(
        "--base-port",
        default=UnityEnvironment.BASE_ENVIRONMENT_PORT,
        type=int,
        help="The starting port for environment communication. Each concurrent Unity environment "
        "instance will get assigned a port sequentially, starting from the base-port. Each instance "
        "will use the port (base_port + worker_id), where the worker_id is sequential IDs given to "
        "each instance from 0 to (num_envs - 1). Note that when training using the Editor rather "
        "than an executable, the base port will be ignored.",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--num-envs",
        default=1,
        type=int,
        help="The number of concurrent Unity environment instances to collect experiences "
        "from when training",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--num-areas",
        default=1,
        type=int,
        help="The number of parallel training areas in each Unity environment instance.",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--debug",
        default=False,
        action=DetectDefaultStoreTrue,
        help="Whether to enable debug-level logging for some parts of the code",
    )
    argparser.add_argument(
        "--env-args",
        default=None,
        nargs=argparse.REMAINDER,
        help="Arguments passed to the Unity executable. Be aware that the standalone build will also "
        "process these as Unity Command Line Arguments. You should choose different argument names if "
        "you want to create environment-specific arguments. All arguments after this flag will be "
        "passed to the executable.",
        action=DetectDefault,
    )
    # The three restart limits below are integers (their defaults are ints and
    # -1 is a documented sentinel). Without ``type=int`` a value passed on the
    # command line would be stored as a string instead of an int.
    argparser.add_argument(
        "--max-lifetime-restarts",
        default=10,
        type=int,
        help="The max number of times a single Unity executable can crash over its lifetime before ml-agents exits. "
        "Can be set to -1 if no limit is desired.",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--restarts-rate-limit-n",
        default=1,
        type=int,
        help="The maximum number of times a single Unity executable can crash over a period of time (period set in "
        "restarts-rate-limit-period-s). Can be set to -1 to not use rate limiting with restarts.",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--restarts-rate-limit-period-s",
        default=60,
        type=int,
        help="The period of time --restarts-rate-limit-n applies to.",
        action=DetectDefault,
    )
    argparser.add_argument(
        "--torch",
        default=False,
        action=RaiseRemovedWarning,
        help="(Removed) Use the PyTorch framework.",
    )
    argparser.add_argument(
        "--tensorflow",
        default=False,
        action=RaiseRemovedWarning,
        help="(Removed) Use the TensorFlow framework.",
    )
    argparser.add_argument(
        "--results-dir",
        default="results",
        action=DetectDefault,
        help="Results base directory",
    )

    eng_conf = argparser.add_argument_group(title="Engine Configuration")
    eng_conf.add_argument(
        "--width",
        default=84,
        type=int,
        help="The width of the executable window of the environment(s) in pixels "
        "(ignored for editor training).",
        action=DetectDefault,
    )
    eng_conf.add_argument(
        "--height",
        default=84,
        type=int,
        help="The height of the executable window of the environment(s) in pixels "
        "(ignored for editor training)",
        action=DetectDefault,
    )
    eng_conf.add_argument(
        "--quality-level",
        default=5,
        type=int,
        help="The quality level of the environment(s). Equivalent to calling "
        "QualitySettings.SetQualityLevel in Unity.",
        action=DetectDefault,
    )
    eng_conf.add_argument(
        "--time-scale",
        default=20,
        type=float,
        help="The time scale of the Unity environment(s). Equivalent to setting "
        "Time.timeScale in Unity.",
        action=DetectDefault,
    )
    eng_conf.add_argument(
        "--target-frame-rate",
        default=-1,
        type=int,
        help="The target frame rate of the Unity environment(s). Equivalent to setting "
        "Application.targetFrameRate in Unity.",
        action=DetectDefault,
    )
    eng_conf.add_argument(
        "--capture-frame-rate",
        default=60,
        type=int,
        help="The capture frame rate of the Unity environment(s). Equivalent to setting "
        "Time.captureFramerate in Unity.",
        action=DetectDefault,
    )
    eng_conf.add_argument(
        "--no-graphics",
        default=False,
        action=DetectDefaultStoreTrue,
        help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
        "the graphics driver). Use this only if your agents don't use visual observations.",
    )

    torch_conf = argparser.add_argument_group(title="Torch Configuration")
    torch_conf.add_argument(
        "--torch-device",
        default=None,
        dest="device",
        action=DetectDefault,
        help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
    )
    return argparser
| |
|
| |
|
def load_config(config_path: str) -> Dict[str, Any]:
    """
    Open the YAML config file at ``config_path`` and parse it.

    :param config_path: Path to the YAML config file.
    :return: The document parsed by ``_load_config``.
    :raises TrainerConfigError: If the file does not exist, cannot be opened or
        read, is not valid UTF-8, or (via ``_load_config``) fails to parse.
    """
    try:
        # Decode explicitly as UTF-8. Without ``encoding``, open() uses the
        # locale's preferred encoding, which is not necessarily UTF-8, even
        # though the error below tells users to save the file as UTF-8.
        with open(config_path, encoding="utf-8") as data_file:
            return _load_config(data_file)
    except FileNotFoundError as e:
        abs_path = os.path.abspath(config_path)
        raise TrainerConfigError(
            f"Config file could not be found at {abs_path}."
        ) from e
    except OSError as e:
        # E.g. permission denied or a directory was given: the file exists,
        # so "could not be found" would be misleading.
        abs_path = os.path.abspath(config_path)
        raise TrainerConfigError(
            f"Config file at {abs_path} could not be read: {e}"
        ) from e
    except UnicodeDecodeError as e:
        raise TrainerConfigError(
            f"There was an error decoding Config file from {config_path}. "
            f"Make sure your file is saved using UTF-8"
        ) from e
| |
|
| |
|
def _load_config(fp: TextIO) -> Dict[str, Any]:
    """
    Load the yaml config from the file-like object.

    Parsing uses ``yaml.safe_load``. Any PyYAML error is re-raised as
    TrainerConfigError, with the original exception chained as its cause so
    that the line/column details are preserved.

    NOTE(review): ``yaml.safe_load`` returns None for an empty document (and
    non-dict values for scalar/list documents), so the result may not match
    the ``Dict`` annotation. Confirm that callers handle this.

    :param fp: Open text stream containing the YAML document.
    :return: The parsed YAML document.
    :raises TrainerConfigError: If the stream is not valid YAML.
    """
    try:
        return yaml.safe_load(fp)
    except yaml.YAMLError as e:
        # Catch the PyYAML base class rather than only ParserError: inputs
        # such as tab indentation raise ScannerError, which is a sibling of
        # ParserError, not a subclass.
        raise TrainerConfigError(
            "Error parsing yaml file. Please check for formatting errors. "
            "A tool such as http://www.yamllint.com/ can be helpful with this."
        ) from e
| |
|
| |
|
# Parser instance, built once when this module is imported.
parser = _create_parser()
| |
|