| | using System; |
| | using System.Collections.Generic; |
| | using UnityEngine; |
| | using System.IO; |
| |
|
namespace Unity.MLAgents.SideChannels
{
    /// <summary>
    /// Static registry of <see cref="SideChannel"/> instances keyed by their channel id.
    /// Outgoing messages from all channels are serialized into a single byte buffer, and incoming
    /// buffers are parsed and routed to the channel whose id matches each message.
    /// </summary>
    /// <remarks>
    /// Wire format, repeated once per message:
    /// 16 bytes of channel id (<see cref="Guid.ToByteArray"/>), a 4-byte Int32 payload length
    /// written by <see cref="BinaryWriter"/>, then the payload bytes.
    /// </remarks>
    public static class SideChannelManager
    {
        /// <summary>Currently registered channels, keyed by channel id.</summary>
        private static Dictionary<Guid, SideChannel> s_RegisteredChannels = new Dictionary<Guid, SideChannel>();

        /// <summary>An incoming message that arrived for a channel id with no registered channel.</summary>
        private struct CachedSideChannelMessage
        {
            public Guid ChannelId;
            public byte[] Message;
        }

        /// <summary>
        /// Incoming messages waiting for a channel with a matching id. They are delivered when such a
        /// channel is registered, or are delivered-or-dropped at the start of the next
        /// <see cref="ProcessSideChannelData(Dictionary{Guid, SideChannel}, byte[])"/> call.
        /// </summary>
        private static readonly Queue<CachedSideChannelMessage> s_CachedMessages =
            new Queue<CachedSideChannelMessage>();

        /// <summary>Number of bytes used to serialize a <see cref="Guid"/> channel id.</summary>
        private const int k_ChannelIdByteLength = 16;

        /// <summary>
        /// Registers a side channel so that incoming messages with its channel id are routed to it.
        /// Any messages already cached for that id are delivered to the channel, in arrival order,
        /// before it is added to the registry.
        /// </summary>
        /// <param name="sideChannel">The side channel to register.</param>
        /// <exception cref="ArgumentNullException"><paramref name="sideChannel"/> is null.</exception>
        /// <exception cref="UnityAgentsException">
        /// A side channel with the same channel id is already registered.
        /// </exception>
        public static void RegisterSideChannel(SideChannel sideChannel)
        {
            if (sideChannel == null)
            {
                throw new ArgumentNullException(nameof(sideChannel));
            }

            var channelId = sideChannel.ChannelId;
            if (s_RegisteredChannels.ContainsKey(channelId))
            {
                throw new UnityAgentsException(
                    $"A side channel with id {channelId} is already registered. " +
                    "You cannot register multiple side channels of the same id.");
            }

            // Rotate through the queue exactly once: messages for this channel are consumed, all
            // others are re-enqueued, which preserves their relative order.
            var numMessages = s_CachedMessages.Count;
            for (var i = 0; i < numMessages; i++)
            {
                var cachedMessage = s_CachedMessages.Dequeue();
                if (cachedMessage.ChannelId == channelId)
                {
                    sideChannel.ProcessMessage(cachedMessage.Message);
                }
                else
                {
                    s_CachedMessages.Enqueue(cachedMessage);
                }
            }

            s_RegisteredChannels.Add(channelId, sideChannel);
        }

        /// <summary>
        /// Removes the channel registered under <paramref name="sideChannel"/>'s channel id.
        /// Does nothing if no channel is registered under that id.
        /// </summary>
        /// <remarks>
        /// Removal is by id only: whichever channel is registered under that id is removed, even if
        /// it is a different instance from <paramref name="sideChannel"/>.
        /// </remarks>
        /// <param name="sideChannel">The side channel whose id should be unregistered.</param>
        /// <exception cref="ArgumentNullException"><paramref name="sideChannel"/> is null.</exception>
        public static void UnregisterSideChannel(SideChannel sideChannel)
        {
            if (sideChannel == null)
            {
                throw new ArgumentNullException(nameof(sideChannel));
            }

            // Dictionary.Remove is a no-op for a missing key, so no ContainsKey pre-check is needed.
            s_RegisteredChannels.Remove(sideChannel.ChannelId);
        }

        /// <summary>
        /// Replaces the registry with an empty one. Cached incoming messages are not cleared.
        /// </summary>
        internal static void UnregisterAllSideChannels()
        {
            s_RegisteredChannels = new Dictionary<Guid, SideChannel>();
        }

        /// <summary>
        /// Returns the first registered channel whose runtime type is exactly <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">The concrete side channel type to look for.</typeparam>
        /// <returns>The matching channel, or <c>null</c> if none is registered.</returns>
        internal static T GetSideChannel<T>() where T : SideChannel
        {
            foreach (var sc in s_RegisteredChannels.Values)
            {
                // Exact type match: a registered subclass of T is intentionally not returned.
                if (sc.GetType() == typeof(T))
                {
                    return (T)sc;
                }
            }
            return null;
        }

        /// <summary>
        /// Serializes and clears the outgoing messages of every registered channel.
        /// </summary>
        /// <returns>The serialized messages, or an empty array when there is nothing to send.</returns>
        internal static byte[] GetSideChannelMessage()
        {
            return GetSideChannelMessage(s_RegisteredChannels);
        }

        /// <summary>
        /// Serializes the outgoing messages of every channel in <paramref name="sideChannels"/> into
        /// one buffer, then clears each channel's message queue.
        /// </summary>
        /// <param name="sideChannels">The channels whose queued messages should be written.</param>
        /// <returns>The serialized messages, or an empty array when no channel has a queued message.</returns>
        internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels)
        {
            if (!HasOutgoingMessages(sideChannels))
            {
                // Avoid allocating a stream when there is nothing to send.
                return Array.Empty<byte>();
            }

            using (var memStream = new MemoryStream())
            {
                using (var binaryWriter = new BinaryWriter(memStream))
                {
                    foreach (var sideChannel in sideChannels.Values)
                    {
                        // The id prefix is identical for every message of this channel; compute it once.
                        var channelIdBytes = sideChannel.ChannelId.ToByteArray();
                        foreach (var message in sideChannel.MessageQueue)
                        {
                            binaryWriter.Write(channelIdBytes);
                            binaryWriter.Write(message.Length);
                            binaryWriter.Write(message);
                        }
                        sideChannel.MessageQueue.Clear();
                    }

                    // Make sure everything written has reached the MemoryStream before copying it out.
                    binaryWriter.Flush();
                    return memStream.ToArray();
                }
            }
        }

        /// <summary>
        /// Returns whether any channel in <paramref name="sideChannels"/> has a queued outgoing message.
        /// </summary>
        /// <param name="sideChannels">The channels to inspect.</param>
        /// <returns><c>true</c> if at least one channel's message queue is non-empty.</returns>
        private static bool HasOutgoingMessages(Dictionary<Guid, SideChannel> sideChannels)
        {
            foreach (var sideChannel in sideChannels.Values)
            {
                if (sideChannel.MessageQueue.Count > 0)
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Parses <paramref name="dataReceived"/> and routes each message to the registered channels.
        /// </summary>
        /// <param name="dataReceived">Serialized side channel messages.</param>
        internal static void ProcessSideChannelData(byte[] dataReceived)
        {
            ProcessSideChannelData(s_RegisteredChannels, dataReceived);
        }

        /// <summary>
        /// Routes incoming side channel data to the channels in <paramref name="sideChannels"/>.
        /// </summary>
        /// <remarks>
        /// Messages cached by earlier calls are handled first. A cached message is delivered if its
        /// channel is now present, and is otherwise logged and dropped. The new buffer is then parsed.
        /// Each message whose id has no channel is cached, so that a channel registered before the
        /// next call can still receive it.
        /// </remarks>
        /// <param name="sideChannels">The channels to deliver messages to, keyed by channel id.</param>
        /// <param name="dataReceived">
        /// Serialized messages. A <c>null</c> or empty buffer only flushes the cache.
        /// </param>
        /// <exception cref="UnityAgentsException">
        /// The buffer is malformed or truncated, so a message could not be read completely.
        /// </exception>
        internal static void ProcessSideChannelData(Dictionary<Guid, SideChannel> sideChannels, byte[] dataReceived)
        {
            DeliverOrDropCachedMessages(sideChannels);

            if (dataReceived == null || dataReceived.Length == 0)
            {
                return;
            }

            using (var memStream = new MemoryStream(dataReceived))
            {
                using (var binaryReader = new BinaryReader(memStream))
                {
                    while (memStream.Position < memStream.Length)
                    {
                        Guid channelId;
                        byte[] message;
                        try
                        {
                            channelId = new Guid(ReadBytesExactly(binaryReader, k_ChannelIdByteLength));
                            var messageLength = binaryReader.ReadInt32();
                            message = ReadBytesExactly(binaryReader, messageLength);
                        }
                        catch (Exception ex)
                        {
                            throw new UnityAgentsException(
                                "There was a problem reading a message in a SideChannel. Please make sure the " +
                                "version of MLAgents in Unity is compatible with the Python version. Original error : "
                                + ex.Message);
                        }

                        SideChannel sideChannel;
                        if (sideChannels.TryGetValue(channelId, out sideChannel))
                        {
                            sideChannel.ProcessMessage(message);
                        }
                        else
                        {
                            // No channel for this id yet: keep the message so that a channel registered
                            // before the next call can still receive it.
                            s_CachedMessages.Enqueue(new CachedSideChannelMessage
                            {
                                ChannelId = channelId,
                                Message = message
                            });
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Empties the message cache. Each cached message is delivered to its channel in
        /// <paramref name="sideChannels"/> if present, and is otherwise logged and discarded.
        /// </summary>
        /// <param name="sideChannels">The channels to deliver cached messages to.</param>
        private static void DeliverOrDropCachedMessages(Dictionary<Guid, SideChannel> sideChannels)
        {
            while (s_CachedMessages.Count != 0)
            {
                var cachedMessage = s_CachedMessages.Dequeue();
                SideChannel sideChannel;
                if (sideChannels.TryGetValue(cachedMessage.ChannelId, out sideChannel))
                {
                    sideChannel.ProcessMessage(cachedMessage.Message);
                }
                else
                {
                    Debug.Log($"Unknown side channel data received. Channel Id is : {cachedMessage.ChannelId}");
                }
            }
        }

        /// <summary>
        /// Reads exactly <paramref name="count"/> bytes from <paramref name="reader"/>.
        /// </summary>
        /// <remarks>
        /// <see cref="BinaryReader.ReadBytes"/> returns a shorter array, rather than throwing, when the
        /// stream ends early. That would let a truncated message through unnoticed, so this method
        /// turns a short read into an error.
        /// </remarks>
        /// <param name="reader">The reader to consume bytes from.</param>
        /// <param name="count">The number of bytes required. A negative value makes ReadBytes throw.</param>
        /// <returns>An array of exactly <paramref name="count"/> bytes.</returns>
        /// <exception cref="EndOfStreamException">Fewer than <paramref name="count"/> bytes remained.</exception>
        private static byte[] ReadBytesExactly(BinaryReader reader, int count)
        {
            var bytes = reader.ReadBytes(count);
            if (bytes.Length != count)
            {
                throw new EndOfStreamException(
                    $"Expected {count} bytes but only {bytes.Length} remained in the side channel data.");
            }
            return bytes;
        }
    }
}
| |
|