using System;
using System.Collections.Generic;
#if UNITY_2020_1_OR_NEWER
using UnityEngine;
#endif
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Extensions.Sensors
{
    /// <summary>
    /// ISensor implementation that generates vector observations for a group of
    /// Rigidbodies or (on Unity 2020.1+) ArticulationBodies: pose observations from a
    /// <see cref="PoseExtractor"/> followed by per-body joint observations.
    /// </summary>
    public class PhysicsBodySensor : ISensor, IBuiltInSensor
    {
        readonly ObservationSpec m_ObservationSpec;
        readonly string m_SensorName;

        readonly PoseExtractor m_PoseExtractor;
        readonly List<IJointExtractor> m_JointExtractors;
        readonly PhysicsSensorSettings m_Settings;

        /// <summary>
        /// Construct a new PhysicsBodySensor for the rigid bodies reported by a pose extractor.
        /// </summary>
        /// <param name="poseExtractor">
        /// Pose extractor; one <see cref="RigidBodyJointExtractor"/> is created for each
        /// rigid body it reports as enabled.
        /// </param>
        /// <param name="settings">Settings controlling which observations are generated.</param>
        /// <param name="sensorName">Name returned by <see cref="GetName"/>; stored as given.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="poseExtractor"/> is null.</exception>
        public PhysicsBodySensor(
            RigidBodyPoseExtractor poseExtractor,
            PhysicsSensorSettings settings,
            string sensorName
        )
        {
            if (poseExtractor == null)
            {
                throw new ArgumentNullException(nameof(poseExtractor));
            }

            m_PoseExtractor = poseExtractor;
            m_SensorName = sensorName;
            m_Settings = settings;

            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var rb in poseExtractor.GetEnabledRigidbodies())
            {
                m_JointExtractors.Add(new RigidBodyJointExtractor(rb));
            }

            m_ObservationSpec = ComputeObservationSpec(m_PoseExtractor, m_JointExtractors, m_Settings);
        }

#if UNITY_2020_1_OR_NEWER
        /// <summary>
        /// Construct a new PhysicsBodySensor for the articulation hierarchy under <paramref name="rootBody"/>.
        /// </summary>
        /// <param name="rootBody">Root body passed to <see cref="ArticulationBodyPoseExtractor"/>.</param>
        /// <param name="settings">Settings controlling which observations are generated.</param>
        /// <param name="sensorName">
        /// Optional name. When null or empty, "ArticulationBodySensor:" followed by the
        /// root body's name (or nothing, if the root body is missing) is used.
        /// </param>
        public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName = null)
        {
            var poseExtractor = new ArticulationBodyPoseExtractor(rootBody);
            m_PoseExtractor = poseExtractor;

            if (string.IsNullOrEmpty(sensorName))
            {
                // Use Unity's overloaded != rather than ?. so that a destroyed
                // ArticulationBody is treated as missing instead of throwing on .name.
                var rootName = rootBody != null ? rootBody.name : null;
                sensorName = $"ArticulationBodySensor:{rootName}";
            }
            m_SensorName = sensorName;
            m_Settings = settings;

            m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
            foreach (var articBody in poseExtractor.GetEnabledArticulationBodies())
            {
                m_JointExtractors.Add(new ArticulationBodyJointExtractor(articBody));
            }

            m_ObservationSpec = ComputeObservationSpec(m_PoseExtractor, m_JointExtractors, m_Settings);
        }

#endif

        /// <summary>
        /// Build the vector observation spec whose size is the pose observation count
        /// plus the observation counts of every joint extractor.
        /// </summary>
        /// <param name="poseExtractor">Pose extractor providing the pose observation count.</param>
        /// <param name="jointExtractors">Joint extractors whose observation counts are summed.</param>
        /// <param name="settings">Settings passed to each observation-count query.</param>
        /// <returns>A vector <see cref="ObservationSpec"/> of the combined size.</returns>
        static ObservationSpec ComputeObservationSpec(
            PoseExtractor poseExtractor,
            List<IJointExtractor> jointExtractors,
            PhysicsSensorSettings settings)
        {
            var numJointExtractorObservations = 0;
            foreach (var jointExtractor in jointExtractors)
            {
                numJointExtractorObservations += jointExtractor.NumObservations(settings);
            }

            var numTransformObservations = poseExtractor.GetNumPoseObservations(settings);
            return ObservationSpec.Vector(numTransformObservations + numJointExtractorObservations);
        }

        /// <inheritdoc/>
        public ObservationSpec GetObservationSpec()
        {
            return m_ObservationSpec;
        }

        /// <summary>
        /// Write the pose observations, then each joint extractor's observations.
        /// </summary>
        /// <param name="writer">Destination for the observations.</param>
        /// <returns>Total number of values written.</returns>
        public int Write(ObservationWriter writer)
        {
            var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
            foreach (var jointExtractor in m_JointExtractors)
            {
                // The running count is the offset at which this extractor starts writing.
                numWritten += jointExtractor.Write(m_Settings, writer, numWritten);
            }
            return numWritten;
        }

        /// <summary>
        /// Always returns null; this sensor does not provide compressed observations.
        /// </summary>
        public byte[] GetCompressedObservation()
        {
            return null;
        }

        /// <summary>
        /// Refresh the pose extractor's model-space and/or local-space poses,
        /// depending on which spaces are enabled in the settings.
        /// </summary>
        public void Update()
        {
            if (m_Settings.UseModelSpace)
            {
                m_PoseExtractor.UpdateModelSpacePoses();
            }

            if (m_Settings.UseLocalSpace)
            {
                m_PoseExtractor.UpdateLocalSpacePoses();
            }
        }

        /// <summary>
        /// No-op; this sensor keeps no per-episode state to reset.
        /// </summary>
        public void Reset() { }

        /// <summary>
        /// Returns the default compression spec.
        /// </summary>
        public CompressionSpec GetCompressionSpec()
        {
            return CompressionSpec.Default();
        }

        /// <inheritdoc/>
        public string GetName()
        {
            return m_SensorName;
        }

        /// <inheritdoc/>
        public BuiltInSensorType GetBuiltInSensorType()
        {
            return BuiltInSensorType.PhysicsBodySensor;
        }
    }
}