using System.Collections.Generic;
using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents.Sensors;
| |
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Tests for <see cref="BufferSensor"/> and <see cref="BufferSensorComponent"/>: the
    /// observation spec (shape and dimension properties) and the layout of the float data
    /// written into the observation proto after two observations have been appended.
    /// </summary>
    [TestFixture]
    public class BufferSensorTest
    {
        // Buffer dimensions used by both tests: 20 observable slots of 4 floats each.
        const int k_MaxNumObservables = 20;
        const int k_ObservableSize = 4;

        // Each test appends two rows ({1,2,3,4} and {5,6,7,8}), i.e. the values 1..8.
        const int k_NumAppendedValues = 2 * k_ObservableSize;

        [Test]
        public void TestBufferSensor()
        {
            var bufferSensor = new BufferSensor(k_MaxNumObservables, k_ObservableSize, "testName");
            var spec = bufferSensor.GetObservationSpec();
            var shape = spec.Shape;
            var dimProp = spec.DimensionProperties;

            // NUnit's Assert.AreEqual is (expected, actual); keeping that order makes
            // failure messages report the values the right way round.
            Assert.AreEqual(2, shape.Length);
            Assert.AreEqual(k_MaxNumObservables, shape[0]);
            Assert.AreEqual(k_ObservableSize, shape[1]);
            Assert.AreEqual(2, dimProp.Length);
            Assert.AreEqual(DimensionProperty.VariableSize, dimProp[0]);
            Assert.AreEqual(DimensionProperty.None, dimProp[1]);

            bufferSensor.AppendObservation(new float[] { 1, 2, 3, 4 });
            bufferSensor.AppendObservation(new float[] { 5, 6, 7, 8 });
            var obsWriter = new ObservationWriter();
            var obs = bufferSensor.GetObservationProto(obsWriter);

            // The proto must describe the same shape and dimension properties as the spec.
            Assert.AreEqual(shape, InplaceArray<int>.FromList(obs.Shape));
            Assert.AreEqual(2, obs.DimensionProperties.Count);
            Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]);
            Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]);

            AssertAppendedThenZeroPadded(obs.FloatData.Data, k_NumAppendedValues);
        }

        [Test]
        public void TestBufferSensorComponent()
        {
            var agentGameObj = new GameObject("agent");
            try
            {
                var bufferComponent = agentGameObj.AddComponent<BufferSensorComponent>();
                bufferComponent.MaxNumObservables = k_MaxNumObservables;
                bufferComponent.ObservableSize = k_ObservableSize;
                bufferComponent.SensorName = "TestName";

                var sensor = bufferComponent.CreateSensors()[0];
                var spec = sensor.GetObservationSpec();
                var shape = spec.Shape;
                var dimProp = spec.DimensionProperties;

                Assert.AreEqual(2, shape.Length);
                Assert.AreEqual(k_MaxNumObservables, shape[0]);
                Assert.AreEqual(k_ObservableSize, shape[1]);
                Assert.AreEqual("TestName", sensor.GetName());

                // Observations are appended through the component but read back from the
                // sensor it created, so this checks the two are wired together.
                bufferComponent.AppendObservation(new float[] { 1, 2, 3, 4 });
                bufferComponent.AppendObservation(new float[] { 5, 6, 7, 8 });
                var obsWriter = new ObservationWriter();
                var obs = sensor.GetObservationProto(obsWriter);

                Assert.AreEqual(shape, InplaceArray<int>.FromList(obs.Shape));
                Assert.AreEqual(2, obs.DimensionProperties.Count);
                Assert.AreEqual((int)dimProp[0], obs.DimensionProperties[0]);
                Assert.AreEqual((int)dimProp[1], obs.DimensionProperties[1]);

                AssertAppendedThenZeroPadded(obs.FloatData.Data, k_NumAppendedValues);
            }
            finally
            {
                // new GameObject(...) is added to the active scene; destroy it so it does
                // not leak into subsequent tests.
                Object.DestroyImmediate(agentGameObj);
            }
        }

        /// <summary>
        /// Asserts that the flattened observation buffer holds 1, 2, ..., numAppended in its
        /// first slots (the appended rows, in order) and 0 in every remaining slot of the
        /// k_MaxNumObservables x k_ObservableSize buffer.
        /// </summary>
        /// <param name="data">Flattened float data read from the observation proto.</param>
        /// <param name="numAppended">Number of leading values that were appended.</param>
        static void AssertAppendedThenZeroPadded(IList<float> data, int numAppended)
        {
            const int totalSize = k_MaxNumObservables * k_ObservableSize;
            for (var i = 0; i < numAppended; i++)
            {
                Assert.AreEqual(i + 1, data[i], $"Unexpected appended value at index {i}");
            }
            for (var i = numAppended; i < totalSize; i++)
            {
                Assert.AreEqual(0, data[i], $"Expected zero padding at index {i}");
            }
        }
    }
}
| |
|