| | from unittest.mock import Mock |
| |
|
| | import pytest |
| | from unittest import mock |
| |
|
| | import grpc |
| |
|
| | import mlagents_envs.rpc_communicator |
| | from mlagents_envs.rpc_communicator import RpcCommunicator |
| | from mlagents_envs.exception import ( |
| | UnityWorkerInUseException, |
| | UnityTimeOutException, |
| | UnityEnvironmentException, |
| | ) |
| | from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto |
| |
|
| |
|
@pytest.mark.parametrize("n_ports", [1])
def test_rpc_communicator_checks_port_on_create(base_port: int) -> None:
    """A second communicator on an already-used port must be rejected.

    The first communicator is closed in ``finally`` so its port is released
    even when the expected ``UnityWorkerInUseException`` is not raised (in
    that case ``pytest.raises`` fails the test, and previously the
    ``first_comm.close()`` call was skipped).
    """
    # NOTE(review): assumes the ``base_port`` fixture (defined outside this
    # file) provides ``n_ports`` usable ports starting at base_port — verify.
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        with pytest.raises(UnityWorkerInUseException):
            second_comm = RpcCommunicator(base_port=base_port)
            # Only reached if the constructor did NOT raise; release that
            # communicator before pytest.raises reports the failure.
            second_comm.close()
    finally:
        first_comm.close()
| |
|
| |
|
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_close(base_port: int) -> None:
    """Open and close one communicator, then open and close another.

    The first communicator is closed before the second is created; the
    second one is built with ``base_port + 1`` as its base port.
    """
    for port in (base_port, base_port + 1):
        comm = RpcCommunicator(base_port=port)
        comm.close()
| |
|
| |
|
@pytest.mark.parametrize("n_ports", [2])
def test_rpc_communicator_create_multiple_workers(base_port: int) -> None:
    """Two communicators with different ``worker_id`` values can coexist.

    Both share ``base_port``; the second uses ``worker_id=1``. Each one is
    closed in a ``finally`` clause, so a failure while creating or closing
    the second communicator no longer leaks the first one's port.
    """
    first_comm = RpcCommunicator(base_port=base_port)
    try:
        second_comm = RpcCommunicator(base_port=base_port, worker_id=1)
        second_comm.close()
    finally:
        first_comm.close()
| |
|
| |
|
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_OK(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """``initialize`` succeeds when the connection reports data is ready.

    ``grpc.server`` and the servicer implementation are patched out, so no
    real gRPC server is involved. The stacked ``mock.patch.object``
    decorators inject mocks bottom-up: ``mock_impl`` replaces
    ``UnityToExternalServicerImplementation`` and ``mock_grpc_server``
    replaces ``grpc.server``.
    """
    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    # Make every poll report that a message is available.
    comm.unity_to_external.parent_conn.poll.return_value = True
    unity_input = UnityInputProto()
    comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
| |
|
| |
|
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_timeout(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """``initialize`` raises ``UnityTimeOutException`` when no data arrives.

    The patched connection's ``poll`` always returns ``None`` (falsy), and
    ``timeout_wait`` is set to 0.25. The mock arguments arrive bottom-up
    from the decorators: ``mock_impl`` comes first, then
    ``mock_grpc_server``.
    """
    comm = RpcCommunicator(timeout_wait=0.25, base_port=base_port)
    # Never report a message as available, so the wait must time out.
    comm.unity_to_external.parent_conn.poll.return_value = None
    unity_input = UnityInputProto()

    with pytest.raises(UnityTimeOutException):
        comm.initialize(unity_input)
    comm.unity_to_external.parent_conn.poll.assert_called()
| |
|
| |
|
@pytest.mark.parametrize("n_ports", [1])
@mock.patch.object(grpc, "server")
@mock.patch.object(
    mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation"
)
def test_rpc_communicator_initialize_callback(
    mock_impl: Mock, mock_grpc_server: Mock, base_port: int
) -> None:
    """An exception raised by ``poll_callback`` propagates out of ``initialize``.

    ``poll`` never reports data (it returns ``None``). The test expects
    ``initialize`` to invoke the supplied callback while waiting and to let
    its ``UnityEnvironmentException`` escape. NOTE(review): the invocation
    point of the callback lives in RpcCommunicator and is not shown here.
    """

    def callback() -> None:
        raise UnityEnvironmentException

    comm = RpcCommunicator(base_port=base_port, timeout_wait=0.25)
    comm.unity_to_external.parent_conn.poll.return_value = None
    unity_input = UnityInputProto()

    with pytest.raises(UnityEnvironmentException):
        comm.initialize(unity_input, poll_callback=callback)
    comm.unity_to_external.parent_conn.poll.assert_called()
| |
|