ppo-Pyramids-Training / com.unity.ml-agents.extensions /Tests /Runtime /Sensors /RigidBodySensorTests.cs
| using UnityEngine; | |
| using NUnit.Framework; | |
| using Unity.MLAgents.Sensors; | |
| using Unity.MLAgents.Extensions.Sensors; | |
| namespace Unity.MLAgents.Extensions.Tests.Sensors | |
| { | |
/// <summary>
/// NUnit-facing wrappers around <c>SensorHelper.CompareObservation</c>: each overload
/// runs the comparison and fails the current test with the helper's error message
/// when the comparison reports a mismatch.
/// </summary>
public static class SensorTestHelper
{
    /// <summary>
    /// Asserts that <paramref name="sensor"/> matches the flat <paramref name="expected"/> values.
    /// </summary>
    /// <param name="sensor">Sensor whose observation is checked.</param>
    /// <param name="expected">Expected observation values as a one-dimensional array.</param>
    public static void CompareObservation(ISensor sensor, float[] expected)
    {
        // Arguments are evaluated left to right, so failureReason is assigned before it is read.
        Assert.IsTrue(
            SensorHelper.CompareObservation(sensor, expected, out var failureReason),
            failureReason);
    }

    /// <summary>
    /// Asserts that <paramref name="sensor"/> matches the three-dimensional <paramref name="expected"/> values.
    /// </summary>
    /// <param name="sensor">Sensor whose observation is checked.</param>
    /// <param name="expected">Expected observation values as a three-dimensional array.</param>
    public static void CompareObservation(ISensor sensor, float[,,] expected)
    {
        // Arguments are evaluated left to right, so failureReason is assigned before it is read.
        Assert.IsTrue(
            SensorHelper.CompareObservation(sensor, expected, out var failureReason),
            failureReason);
    }
}
/// <summary>
/// Tests for the sensors created by <c>RigidBodySensorComponent</c>.
/// </summary>
public class RigidBodySensorTests
{
    /// <summary>
    /// A component with no RootBody assigned is expected to report itself as
    /// non-trivial and to produce an empty observation from its first sensor.
    /// </summary>
    [Test]
    public void TestNullRootBody()
    {
        var gameObj = new GameObject();
        var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();

        Assert.IsFalse(sensorComponent.IsTrivial());

        var sensor = sensorComponent.CreateSensors()[0];
        SensorTestHelper.CompareObservation(sensor, new float[0]);
    }

    /// <summary>
    /// A lone root Rigidbody with no attached bodies is expected to be trivial
    /// and to produce an empty observation.
    /// </summary>
    [Test]
    public void TestSingleRigidbody()
    {
        var gameObj = new GameObject();
        var rootRb = gameObj.AddComponent<Rigidbody>();
        var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();
        sensorComponent.RootBody = rootRb;
        sensorComponent.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceLinearVelocity = true,
            UseLocalSpaceTranslations = true,
            UseLocalSpaceRotations = true
        };

        Assert.IsTrue(sensorComponent.IsTrivial());

        var sensor = sensorComponent.CreateSensors()[0];
        sensor.Update();

        // The root body is ignored since it always generates identity values
        // and there are no other bodies to generate observations.
        var expected = new float[0];
        Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
        SensorTestHelper.CompareObservation(sensor, expected);
    }

    /// <summary>
    /// Builds a root -> middle -> leaf chain of Rigidbodies connected by
    /// ConfigurableJoints and checks the observations first with translation and
    /// velocity settings, then with joint-only settings.
    /// </summary>
    [Test]
    public void TestBodiesWithJoint()
    {
        // Root body, moving along +X.
        var rootObj = new GameObject();
        var rootRb = rootObj.AddComponent<Rigidbody>();
        rootRb.velocity = new Vector3(1f, 0f, 0f);

        // Middle body: child of the root, offset along X, moving along +Y, jointed to the root.
        var middleGameObj = new GameObject();
        var middleRb = middleGameObj.AddComponent<Rigidbody>();
        middleRb.velocity = new Vector3(0f, 1f, 0f);
        middleGameObj.transform.SetParent(rootObj.transform);
        middleGameObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
        var joint = middleGameObj.AddComponent<ConfigurableJoint>();
        joint.connectedBody = rootRb;

        // Leaf body: child of the middle, offset along X, moving along +Z, jointed to the middle.
        var leafGameObj = new GameObject();
        var leafRb = leafGameObj.AddComponent<Rigidbody>();
        leafRb.velocity = new Vector3(0f, 0f, 1f);
        leafGameObj.transform.SetParent(middleGameObj.transform);
        leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
        var joint2 = leafGameObj.AddComponent<ConfigurableJoint>();
        joint2.connectedBody = middleRb;

        var virtualRoot = new GameObject();

        var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>();
        sensorComponent.RootBody = rootRb;
        sensorComponent.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceTranslations = true,
            UseLocalSpaceTranslations = true,
            UseLocalSpaceLinearVelocity = true
        };
        sensorComponent.VirtualRoot = virtualRoot;

        Assert.IsFalse(sensorComponent.IsTrivial());

        var sensor = sensorComponent.CreateSensors()[0];
        sensor.Update();

        // Note that the VirtualRoot is ignored from the observations
        var expected = new[]
        {
            // Model space
            0f, 0f, 0f, // Root pos
            13.37f, 0f, 0f, // Middle pos
            leafGameObj.transform.position.x, 0f, 0f, // Leaf pos

            // Local space
            0f, 0f, 0f, // Root pos
            13.37f, 0f, 0f, // Attached pos
            4.2f, 0f, 0f, // Leaf pos

            1f, 0f, 0f, // Root vel (relative to virtual root)
            -1f, 1f, 0f, // Attached vel
            0f, -1f, 1f // Leaf vel
        };
        Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
        SensorTestHelper.CompareObservation(sensor, expected);

        // Update the settings to only process joint observations
        sensorComponent.Settings = new PhysicsSensorSettings
        {
            UseJointPositionsAndAngles = true,
            UseJointForces = true,
        };

        // Create a fresh sensor after replacing the settings.
        sensor = sensorComponent.CreateSensors()[0];
        sensor.Update();

        expected = new[]
        {
            0f, 0f, 0f, // joint1.force
            0f, 0f, 0f, // joint1.torque
            0f, 0f, 0f, // joint2.force
            0f, 0f, 0f, // joint2.torque
        };
        // Check the declared size first (same order as above) so a size mismatch
        // is reported before the element-wise comparison.
        Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
        SensorTestHelper.CompareObservation(sensor, expected);
    }
}
| } | |