| from typing import Optional |
|
|
| from .communicator import Communicator, PollCallback |
| from .environment import UnityEnvironment |
| from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto |
| from mlagents_envs.communicator_objects.brain_parameters_pb2 import ( |
| BrainParametersProto, |
| ActionSpecProto, |
| ) |
| from mlagents_envs.communicator_objects.unity_rl_initialization_output_pb2 import ( |
| UnityRLInitializationOutputProto, |
| ) |
| from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto |
| from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto |
| from mlagents_envs.communicator_objects.agent_info_pb2 import AgentInfoProto |
| from mlagents_envs.communicator_objects.observation_pb2 import ( |
| ObservationProto, |
| NONE as COMPRESSION_TYPE_NONE, |
| PNG as COMPRESSION_TYPE_PNG, |
| ) |
|
|
|
|
class MockCommunicator(Communicator):
    """
    Mock of the Python<->Unity Communicator used in unit tests.

    Fabricates the protobuf messages a real Unity server would send
    (brain parameters, per-agent observations) without opening a gRPC
    connection, so environment-side code can be exercised in isolation.
    """

    def __init__(
        self,
        discrete_action=False,
        visual_inputs=0,
        num_agents=3,
        brain_name="RealFakeBrain",
        vec_obs_size=3,
    ):
        """
        Python side of the grpc communication. Python is the client and Unity the server

        :param discrete_action: if True, advertise a discrete action spec with
            branch sizes [3, 2]; otherwise a 2-dimensional continuous spec.
        :param visual_inputs: number of PNG-compressed visual observations
            reported for each agent.
        :param num_agents: number of fake agents returned on every step.
        :param brain_name: behavior name reported in the initialization output
            and used as the key of the agent-info dictionary.
        :param vec_obs_size: length of the fake vector observation.
        """
        super().__init__()
        self.is_discrete = discrete_action
        self.steps = 0
        self.visual_inputs = visual_inputs
        self.has_been_closed = False
        self.num_agents = num_agents
        self.brain_name = brain_name
        self.vec_obs_size = vec_obs_size

    def initialize(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """
        Return a fake initialization handshake: academy and brain metadata
        plus a first batch of agent infos.
        """
        if self.is_discrete:
            action_spec = ActionSpecProto(
                num_discrete_actions=2, discrete_branch_sizes=[3, 2]
            )
        else:
            action_spec = ActionSpecProto(num_continuous_actions=2)
        bp = BrainParametersProto(
            brain_name=self.brain_name, is_training=True, action_spec=action_spec
        )
        rl_init = UnityRLInitializationOutputProto(
            name="RealFakeAcademy",
            communication_version=UnityEnvironment.API_VERSION,
            package_version="mock_package_version",
            log_path="",
            brain_parameters=[bp],
        )
        output = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_initialization_output=rl_init, rl_output=output)

    def _get_agent_infos(self):
        """
        Build the fake per-brain agent infos.

        Each of the ``num_agents`` agents carries ``visual_inputs`` PNG
        observations of shape [30, 40, 3] plus one vector observation of
        length ``vec_obs_size``; the agent with id 2 is marked done.
        """
        # Values 1..vec_obs_size keep the historical [1, 2, 3] default while
        # honoring a custom vec_obs_size (previously hard-coded to [1, 2, 3]).
        vector_obs = list(range(1, self.vec_obs_size + 1))
        observations = [
            ObservationProto(
                compressed_data=None,
                shape=[30, 40, 3],
                compression_type=COMPRESSION_TYPE_PNG,
            )
            for _ in range(self.visual_inputs)
        ]
        vector_obs_proto = ObservationProto(
            float_data=ObservationProto.FloatData(data=vector_obs),
            shape=[len(vector_obs)],
            compression_type=COMPRESSION_TYPE_NONE,
        )
        observations.append(vector_obs_proto)
        list_agent_info = [
            AgentInfoProto(
                reward=1,
                done=(i == 2),
                max_step_reached=False,
                id=i,
                observations=observations,
            )
            for i in range(self.num_agents)
        ]
        # Key by the configured brain name. This was hard-coded to
        # "RealFakeBrain", which left the agent infos under the wrong key
        # whenever a custom brain_name was passed to the constructor.
        return {
            self.brain_name: UnityRLOutputProto.ListAgentInfoProto(
                value=list_agent_info
            )
        }

    def exchange(
        self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None
    ) -> UnityOutputProto:
        """Return one fake step of agent infos for the given input."""
        result = UnityRLOutputProto(agentInfos=self._get_agent_infos())
        return UnityOutputProto(rl_output=result)

    def close(self):
        """
        Sends a shutdown signal to the unity environment, and closes the grpc connection.
        """
        self.has_been_closed = True
|