| | using System.Collections.Generic; |
| | using System.Text.RegularExpressions; |
| | using NUnit.Framework; |
| | using UnityEngine; |
| | using UnityEngine.TestTools; |
| | using Unity.MLAgents.Sensors; |
| |
|
| | namespace Unity.MLAgents.Tests |
| | { |
/// <summary>
/// Bare-bones <see cref="ISensor"/> used by the tests in this file. Only its
/// observation spec carries information; the constructor overload that is
/// called decides which <c>ObservationSpec</c> factory builds that spec.
/// </summary>
public class DummySensor : ISensor
{
    readonly string m_SensorName = "DummySensor";
    readonly ObservationSpec m_Spec;

    /// <summary>
    /// Creates a sensor whose spec is <c>ObservationSpec.Vector(dim1)</c>.
    /// </summary>
    public DummySensor(int dim1)
    {
        m_Spec = ObservationSpec.Vector(dim1);
    }

    /// <summary>
    /// Creates a sensor whose spec is <c>ObservationSpec.VariableLength(dim1, dim2)</c>.
    /// </summary>
    public DummySensor(int dim1, int dim2)
    {
        m_Spec = ObservationSpec.VariableLength(dim1, dim2);
    }

    /// <summary>
    /// Creates a sensor whose spec is <c>ObservationSpec.Visual(dim1, dim2, dim3)</c>.
    /// </summary>
    public DummySensor(int dim1, int dim2, int dim3)
    {
        m_Spec = ObservationSpec.Visual(dim1, dim2, dim3);
    }

    /// <summary>Returns the fixed name "DummySensor".</summary>
    public string GetName() => m_SensorName;

    /// <summary>Returns the spec chosen at construction time.</summary>
    public ObservationSpec GetObservationSpec() => m_Spec;

    /// <summary>This sensor has no compressed form; always returns null.</summary>
    public byte[] GetCompressedObservation() => null;

    /// <summary>
    /// Leaves <paramref name="writer"/> untouched and reports the value of
    /// <c>ObservationSize()</c> for this sensor.
    /// </summary>
    public int Write(ObservationWriter writer) => this.ObservationSize();

    /// <summary>No-op.</summary>
    public void Update()
    {
    }

    /// <summary>No-op.</summary>
    public void Reset()
    {
    }

    /// <summary>Returns <c>CompressionSpec.Default()</c>.</summary>
    public CompressionSpec GetCompressionSpec() => CompressionSpec.Default();
}
| |
|
/// <summary>
/// Tests for <c>SensorShapeValidator.ValidateSensors</c>. Each mismatch test
/// validates a reference list first and a differing list second, then repeats
/// with a fresh validator and the two lists in the opposite order.
/// </summary>
public class SensorShapeValidatorTests
{
    // Pattern for the shape-mismatch assertion log; the trailing ".*" lets the
    // rest of the message vary.
    const string k_ShapeMismatchPattern = "Sensor shapes must match.*";

    /// <summary>
    /// Builds the three-sensor list used as the reference in every test:
    /// DummySensor(1), DummySensor(2, 3), DummySensor(4, 5, 6).
    /// </summary>
    static List<ISensor> MakeReferenceSensors()
    {
        return new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5, 6)
        };
    }

    /// <summary>
    /// Registers an expected assertion log matching the shape-mismatch pattern.
    /// </summary>
    static void ExpectShapeMismatch()
    {
        LogAssert.Expect(LogType.Assert, new Regex(k_ShapeMismatchPattern));
    }

    /// <summary>
    /// Two separately built but identical lists are validated in sequence;
    /// no log expectations are registered.
    /// </summary>
    [Test]
    public void TestShapesAgree()
    {
        var shapeValidator = new SensorShapeValidator();

        List<ISensor> firstPass = MakeReferenceSensors();
        shapeValidator.ValidateSensors(firstPass);

        List<ISensor> secondPass = MakeReferenceSensors();
        shapeValidator.ValidateSensors(secondPass);
    }

    /// <summary>
    /// A three-sensor list against a two-sensor list: a count-mismatch
    /// assertion is expected in both orders.
    /// </summary>
    [Test]
    public void TestNumSensorMismatch()
    {
        var shapeValidator = new SensorShapeValidator();
        List<ISensor> threeSensors = MakeReferenceSensors();
        shapeValidator.ValidateSensors(threeSensors);

        List<ISensor> twoSensors = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
        };
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
        shapeValidator.ValidateSensors(twoSensors);

        // Fresh validator, lists in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(twoSensors);
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
        shapeValidator.ValidateSensors(threeSensors);
    }

    /// <summary>
    /// The last sensor is built with two arguments instead of three: a
    /// shape-mismatch assertion is expected in both orders.
    /// </summary>
    [Test]
    public void TestDimensionMismatch()
    {
        var shapeValidator = new SensorShapeValidator();
        List<ISensor> reference = MakeReferenceSensors();
        shapeValidator.ValidateSensors(reference);

        List<ISensor> fewerDims = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5)
        };
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(fewerDims);

        // Fresh validator, lists in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(fewerDims);
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(reference);
    }

    /// <summary>
    /// The last sensor keeps three arguments but one value differs (7 vs 6):
    /// a shape-mismatch assertion is expected in both orders.
    /// </summary>
    [Test]
    public void TestSizeMismatch()
    {
        var shapeValidator = new SensorShapeValidator();
        List<ISensor> reference = MakeReferenceSensors();
        shapeValidator.ValidateSensors(reference);

        List<ISensor> differentSize = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(2, 3),
            new DummySensor(4, 5, 7)
        };
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(differentSize);

        // Fresh validator, lists in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(differentSize);
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(reference);
    }

    /// <summary>
    /// The second list differs in both sensor count and sensor construction:
    /// a count-mismatch and a shape-mismatch assertion are both expected, in
    /// both orders.
    /// </summary>
    [Test]
    public void TestEverythingMismatch()
    {
        var shapeValidator = new SensorShapeValidator();
        List<ISensor> reference = MakeReferenceSensors();
        shapeValidator.ValidateSensors(reference);

        List<ISensor> unrelated = new List<ISensor>
        {
            new DummySensor(1),
            new DummySensor(9)
        };
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 3 != 2");
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(unrelated);

        // Fresh validator, lists in the opposite order.
        shapeValidator = new SensorShapeValidator();
        shapeValidator.ValidateSensors(unrelated);
        LogAssert.Expect(LogType.Assert, "Number of Sensors must match. 2 != 3");
        ExpectShapeMismatch();
        shapeValidator.ValidateSensors(reference);
    }
}
| | } |
| |
|