| | #if UNITY_EDITOR || UNITY_STANDALONE |
| | #define MLA_SUPPORTED_TRAINING_PLATFORM |
| | #endif |
| |
|
| | #if MLA_SUPPORTED_TRAINING_PLATFORM |
| | using Grpc.Core; |
| | #if UNITY_EDITOR |
| | using UnityEditor; |
| | #endif |
| | using System; |
| | using System.Collections.Generic; |
| | using System.Linq; |
| | using UnityEngine; |
| | using Unity.MLAgents.Actuators; |
| | using Unity.MLAgents.CommunicatorObjects; |
| | using Unity.MLAgents.Sensors; |
| | using Unity.MLAgents.SideChannels; |
| | using Google.Protobuf; |
| |
|
| | using Unity.MLAgents.Analytics; |
| |
|
| | namespace Unity.MLAgents |
| | { |
| | |
| | public class RpcCommunicator : ICommunicator |
| | { |
/// <summary>Invoked whenever the communicator notifies a quit and shuts its channel down.</summary>
public event QuitCommandHandler QuitCommandReceived;

/// <summary>Invoked when a reset command is received from the trainer.</summary>
public event ResetCommandHandler ResetCommandReceived;

/// <summary>True while exchanges with the trainer are still attempted.</summary>
bool m_IsOpen;

/// <summary>Behavior names registered through SubscribeBrain.</summary>
List<string> m_BehaviorNames = new List<string>();

/// <summary>Set when observations were queued; consumed by DecideBatch.</summary>
bool m_NeedCommunicateThisStep;

/// <summary>Writer handed to each sensor when its observation proto is generated.</summary>
ObservationWriter m_ObservationWriter = new ObservationWriter();

/// <summary>Per-behavior sensor validators (only populated in DEBUG builds).</summary>
Dictionary<string, SensorShapeValidator> m_SensorShapeValidators =
    new Dictionary<string, SensorShapeValidator>();

/// <summary>Per-behavior episode ids, in the order the agents queued observations this step.</summary>
Dictionary<string, List<int>> m_OrderedAgentsRequestingDecisions =
    new Dictionary<string, List<int>>();

/// <summary>The outgoing message that accumulates agent infos until the next exchange.</summary>
UnityRLOutputProto m_CurrentUnityRlOutput = new UnityRLOutputProto();

/// <summary>Most recent actions per behavior name and episode id.</summary>
Dictionary<string, Dictionary<int, ActionBuffers>> m_LastActionsReceived =
    new Dictionary<string, Dictionary<int, ActionBuffers>>();

/// <summary>Behavior names whose action spec has already been sent.</summary>
HashSet<string> m_SentBrainKeys = new HashSet<string>();

/// <summary>Action specs that still have to be sent, keyed by behavior name.</summary>
Dictionary<string, ActionSpec> m_UnsentBrainKeys = new Dictionary<string, ActionSpec>();

/// <summary>gRPC client used for every exchange.</summary>
UnityToExternalProto.UnityToExternalProtoClient m_Client;

/// <summary>gRPC channel backing <see cref="m_Client"/>.</summary>
Channel m_Channel;

/// <summary>
/// Non-public constructor; instances are obtained through <see cref="Create"/>.
/// </summary>
protected RpcCommunicator()
{
}
| |
|
/// <summary>
/// Factory method. Produces a new communicator when MLA_SUPPORTED_TRAINING_PLATFORM
/// is defined and null otherwise.
/// </summary>
/// <returns>A new <see cref="RpcCommunicator"/>, or null on unsupported platforms.</returns>
public static RpcCommunicator Create()
{
#if MLA_SUPPORTED_TRAINING_PLATFORM
    RpcCommunicator communicator = new RpcCommunicator();
    return communicator;
#else
    return null;
#endif
}
| |
|
| | #region Initialization |
| |
|
/// <summary>
/// Decides whether the Unity communication protocol version can work with the
/// version reported by the Python side.
/// </summary>
/// <remarks>
/// The major components must always match. While the Unity major version is 0,
/// the minor components must match as well; from major version 1 on, a minor
/// difference is accepted. A version string that cannot be parsed (null, empty or
/// malformed) is reported as incompatible instead of raising an exception.
/// </remarks>
/// <param name="unityCommunicationVersion">Version string of the Unity side.</param>
/// <param name="pythonApiVersion">Version string reported by the Python side.</param>
/// <returns>True when the two versions are considered compatible.</returns>
internal static bool CheckCommunicationVersionsAreCompatible(
    string unityCommunicationVersion,
    string pythonApiVersion
)
{
    // TryParse returns false for null/empty/malformed input, so a bad version
    // string from either side simply means "not compatible".
    if (!Version.TryParse(unityCommunicationVersion, out var unityVersion) ||
        !Version.TryParse(pythonApiVersion, out var pythonVersion))
    {
        return false;
    }

    if (unityVersion.Major != pythonVersion.Major)
    {
        return false;
    }

    // Same major version: only 0.x additionally requires identical minor versions.
    return unityVersion.Major != 0 || unityVersion.Minor == pythonVersion.Minor;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Connects to the trainer on the configured port and performs the initialization
/// handshake.
/// </summary>
/// <param name="initParameters">
/// Supplies the name, package version, communication version, capabilities and port
/// used for the handshake.
/// </param>
/// <param name="initParametersOut">
/// On success, the parameters converted from the trainer's initialization reply;
/// on failure, a default-constructed value.
/// </param>
/// <returns>True when the handshake completed, false otherwise.</returns>
public bool Initialize(CommunicatorInitParameters initParameters, out UnityRLInitParameters initParametersOut)
{
#if MLA_SUPPORTED_TRAINING_PLATFORM
    var academyParameters = new UnityRLInitializationOutputProto
    {
        Name = initParameters.name,
        PackageVersion = initParameters.unityPackageVersion,
        CommunicationVersion = initParameters.unityCommunicationVersion,
        Capabilities = initParameters.CSharpCapabilities.ToProto()
    };

    UnityInputProto input;
    UnityInputProto initializationInput;
    try
    {
        initializationInput = Initialize(
            initParameters.port,
            new UnityOutputProto
            {
                RlInitializationOutput = academyParameters
            },
            out input
        );
    }
    catch (Exception ex)
    {
        if (ex is RpcException rpcException)
        {
            switch (rpcException.Status.StatusCode)
            {
                case StatusCode.Unavailable:
                    // Nothing is logged for this status code.
                    break;
                case StatusCode.DeadlineExceeded:
                    // Nothing is logged for this status code.
                    break;
                default:
                    Debug.Log($"Unexpected gRPC exception when trying to initialize communication: {rpcException}");
                    break;
            }
        }
        else
        {
            Debug.Log($"Unexpected exception when trying to initialize communication: {ex}");
        }
        initParametersOut = new UnityRLInitParameters();
        NotifyQuitAndShutDownChannel();
        return false;
    }

    // The reply must be validated before its fields are read; a reply without
    // initialization data would otherwise surface as a NullReferenceException.
    var rlInitializationInput = initializationInput?.RlInitializationInput;
    if (rlInitializationInput == null)
    {
        Debug.LogWarning(
            "The trainer reply did not contain initialization data. " +
            "Communication could not be initialized."
        );
        initParametersOut = new UnityRLInitParameters();
        return false;
    }

    var pythonPackageVersion = rlInitializationInput.PackageVersion;
    var pythonCommunicationVersion = rlInitializationInput.CommunicationVersion;
    TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);

    var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(
        initParameters.unityCommunicationVersion,
        pythonCommunicationVersion
    );

    // The first reply arrived but the second one carried no input: report a
    // version mismatch if there is one, otherwise a generic warning.
    if (input == null)
    {
        if (!communicationIsCompatible)
        {
            Debug.LogWarningFormat(
                "Communication protocol between python ({0}) and Unity ({1}) have different " +
                "versions which make them incompatible. Python library version: {2}.",
                pythonCommunicationVersion, initParameters.unityCommunicationVersion,
                pythonPackageVersion
            );
        }
        else
        {
            Debug.LogWarningFormat(
                "Unknown communication error between Python. Python communication protocol: {0}, " +
                "Python library version: {1}.",
                pythonCommunicationVersion,
                pythonPackageVersion
            );
        }

        initParametersOut = new UnityRLInitParameters();
        return false;
    }

    UpdateEnvironmentWithInput(input.RlInput);
    initParametersOut = rlInitializationInput.ToUnityRLInitParameters();
    // Make sure the channel is shut down when the application quits.
    Application.quitting += NotifyQuitAndShutDownChannel;
    return true;
#else
    initParametersOut = new UnityRLInitParameters();
    return false;
#endif
}
| |
|
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Registers a behavior name with the communicator. A name that is already
/// registered is ignored.
/// </summary>
/// <param name="brainKey">Behavior name to register.</param>
/// <param name="actionSpec">Action spec cached for sending alongside that behavior.</param>
public void SubscribeBrain(string brainKey, ActionSpec actionSpec)
{
    if (!m_BehaviorNames.Contains(brainKey))
    {
        m_BehaviorNames.Add(brainKey);

        // Start the behavior with an empty agent-info list in the outgoing message.
        var emptyAgentInfos = new UnityRLOutputProto.Types.ListAgentInfoProto();
        m_CurrentUnityRlOutput.AgentInfos.Add(brainKey, emptyAgentInfos);

        CacheActionSpec(brainKey, actionSpec);
    }
}
| |
|
/// <summary>
/// Applies an incoming RL input: forwards its side-channel bytes to the
/// <c>SideChannelManager</c>, then dispatches its command.
/// </summary>
/// <param name="rlInput">Input received from the trainer.</param>
void UpdateEnvironmentWithInput(UnityRLInputProto rlInput)
{
    byte[] sideChannelBytes = rlInput.SideChannel.ToArray();
    SideChannelManager.ProcessSideChannelData(sideChannelBytes);

    CommandProto receivedCommand = rlInput.Command;
    SendCommandEvent(receivedCommand);
}
| |
|
/// <summary>
/// Creates the gRPC channel and client for <c>localhost:port</c> and performs two
/// exchanges: the first carries <paramref name="unityOutput"/>, the second an
/// empty payload.
/// </summary>
/// <param name="port">Local port the trainer is reached on.</param>
/// <param name="unityOutput">Payload of the first exchange.</param>
/// <param name="unityInput">Receives the input contained in the second reply.</param>
/// <returns>The input contained in the first reply.</returns>
/// <remarks>
/// Exceptions from channel creation or the exchanges propagate to the caller.
/// Before they do, <c>m_IsOpen</c> is reset so that the communicator does not
/// stay marked as open after a failed connection attempt.
/// </remarks>
UnityInputProto Initialize(int port, UnityOutputProto unityOutput, out UnityInputProto unityInput)
{
    try
    {
        m_IsOpen = true;
        m_Channel = new Channel($"localhost:{port}", ChannelCredentials.Insecure);

        m_Client = new UnityToExternalProto.UnityToExternalProtoClient(m_Channel);
        var result = m_Client.Exchange(WrapMessage(unityOutput, 200));
        var inputMessage = m_Client.Exchange(WrapMessage(null, 200));
        unityInput = inputMessage.UnityInput;
#if UNITY_EDITOR
        EditorApplication.playModeStateChanged += HandleOnPlayModeChanged;
#endif
        // Any status other than 200 in either reply closes the communicator.
        if (result.Header.Status != 200 || inputMessage.Header.Status != 200)
        {
            m_IsOpen = false;
            NotifyQuitAndShutDownChannel();
        }
        return result.UnityInput;
    }
    catch
    {
        // Mirror the failure handling of Exchange: a failed attempt leaves the
        // communicator closed. The exception itself is handled by the caller.
        m_IsOpen = false;
        throw;
    }
}
| |
|
/// <summary>
/// Raises <see cref="QuitCommandReceived"/> and then waits for the gRPC channel to
/// shut down. Any exception thrown by the shutdown is swallowed.
/// </summary>
void NotifyQuitAndShutDownChannel()
{
    QuitCommandReceived?.Invoke();

    try
    {
        var pendingShutdown = m_Channel.ShutdownAsync();
        pendingShutdown.Wait();
    }
    catch (Exception)
    {
        // Shutdown failures are ignored.
    }
}
| |
|
| | #endregion |
| |
|
| | #region Destruction |
| |
|
| | |
| | |
| | |
/// <summary>
/// Sends a final message with status 400 to the trainer and marks the communicator
/// as closed. Does nothing when the communicator is not open; a failing exchange
/// is swallowed and leaves the open flag unchanged.
/// </summary>
public void Dispose()
{
    if (m_IsOpen)
    {
        try
        {
            var closingMessage = WrapMessage(null, 400);
            m_Client.Exchange(closingMessage);
            m_IsOpen = false;
        }
        catch
        {
            // Errors while sending the closing message are ignored.
        }
    }
}
| |
|
| | #endregion |
| |
|
| | #region Sending Events |
| |
|
/// <summary>
/// Reacts to a command received from the trainer. Quit shuts the channel down;
/// Reset empties every pending decision-request list and raises
/// <see cref="ResetCommandReceived"/>. Other commands are ignored.
/// </summary>
/// <param name="command">Command to dispatch.</param>
void SendCommandEvent(CommandProto command)
{
    if (command == CommandProto.Quit)
    {
        NotifyQuitAndShutDownChannel();
    }
    else if (command == CommandProto.Reset)
    {
        foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
        {
            pendingAgentIds.Clear();
        }
        ResetCommandReceived?.Invoke();
    }
}
| |
|
| | #endregion |
| |
|
| | #region Sending and retrieving data |
| |
|
/// <summary>
/// Sends the batched message to the trainer if observations were queued since the
/// last call; otherwise does nothing.
/// </summary>
public void DecideBatch()
{
    if (m_NeedCommunicateThisStep)
    {
        m_NeedCommunicateThisStep = false;
        SendBatchedMessageHelper();
    }
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Queues one agent's info and sensor observations for the next exchange and
/// records that the agent is waiting for an action.
/// </summary>
/// <param name="behaviorName">Behavior name the agent belongs to.</param>
/// <param name="info">Agent info converted to a proto and added to the outgoing message.</param>
/// <param name="sensors">Sensors whose observation protos are attached to the agent info.</param>
public void PutObservations(string behaviorName, AgentInfo info, List<ISensor> sensors)
{
#if DEBUG
    SensorShapeValidator shapeValidator;
    if (!m_SensorShapeValidators.TryGetValue(behaviorName, out shapeValidator))
    {
        shapeValidator = new SensorShapeValidator();
        m_SensorShapeValidators[behaviorName] = shapeValidator;
    }
    shapeValidator.ValidateSensors(sensors);
#endif

    using (TimerStack.Instance.Scoped("AgentInfo.ToProto"))
    {
        var agentInfoProto = info.ToAgentInfoProto();

        using (TimerStack.Instance.Scoped("GenerateSensorData"))
        {
            foreach (var currentSensor in sensors)
            {
                agentInfoProto.Observations.Add(currentSensor.GetObservationProto(m_ObservationWriter));
            }
        }

        var queuedAgentInfos = m_CurrentUnityRlOutput.AgentInfos[behaviorName];
        queuedAgentInfos.Value.Add(agentInfoProto);
    }

    m_NeedCommunicateThisStep = true;

    List<int> pendingAgentIds;
    if (!m_OrderedAgentsRequestingDecisions.TryGetValue(behaviorName, out pendingAgentIds))
    {
        pendingAgentIds = new List<int>();
        m_OrderedAgentsRequestingDecisions[behaviorName] = pendingAgentIds;
    }

    Dictionary<int, ActionBuffers> actionsByAgent;
    if (!m_LastActionsReceived.TryGetValue(behaviorName, out actionsByAgent))
    {
        actionsByAgent = new Dictionary<int, ActionBuffers>();
        m_LastActionsReceived[behaviorName] = actionsByAgent;
    }

    if (info.done)
    {
        // A finished agent is not queued for a decision and keeps no stored action.
        actionsByAgent.Remove(info.episodeId);
    }
    else
    {
        pendingAgentIds.Add(info.episodeId);
        actionsByAgent[info.episodeId] = ActionBuffers.Empty;
    }
}
| |
|
| | |
| | |
| | |
| | |
/// <summary>
/// Sends the accumulated agent infos (plus any not-yet-sent action specs and the
/// aggregated side-channel message) to the trainer, then stores the returned
/// actions for the agents that queued observations this step.
/// </summary>
/// <remarks>
/// Actions are matched to agents by position: the i-th action returned for a
/// behavior is assigned to the i-th episode id queued for that behavior. The
/// queued ids are therefore always cleared when this method finishes, including
/// when no actions came back, so that ids from one step can never shift the
/// mapping of a later step. Behavior names in the reply that have no queued
/// agents on the Unity side are skipped.
/// </remarks>
void SendBatchedMessageHelper()
{
    var message = new UnityOutputProto
    {
        RlOutput = m_CurrentUnityRlOutput,
    };
    var tempUnityRlInitializationOutput = GetTempUnityRlInitializationOutput();
    if (tempUnityRlInitializationOutput != null)
    {
        message.RlInitializationOutput = tempUnityRlInitializationOutput;
    }

    byte[] messageAggregated = SideChannelManager.GetSideChannelMessage();
    message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated);

    var input = Exchange(message);
    UpdateSentActionSpec(tempUnityRlInitializationOutput);

    // The agent infos have been handed to Exchange; start the next step empty.
    foreach (var agentInfos in m_CurrentUnityRlOutput.AgentInfos.Values)
    {
        agentInfos.Value.Clear();
    }

    try
    {
        var rlInput = input?.RlInput;
        if (rlInput?.AgentActions == null)
        {
            return;
        }

        UpdateEnvironmentWithInput(rlInput);

        foreach (var brainName in rlInput.AgentActions.Keys)
        {
            // Skip behaviors the trainer mentions but for which nothing is queued here.
            if (!m_OrderedAgentsRequestingDecisions.TryGetValue(brainName, out var pendingAgentIds) ||
                pendingAgentIds.Count == 0)
            {
                continue;
            }

            var actionsProto = rlInput.AgentActions[brainName];
            if (!actionsProto.Value.Any())
            {
                continue;
            }

            if (!m_LastActionsReceived.TryGetValue(brainName, out var actionsByAgent))
            {
                continue;
            }

            var agentActions = actionsProto.ToAgentActionList();
            for (var i = 0; i < pendingAgentIds.Count; i++)
            {
                var agentId = pendingAgentIds[i];
                // Only agents that still have a stored entry receive the new action.
                if (actionsByAgent.ContainsKey(agentId))
                {
                    actionsByAgent[agentId] = agentActions[i];
                }
            }
        }
    }
    finally
    {
        foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
        {
            pendingAgentIds.Clear();
        }
    }
}
| |
|
/// <summary>
/// Looks up the most recently stored action for an agent.
/// </summary>
/// <param name="behaviorName">Behavior name the agent belongs to.</param>
/// <param name="agentId">Id the action was stored under.</param>
/// <returns>
/// The stored action, or <c>ActionBuffers.Empty</c> when nothing is stored for that
/// behavior name and id.
/// </returns>
public ActionBuffers GetActions(string behaviorName, int agentId)
{
    Dictionary<int, ActionBuffers> actionsByAgent;
    ActionBuffers storedAction;
    if (m_LastActionsReceived.TryGetValue(behaviorName, out actionsByAgent) &&
        actionsByAgent.TryGetValue(agentId, out storedAction))
    {
        return storedAction;
    }

    return ActionBuffers.Empty;
}
| |
|
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Sends one output message (status 200) to the trainer and returns the input
/// contained in the reply.
/// </summary>
/// <param name="unityOutput">Payload to send.</param>
/// <returns>
/// The reply's input. Null when the communicator is not open or when the exchange
/// throws. A reply with a status other than 200 closes the communicator but its
/// input is still returned.
/// </returns>
UnityInputProto Exchange(UnityOutputProto unityOutput)
{
    if (!m_IsOpen)
    {
        return null;
    }

    try
    {
        var reply = m_Client.Exchange(WrapMessage(unityOutput, 200));
        if (reply.Header.Status != 200)
        {
            // Non-200 status: stop communicating and notify listeners.
            m_IsOpen = false;
            NotifyQuitAndShutDownChannel();
        }

        return reply.UnityInput;
    }
    catch (Exception ex)
    {
        var rpcException = ex as RpcException;
        if (rpcException == null)
        {
            Debug.LogError($"Communication Exception: {ex.Message}. Disconnecting from trainer.");
        }
        else
        {
            var statusCode = rpcException.Status.StatusCode;
            if (statusCode == StatusCode.ResourceExhausted)
            {
                Debug.LogError($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
            }
            else if (statusCode != StatusCode.Unavailable)
            {
                // Unavailable is the only gRPC status that produces no log entry.
                Debug.Log($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
            }
        }

        m_IsOpen = false;
        NotifyQuitAndShutDownChannel();
        return null;
    }
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Builds the message envelope sent over the wire: a header carrying
/// <paramref name="status"/> plus the given output payload.
/// </summary>
/// <param name="content">Payload to wrap; callers in this class also pass null.</param>
/// <param name="status">Status code written to the header.</param>
/// <returns>The wrapped message.</returns>
static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
{
    var header = new HeaderProto();
    header.Status = status;

    var wrapped = new UnityMessageProto();
    wrapped.Header = header;
    wrapped.UnityOutput = content;
    return wrapped;
}
| |
|
/// <summary>
/// Remembers an action spec for later sending, unless the behavior name has
/// already been marked as sent.
/// </summary>
/// <param name="behaviorName">Behavior name the spec belongs to.</param>
/// <param name="actionSpec">Action spec to store; replaces any spec stored earlier for the name.</param>
void CacheActionSpec(string behaviorName, ActionSpec actionSpec)
{
    var alreadySent = m_SentBrainKeys.Contains(behaviorName);
    if (!alreadySent)
    {
        m_UnsentBrainKeys[behaviorName] = actionSpec;
    }
}
| |
|
/// <summary>
/// Collects brain parameters for every unsent behavior whose agent-info entry in
/// the outgoing message currently has a non-zero serialized size.
/// </summary>
/// <returns>
/// An initialization output holding those brain parameters, or null when no unsent
/// behavior qualifies.
/// </returns>
UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
{
    UnityRLInitializationOutputProto pendingOutput = null;

    foreach (var unsentEntry in m_UnsentBrainKeys)
    {
        var behaviorName = unsentEntry.Key;

        if (!m_CurrentUnityRlOutput.AgentInfos.TryGetValue(behaviorName, out var queuedAgentInfos))
        {
            continue;
        }

        if (queuedAgentInfos.CalculateSize() <= 0)
        {
            continue;
        }

        // Created lazily so that null can signal "nothing to send".
        if (pendingOutput == null)
        {
            pendingOutput = new UnityRLInitializationOutputProto();
        }

        var brainParameters = unsentEntry.Value.ToBrainParametersProto(behaviorName, true);
        pendingOutput.BrainParameters.Add(brainParameters);
    }

    return pendingOutput;
}
| |
|
/// <summary>
/// Moves every brain name contained in <paramref name="output"/> from the unsent
/// set to the sent set. A null argument is a no-op.
/// </summary>
/// <param name="output">Initialization output that was included in the last message, or null.</param>
void UpdateSentActionSpec(UnityRLInitializationOutputProto output)
{
    if (output != null)
    {
        foreach (var brainProto in output.BrainParameters)
        {
            var sentName = brainProto.BrainName;
            m_SentBrainKeys.Add(sentName);
            m_UnsentBrainKeys.Remove(sentName);
        }
    }
}
| |
|
| | #endregion |
| |
|
| | #if UNITY_EDITOR |
| | |
| | |
| | |
| | |
/// <summary>
/// Editor play-mode callback: disposes the communicator when play mode is being
/// exited and ignores every other state change.
/// </summary>
/// <param name="state">The play-mode transition reported by the editor.</param>
void HandleOnPlayModeChanged(PlayModeStateChange state)
{
    if (state != PlayModeStateChange.ExitingPlayMode)
    {
        return;
    }

    Dispose();
}
| |
|
| | #endif |
| | } |
| | } |
| | #endif // UNITY_EDITOR || UNITY_STANDALONE |
| |
|