| | using System; |
| | using System.Collections.Generic; |
| | using NUnit.Framework; |
| | using UnityEngine; |
| | using Unity.MLAgents.Sensors; |
| | using Unity.MLAgents.Sensors.Reflection; |
| |
|
| | namespace Unity.MLAgents.Tests |
| | { |
| | [TestFixture] |
| | public class ObservableAttributeTests |
| | { |
| | public enum TestEnum |
| | { |
| | ValueA = -100, |
| | ValueB = 1, |
| | ValueC = 42, |
| | } |
| |
|
| | [Flags] |
| | public enum TestFlags |
| | { |
| | FlagA = 1, |
| | FlagB = 2, |
| | FlagC = 4 |
| | } |
| |
|
| | class TestClass |
| | { |
| | |
| | int m_NonObservableInt; |
| | float m_NonObservableFloat; |
| |
|
| | |
| | |
| | |
| | [Observable] |
| | public int m_IntMember; |
| |
|
| | int m_IntProperty; |
| |
|
| | [Observable] |
| | public int IntProperty |
| | { |
| | get => m_IntProperty; |
| | set => m_IntProperty = value; |
| | } |
| |
|
| | |
| | |
| | |
| | [Observable("floatMember")] |
| | public float m_FloatMember; |
| |
|
| | float m_FloatProperty; |
| | [Observable("floatProperty")] |
| | public float FloatProperty |
| | { |
| | get => m_FloatProperty; |
| | set => m_FloatProperty = value; |
| | } |
| |
|
| | |
| | |
| | |
| | [Observable("boolMember")] |
| | public bool m_BoolMember; |
| |
|
| | bool m_BoolProperty; |
| | [Observable("boolProperty")] |
| | public bool BoolProperty |
| | { |
| | get => m_BoolProperty; |
| | set => m_BoolProperty = value; |
| | } |
| |
|
| | |
| | |
| | |
| |
|
| | [Observable("vector2Member")] |
| | public Vector2 m_Vector2Member; |
| |
|
| | Vector2 m_Vector2Property; |
| |
|
| | [Observable("vector2Property")] |
| | public Vector2 Vector2Property |
| | { |
| | get => m_Vector2Property; |
| | set => m_Vector2Property = value; |
| | } |
| |
|
| | |
| | |
| | |
| | [Observable("vector3Member")] |
| | public Vector3 m_Vector3Member; |
| |
|
| | Vector3 m_Vector3Property; |
| |
|
| | [Observable("vector3Property")] |
| | public Vector3 Vector3Property |
| | { |
| | get => m_Vector3Property; |
| | set => m_Vector3Property = value; |
| | } |
| |
|
| | |
| | |
| | |
| |
|
| | [Observable("vector4Member")] |
| | public Vector4 m_Vector4Member; |
| |
|
| | Vector4 m_Vector4Property; |
| |
|
| | [Observable("vector4Property")] |
| | public Vector4 Vector4Property |
| | { |
| | get => m_Vector4Property; |
| | set => m_Vector4Property = value; |
| | } |
| |
|
| | |
| | |
| | |
| | [Observable("quaternionMember")] |
| | public Quaternion m_QuaternionMember; |
| |
|
| | Quaternion m_QuaternionProperty; |
| |
|
| | [Observable("quaternionProperty")] |
| | public Quaternion QuaternionProperty |
| | { |
| | get => m_QuaternionProperty; |
| | set => m_QuaternionProperty = value; |
| | } |
| |
|
| | |
| | |
| | |
| |
|
| | [Observable("enumMember")] |
| | public TestEnum m_EnumMember = TestEnum.ValueA; |
| |
|
| | TestEnum m_EnumProperty = TestEnum.ValueC; |
| |
|
| | [Observable("enumProperty")] |
| | public TestEnum EnumProperty |
| | { |
| | get => m_EnumProperty; |
| | set => m_EnumProperty = value; |
| | } |
| |
|
| | [Observable("badEnumMember")] |
| | public TestEnum m_BadEnumMember = (TestEnum)1337; |
| |
|
| | |
| | |
| | |
| | [Observable("flagMember")] |
| | public TestFlags m_FlagMember = TestFlags.FlagA; |
| |
|
| | TestFlags m_FlagProperty = TestFlags.FlagB | TestFlags.FlagC; |
| |
|
| | [Observable("flagProperty")] |
| | public TestFlags FlagProperty |
| | { |
| | get => m_FlagProperty; |
| | set => m_FlagProperty = value; |
| | } |
| | } |
| |
|
| | [Test] |
| | public void TestGetObservableSensors() |
| | { |
| | var testClass = new TestClass(); |
| | testClass.m_IntMember = 1; |
| | testClass.IntProperty = 2; |
| |
|
| | testClass.m_FloatMember = 1.1f; |
| | testClass.FloatProperty = 1.2f; |
| |
|
| | testClass.m_BoolMember = true; |
| | testClass.BoolProperty = true; |
| |
|
| | testClass.m_Vector2Member = new Vector2(2.0f, 2.1f); |
| | testClass.Vector2Property = new Vector2(2.2f, 2.3f); |
| |
|
| | testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f); |
| | testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f); |
| |
|
| | testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); |
| | testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); |
| |
|
| | testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); |
| | testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); |
| |
|
| | testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); |
| | testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); |
| |
|
| | var sensors = ObservableAttribute.CreateObservableSensors(testClass, false); |
| |
|
| | var sensorsByName = new Dictionary<string, ISensor>(); |
| | foreach (var sensor in sensors) |
| | { |
| | sensorsByName[sensor.GetName()] = sensor; |
| | } |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f }); |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f }); |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f }); |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f }); |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f }); |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f }); |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f }); |
| |
|
| | |
| | SensorTestHelper.CompareObservation(sensorsByName["enumMember"], new[] { 0.0f, 0.0f, 1.0f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["enumProperty"], new[] { 0.0f, 1.0f, 0.0f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["badEnumMember"], new[] { 0.0f, 0.0f, 0.0f }); |
| |
|
| | SensorTestHelper.CompareObservation(sensorsByName["flagMember"], new[] { 1.0f, 0.0f, 0.0f }); |
| | SensorTestHelper.CompareObservation(sensorsByName["flagProperty"], new[] { 0.0f, 1.0f, 1.0f }); |
| | } |
| |
|
| | [Test] |
| | public void TestGetTotalObservationSize() |
| | { |
| | var testClass = new TestClass(); |
| | var errors = new List<string>(); |
| | var expectedObsSize = 2 * ( |
| | 1 |
| | + 1 |
| | + 1 |
| | + 2 |
| | + 3 |
| | + 4 |
| | + 4 |
| | + 3 |
| | + 3 |
| | ) |
| | + 3; |
| | Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors)); |
| | Assert.AreEqual(0, errors.Count); |
| | } |
| |
|
| | class BadClass |
| | { |
| | [Observable] |
| | double m_Double; |
| |
|
| | [Observable] |
| | double DoubleProperty |
| | { |
| | get => m_Double; |
| | set => m_Double = value; |
| | } |
| |
|
| | float m_WriteOnlyProperty; |
| |
|
| | [Observable] |
| | |
| | public float WriteOnlyProperty |
| | { |
| | set => m_WriteOnlyProperty = value; |
| | } |
| | } |
| |
|
| | [Test] |
| | public void TestInvalidObservables() |
| | { |
| | var bad = new BadClass(); |
| | bad.WriteOnlyProperty = 1.0f; |
| | var errors = new List<string>(); |
| | Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); |
| | Assert.AreEqual(3, errors.Count); |
| |
|
| | |
| | var sensors = ObservableAttribute.CreateObservableSensors(bad, false); |
| | Assert.AreEqual(0, sensors.Count); |
| | } |
| |
|
| | class StackingClass |
| | { |
| | [Observable(numStackedObservations: 2)] |
| | public float FloatVal; |
| | } |
| |
|
| | [Test] |
| | public void TestObservableAttributeStacking() |
| | { |
| | var c = new StackingClass(); |
| | c.FloatVal = 1.0f; |
| | var sensors = ObservableAttribute.CreateObservableSensors(c, false); |
| | var sensor = sensors[0]; |
| | Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); |
| | SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); |
| |
|
| | sensor.Update(); |
| | c.FloatVal = 3.0f; |
| | SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f }); |
| |
|
| | var errors = new List<string>(); |
| | Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors)); |
| | Assert.AreEqual(0, errors.Count); |
| | } |
| |
|
| | class BaseClass |
| | { |
| | [Observable("base")] |
| | public float m_BaseField; |
| |
|
| | [Observable("private")] |
| | float m_PrivateField; |
| | } |
| |
|
| | class DerivedClass : BaseClass |
| | { |
| | [Observable("derived")] |
| | float m_DerivedField; |
| | } |
| |
|
| | [Test] |
| | public void TestObservableAttributeExcludeInherited() |
| | { |
| | var d = new DerivedClass(); |
| | d.m_BaseField = 1.0f; |
| |
|
| | |
| | var sensorAll = ObservableAttribute.CreateObservableSensors(d, false); |
| | Assert.AreEqual(2, sensorAll.Count); |
| | |
| | Assert.AreEqual("derived", sensorAll[0].GetName()); |
| | Assert.AreEqual("base", sensorAll[1].GetName()); |
| |
|
| | |
| | var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true); |
| | Assert.AreEqual(1, sensorsDerivedOnly.Count); |
| | Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName()); |
| |
|
| | var b = new BaseClass(); |
| | var baseSensors = ObservableAttribute.CreateObservableSensors(b, false); |
| | Assert.AreEqual(2, baseSensors.Count); |
| | } |
| | } |
| | } |
| |
|