File size: 10,333 Bytes
05c9ac2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 | using System;
using System.Collections.Generic;
using UnityEngine;
using System.IO;
namespace Unity.MLAgents.SideChannels
{
/// <summary>
/// Collection of static utilities for managing the registering/unregistering of
/// <see cref="SideChannel"/> instances and the sending/receiving of messages for all the channels.
/// </summary>
/// <remarks>
/// Outgoing and incoming side channel data share one wire format: a sequence of records, each made of
/// a 16-byte channel id (<see cref="Guid.ToByteArray"/>), a 32-bit payload length written by
/// <see cref="BinaryWriter"/>, and then that many payload bytes.
/// </remarks>
public static class SideChannelManager
{
    // Channels that currently send and receive messages, keyed by SideChannel.ChannelId.
    static Dictionary<Guid, SideChannel> s_RegisteredChannels = new Dictionary<Guid, SideChannel>();

    // Number of bytes used to encode the channel id (a Guid) at the start of every record.
    const int k_ChannelIdByteCount = 16;

    /// <summary>
    /// A message received from Python whose channel id had no matching channel when it arrived.
    /// </summary>
    struct CachedSideChannelMessage
    {
        public Guid ChannelId;
        public byte[] Message;
    }

    // Messages for channel ids that were unknown when received. They are handed to a channel by
    // RegisterSideChannel when a channel with that id is registered, or on the next call to
    // ProcessSideChannelData, which logs and drops any that still have no matching channel.
    static readonly Queue<CachedSideChannelMessage> s_CachedMessages =
        new Queue<CachedSideChannelMessage>();

    /// <summary>
    /// Register a side channel to begin sending and receiving messages. This method is
    /// available for environments that have custom side channels. All built-in side
    /// channels within the ML-Agents Toolkit are managed internally and do not need to
    /// be explicitly registered/unregistered. A side channel may only be registered once.
    /// Messages that were already received for this channel's id (and cached because no channel
    /// was registered for it) are passed to the channel before this method returns.
    /// </summary>
    /// <param name="sideChannel">The side channel to register.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sideChannel"/> is null.</exception>
    /// <exception cref="UnityAgentsException">
    /// A side channel with the same <see cref="SideChannel.ChannelId"/> is already registered.
    /// </exception>
    public static void RegisterSideChannel(SideChannel sideChannel)
    {
        if (sideChannel == null)
        {
            throw new ArgumentNullException(nameof(sideChannel));
        }

        var channelId = sideChannel.ChannelId;
        if (s_RegisteredChannels.ContainsKey(channelId))
        {
            throw new UnityAgentsException(
                $"A side channel with id {channelId} is already registered. " +
                "You cannot register multiple side channels of the same id.");
        }

        // Process any messages that we've already received for this channel ID. Every cached
        // message is dequeued exactly once; those for other channels are re-enqueued, so their
        // relative order in the queue is preserved.
        var numMessages = s_CachedMessages.Count;
        for (var i = 0; i < numMessages; i++)
        {
            var cachedMessage = s_CachedMessages.Dequeue();
            if (channelId == cachedMessage.ChannelId)
            {
                sideChannel.ProcessMessage(cachedMessage.Message);
            }
            else
            {
                s_CachedMessages.Enqueue(cachedMessage);
            }
        }

        s_RegisteredChannels.Add(channelId, sideChannel);
    }

    /// <summary>
    /// Unregister a side channel to stop sending and receiving messages. This method is
    /// available for environments that have custom side channels. All built-in side
    /// channels within the ML-Agents Toolkit are managed internally and do not need to
    /// be explicitly registered/unregistered. Unregistering a side channel that has already
    /// been unregistered (or never registered in the first place) has no negative side effects.
    /// Note that unregistering a side channel may not stop the Python side
    /// from sending messages, but it does mean that sent messages will not result in a call
    /// to <see cref="SideChannel.OnMessageReceived(IncomingMessage)"/>. Furthermore,
    /// those messages will not be buffered and will, in essence, be lost.
    /// </summary>
    /// <param name="sideChannel">The side channel to unregister.</param>
    /// <exception cref="ArgumentNullException"><paramref name="sideChannel"/> is null.</exception>
    public static void UnregisterSideChannel(SideChannel sideChannel)
    {
        if (sideChannel == null)
        {
            throw new ArgumentNullException(nameof(sideChannel));
        }

        // Dictionary.Remove simply returns false when the id is not registered.
        s_RegisteredChannels.Remove(sideChannel.ChannelId);
    }

    /// <summary>
    /// Unregisters all the side channels from the communicator.
    /// Cached messages for unknown channel ids are not affected.
    /// </summary>
    internal static void UnregisterAllSideChannels()
    {
        s_RegisteredChannels = new Dictionary<Guid, SideChannel>();
    }

    /// <summary>
    /// Returns the registered SideChannel whose runtime type is exactly <typeparamref name="T"/>,
    /// or null if there is none. Subclasses of <typeparamref name="T"/> do not match.
    /// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
    /// </summary>
    /// <typeparam name="T">The exact side channel type to look for.</typeparam>
    /// <returns>The matching side channel, or null.</returns>
    internal static T GetSideChannel<T>() where T : SideChannel
    {
        foreach (var sc in s_RegisteredChannels.Values)
        {
            if (sc.GetType() == typeof(T))
            {
                return (T)sc;
            }
        }
        return null;
    }

    /// <summary>
    /// Grabs the messages that the registered side channels will send to Python at the current step
    /// into a single byte array, and clears each channel's outgoing queue.
    /// </summary>
    /// <returns>The serialized messages, or an empty array if no channel has anything queued.</returns>
    internal static byte[] GetSideChannelMessage()
    {
        return GetSideChannelMessage(s_RegisteredChannels);
    }

    /// <summary>
    /// Grabs the messages that the given side channels will send to Python at the current step
    /// into a single byte array, and clears each channel's outgoing queue.
    /// </summary>
    /// <param name="sideChannels">A dictionary of channel id to channel.</param>
    /// <returns>The serialized messages, or an empty array if no channel has anything queued.</returns>
    internal static byte[] GetSideChannelMessage(Dictionary<Guid, SideChannel> sideChannels)
    {
        if (!HasOutgoingMessages(sideChannels))
        {
            // Early out so that we don't create the MemoryStream or BinaryWriter.
            // This is the most common case.
            return Array.Empty<byte>();
        }

        using (var memStream = new MemoryStream())
        {
            using (var binaryWriter = new BinaryWriter(memStream))
            {
                foreach (var sideChannel in sideChannels.Values)
                {
                    var messageList = sideChannel.MessageQueue;
                    if (messageList.Count == 0)
                    {
                        continue;
                    }

                    // The id is the same for every message of this channel; serialize it once.
                    var channelIdBytes = sideChannel.ChannelId.ToByteArray();
                    foreach (var message in messageList)
                    {
                        binaryWriter.Write(channelIdBytes);
                        binaryWriter.Write(message.Length);
                        binaryWriter.Write(message);
                    }
                    messageList.Clear();
                }

                // Ensure everything written has reached the stream before copying it out.
                binaryWriter.Flush();
                return memStream.ToArray();
            }
        }
    }

    /// <summary>
    /// Check whether any of the side channels have queued outgoing messages.
    /// </summary>
    /// <param name="sideChannels">A dictionary of channel id to channel.</param>
    /// <returns>True if at least one channel's message queue is non-empty.</returns>
    static bool HasOutgoingMessages(Dictionary<Guid, SideChannel> sideChannels)
    {
        foreach (var sideChannel in sideChannels.Values)
        {
            if (sideChannel.MessageQueue.Count > 0)
            {
                return true;
            }
        }
        return false;
    }

    /// <summary>
    /// Separates the data received from Python into individual messages for each registered side channel.
    /// </summary>
    /// <param name="dataReceived">The byte array of data received from Python.</param>
    /// <exception cref="UnityAgentsException">The data could not be parsed.</exception>
    internal static void ProcessSideChannelData(byte[] dataReceived)
    {
        ProcessSideChannelData(s_RegisteredChannels, dataReceived);
    }

    /// <summary>
    /// Separates the data received from Python into individual messages for each side channel.
    /// Messages cached by a previous call are delivered first (or logged and dropped if their channel
    /// is still unknown). Messages in <paramref name="dataReceived"/> for unknown ids are cached
    /// until the next call or until a matching channel is registered.
    /// </summary>
    /// <param name="sideChannels">A dictionary of channel id to channel.</param>
    /// <param name="dataReceived">
    /// The byte array of data received from Python. Null or empty means there is no new data.
    /// </param>
    /// <exception cref="UnityAgentsException">
    /// The data is malformed, e.g. a record is truncated or has a negative length.
    /// </exception>
    internal static void ProcessSideChannelData(Dictionary<Guid, SideChannel> sideChannels, byte[] dataReceived)
    {
        // Cached messages get exactly one more chance to find a channel.
        while (s_CachedMessages.Count != 0)
        {
            var cachedMessage = s_CachedMessages.Dequeue();
            if (sideChannels.TryGetValue(cachedMessage.ChannelId, out var cachedChannel))
            {
                cachedChannel.ProcessMessage(cachedMessage.Message);
            }
            else
            {
                Debug.Log($"Unknown side channel data received. Channel Id is : {cachedMessage.ChannelId}");
            }
        }

        if (dataReceived == null || dataReceived.Length == 0)
        {
            return;
        }

        using (var memStream = new MemoryStream(dataReceived))
        {
            using (var binaryReader = new BinaryReader(memStream))
            {
                while (memStream.Position < memStream.Length)
                {
                    Guid channelId = Guid.Empty;
                    byte[] message = null;
                    try
                    {
                        channelId = new Guid(binaryReader.ReadBytes(k_ChannelIdByteCount));
                        var messageLength = binaryReader.ReadInt32();
                        message = binaryReader.ReadBytes(messageLength);
                        // ReadBytes returns a shorter array instead of throwing when the stream ends
                        // early; a truncated payload is corrupt data, not a valid message.
                        if (message.Length != messageLength)
                        {
                            throw new EndOfStreamException(
                                $"Expected {messageLength} bytes of message data but only {message.Length} remained.");
                        }
                    }
                    catch (Exception ex)
                    {
                        throw new UnityAgentsException(
                            "There was a problem reading a message in a SideChannel. Please make sure the " +
                            "version of MLAgents in Unity is compatible with the Python version. Original error : "
                            + ex.Message);
                    }

                    if (sideChannels.TryGetValue(channelId, out var channel))
                    {
                        channel.ProcessMessage(message);
                    }
                    else
                    {
                        // Don't recognize this ID, but cache it in case the SideChannel that can handle
                        // it is registered before the next call to ProcessSideChannelData.
                        s_CachedMessages.Enqueue(new CachedSideChannelMessage
                        {
                            ChannelId = channelId,
                            Message = message
                        });
                    }
                }
            }
        }
    }
}
}
|