using System;
using System.IO;
using NUnit.Framework;
using Unity.MLAgents.SideChannels;

namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests that sampler-type environment parameters sent over the
    /// <see cref="EnvironmentParametersChannel"/> produce the expected seed-determined
    /// sequence of values when read back with <c>GetWithDefault</c>.
    /// </summary>
    public class SamplerTests
    {
        const int k_Seed = 1337;
        const double k_Epsilon = 0.0001;

        // Value written right after the parameter name in every sampler message.
        // NOTE(review): presumably the channel's "sampler" data-type id; the original tests
        // wrote the literal 1 here. Confirm against EnvironmentParametersChannel.
        const int k_SamplerDataType = 1;

        EnvironmentParametersChannel m_Channel;

        /// <summary>
        /// Looks up the globally registered <see cref="EnvironmentParametersChannel"/> and
        /// registers a new one if none exists yet.
        /// </summary>
        /// <remarks>
        /// This runs before every test instead of once in the fixture constructor. Constructing
        /// the fixture therefore has no global side effects, and each test re-checks that the
        /// channel is registered with <c>SideChannelManager</c> before it sends data.
        /// </remarks>
        [SetUp]
        public void SetUp()
        {
            m_Channel = SideChannelManager.GetSideChannel<EnvironmentParametersChannel>();
            if (m_Channel != null)
            {
                return;
            }

            m_Channel = new EnvironmentParametersChannel();
            SideChannelManager.RegisterSideChannel(m_Channel);
        }

        [Test]
        public void UniformSamplerTest()
        {
            const float minValue = 1.0f;
            const float maxValue = 2.0f;
            const string parameter = "parameter1";

            SendSamplerMessage(parameter, SamplerType.Uniform, msg =>
            {
                msg.WriteFloat32(minValue);
                msg.WriteFloat32(maxValue);
            });

            // Two consecutive reads are expected to yield two different values:
            // successive draws from the seeded sampler.
            Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
            Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        }

        [Test]
        public void GaussianSamplerTest()
        {
            const float mean = 3.0f;
            const float stddev = 0.2f;
            const string parameter = "parameter2";

            SendSamplerMessage(parameter, SamplerType.Gaussian, msg =>
            {
                msg.WriteFloat32(mean);
                msg.WriteFloat32(stddev);
            });

            Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
            Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        }

        [Test]
        public void MultiRangeUniformSamplerTest()
        {
            // Presumably a flattened list of (min, max) pairs: [1.2, 2] and [3.2, 4.1].
            // The expected values below each fall inside one of those ranges.
            float[] intervals = { 1.2f, 2f, 3.2f, 4.1f };
            const string parameter = "parameter3";

            SendSamplerMessage(parameter, SamplerType.MultiRangeUniform, msg =>
            {
                msg.WriteFloatList(intervals);
            });

            Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
            Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        }

        /// <summary>
        /// Builds a sampler message for <paramref name="parameter"/> and delivers it through
        /// <c>SideChannelManager.ProcessSideChannelData</c>.
        /// </summary>
        /// <param name="parameter">Name of the environment parameter to configure.</param>
        /// <param name="samplerType">Sampler type written into the message header.</param>
        /// <param name="writeSamplerArguments">
        /// Writes the sampler-specific arguments (e.g. bounds, mean/stddev, intervals)
        /// after the common header.
        /// </param>
        void SendSamplerMessage(string parameter, SamplerType samplerType, Action<OutgoingMessage> writeSamplerArguments)
        {
            using (var outgoingMsg = new OutgoingMessage())
            {
                // Common header, in wire order: parameter name, data-type marker, seed, sampler type.
                outgoingMsg.WriteString(parameter);
                outgoingMsg.WriteInt32(k_SamplerDataType);
                outgoingMsg.WriteInt32(k_Seed);
                outgoingMsg.WriteInt32((int)samplerType);
                writeSamplerArguments(outgoingMsg);

                byte[] message = GetByteMessage(m_Channel, outgoingMsg);
                SideChannelManager.ProcessSideChannelData(message);
            }
        }

        /// <summary>
        /// Frames <paramref name="msg"/> for <paramref name="sideChannel"/>. The output is the
        /// bytes of the channel id, then the payload length, then the payload bytes.
        /// </summary>
        /// <param name="sideChannel">Channel whose <c>ChannelId</c> prefixes the frame.</param>
        /// <param name="msg">Message whose bytes form the payload.</param>
        /// <returns>The framed bytes, suitable for <c>SideChannelManager.ProcessSideChannelData</c>.</returns>
        internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
        {
            byte[] payload = msg.ToByteArray();
            using (var memStream = new MemoryStream())
            {
                using (var binaryWriter = new BinaryWriter(memStream))
                {
                    binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
                    // Int32 overload: BinaryWriter always writes this as 4 little-endian bytes.
                    binaryWriter.Write(payload.Length);
                    binaryWriter.Write(payload);
                }

                // Disposing the writer closes memStream; MemoryStream.ToArray still works on a closed stream.
                return memStream.ToArray();
            }
        }
    }
}