| | from unittest import mock |
| | import pytest |
| |
|
| | from mlagents_envs.environment import UnityEnvironment |
| | from mlagents_envs.base_env import DecisionSteps, TerminalSteps, ActionTuple |
| | from mlagents_envs.exception import UnityEnvironmentException, UnityActionException |
| | from mlagents_envs.mock_communicator import MockCommunicator |
| |
|
| |
|
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_handles_bad_filename(get_communicator):
    """Constructing an environment from a blank file name must fail.

    Only the communicator factory is patched here (the executable launcher is
    not), and the constructor is expected to raise UnityEnvironmentException.
    """
    blank_name = " "
    with pytest.raises(UnityEnvironmentException):
        UnityEnvironment(blank_name)
| |
|
| |
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_initialization(mock_communicator, mock_launcher):
    """A freshly built environment exposes only the "RealFakeBrain" behavior.

    The communicator factory is patched to return a MockCommunicator and the
    executable launcher is patched out, so no real Unity process is started.
    """
    fake_comm = MockCommunicator(discrete_action=False, visual_inputs=0)
    mock_communicator.return_value = fake_comm
    env = UnityEnvironment(" ")
    behavior_names = list(env.behavior_specs.keys())
    assert behavior_names == ["RealFakeBrain"]
    env.close()
| |
|
| |
|
@pytest.mark.parametrize(
    "base_port,file_name,expected",
    [
        # An explicitly supplied base port is used unchanged.
        (6001, "foo.exe", 6001),
        # No base port, but an executable name: fall back to BASE_ENVIRONMENT_PORT.
        (None, "foo.exe", UnityEnvironment.BASE_ENVIRONMENT_PORT),
        # No base port and no executable (editor mode): use DEFAULT_EDITOR_PORT.
        (None, None, UnityEnvironment.DEFAULT_EDITOR_PORT),
    ],
)
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_port_defaults(
    mock_communicator, mock_launcher, base_port, file_name, expected
):
    """Check which port the environment selects for each base_port/file_name combination."""
    fake_comm = MockCommunicator(discrete_action=False, visual_inputs=0)
    mock_communicator.return_value = fake_comm
    env = UnityEnvironment(file_name=file_name, worker_id=0, base_port=base_port)
    chosen_port = env._port
    assert expected == chosen_port
    env.close()
| |
|
| |
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_log_file_path_is_set(mock_communicator, mock_launcher):
    """Passing log_folder adds a "-logFile <folder>/Player-<worker_id>.log" executable argument."""
    mock_communicator.return_value = MockCommunicator()
    log_folder = "./some-log-folder-path"
    env = UnityEnvironment(file_name="myfile", worker_id=0, log_folder=log_folder)
    executable_args = env._executable_args()
    # The path is the argument that immediately follows the "-logFile" flag.
    flag_position = executable_args.index("-logFile")
    # NOTE(review): the expected path hard-codes "/"; if the environment builds
    # it with os.path.join this will not match on Windows — confirm.
    assert executable_args[flag_position + 1] == "./some-log-folder-path/Player-0.log"
    env.close()
| |
|
| |
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_reset(mock_communicator, mock_launcher):
    """Reset the environment and validate the decision/terminal steps it yields.

    Both step objects must have the right type, one observation array per
    observation spec, and each array shaped (num_agents,) + spec.shape.
    """
    fake_comm = MockCommunicator(discrete_action=False, visual_inputs=0)
    mock_communicator.return_value = fake_comm
    env = UnityEnvironment(" ")
    spec = env.behavior_specs["RealFakeBrain"]
    env.reset()
    decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
    env.close()

    def _check_obs_shapes(steps):
        # Every observation array is batched along a leading agent dimension.
        agent_count = len(steps)
        for obs_spec, obs in zip(spec.observation_specs, steps.obs):
            assert (agent_count,) + obs_spec.shape == obs.shape

    assert isinstance(decision_steps, DecisionSteps)
    assert isinstance(terminal_steps, TerminalSteps)
    assert len(spec.observation_specs) == len(decision_steps.obs)
    assert len(spec.observation_specs) == len(terminal_steps.obs)
    _check_obs_shapes(decision_steps)
    _check_obs_shapes(terminal_steps)
| |
|
| |
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_step(mock_communicator, mock_launcher):
    """Step the environment with valid and invalid action batches.

    Checks that ``set_actions`` rejects a batch whose size is one less than the
    number of agents in the current decision steps (UnityActionException), and
    that the steps returned by ``get_steps`` have the expected types and
    observation shapes.
    """
    mock_communicator.return_value = MockCommunicator(
        discrete_action=False, visual_inputs=0
    )
    env = UnityEnvironment(" ")
    spec = env.behavior_specs["RealFakeBrain"]
    env.step()
    decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
    n_agents = len(decision_steps)
    env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents))
    env.step()
    # A batch that is one row short of the agent count must be rejected.
    with pytest.raises(UnityActionException):
        env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents - 1))
    decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
    n_agents = len(decision_steps)
    empty_action = spec.action_spec.empty_action(n_agents)
    # Offset the empty action by -1 so a correctly sized but non-empty batch is sent.
    next_action = ActionTuple(empty_action.continuous - 1, empty_action.discrete - 1)
    env.set_actions("RealFakeBrain", next_action)
    env.step()

    env.close()
    assert isinstance(decision_steps, DecisionSteps)
    assert isinstance(terminal_steps, TerminalSteps)
    assert len(spec.observation_specs) == len(decision_steps.obs)
    assert len(spec.observation_specs) == len(terminal_steps.obs)
    # Loop variable is ``obs_spec`` (not ``spec``) so the behavior spec above
    # is not shadowed.
    for obs_spec, obs in zip(spec.observation_specs, decision_steps.obs):
        assert (n_agents,) + obs_spec.shape == obs.shape
    # Agent ids come from MockCommunicator — presumably agent 0 is awaiting a
    # decision and agent 2 has terminated; confirm against the mock.
    assert 0 in decision_steps
    assert 2 in terminal_steps
| |
|
| |
|
@mock.patch("mlagents_envs.env_utils.launch_executable")
@mock.patch("mlagents_envs.environment.UnityEnvironment._get_communicator")
def test_close(mock_communicator, mock_launcher):
    """close() clears the loaded flag and closes the underlying communicator."""
    fake_comm = MockCommunicator(discrete_action=False, visual_inputs=0)
    mock_communicator.return_value = fake_comm
    env = UnityEnvironment(" ")
    was_loaded = env._loaded
    assert was_loaded
    env.close()
    assert not env._loaded
    assert fake_comm.has_been_closed
| |
|
| |
|
def test_check_communication_compatibility():
    """Table-driven check of UnityEnvironment._check_communication_compatibility.

    Per the cases below: for 1.x versions, a differing minor version is
    accepted but a differing major version is not; for 0.x versions the minor
    version must match exactly.
    """
    unity_package_version = "0.15.0"
    # (unity communication version, python communication version, compatible?)
    cases = [
        ("1.0.0", "1.0.0", True),
        ("1.1.0", "1.0.0", True),
        ("2.0.0", "1.0.0", False),
        ("0.16.0", "0.16.0", True),
        ("0.17.0", "0.16.0", False),
        ("1.16.0", "0.16.0", False),
    ]
    for unity_ver, python_ver, expected in cases:
        is_compatible = UnityEnvironment._check_communication_compatibility(
            unity_ver, python_ver, unity_package_version
        )
        assert bool(is_compatible) is expected
| |
|
| |
|
def test_returncode_to_signal_name():
    """Return code -2 maps to "SIGINT"; a positive code or a non-integer maps to None."""
    to_signal_name = UnityEnvironment._returncode_to_signal_name
    assert to_signal_name(-2) == "SIGINT"
    for unmapped in (42, "SIGINT"):
        assert to_signal_name(unmapped) is None
| |
|
| |
|
if __name__ == "__main__":
    import sys

    # Run only this module (forwarding any extra command-line arguments) and
    # propagate pytest's exit status so failures reach the calling shell.
    # A bare pytest.main() would collect from the working directory instead.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))
| |
|