ppo-Pyramids-Training / com.unity.ml-agents.extensions /Runtime /Sensors /ArticulationBodyJointExtractor.cs
| using System.Collections.Generic; | |
| using UnityEngine; | |
| using Unity.MLAgents.Sensors; | |
| namespace Unity.MLAgents.Extensions.Sensors | |
| { | |
/// <summary>
/// Produces joint observations for a single <see cref="ArticulationBody"/>:
/// joint positions/angles and joint forces, depending on the
/// <see cref="PhysicsSensorSettings"/> passed in.
/// </summary>
/// <remarks>
/// A null body or the articulation root yields no observations.
/// <see cref="NumObservations(ArticulationBody, PhysicsSensorSettings)"/> and
/// <see cref="Write"/> share the same per-joint counting rules so that the
/// number of values written always equals the advertised count.
/// </remarks>
public class ArticulationBodyJointExtractor : IJointExtractor
{
    readonly ArticulationBody m_Body;

    /// <summary>
    /// Creates an extractor for the given body.
    /// </summary>
    /// <param name="body">The body to observe. May be null, in which case nothing is observed.</param>
    public ArticulationBodyJointExtractor(ArticulationBody body)
    {
        m_Body = body;
    }

    /// <summary>
    /// Number of floats that <see cref="Write"/> produces for this extractor's body.
    /// </summary>
    /// <param name="settings">Settings selecting which observation groups are enabled.</param>
    /// <returns>The observation count; 0 for a null or root body.</returns>
    public int NumObservations(PhysicsSensorSettings settings)
    {
        return NumObservations(m_Body, settings);
    }

    /// <summary>
    /// Number of floats that would be produced for <paramref name="body"/>.
    /// </summary>
    /// <param name="body">The body to inspect. May be null.</param>
    /// <param name="settings">Settings selecting which observation groups are enabled.</param>
    /// <returns>The observation count; 0 for a null or root body.</returns>
    public static int NumObservations(ArticulationBody body, PhysicsSensorSettings settings)
    {
        // Keep "== null" (not "is null"): Unity overloads == on UnityEngine.Object.
        if (body == null || body.isRoot)
        {
            return 0;
        }

        var totalCount = 0;
        if (settings.UseJointPositionsAndAngles)
        {
            switch (body.jointType)
            {
                case ArticulationJointType.RevoluteJoint:
                case ArticulationJointType.SphericalJoint:
                    // Both RevoluteJoint and SphericalJoint have only angular components.
                    // Each angle is observed as a (sine, cosine) pair.
                    totalCount += 2 * body.dofCount;
                    break;
                case ArticulationJointType.FixedJoint:
                    // A FixedJoint can't move, so there is nothing interesting to observe.
                    break;
                case ArticulationJointType.PrismaticJoint:
                    // At most one linear component; must match what Write() emits.
                    totalCount += PrismaticObservationCount(body);
                    break;
            }
        }

        if (settings.UseJointForces)
        {
            totalCount += body.dofCount;
        }

        return totalCount;
    }

    /// <summary>
    /// Writes this body's joint observations into <paramref name="writer"/>,
    /// starting at <paramref name="offset"/>.
    /// </summary>
    /// <param name="settings">Settings selecting which observation groups are enabled.</param>
    /// <param name="writer">Destination for the observation values.</param>
    /// <param name="offset">Index in <paramref name="writer"/> of the first value to write.</param>
    /// <returns>The number of values written; 0 for a null or root body.</returns>
    public int Write(PhysicsSensorSettings settings, ObservationWriter writer, int offset)
    {
        // Keep "== null" (not "is null"): Unity overloads == on UnityEngine.Object.
        if (m_Body == null || m_Body.isRoot)
        {
            return 0;
        }

        var currentOffset = offset;
        var dofCount = m_Body.dofCount;

        // Write joint positions
        if (settings.UseJointPositionsAndAngles)
        {
            switch (m_Body.jointType)
            {
                case ArticulationJointType.RevoluteJoint:
                case ArticulationJointType.SphericalJoint:
                    // All joint positions are angular. Read the reduced-space
                    // value once instead of re-fetching it for every DOF.
                    var jointPositions = m_Body.jointPosition;
                    for (var dofIndex = 0; dofIndex < dofCount; dofIndex++)
                    {
                        var jointRotationRads = jointPositions[dofIndex];
                        writer[currentOffset++] = Mathf.Sin(jointRotationRads);
                        writer[currentOffset++] = Mathf.Cos(jointRotationRads);
                    }
                    break;
                case ArticulationJointType.FixedJoint:
                    // No observations
                    break;
                case ArticulationJointType.PrismaticJoint:
                    // Only emit a value when NumObservations() advertises one.
                    // This also avoids reading jointPosition[0] when the joint
                    // has no degrees of freedom.
                    if (PrismaticObservationCount(m_Body) > 0)
                    {
                        writer[currentOffset++] = GetPrismaticValue();
                    }
                    break;
            }
        }

        if (settings.UseJointForces)
        {
            var jointForces = m_Body.jointForce;
            for (var dofIndex = 0; dofIndex < dofCount; dofIndex++)
            {
                // take tanh to keep in [-1, 1]
                writer[currentOffset++] = (float)System.Math.Tanh(jointForces[dofIndex]);
            }
        }

        return currentOffset - offset;
    }

    /// <summary>
    /// Number of position observations for a prismatic joint: one if the body
    /// has at least one degree of freedom, otherwise zero.
    /// </summary>
    static int PrismaticObservationCount(ArticulationBody body)
    {
        // Prismatic joints should have at most one free axis.
        return body.dofCount > 0 ? 1 : 0;
    }

    /// <summary>
    /// Finds the drive of the first linear axis (checked in X, Y, Z order)
    /// whose lock is <see cref="ArticulationDofLock.LimitedMotion"/>.
    /// </summary>
    /// <param name="drive">The matching drive, or a default value when none is limited.</param>
    /// <returns>True if a limited axis was found.</returns>
    bool TryGetLimitedDrive(out ArticulationDrive drive)
    {
        if (m_Body.linearLockX == ArticulationDofLock.LimitedMotion)
        {
            drive = m_Body.xDrive;
            return true;
        }

        if (m_Body.linearLockY == ArticulationDofLock.LimitedMotion)
        {
            drive = m_Body.yDrive;
            return true;
        }

        if (m_Body.linearLockZ == ArticulationDofLock.LimitedMotion)
        {
            drive = m_Body.zDrive;
            return true;
        }

        drive = default(ArticulationDrive);
        return false;
    }

    /// <summary>
    /// Normalized position of a prismatic joint in [-1, 1].
    /// The caller must ensure the body has at least one degree of freedom.
    /// </summary>
    /// <returns>
    /// For a limited axis, the position mapped linearly from the drive limits
    /// to [-1, 1] (0 if the limits are not strictly increasing); otherwise
    /// tanh of the raw joint position.
    /// </returns>
    float GetPrismaticValue()
    {
        var jointPos = m_Body.jointPosition[0];

        ArticulationDrive drive;
        if (!TryGetLimitedDrive(out drive))
        {
            // Unbounded motion: take tanh() to keep in [-1, 1]
            return (float)System.Math.Tanh(jointPos);
        }

        // Motion is limited, so interpolate between the limits.
        var lowerLimit = drive.lowerLimit;
        var upperLimit = drive.upperLimit;
        if (upperLimit <= lowerLimit)
        {
            // Invalid limits (probably equal), so don't try to lerp
            return 0;
        }

        var invLerped = Mathf.InverseLerp(lowerLimit, upperLimit, jointPos);
        // Convert [0, 1] -> [-1, 1]
        return 2.0f * invLerped - 1.0f;
    }
}
| } | |