Spaces:
Sleeping
Sleeping
import os
import warnings

from MCAgent import MCAgent
from DPAgent import DPAgent

# Registry of every supported agent class. load_agent() looks agent names up
# here, both when given a key directly and when the key is parsed from the
# first "_"-separated component of a policy file name.
AGENTS_MAP = dict(MCAgent=MCAgent, DPAgent=DPAgent)
def load_agent(agent_key, **kwargs):
    """
    Loads an agent from a file or from the AGENTS_MAP.

    :param agent_key: Which agent to load. Can be a key in AGENTS_MAP or a path to a
        policy file ending with ".npy". If a policy file is provided, the agent name,
        environment name, and other parameters will be parsed from the file name,
        which must look like 'AgentName_EnvName_key:value_key:value....npy'.
        Parsed values are passed to the agent constructor as strings.
    :param kwargs: Additional arguments to pass to the agent constructor. If loading
        from a policy file, any conflicting arguments will be overwritten by the values
        parsed from the file name (including ``env``).
    :return: The constructed agent. When ``agent_key`` is a policy file, the agent's
        ``load_policy`` has already been called with that file.
    :raises ValueError: If the policy file name has fewer than two "_"-separated
        components, or if the resolved agent key is not in AGENTS_MAP.
    """
    policy_suffix = ".npy"
    agent_policy_file = agent_key if agent_key.endswith(policy_suffix) else None

    # If loading from a policy file, parse the agent key, environment key, and
    # other parameters from the file name.
    if agent_policy_file is not None:
        # Strip the ".npy" extension before splitting. Otherwise it sticks to the
        # last component, e.g. "gamma:0.99.npy" -> "0.99.npy", or env "Env.npy".
        stem = os.path.basename(agent_policy_file)[: -len(policy_suffix)]
        props = stem.split("_")
        try:
            agent_key, env_key = props[0], props[1]
        except IndexError as e:
            raise ValueError(
                "ERROR: Could not parse agent properties. Must be of the format 'AgentName_EnvName_key:value_key:value...'."
            ) from e

        # Every remaining component should be a single "key:value" pair.
        parsed_args = {}
        for prop in props[2:]:
            props_split = prop.split(":")
            if len(props_split) == 2:
                parsed_args[props_split[0]] = props_split[1]
            else:
                warnings.warn(
                    f"Skipping property {prop} as it does not have the format 'key:value'.",
                    UserWarning,
                    stacklevel=2,
                )
        # Overwrite any conflicting arguments with those from the file name
        parsed_args["env"] = env_key
        kwargs.update(parsed_args)
        print("agent_args:", kwargs)

    # Check if agent key is valid
    if agent_key not in AGENTS_MAP:
        raise ValueError(
            f"ERROR: Agent '{agent_key}' not valid. Must be one of: {AGENTS_MAP.keys()}"
        )

    # Load agent based on key and arguments
    agent = AGENTS_MAP[agent_key](**kwargs)

    # If loading from a policy file, load the policy into the agent
    if agent_policy_file is not None:
        agent.load_policy(agent_policy_file)
    return agent