| | using System; |
| | using System.Collections.Generic; |
| | using UnityEngine; |
| | using Object = UnityEngine.Object; |
| |
|
| | namespace Unity.MLAgents.Extensions.Sensors |
| | { |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Base class for tracking the poses and linear velocities of a hierarchy of nodes.
/// The hierarchy is described by an array of parent indices passed to <see cref="Setup"/>;
/// subclasses supply the actual per-node pose and velocity via <see cref="GetPoseAt"/>
/// and <see cref="GetLinearVelocityAt"/>.
///
/// Two derived representations are maintained:
/// model space, which is relative to the node at index 0 (the root), and
/// local space, which is relative to each node's parent.
/// </summary>
public abstract class PoseExtractor
{
    // Parent index of each node; -1 marks a node without a parent (the root).
    int[] m_ParentIndices;
    Pose[] m_ModelSpacePoses;
    Pose[] m_LocalSpacePoses;

    Vector3[] m_ModelSpaceLinearVelocities;
    Vector3[] m_LocalSpaceLinearVelocities;

    // Per-node flag; disabled nodes are skipped by the GetEnabled* enumerations.
    bool[] m_PoseEnabled;

    /// <summary>
    /// Enumerate the model space poses of the enabled nodes, in index order.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>A lazily evaluated sequence of model space poses.</returns>
    public IEnumerable<Pose> GetEnabledModelSpacePoses()
    {
        if (m_ModelSpacePoses == null)
        {
            yield break;
        }

        for (var i = 0; i < m_ModelSpacePoses.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_ModelSpacePoses[i];
            }
        }
    }

    /// <summary>
    /// Enumerate the local space poses of the enabled nodes, in index order.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>A lazily evaluated sequence of local space poses.</returns>
    public IEnumerable<Pose> GetEnabledLocalSpacePoses()
    {
        if (m_LocalSpacePoses == null)
        {
            yield break;
        }

        for (var i = 0; i < m_LocalSpacePoses.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_LocalSpacePoses[i];
            }
        }
    }

    /// <summary>
    /// Enumerate the model space linear velocities of the enabled nodes, in index order.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>A lazily evaluated sequence of model space velocities.</returns>
    public IEnumerable<Vector3> GetEnabledModelSpaceVelocities()
    {
        if (m_ModelSpaceLinearVelocities == null)
        {
            yield break;
        }

        for (var i = 0; i < m_ModelSpaceLinearVelocities.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_ModelSpaceLinearVelocities[i];
            }
        }
    }

    /// <summary>
    /// Enumerate the local space linear velocities of the enabled nodes, in index order.
    /// Yields nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <returns>A lazily evaluated sequence of local space velocities.</returns>
    public IEnumerable<Vector3> GetEnabledLocalSpaceVelocities()
    {
        if (m_LocalSpaceLinearVelocities == null)
        {
            yield break;
        }

        for (var i = 0; i < m_LocalSpaceLinearVelocities.Length; i++)
        {
            if (m_PoseEnabled[i])
            {
                yield return m_LocalSpaceLinearVelocities[i];
            }
        }
    }

    /// <summary>
    /// Number of nodes that are currently enabled; 0 if <see cref="Setup"/> has not been called.
    /// </summary>
    public int NumEnabledPoses
    {
        get
        {
            if (m_PoseEnabled == null)
            {
                return 0;
            }

            var numEnabled = 0;
            foreach (var enabled in m_PoseEnabled)
            {
                if (enabled)
                {
                    numEnabled++;
                }
            }

            return numEnabled;
        }
    }

    /// <summary>
    /// Total number of nodes (enabled or not); 0 if <see cref="Setup"/> has not been called.
    /// </summary>
    public int NumPoses => m_ModelSpacePoses?.Length ?? 0;

    /// <summary>
    /// Get the parent index of the node at the given index.
    /// </summary>
    /// <param name="index">Index of the node to query.</param>
    /// <returns>The parent's index, or -1 for a node without a parent.</returns>
    /// <exception cref="NullReferenceException">
    /// Thrown if <see cref="Setup"/> has not been called.
    /// </exception>
    public int GetParentIndex(int index)
    {
        if (m_ParentIndices == null)
        {
            // The exception type is kept as-is because existing callers may catch it.
            throw new NullReferenceException("No parent indices set");
        }

        return m_ParentIndices[index];
    }

    /// <summary>
    /// Enable or disable the node at the given index.
    /// Disabled nodes are skipped by the GetEnabled* enumerations and by <see cref="NumEnabledPoses"/>.
    /// <see cref="Setup"/> must have been called first.
    /// </summary>
    /// <param name="index">Index of the node to change.</param>
    /// <param name="val">True to enable the node, false to disable it.</param>
    public void SetPoseEnabled(int index, bool val)
    {
        m_PoseEnabled[index] = val;
    }

    /// <summary>
    /// Whether the node at the given index is enabled.
    /// <see cref="Setup"/> must have been called first.
    /// </summary>
    /// <param name="index">Index of the node to query.</param>
    /// <returns>True if the node is enabled.</returns>
    public bool IsPoseEnabled(int index)
    {
        return m_PoseEnabled[index];
    }

    /// <summary>
    /// Initialize the extractor from the hierarchy description.
    /// Allocates the pose and velocity storage (one entry per node) and enables every node.
    /// </summary>
    /// <param name="parentIndices">
    /// Parent index of each node. The array is stored, not copied.
    /// Element 0 must be -1; this is only verified in DEBUG builds.
    /// </param>
    protected void Setup(int[] parentIndices)
    {
#if DEBUG
        if (parentIndices[0] != -1)
        {
            throw new UnityAgentsException($"Expected parentIndices[0] to be -1, got {parentIndices[0]}");
        }
#endif
        m_ParentIndices = parentIndices;
        var numPoses = parentIndices.Length;
        m_ModelSpacePoses = new Pose[numPoses];
        m_LocalSpacePoses = new Pose[numPoses];

        m_ModelSpaceLinearVelocities = new Vector3[numPoses];
        m_LocalSpaceLinearVelocities = new Vector3[numPoses];

        m_PoseEnabled = new bool[numPoses];
        // All poses are enabled by default.
        for (var i = 0; i < numPoses; i++)
        {
            m_PoseEnabled[i] = true;
        }
    }

    /// <summary>
    /// Return the current pose of the node at the given index.
    /// The update methods treat the result as a world space pose: they re-express it
    /// relative to the root (model space) or to the node's parent (local space).
    /// </summary>
    /// <param name="index">Index of the node.</param>
    /// <returns>The node's pose.</returns>
    protected internal abstract Pose GetPoseAt(int index);

    /// <summary>
    /// Return the current linear velocity of the node at the given index,
    /// expressed in the same frame as <see cref="GetPoseAt"/>.
    /// </summary>
    /// <param name="index">Index of the node.</param>
    /// <returns>The node's linear velocity.</returns>
    protected internal abstract Vector3 GetLinearVelocityAt(int index);

    /// <summary>
    /// Return the object associated with the node at the given index, if any.
    /// The result is only used to populate <see cref="DisplayNode.NodeObject"/>.
    /// The default implementation returns null.
    /// </summary>
    /// <param name="index">Index of the node.</param>
    /// <returns>The associated object, or null.</returns>
    protected internal virtual Object GetObjectAt(int index)
    {
        return null;
    }

    /// <summary>
    /// Recompute the model space poses and velocities from the current node state.
    /// Does nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    public void UpdateModelSpacePoses()
    {
        using (TimerStack.Instance.Scoped("UpdateModelSpacePoses"))
        {
            if (m_ModelSpacePoses == null)
            {
                return;
            }

            var rootWorldTransform = GetPoseAt(0);
            var worldToModel = rootWorldTransform.Inverse();
            var rootLinearVel = GetLinearVelocityAt(0);

            for (var i = 0; i < m_ModelSpacePoses.Length; i++)
            {
                var currentWorldSpacePose = GetPoseAt(i);
                m_ModelSpacePoses[i] = worldToModel.Multiply(currentWorldSpacePose);

                // Velocity relative to the root, rotated into the root's frame.
                var currentBodyLinearVel = GetLinearVelocityAt(i);
                var relativeVelocity = currentBodyLinearVel - rootLinearVel;
                m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
            }
        }
    }

    /// <summary>
    /// Recompute the local space poses and velocities from the current node state.
    /// Nodes without a parent get the identity pose and zero velocity.
    /// Does nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    public void UpdateLocalSpacePoses()
    {
        using (TimerStack.Instance.Scoped("UpdateLocalSpacePoses"))
        {
            if (m_LocalSpacePoses == null)
            {
                return;
            }

            for (var i = 0; i < m_LocalSpacePoses.Length; i++)
            {
                var parentIndex = m_ParentIndices[i];
                if (parentIndex != -1)
                {
                    var parentTransform = GetPoseAt(parentIndex);
                    // The parent's inverse is recomputed for every child; nothing is cached
                    // between nodes that share a parent.
                    var invParent = parentTransform.Inverse();
                    var currentTransform = GetPoseAt(i);
                    m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);

                    // Velocity relative to the parent, rotated into the parent's frame.
                    var parentLinearVel = GetLinearVelocityAt(parentIndex);
                    var currentLinearVel = GetLinearVelocityAt(i);
                    m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
                }
                else
                {
                    m_LocalSpacePoses[i] = Pose.identity;
                    m_LocalSpaceLinearVelocities[i] = Vector3.zero;
                }
            }
        }
    }

    /// <summary>
    /// Compute the number of floats needed to represent the enabled poses with the given settings:
    /// 3 per enabled translation option, 4 per enabled rotation option and 3 per enabled
    /// velocity option, multiplied by <see cref="NumEnabledPoses"/>.
    /// </summary>
    /// <param name="settings">Which quantities are observed.</param>
    /// <returns>The total observation size.</returns>
    public int GetNumPoseObservations(PhysicsSensorSettings settings)
    {
        int obsPerPose = 0;
        obsPerPose += settings.UseModelSpaceTranslations ? 3 : 0;
        obsPerPose += settings.UseModelSpaceRotations ? 4 : 0;
        obsPerPose += settings.UseLocalSpaceTranslations ? 3 : 0;
        obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;

        obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
        obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;

        return NumEnabledPoses * obsPerPose;
    }

    /// <summary>
    /// Debug helper: refresh both pose sets, then draw the model space skeleton with
    /// <c>Debug.DrawLine</c>. Each node with a parent gets a cyan line to its parent and three
    /// short (0.1 unit) lines along its local-rotation up (red), forward (green) and
    /// right (blue) directions.
    /// Does nothing if <see cref="Setup"/> has not been called.
    /// </summary>
    /// <param name="offset">Translation applied to every drawn point.</param>
    internal void DrawModelSpace(Vector3 offset)
    {
        // Without Setup there is no skeleton to draw; bail out instead of dereferencing null arrays.
        if (m_ModelSpacePoses == null)
        {
            return;
        }

        UpdateLocalSpacePoses();
        UpdateModelSpacePoses();

        var pose = m_ModelSpacePoses;
        var localPose = m_LocalSpacePoses;
        for (var i = 0; i < pose.Length; i++)
        {
            var parentIndex = m_ParentIndices[i];
            if (parentIndex == -1)
            {
                continue;
            }

            var origin = pose[i].position + offset;
            var parent = pose[parentIndex];
            Debug.DrawLine(origin, parent.position + offset, Color.cyan);

            var localRotation = localPose[i].rotation;
            var localUp = localRotation * Vector3.up;
            var localFwd = localRotation * Vector3.forward;
            var localRight = localRotation * Vector3.right;
            Debug.DrawLine(origin, origin + .1f * localUp, Color.red);
            Debug.DrawLine(origin, origin + .1f * localFwd, Color.green);
            Debug.DrawLine(origin, origin + .1f * localRight, Color.blue);
        }
    }

    /// <summary>
    /// Description of one node of the hierarchy, as produced by <see cref="GetDisplayNodes"/>.
    /// </summary>
    internal struct DisplayNode
    {
        /// <summary>
        /// Object returned by <see cref="GetObjectAt"/> for this node; may be null.
        /// </summary>
        public Object NodeObject;

        /// <summary>
        /// Whether the node's pose is enabled.
        /// </summary>
        public bool Enabled;

        /// <summary>
        /// Depth of the node in the hierarchy; the root has depth 0.
        /// </summary>
        public int Depth;

        /// <summary>
        /// Index of the node in the extractor's arrays.
        /// </summary>
        public int OriginalIndex;
    }

    /// <summary>
    /// Build a depth-first listing of the hierarchy starting at node 0.
    /// Children of a node are listed in ascending index order.
    /// </summary>
    /// <returns>
    /// The nodes in depth-first order, or an empty list if there are no poses.
    /// </returns>
    internal IList<DisplayNode> GetDisplayNodes()
    {
        if (NumPoses == 0)
        {
            return Array.Empty<DisplayNode>();
        }
        var nodesOut = new List<DisplayNode>(NumPoses);

        // Map each parent index to the list of its children.
        var tree = new Dictionary<int, List<int>>();
        for (var i = 0; i < NumPoses; i++)
        {
            var parent = GetParentIndex(i);
            // Nodes without a parent have no entry to add to.
            if (parent == -1)
            {
                continue;
            }

            if (!tree.TryGetValue(parent, out var siblings))
            {
                siblings = new List<int>();
                tree[parent] = siblings;
            }
            siblings.Add(i);
        }

        // Iterative depth-first traversal; each stack entry is (node index, depth).
        var stack = new Stack<(int, int)>();
        stack.Push((0, 0));

        while (stack.Count != 0)
        {
            var (current, depth) = stack.Pop();
            var obj = GetObjectAt(current);

            var node = new DisplayNode
            {
                NodeObject = obj,
                Enabled = IsPoseEnabled(current),
                OriginalIndex = current,
                Depth = depth
            };
            nodesOut.Add(node);

            // Queue up the children, if any.
            if (tree.TryGetValue(current, out var children))
            {
                // Push in reverse so that the lowest-indexed child is popped first.
                for (var childIdx = children.Count - 1; childIdx >= 0; childIdx--)
                {
                    stack.Push((children[childIdx], depth + 1));
                }
            }

            // Safety check: emitting more nodes than exist means the parent indices
            // contain a cycle, so stop rather than loop forever.
            if (nodesOut.Count > NumPoses)
            {
                return nodesOut;
            }
        }

        return nodesOut;
    }
}
| |
|
| | |
| | |
| | |
/// <summary>
/// Extension methods that treat a <see cref="Pose"/> as a rigid transform
/// (a rotation followed by a translation).
/// </summary>
public static class PoseExtensions
{
    /// <summary>
    /// Compute the inverse of a pose: the rotation is inverted, and the position becomes
    /// the negated original position rotated by that inverse rotation.
    /// Composing the result with the original pose via
    /// <see cref="Multiply(Pose, Pose)"/> is intended to give the identity pose.
    /// </summary>
    /// <param name="pose">The pose to invert.</param>
    /// <returns>The inverse pose.</returns>
    public static Pose Inverse(this Pose pose)
    {
        var inverseRotation = Quaternion.Inverse(pose.rotation);
        return new Pose
        {
            rotation = inverseRotation,
            position = -(inverseRotation * pose.position)
        };
    }

    /// <summary>
    /// Compose two poses by transforming <paramref name="rhs"/> by <paramref name="pose"/>;
    /// this simply forwards to <c>rhs.GetTransformedBy(pose)</c>.
    /// </summary>
    /// <param name="pose">The left-hand (outer) pose.</param>
    /// <param name="rhs">The right-hand pose, which gets transformed.</param>
    /// <returns>The composed pose.</returns>
    public static Pose Multiply(this Pose pose, Pose rhs) => rhs.GetTransformedBy(pose);

    /// <summary>
    /// Transform a point by a pose: the point is rotated by the pose's rotation and then
    /// offset by the pose's position.
    /// </summary>
    /// <param name="pose">The pose to apply.</param>
    /// <param name="rhs">The point to transform.</param>
    /// <returns>The transformed point.</returns>
    public static Vector3 Multiply(this Pose pose, Vector3 rhs)
    {
        var rotated = pose.rotation * rhs;
        return rotated + pose.position;
    }
}
| | } |
| |
|