// ppo-Pyramids-Training/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/RigidBodyPoseExtractorTests.cs
| using UnityEngine; | |
| using NUnit.Framework; | |
| using Unity.MLAgents.Extensions.Sensors; | |
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
    /// <summary>
    /// Unit tests for <see cref="RigidBodyPoseExtractor"/>. They cover construction from
    /// Rigidbody hierarchies, the optional virtual root, display nodes, and the per-body
    /// "pose enabled" bookkeeping.
    /// </summary>
    public class RigidBodyPoseExtractorTests
    {
        /// <summary>
        /// Destroys every GameObject found in the scene after each test. This keeps the
        /// hierarchies built by one test from leaking into the next.
        /// </summary>
        [TearDown]
        public void RemoveGameObjects()
        {
            var objects = GameObject.FindObjectsOfType<GameObject>();
            foreach (var o in objects)
            {
                // Destroying a parent also destroys its children, so an entry later in the
                // array may already be gone. Unity's overloaded != reports destroyed objects
                // as null, so those entries are skipped.
                if (o != null)
                {
                    UnityEngine.Object.DestroyImmediate(o);
                }
            }
        }

        /// <summary>
        /// A null root Rigidbody produces an extractor with no poses. Updating it must not throw.
        /// </summary>
        [Test]
        public void TestNullRoot()
        {
            var poseExtractor = new RigidBodyPoseExtractor(null);

            // These should be no-ops
            poseExtractor.UpdateLocalSpacePoses();
            poseExtractor.UpdateModelSpacePoses();
            Assert.AreEqual(0, poseExtractor.NumPoses);
        }

        /// <summary>
        /// A lone Rigidbody yields exactly one pose. This holds whether or not its own
        /// GameObject is passed explicitly.
        /// </summary>
        [Test]
        public void TestSingleBody()
        {
            var go = new GameObject();
            var rootRb = go.AddComponent<Rigidbody>();
            var poseExtractor = new RigidBodyPoseExtractor(rootRb);
            Assert.AreEqual(1, poseExtractor.NumPoses);

            // Also pass the GameObject
            poseExtractor = new RigidBodyPoseExtractor(rootRb, go);
            Assert.AreEqual(1, poseExtractor.NumPoses);
        }

        /// <summary>
        /// When the search GameObject does not contain the root Rigidbody, the extractor is
        /// empty. This is true even if that GameObject has its own, unrelated Rigidbody.
        /// </summary>
        [Test]
        public void TestNoBodiesFound()
        {
            // Check that if we can't find any bodies under the game object, we get an empty extractor
            var gameObj = new GameObject();
            var rootRb = gameObj.AddComponent<Rigidbody>();
            var otherGameObj = new GameObject();
            var poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj);
            Assert.AreEqual(0, poseExtractor.NumPoses);

            // Add an RB under the other GameObject. Constructor will find a rigid body, but not the root.
            otherGameObj.AddComponent<Rigidbody>();
            poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj);
            Assert.AreEqual(0, poseExtractor.NumPoses);
        }

        /// <summary>
        /// A root body with one jointed child yields two poses. Pose 0 reflects the root
        /// Rigidbody's current state. Display nodes list the root (disabled) before the
        /// child (enabled).
        /// </summary>
        [Test]
        public void TestTwoBodies()
        {
            // * rootObj
            //   - rb1
            //   * go2
            //     - rb2
            //     - joint
            var rootObj = new GameObject();
            var rb1 = rootObj.AddComponent<Rigidbody>();
            var go2 = new GameObject();
            var rb2 = go2.AddComponent<Rigidbody>();
            go2.transform.SetParent(rootObj.transform);
            var joint = go2.AddComponent<ConfigurableJoint>();
            joint.connectedBody = rb1;

            var poseExtractor = new RigidBodyPoseExtractor(rb1);
            Assert.AreEqual(2, poseExtractor.NumPoses);

            // Mutate the body after construction; the extractor should report the new state.
            rb1.position = new Vector3(1, 0, 0);
            rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
            rb1.velocity = new Vector3(2, 0, 0);

            Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position);
            // Quaternion == is Unity's approximate comparison; AreEqual would use exact Equals.
            Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation);
            Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0));

            // Check DisplayNodes gives expected results
            var displayNodes = poseExtractor.GetDisplayNodes();
            Assert.AreEqual(2, displayNodes.Count);
            Assert.AreEqual(rb1, displayNodes[0].NodeObject);
            Assert.IsFalse(displayNodes[0].Enabled);

            Assert.AreEqual(rb2, displayNodes[1].NodeObject);
            Assert.IsTrue(displayNodes[1].Enabled);
        }

        /// <summary>
        /// With a virtual root, index 0 is the virtual root and the real bodies shift up by
        /// one. Pose 0 tracks the virtual root's transform and reports zero linear velocity.
        /// </summary>
        [Test]
        public void TestTwoBodiesVirtualRoot()
        {
            // * virtualRoot
            // * rootObj
            //   - rb1
            //   * go2
            //     - rb2
            //     - joint
            var virtualRoot = new GameObject("I am vroot");

            var rootObj = new GameObject();
            var rb1 = rootObj.AddComponent<Rigidbody>();

            var go2 = new GameObject();
            go2.AddComponent<Rigidbody>();
            go2.transform.SetParent(rootObj.transform);

            var joint = go2.AddComponent<ConfigurableJoint>();
            joint.connectedBody = rb1;

            var poseExtractor = new RigidBodyPoseExtractor(rb1, null, virtualRoot);
            Assert.AreEqual(3, poseExtractor.NumPoses);

            // "body" 0 has no parent
            Assert.AreEqual(-1, poseExtractor.GetParentIndex(0));

            // body 1 has parent 0
            Assert.AreEqual(0, poseExtractor.GetParentIndex(1));

            var virtualRootPos = new Vector3(0, 2, 0);
            var virtualRootRot = Quaternion.Euler(0, 42, 0);
            virtualRoot.transform.position = virtualRootPos;
            virtualRoot.transform.rotation = virtualRootRot;

            Assert.AreEqual(virtualRootPos, poseExtractor.GetPoseAt(0).position);
            // Approximate Quaternion comparison, see TestTwoBodies.
            Assert.IsTrue(virtualRootRot == poseExtractor.GetPoseAt(0).rotation);
            Assert.AreEqual(Vector3.zero, poseExtractor.GetLinearVelocityAt(0));

            // Same as above test, but using index 1
            rb1.position = new Vector3(1, 0, 0);
            rb1.rotation = Quaternion.Euler(0, 13.37f, 0);
            rb1.velocity = new Vector3(2, 0, 0);

            Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position);
            Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation);
            Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1)));
        }

        /// <summary>
        /// By default the root pose is disabled and the attached body's pose is enabled.
        /// Passing an edited dictionary from GetBodyPosesEnabled() into a new extractor
        /// applies the edited flags.
        /// </summary>
        [Test]
        public void TestBodyPosesEnabledDictionary()
        {
            // * rootObj
            //   - rb1
            //   * go2
            //     - rb2
            //     - joint
            var rootObj = new GameObject();
            var rb1 = rootObj.AddComponent<Rigidbody>();

            var go2 = new GameObject();
            var rb2 = go2.AddComponent<Rigidbody>();
            go2.transform.SetParent(rootObj.transform);

            var joint = go2.AddComponent<ConfigurableJoint>();
            joint.connectedBody = rb1;

            var poseExtractor = new RigidBodyPoseExtractor(rb1);

            // Expect the root body disabled and the attached one enabled.
            Assert.IsFalse(poseExtractor.IsPoseEnabled(0));
            Assert.IsTrue(poseExtractor.IsPoseEnabled(1));
            var bodyPosesEnabled = poseExtractor.GetBodyPosesEnabled();
            Assert.IsFalse(bodyPosesEnabled[rb1]);
            Assert.IsTrue(bodyPosesEnabled[rb2]);

            // Swap the values
            bodyPosesEnabled[rb1] = true;
            bodyPosesEnabled[rb2] = false;

            var poseExtractor2 = new RigidBodyPoseExtractor(rb1, null, null, bodyPosesEnabled);
            Assert.IsTrue(poseExtractor2.IsPoseEnabled(0));
            Assert.IsFalse(poseExtractor2.IsPoseEnabled(1));
        }
    }
}