| | using NUnit.Framework; |
| | using UnityEngine; |
| | using Unity.MLAgents.Sensors; |
| |
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Unit tests for <see cref="VectorSensor"/>: sensor naming, observation specs,
    /// the typed <c>AddObservation</c> overloads, and how the sensor handles writes
    /// with fewer or more values than its declared size.
    /// </summary>
    public class VectorSensorTests
    {
        [Test]
        public void TestCtor()
        {
            // Without an explicit name, the sensor name is generated from its size.
            ISensor sensor = new VectorSensor(4);
            Assert.AreEqual("VectorSensor_size4", sensor.GetName());

            // An explicit name is used verbatim.
            sensor = new VectorSensor(3, "test_sensor");
            Assert.AreEqual("test_sensor", sensor.GetName());
        }

        [Test]
        public void TestWrite()
        {
            var sensor = new VectorSensor(4);
            sensor.AddObservation(1f);
            sensor.AddObservation(2f);
            sensor.AddObservation(3f);
            sensor.AddObservation(4f);

            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
            // Without a call to Update(), writing again yields the same observations.
            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });

            // Update() clears the accumulated observations, so the next write is all zeros.
            sensor.Update();
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 0f });
        }

        [Test]
        public void TestAddObservationFloat()
        {
            var sensor = new VectorSensor(1);
            sensor.AddObservation(1.2f);
            SensorTestHelper.CompareObservation(sensor, new[] { 1.2f });
        }

        [Test]
        public void TestObservationType()
        {
            // NUnit's Assert.AreEqual takes (expected, actual); keeping that order makes
            // failure messages report the expected and actual values the right way round.

            // Omitting observationType yields the Default type.
            var sensor = new VectorSensor(1);
            var spec = sensor.GetObservationSpec();
            Assert.AreEqual((int)ObservationType.Default, (int)spec.ObservationType);

            sensor = new VectorSensor(1, observationType: ObservationType.Default);
            spec = sensor.GetObservationSpec();
            Assert.AreEqual((int)ObservationType.Default, (int)spec.ObservationType);

            sensor = new VectorSensor(1, observationType: ObservationType.GoalSignal);
            spec = sensor.GetObservationSpec();
            Assert.AreEqual((int)ObservationType.GoalSignal, (int)spec.ObservationType);
        }

        [Test]
        public void TestAddObservationInt()
        {
            var sensor = new VectorSensor(1);
            sensor.AddObservation(42);
            SensorTestHelper.CompareObservation(sensor, new[] { 42f });
        }

        [Test]
        public void TestAddObservationVec()
        {
            // Vector components are written in declaration order.
            var sensor = new VectorSensor(3);
            sensor.AddObservation(new Vector3(1, 2, 3));
            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f });

            sensor = new VectorSensor(2);
            sensor.AddObservation(new Vector2(4, 5));
            SensorTestHelper.CompareObservation(sensor, new[] { 4f, 5f });
        }

        [Test]
        public void TestAddObservationQuaternion()
        {
            // The identity quaternion is written as x, y, z, w = (0, 0, 0, 1).
            var sensor = new VectorSensor(4);
            sensor.AddObservation(Quaternion.identity);
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f, 1f });
        }

        [Test]
        public void TestWriteEnumerable()
        {
            var sensor = new VectorSensor(4);
            sensor.AddObservation(new[] { 1f, 2f, 3f, 4f });

            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 3f, 4f });
        }

        [Test]
        public void TestAddObservationBool()
        {
            // Booleans are encoded as 1 for true and 0 for false.
            var sensor = new VectorSensor(1);
            sensor.AddObservation(true);
            SensorTestHelper.CompareObservation(sensor, new[] { 1f });

            sensor = new VectorSensor(1);
            sensor.AddObservation(false);
            SensorTestHelper.CompareObservation(sensor, new[] { 0f });
        }

        [Test]
        public void TestAddObservationOneHot()
        {
            // Index 2 in a range of 4 sets only the third element.
            var sensor = new VectorSensor(4);
            sensor.AddOneHotObservation(2, 4);
            SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 1f, 0f });
        }

        [Test]
        public void TestWriteTooMany()
        {
            // Values beyond the sensor's declared size are dropped.
            var sensor = new VectorSensor(2);
            sensor.AddObservation(new[] { 1f, 2f, 3f, 4f });

            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f });
        }

        [Test]
        public void TestWriteNotEnough()
        {
            var sensor = new VectorSensor(4);
            sensor.AddObservation(new[] { 1f, 2f });

            // Missing values are padded with zeros up to the sensor's declared size.
            SensorTestHelper.CompareObservation(sensor, new[] { 1f, 2f, 0f, 0f });
        }
    }
}
| |
|