| | using System.Collections.Generic; |
| | using UnityEngine; |
| |
|
namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// Extracts poses from a hierarchy of <see cref="Rigidbody"/> components.
    /// The hierarchy has a single root body; every other body is parented to the
    /// <see cref="Joint.connectedBody"/> of the <see cref="Joint"/> on its own GameObject.
    /// </summary>
    public class RigidBodyPoseExtractor : PoseExtractor
    {
        /// <summary>
        /// Rigidbodies tracked by this extractor, in pose-index order.
        /// When a virtual root is used, slot 0 is null and the real root body is at index 1.
        /// Remains null if construction bailed out early.
        /// </summary>
        Rigidbody[] m_Bodies;

        /// <summary>
        /// Optional GameObject reported as the pose at index 0 instead of a Rigidbody.
        /// Null when no virtual root was supplied.
        /// </summary>
        GameObject m_VirtualRoot;

        /// <summary>
        /// Build the pose hierarchy from the Rigidbodies and Joints found under the root.
        /// If no bodies are found, or the first body found is not <paramref name="rootBody"/>,
        /// a message is logged and the extractor is left empty.
        /// </summary>
        /// <param name="rootBody">The root Rigidbody. It must be the first Rigidbody returned when
        /// searching the hierarchy.</param>
        /// <param name="rootGameObject">Optional object to search for Rigidbodies and Joints. When null,
        /// the search starts from <paramref name="rootBody"/>.</param>
        /// <param name="virtualRoot">Optional GameObject inserted at index 0. When set, the real root
        /// body moves to index 1 and is parented to it.</param>
        /// <param name="enableBodyPoses">Optional map from Rigidbody to whether its pose is enabled.
        /// Bodies not found in the hierarchy are ignored.</param>
        public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null,
            GameObject virtualRoot = null, Dictionary<Rigidbody, bool> enableBodyPoses = null)
        {
            if (rootBody == null)
            {
                return;
            }

            // Search under the explicit root GameObject when one is given, otherwise under the root body.
            Rigidbody[] rbs;
            Joint[] joints;
            if (rootGameObject == null)
            {
                rbs = rootBody.GetComponentsInChildren<Rigidbody>();
                joints = rootBody.GetComponentsInChildren<Joint>();
            }
            else
            {
                rbs = rootGameObject.GetComponentsInChildren<Rigidbody>();
                joints = rootGameObject.GetComponentsInChildren<Joint>();
            }

            if (rbs == null || rbs.Length == 0)
            {
                Debug.Log("No rigid bodies found!");
                return;
            }

            if (rbs[0] != rootBody)
            {
                Debug.Log("Expected root body at index 0");
                return;
            }

            // With a virtual root, shift every body up one slot. Index 0 is reserved for the
            // virtual root and intentionally left null in the body array.
            if (virtualRoot != null)
            {
                var extendedRbs = new Rigidbody[rbs.Length + 1];
                for (var i = 0; i < rbs.Length; i++)
                {
                    extendedRbs[i + 1] = rbs[i];
                }

                rbs = extendedRbs;
            }

            var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
            var parentIndices = new int[rbs.Length];
            parentIndices[0] = -1;

            for (var i = 0; i < rbs.Length; i++)
            {
                if (rbs[i] != null)
                {
                    bodyToIndex[rbs[i]] = i;
                }
            }

            // Each joint links the body on its own GameObject (child) to its connectedBody (parent).
            // Bodies not reached by any joint keep the default parent index 0.
            foreach (var j in joints)
            {
                var parent = j.connectedBody;
                var child = j.GetComponent<Rigidbody>();

                // A joint anchored to the world has no connectedBody. A joint may also connect to a
                // body outside the searched hierarchy. Neither describes an edge of this tree, so skip
                // them rather than throwing from the dictionary lookup.
                if (parent == null || child == null)
                {
                    continue;
                }

                if (!bodyToIndex.TryGetValue(parent, out var parentIndex) ||
                    !bodyToIndex.TryGetValue(child, out var childIndex))
                {
                    continue;
                }

                // Index 0 must remain the root (parent index -1); a joint on the root body must not
                // turn it into somebody's child.
                if (childIndex == 0)
                {
                    continue;
                }

                parentIndices[childIndex] = parentIndex;
            }

            if (virtualRoot != null)
            {
                // The real root (now at index 1) is parented to the virtual root at index 0.
                parentIndices[1] = 0;
                m_VirtualRoot = virtualRoot;
            }

            m_Bodies = rbs;
            Setup(parentIndices);

            // The root pose is disabled by default.
            SetPoseEnabled(0, false);

            if (enableBodyPoses != null)
            {
                foreach (var pair in enableBodyPoses)
                {
                    // Bodies that are not part of this hierarchy are ignored.
                    if (bodyToIndex.TryGetValue(pair.Key, out var index))
                    {
                        SetPoseEnabled(index, pair.Value);
                    }
                }
            }
        }

        /// <inheritdoc/>
        /// <remarks>The virtual root, if any, always reports zero velocity.</remarks>
        protected internal override Vector3 GetLinearVelocityAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                return Vector3.zero;
            }
            return m_Bodies[index].velocity;
        }

        /// <inheritdoc/>
        /// <remarks>The virtual root, if any, reports its transform's world rotation and position.</remarks>
        protected internal override Pose GetPoseAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                return new Pose
                {
                    rotation = m_VirtualRoot.transform.rotation,
                    position = m_VirtualRoot.transform.position
                };
            }

            var body = m_Bodies[index];
            return new Pose { rotation = body.rotation, position = body.position };
        }

        /// <inheritdoc/>
        /// <remarks>Returns the virtual root GameObject at index 0 when one is set, otherwise the Rigidbody.</remarks>
        protected internal override Object GetObjectAt(int index)
        {
            if (index == 0 && m_VirtualRoot != null)
            {
                return m_VirtualRoot;
            }
            return m_Bodies[index];
        }

        /// <summary>
        /// The Rigidbodies tracked by this extractor, in pose-index order (slot 0 is null when a
        /// virtual root is used). Null if construction bailed out early.
        /// </summary>
        internal Rigidbody[] Bodies => m_Bodies;

        /// <summary>
        /// Get a mapping from each tracked Rigidbody to whether its pose is enabled.
        /// The virtual root, if any, is not a Rigidbody and is omitted.
        /// </summary>
        /// <returns>The mapping; empty if the extractor was not initialized with any bodies.</returns>
        internal Dictionary<Rigidbody, bool> GetBodyPosesEnabled()
        {
            // m_Bodies stays null when the constructor returned early; mirror GetEnabledRigidbodies.
            if (m_Bodies == null)
            {
                return new Dictionary<Rigidbody, bool>();
            }

            var bodyPosesEnabled = new Dictionary<Rigidbody, bool>(m_Bodies.Length);
            for (var i = 0; i < m_Bodies.Length; i++)
            {
                var rb = m_Bodies[i];
                if (rb == null)
                {
                    // Virtual-root slot.
                    continue;
                }

                bodyPosesEnabled[rb] = IsPoseEnabled(i);
            }

            return bodyPosesEnabled;
        }

        /// <summary>
        /// Enumerate the tracked Rigidbodies whose poses are enabled, in pose-index order.
        /// </summary>
        /// <returns>The enabled Rigidbodies; nothing if the extractor was not initialized with any bodies.</returns>
        internal IEnumerable<Rigidbody> GetEnabledRigidbodies()
        {
            if (m_Bodies == null)
            {
                yield break;
            }

            for (var i = 0; i < m_Bodies.Length; i++)
            {
                var rb = m_Bodies[i];
                if (rb == null)
                {
                    // Virtual-root slot.
                    continue;
                }

                if (IsPoseEnabled(i))
                {
                    yield return rb;
                }
            }
        }
    }
}
| |
|