| | using System; |
| | using NUnit.Framework; |
| | using System.Collections.Generic; |
| | using System.Text; |
| | using Unity.MLAgents.SideChannels; |
| |
|
| | namespace Unity.MLAgents.Tests |
| | { |
| | public class SideChannelTests |
| | { |
| | |
/// <summary>
/// Minimal side channel used by the tests: it sends single 32-bit integers
/// and records every integer it receives.
/// </summary>
public class TestSideChannel : SideChannel
{
    /// <summary>Integers read from incoming messages, in arrival order.</summary>
    public List<int> messagesReceived = new List<int>();

    /// <summary>
    /// Creates the channel and assigns it a fixed channel id.
    /// </summary>
    public TestSideChannel()
    {
        var fixedId = new Guid("6afa2c06-4f82-11ea-b238-784f4387d1f7");
        ChannelId = fixedId;
    }

    /// <summary>
    /// Reads one 32-bit integer from the incoming message and stores it.
    /// </summary>
    /// <param name="msg">The message delivered to this channel.</param>
    protected override void OnMessageReceived(IncomingMessage msg)
    {
        var received = msg.ReadInt32();
        messagesReceived.Add(received);
    }

    /// <summary>
    /// Writes <paramref name="value"/> into a new outgoing message and queues it for sending.
    /// </summary>
    /// <param name="value">The integer to send.</param>
    public void SendInt(int value)
    {
        using (var outgoing = new OutgoingMessage())
        {
            outgoing.WriteInt32(value);
            QueueMessageToSend(outgoing);
        }
    }
}
| |
|
/// <summary>
/// Sends three integers through one <see cref="TestSideChannel"/>, serializes the
/// queued messages, feeds the bytes to a second channel with the same id, and checks
/// that the receiver recorded the same integers in the same order.
/// </summary>
[Test]
public void TestIntegerSideChannel()
{
    var intSender = new TestSideChannel();
    var intReceiver = new TestSideChannel();
    var dictSender = new Dictionary<Guid, SideChannel> { { intSender.ChannelId, intSender } };
    var dictReceiver = new Dictionary<Guid, SideChannel> { { intReceiver.ChannelId, intReceiver } };

    intSender.SendInt(4);
    intSender.SendInt(5);
    intSender.SendInt(6);

    byte[] fakeData = SideChannelManager.GetSideChannelMessage(dictSender);
    SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData);

    // Check the count first so a missing message fails as an assertion
    // instead of an index-out-of-range exception.
    Assert.AreEqual(3, intReceiver.messagesReceived.Count);

    // NUnit's AreEqual takes (expected, actual); keeping that order makes
    // failure messages report the right values.
    Assert.AreEqual(4, intReceiver.messagesReceived[0]);
    Assert.AreEqual(5, intReceiver.messagesReceived[1]);
    Assert.AreEqual(6, intReceiver.messagesReceived[2]);
}
| |
|
/// <summary>
/// Sends two ASCII-encoded strings through a <see cref="RawBytesChannel"/>, transfers the
/// serialized data to a second channel created with the same id, and checks that both
/// payloads arrive intact and in order.
/// </summary>
[Test]
public void TestRawBytesSideChannel()
{
    var str1 = "Test string";
    var str2 = "Test string, second";

    var strSender = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7"));
    var strReceiver = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7"));
    var dictSender = new Dictionary<Guid, SideChannel> { { strSender.ChannelId, strSender } };
    var dictReceiver = new Dictionary<Guid, SideChannel> { { strReceiver.ChannelId, strReceiver } };

    strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1));
    strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2));

    byte[] fakeData = SideChannelManager.GetSideChannelMessage(dictSender);
    SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData);

    var messages = strReceiver.GetAndClearReceivedMessages();

    // NUnit's AreEqual takes (expected, actual); keeping that order makes
    // failure messages report the right values.
    Assert.AreEqual(2, messages.Count);
    Assert.AreEqual(str1, Encoding.ASCII.GetString(messages[0]));
    Assert.AreEqual(str2, Encoding.ASCII.GetString(messages[1]));
}
| |
|
/// <summary>
/// Exercises two <see cref="FloatPropertiesChannel"/> instances: values set on the sender
/// (propB) become readable on the receiver (propA) once the serialized data is processed,
/// a callback registered on the receiver fires only after processing, and both channels
/// end up listing the same two keys.
/// </summary>
[Test]
public void TestFloatPropertiesSideChannel()
{
    var k1 = "gravity";
    var k2 = "length";
    int wasCalled = 0;

    var propA = new FloatPropertiesChannel();
    var propB = new FloatPropertiesChannel();
    var dictReceiver = new Dictionary<Guid, SideChannel> { { propA.ChannelId, propA } };
    var dictSender = new Dictionary<Guid, SideChannel> { { propB.ChannelId, propB } };

    propA.RegisterCallback(k1, f => { wasCalled++; });

    // Before anything is set, the supplied default is returned.
    var tmp = propB.GetWithDefault(k2, 3.0f);
    Assert.AreEqual(3.0f, tmp);

    // After Set, the sender itself reports the new value.
    propB.Set(k2, 1.0f);
    tmp = propB.GetWithDefault(k2, 3.0f);
    Assert.AreEqual(1.0f, tmp);

    byte[] fakeData = SideChannelManager.GetSideChannelMessage(dictSender);
    SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData);

    // The receiver sees the value once the sender's data has been processed.
    tmp = propA.GetWithDefault(k2, 3.0f);
    Assert.AreEqual(1.0f, tmp);

    // The k1 callback must not fire for k2, nor before the k1 update is processed.
    Assert.AreEqual(0, wasCalled);
    propB.Set(k1, 1.0f);
    Assert.AreEqual(0, wasCalled);
    fakeData = SideChannelManager.GetSideChannelMessage(dictSender);
    SideChannelManager.ProcessSideChannelData(dictReceiver, fakeData);
    Assert.AreEqual(1, wasCalled);

    var keysA = propA.Keys();
    Assert.AreEqual(2, keysA.Count);
    Assert.IsTrue(keysA.Contains(k1));
    Assert.IsTrue(keysA.Contains(k2));

    // Check the sender's own key list; this previously re-read propA by mistake,
    // leaving propB.Keys() untested.
    var keysB = propB.Keys();
    Assert.AreEqual(2, keysB.Count);
    Assert.IsTrue(keysB.Contains(k1));
    Assert.IsTrue(keysB.Contains(k2));
}
| |
|
/// <summary>
/// Checks that <c>SetRawBytes</c> replaces whatever was written to the message before:
/// after writing an int and a float and then setting four raw bytes, the serialized
/// message must equal exactly those four bytes.
/// </summary>
[Test]
public void TestOutgoingMessageRawBytes()
{
    // OutgoingMessage is disposable (it is wrapped in `using` everywhere else in this
    // file), so release it here as well instead of leaking it.
    using (var msg = new OutgoingMessage())
    {
        // Content written before SetRawBytes; it must not appear in the result.
        msg.WriteInt32(42);
        msg.WriteFloat32(1.0f);

        var data = new byte[] { 1, 2, 3, 4 };
        msg.SetRawBytes(data);

        var result = msg.ToByteArray();
        Assert.AreEqual(data, result);
    }
}
| |
|
/// <summary>
/// Writes one value of each supported kind (bool, int, float, string, float list) into an
/// <see cref="OutgoingMessage"/>, builds an <see cref="IncomingMessage"/> from its bytes,
/// and checks that the values read back match, in the order they were written.
/// </summary>
[Test]
public void TestMessageReadWrites()
{
    bool expectedBool = true;
    int expectedInt = 1337;
    float expectedFloat = 4.2f;
    float[] expectedFloats = new float[] { 1001, 1002 };
    string expectedString = "mlagents!";

    IncomingMessage reader;
    using (var writer = new OutgoingMessage())
    {
        writer.WriteBoolean(expectedBool);
        writer.WriteInt32(expectedInt);
        writer.WriteFloat32(expectedFloat);
        writer.WriteString(expectedString);
        writer.WriteFloatList(expectedFloats);

        // Build the reader while the writer is still alive.
        byte[] serialized = writer.ToByteArray();
        reader = new IncomingMessage(serialized);
    }

    // Reads must happen in the same order as the writes above.
    Assert.AreEqual(expectedBool, reader.ReadBoolean());
    Assert.AreEqual(expectedInt, reader.ReadInt32());
    Assert.AreEqual(expectedFloat, reader.ReadFloat32());
    Assert.AreEqual(expectedString, reader.ReadString());
    Assert.AreEqual(expectedFloats, reader.ReadFloatList());
}
| |
|
/// <summary>
/// Builds an <see cref="IncomingMessage"/> from an <see cref="OutgoingMessage"/> that had
/// nothing written to it, then checks that each Read* method returns the type's default
/// value, or the caller-supplied default when one is given.
/// </summary>
[Test]
public void TestMessageReadDefaults()
{
    IncomingMessage emptyReader;
    using (var emptyWriter = new OutgoingMessage())
    {
        // Nothing is written, so the reader is built from an empty message.
        byte[] serialized = emptyWriter.ToByteArray();
        emptyReader = new IncomingMessage(serialized);
    }

    // Booleans
    Assert.IsFalse(emptyReader.ReadBoolean());
    Assert.IsTrue(emptyReader.ReadBoolean(defaultValue: true));

    // Integers
    Assert.AreEqual(0, emptyReader.ReadInt32());
    Assert.AreEqual(42, emptyReader.ReadInt32(defaultValue: 42));

    // Floats
    Assert.AreEqual(0.0f, emptyReader.ReadFloat32());
    Assert.AreEqual(1337.0f, emptyReader.ReadFloat32(defaultValue: 1337.0f));

    // Strings
    Assert.IsNull(emptyReader.ReadString());
    Assert.AreEqual("foo", emptyReader.ReadString(defaultValue: "foo"));

    // Float lists
    Assert.IsNull(emptyReader.ReadFloatList());
    var expectedList = new float[] { 1001, 1002 };
    var fallbackList = new float[] { 1001, 1002 };
    Assert.AreEqual(expectedList, emptyReader.ReadFloatList(fallbackList));
}
| | } |
| | } |
| |
|