| | using System; |
| | using System.Collections.Generic; |
| | using Unity.Collections; |
| | using Unity.Jobs; |
| | using UnityEngine; |
| |
|
| | namespace Unity.MLAgents.Sensors |
| | { |
| | |
| | |
| | |
/// <summary>
/// Selects which family of physics queries is used when rays are cast.
/// </summary>
public enum RayPerceptionCastType
{
    /// <summary>
    /// Cast in two dimensions (Physics2D.CircleCast or Physics2D.Raycast).
    /// </summary>
    Cast2D = 0,

    /// <summary>
    /// Cast in three dimensions (Physics.SphereCast or Physics.Raycast).
    /// </summary>
    Cast3D = 1,
}
| |
|
| | |
| | |
| | |
/// <summary>
/// Describes a fan of rays to cast: their geometry, the tags to look for and the
/// physics settings used for the queries.
/// </summary>
public struct RayPerceptionInput
{
    /// <summary>
    /// Length of each ray in the local space of <see cref="Transform"/>, before it is
    /// converted to world space.
    /// </summary>
    public float RayLength;

    /// <summary>
    /// Tags that a hit object is compared against. May be null, which is treated as empty
    /// by <see cref="OutputSize"/>.
    /// </summary>
    public IReadOnlyList<string> DetectableTags;

    /// <summary>
    /// One angle per ray, in degrees. The angle is converted to a local direction with
    /// cosine on the x axis and sine on the z axis (3D) or y axis (2D).
    /// </summary>
    public IReadOnlyList<float> Angles;

    /// <summary>
    /// Local-space y coordinate of the ray start point. Only applied for 3D casts.
    /// </summary>
    public float StartOffset;

    /// <summary>
    /// Value added to the local-space y coordinate of the ray end point. Only applied for 3D casts.
    /// </summary>
    public float EndOffset;

    /// <summary>
    /// Unscaled radius of the cast. The perceive methods in this file use a sphere/circle
    /// cast when the radius is greater than zero and a plain raycast otherwise.
    /// </summary>
    public float CastRadius;

    /// <summary>
    /// Transform used to convert the local start and end points into world space.
    /// </summary>
    public Transform Transform;

    /// <summary>
    /// Whether the rays are built and cast in 2D or 3D.
    /// </summary>
    public RayPerceptionCastType CastType;

    /// <summary>
    /// Layer mask passed to the physics queries.
    /// </summary>
    public int LayerMask;

    /// <summary>
    /// Requests batched raycasts; read by the RayPerceptionSensor constructor.
    /// </summary>
    public bool UseBatchedRaycasts;

    /// <summary>
    /// Number of floats produced for this input: (number of tags + 2) per ray.
    /// A null tag list or angle list counts as empty.
    /// </summary>
    /// <returns>The total observation size.</returns>
    public int OutputSize()
    {
        var tagCount = DetectableTags == null ? 0 : DetectableTags.Count;
        var rayCount = Angles == null ? 0 : Angles.Count;
        return (tagCount + 2) * rayCount;
    }

    /// <summary>
    /// Computes the world-space start and end points of one ray.
    /// </summary>
    /// <param name="rayIndex">Index into <see cref="Angles"/>.</param>
    /// <returns>The start and end positions after <see cref="Transform"/> has been applied.</returns>
    public (Vector3 StartPositionWorld, Vector3 EndPositionWorld) RayExtents(int rayIndex)
    {
        var angleDegrees = Angles[rayIndex];

        Vector3 localStart;
        Vector3 localEnd;
        if (CastType != RayPerceptionCastType.Cast3D)
        {
            // 2D rays start at the local origin; the Vector2 values widen implicitly to Vector3.
            localStart = new Vector2();
            localEnd = PolarToCartesian2D(RayLength, angleDegrees);
        }
        else
        {
            localStart = new Vector3(0, StartOffset, 0);
            localEnd = PolarToCartesian3D(RayLength, angleDegrees);
            localEnd.y += EndOffset;
        }

        return (
            StartPositionWorld: Transform.TransformPoint(localStart),
            EndPositionWorld: Transform.TransformPoint(localEnd)
        );
    }

    /// <summary>
    /// Converts polar coordinates to a point on the local x-z plane (y is always 0).
    /// </summary>
    internal static Vector3 PolarToCartesian3D(float radius, float angleDegrees)
    {
        var angleRadians = Mathf.Deg2Rad * angleDegrees;
        return new Vector3(radius * Mathf.Cos(angleRadians), 0f, radius * Mathf.Sin(angleRadians));
    }

    /// <summary>
    /// Converts polar coordinates to a point on the local x-y plane.
    /// </summary>
    internal static Vector2 PolarToCartesian2D(float radius, float angleDegrees)
    {
        var angleRadians = Mathf.Deg2Rad * angleDegrees;
        return new Vector2(radius * Mathf.Cos(angleRadians), radius * Mathf.Sin(angleRadians));
    }
}
| |
|
| | |
| | |
| | |
/// <summary>
/// Holds the per-ray results of a perception pass.
/// </summary>
public class RayPerceptionOutput
{
    /// <summary>
    /// Result of casting a single ray.
    /// </summary>
    public struct RayOutput
    {
        /// <summary>
        /// Whether the cast hit anything.
        /// </summary>
        public bool HasHit;

        /// <summary>
        /// Whether the hit object matched one of the detectable tags.
        /// </summary>
        public bool HitTaggedObject;

        /// <summary>
        /// Index of the matched tag in the detectable tag list. The perceive methods in
        /// this file set it to -1 when no tag matched.
        /// </summary>
        public int HitTagIndex;

        /// <summary>
        /// Hit distance as a fraction of the ray length. The perceive methods in this
        /// file set it to 1 when nothing was hit.
        /// </summary>
        public float HitFraction;

        /// <summary>
        /// The object that was hit, or null when nothing was hit.
        /// </summary>
        public GameObject HitGameObject;

        /// <summary>
        /// World-space start of the ray.
        /// </summary>
        public Vector3 StartPositionWorld;

        /// <summary>
        /// World-space end of the ray.
        /// </summary>
        public Vector3 EndPositionWorld;

        /// <summary>
        /// World-space length of the ray: the distance between
        /// <see cref="StartPositionWorld"/> and <see cref="EndPositionWorld"/>.
        /// </summary>
        public float ScaledRayLength => (EndPositionWorld - StartPositionWorld).magnitude;

        /// <summary>
        /// Cast radius after scaling into world space.
        /// </summary>
        public float ScaledCastRadius;

        /// <summary>
        /// Writes this ray's result into its slice of <paramref name="buffer"/>.
        /// The slice starts at (numDetectableTags + 2) * rayIndex and is laid out as:
        /// one slot per tag (the matched tag's slot is set to 1), then a slot that is 1
        /// when nothing was hit and 0 otherwise, then the hit fraction.
        /// Tag slots that do not match are left untouched, so the caller is expected to
        /// pass a cleared buffer.
        /// </summary>
        /// <param name="numDetectableTags">Number of tag slots per ray.</param>
        /// <param name="rayIndex">Index of this ray, used to locate its slice.</param>
        /// <param name="buffer">Destination array.</param>
        public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer)
        {
            var sliceStart = (numDetectableTags + 2) * rayIndex;
            var missSlot = sliceStart + numDetectableTags;

            if (HitTaggedObject)
            {
                buffer[sliceStart + HitTagIndex] = 1f;
            }

            if (HasHit)
            {
                buffer[missSlot] = 0f;
            }
            else
            {
                buffer[missSlot] = 1f;
            }

            buffer[missSlot + 1] = HitFraction;
        }
    }

    /// <summary>
    /// One entry per ray. Null until a perception pass has filled it in.
    /// </summary>
    public RayOutput[] RayOutputs;
}
| |
|
| | |
| | |
| | |
| | public class RayPerceptionSensor : ISensor, IBuiltInSensor |
| | { |
// Flat observation buffer that Write() fills and passes to the writer.
float[] m_Observations;
// Vector observation spec; kept the same size as m_Observations by SetNumObservations().
ObservationSpec m_ObservationSpec;
// Name returned by GetName().
string m_Name;

// Ray configuration used by Update() and Write().
RayPerceptionInput m_RayPerceptionInput;
// Results produced by the most recent Update().
RayPerceptionOutput m_RayPerceptionOutput;

// Whether Update() may use batched raycasts; initialised from RayPerceptionInput.UseBatchedRaycasts.
bool m_UseBatchedRaycasts;

// Value of Time.frameCount recorded by the constructor and by each Update().
int m_DebugLastFrameCount;

/// <summary>
/// Frame count recorded the last time the sensor was constructed or updated.
/// </summary>
internal int DebugLastFrameCount => m_DebugLastFrameCount;
| |
|
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Creates a sensor with the given name and ray configuration, and sizes the
/// observation buffer to match <c>rayInput.OutputSize()</c>.
/// </summary>
/// <param name="name">The name returned by GetName().</param>
/// <param name="rayInput">The ray configuration to perceive with.</param>
public RayPerceptionSensor(string name, RayPerceptionInput rayInput)
{
    m_Name = name;

    m_RayPerceptionInput = rayInput;
    m_UseBatchedRaycasts = rayInput.UseBatchedRaycasts;

    var observationCount = rayInput.OutputSize();
    SetNumObservations(observationCount);

    m_DebugLastFrameCount = Time.frameCount;

    // RayOutputs stays null until the first Update().
    m_RayPerceptionOutput = new RayPerceptionOutput();
}
| |
|
| | |
| | |
| | |
/// <summary>
/// The output object owned by this sensor; its RayOutputs are refreshed by Update().
/// </summary>
public RayPerceptionOutput RayPerceptionOutput => m_RayPerceptionOutput;
| |
|
// Rebuilds the observation spec and allocates a fresh (zeroed) observation buffer of the same size.
void SetNumObservations(int numObservations)
{
    var spec = ObservationSpec.Vector(numObservations);
    m_ObservationSpec = spec;

    var buffer = new float[numObservations];
    m_Observations = buffer;
}
| |
|
/// <summary>
/// Replaces the ray configuration used by subsequent calls to Update() and Write().
/// If the new configuration has a different output size, a message is logged and the
/// observation spec and buffer are resized to match.
/// </summary>
/// <param name="rayInput">The new ray configuration.</param>
internal void SetRayPerceptionInput(RayPerceptionInput rayInput)
{
    var newOutputSize = rayInput.OutputSize();
    if (m_RayPerceptionInput.OutputSize() != newOutputSize)
    {
        Debug.Log(
            "Changing the number of tags or rays at runtime is not " +
            "supported and may cause errors in training or inference."
        );

        // A changed shape is likely to cause trouble downstream, but the spec and the
        // buffer are at least kept consistent with the new input.
        SetNumObservations(newOutputSize);
    }

    m_RayPerceptionInput = rayInput;

    // Update() reads the cached flag rather than the input, so refresh it here;
    // otherwise a changed UseBatchedRaycasts would be silently ignored.
    m_UseBatchedRaycasts = rayInput.UseBatchedRaycasts;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Clears the observation buffer, encodes the ray outputs from the most recent Update()
/// into it and hands the buffer to <paramref name="writer"/>.
/// </summary>
/// <param name="writer">Destination for the observations.</param>
/// <returns>The number of floats written (the length of the observation buffer).</returns>
public int Write(ObservationWriter writer)
{
    using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
    {
        Array.Clear(m_Observations, 0, m_Observations.Length);

        // Null lists count as empty, matching RayPerceptionInput.OutputSize().
        var numDetectableTags = m_RayPerceptionInput.DetectableTags?.Count ?? 0;
        var configuredRays = m_RayPerceptionInput.Angles?.Count ?? 0;

        var rayOutputs = m_RayPerceptionOutput.RayOutputs;
        if (rayOutputs != null)
        {
            // The input may have been replaced since the last Update(), so the stored
            // outputs can be shorter than the configured ray count. Only encode rays
            // that actually have a result; the rest stay zeroed.
            var numRays = Math.Min(configuredRays, rayOutputs.Length);
            for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
            {
                rayOutputs[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
            }
        }

        writer.AddList(m_Observations);
    }
    return m_Observations.Length;
}
| |
|
| | |
/// <summary>
/// Records the current frame count and re-casts every ray, storing the results in the
/// sensor's RayPerceptionOutput. The batched path is only taken for 3D casts.
/// </summary>
public void Update()
{
    m_DebugLastFrameCount = Time.frameCount;

    var rayCount = m_RayPerceptionInput.Angles.Count;

    // (Re)allocate the output array when it is missing or the ray count changed.
    var outputs = m_RayPerceptionOutput.RayOutputs;
    if (outputs == null || outputs.Length != rayCount)
    {
        outputs = new RayPerceptionOutput.RayOutput[rayCount];
        m_RayPerceptionOutput.RayOutputs = outputs;
    }

    var batch = m_UseBatchedRaycasts && m_RayPerceptionInput.CastType == RayPerceptionCastType.Cast3D;
    if (!batch)
    {
        for (var i = 0; i < rayCount; i++)
        {
            outputs[i] = PerceiveSingleRay(m_RayPerceptionInput, i);
        }
        return;
    }

    PerceiveBatchedRays(ref m_RayPerceptionOutput.RayOutputs, m_RayPerceptionInput);
}
| |
|
| | |
/// <summary>
/// Does nothing; this method has an empty body.
/// </summary>
public void Reset()
{
}
| |
|
| | |
/// <summary>
/// Returns the vector observation spec that matches the current observation buffer size.
/// </summary>
public ObservationSpec GetObservationSpec() => m_ObservationSpec;
| |
|
| | |
/// <summary>
/// Returns the name this sensor was constructed with.
/// </summary>
public string GetName() => m_Name;
| |
|
| | |
/// <summary>
/// Always returns null in this class; subclasses may override.
/// </summary>
public virtual byte[] GetCompressedObservation() => null;
| |
|
| | |
/// <summary>
/// Returns the default compression spec.
/// </summary>
public CompressionSpec GetCompressionSpec() => CompressionSpec.Default();
| |
|
| | |
/// <summary>
/// Identifies this sensor as <see cref="BuiltInSensorType.RayPerceptionSensor"/>.
/// </summary>
public BuiltInSensorType GetBuiltInSensorType() => BuiltInSensorType.RayPerceptionSensor;
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Casts every ray described by <paramref name="input"/> and returns the results in a
/// newly allocated <see cref="RayPerceptionOutput"/>.
/// </summary>
/// <param name="input">The ray configuration. A null angle list yields an empty result.</param>
/// <param name="batched">
/// Requests the batched raycast path. It is only honoured for
/// <see cref="RayPerceptionCastType.Cast3D"/> input, mirroring Update(); 2D input is
/// always cast one ray at a time because the batched path issues 3D physics commands.
/// </param>
/// <returns>One RayOutput per ray, in the order of <c>input.Angles</c>.</returns>
public static RayPerceptionOutput Perceive(RayPerceptionInput input, bool batched)
{
    // A null angle list counts as "no rays", matching RayPerceptionInput.OutputSize().
    var numRays = input.Angles?.Count ?? 0;

    var output = new RayPerceptionOutput
    {
        RayOutputs = new RayPerceptionOutput.RayOutput[numRays]
    };
    if (numRays == 0)
    {
        return output;
    }

    if (batched && input.CastType == RayPerceptionCastType.Cast3D)
    {
        PerceiveBatchedRays(ref output.RayOutputs, input);
    }
    else
    {
        for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
        {
            output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex);
        }
    }

    return output;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Casts all rays of <paramref name="input"/> as one batch of 3D raycast or spherecast
/// commands and stores the results in <paramref name="batchedRaycastOutputs"/>.
/// </summary>
/// <param name="batchedRaycastOutputs">
/// Destination array; it must already hold at least <c>input.Angles.Count</c> entries.
/// </param>
/// <param name="input">The ray configuration.</param>
internal static void PerceiveBatchedRays(ref RayPerceptionOutput.RayOutput[] batchedRaycastOutputs, RayPerceptionInput input)
{
    var numRays = input.Angles.Count;
    var numTags = input.DetectableTags?.Count ?? 0;
    var unscaledRayLength = input.RayLength;
    var unscaledCastRadius = input.CastRadius;

    // Decide ONCE which command type the whole batch uses. The command arrays are sized
    // from this flag, so every later read and write must use the same flag. Choosing per
    // ray from the scaled radius can disagree with it (a zero-length world-space ray
    // scales a positive radius down to 0) and would index a zero-length array.
    var useSpherecast = unscaledCastRadius > 0f;

    // The query parameters are identical for every ray, so build them once.
    var queryParameters = QueryParameters.Default;
    queryParameters.layerMask = input.LayerMask;

    var results = default(NativeArray<RaycastHit>);
    var raycastCommands = default(NativeArray<RaycastCommand>);
    var spherecastCommands = default(NativeArray<SpherecastCommand>);
    try
    {
        results = new NativeArray<RaycastHit>(numRays, Allocator.TempJob);
        raycastCommands = new NativeArray<RaycastCommand>(useSpherecast ? 0 : numRays, Allocator.TempJob);
        spherecastCommands = new NativeArray<SpherecastCommand>(useSpherecast ? numRays : 0, Allocator.TempJob);

        // Build one command per ray and pre-fill the geometry part of each output.
        for (var i = 0; i < numRays; i++)
        {
            var extents = input.RayExtents(i);
            var startPositionWorld = extents.StartPositionWorld;
            var endPositionWorld = extents.EndPositionWorld;

            var rayDirection = endPositionWorld - startPositionWorld;
            // The world-space length can differ from RayLength because of the transform;
            // scale the cast radius by the same ratio.
            var scaledRayLength = rayDirection.magnitude;
            // Avoid 0/0 when the unscaled length is 0.
            var scaledCastRadius = unscaledRayLength > 0 ?
                unscaledCastRadius * scaledRayLength / unscaledRayLength :
                unscaledCastRadius;

            var rayDirectionNormalized = rayDirection.normalized;

            if (useSpherecast)
            {
                spherecastCommands[i] = new SpherecastCommand(startPositionWorld, scaledCastRadius, rayDirectionNormalized, queryParameters, scaledRayLength);
            }
            else
            {
                raycastCommands[i] = new RaycastCommand(startPositionWorld, rayDirectionNormalized, queryParameters, scaledRayLength);
            }

            batchedRaycastOutputs[i] = new RayPerceptionOutput.RayOutput
            {
                HitTaggedObject = false,
                HitTagIndex = -1,
                StartPositionWorld = startPositionWorld,
                EndPositionWorld = endPositionWorld,
                ScaledCastRadius = scaledCastRadius
            };
        }

        // Run the batch and wait for it to finish.
        var handle = useSpherecast
            ? SpherecastCommand.ScheduleBatch(spherecastCommands, results, 1, 1, default(JobHandle))
            : RaycastCommand.ScheduleBatch(raycastCommands, results, 1, 1, default(JobHandle));
        handle.Complete();

        // Translate the raw hits into ray outputs.
        for (var i = 0; i < results.Length; i++)
        {
            var hit = results[i];
            var castHit = hit.collider != null;
            var scaledRayLength = useSpherecast ? spherecastCommands[i].distance : raycastCommands[i].distance;

            // hitFraction is 1 when nothing was hit, and 0 for a hit on a zero-length ray
            // (avoids dividing by zero).
            var hitFraction = 1.0f;
            GameObject hitObject = null;
            if (castHit)
            {
                hitFraction = scaledRayLength > 0 ? hit.distance / scaledRayLength : 0.0f;
                hitObject = hit.collider.gameObject;

                // Report the first detectable tag that matches the hit object.
                for (var j = 0; j < numTags; j++)
                {
                    var tagsEqual = false;
                    try
                    {
                        var tag = input.DetectableTags[j];
                        if (!string.IsNullOrEmpty(tag))
                        {
                            tagsEqual = hitObject.CompareTag(tag);
                        }
                    }
                    catch (UnityException)
                    {
                        // Deliberately ignored: presumably raised when the tag is not
                        // defined in the project; such a tag simply never matches.
                    }

                    if (tagsEqual)
                    {
                        batchedRaycastOutputs[i].HitTaggedObject = true;
                        batchedRaycastOutputs[i].HitTagIndex = j;
                        break;
                    }
                }
            }

            batchedRaycastOutputs[i].HasHit = castHit;
            batchedRaycastOutputs[i].HitFraction = hitFraction;
            batchedRaycastOutputs[i].HitGameObject = hitObject;
        }
    }
    finally
    {
        // Release the native allocations even when something above throws.
        if (results.IsCreated)
        {
            results.Dispose();
        }
        if (raycastCommands.IsCreated)
        {
            raycastCommands.Dispose();
        }
        if (spherecastCommands.IsCreated)
        {
            spherecastCommands.Dispose();
        }
    }
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Casts the ray at <paramref name="rayIndex"/> with a single 2D or 3D physics query
/// and returns its result.
/// </summary>
/// <param name="input">The ray configuration.</param>
/// <param name="rayIndex">Index into <c>input.Angles</c>.</param>
/// <returns>
/// The hit state, hit fraction (1 when nothing was hit), matched tag index (-1 when no
/// tag matched) and the world-space geometry of the ray.
/// </returns>
internal static RayPerceptionOutput.RayOutput PerceiveSingleRay(
    RayPerceptionInput input,
    int rayIndex
)
{
    var localLength = input.RayLength;
    var localRadius = input.CastRadius;

    var (rayStart, rayEnd) = input.RayExtents(rayIndex);
    var rayVector = rayEnd - rayStart;

    // The world-space length can differ from RayLength because of the transform;
    // the cast radius is scaled by the same ratio.
    var worldLength = rayVector.magnitude;
    var worldRadius = localRadius;
    if (localLength > 0)
    {
        // Guarded so that a zero unscaled length cannot produce 0/0.
        worldRadius = localRadius * worldLength / localLength;
    }

    // Start from the "nothing hit" result; the physics sections below overwrite it.
    var result = new RayPerceptionOutput.RayOutput
    {
        HasHit = false,
        HitFraction = 1.0f,
        HitTaggedObject = false,
        HitTagIndex = -1,
        HitGameObject = null,
        StartPositionWorld = rayStart,
        EndPositionWorld = rayEnd,
        ScaledCastRadius = worldRadius
    };

    if (input.CastType == RayPerceptionCastType.Cast3D)
    {
#if MLA_UNITY_PHYSICS_MODULE
        RaycastHit hit3D;
        var didHit3D = worldRadius > 0f
            ? Physics.SphereCast(rayStart, worldRadius, rayVector, out hit3D, worldLength, input.LayerMask)
            : Physics.Raycast(rayStart, rayVector, out hit3D, worldLength, input.LayerMask);

        if (didHit3D)
        {
            result.HasHit = true;
            // A hit on a zero-length ray reports fraction 0 instead of dividing by zero.
            result.HitFraction = worldLength > 0 ? hit3D.distance / worldLength : 0.0f;
            result.HitGameObject = hit3D.collider.gameObject;
        }
#endif
    }
    else
    {
#if MLA_UNITY_PHYSICS2D_MODULE
        var hit2D = worldRadius > 0f
            ? Physics2D.CircleCast(rayStart, worldRadius, rayVector, worldLength, input.LayerMask)
            : Physics2D.Raycast(rayStart, rayVector, worldLength, input.LayerMask);

        // RaycastHit2D converts to bool.
        bool didHit2D = hit2D;
        if (didHit2D)
        {
            result.HasHit = true;
            result.HitFraction = hit2D.fraction;
            result.HitGameObject = hit2D.collider.gameObject;
        }
#endif
    }

    if (!result.HasHit)
    {
        return result;
    }

    // Report the first detectable tag that matches the hit object.
    var tags = input.DetectableTags;
    var tagCount = tags?.Count ?? 0;
    for (var tagIndex = 0; tagIndex < tagCount; tagIndex++)
    {
        bool matches;
        try
        {
            var candidate = tags[tagIndex];
            matches = !string.IsNullOrEmpty(candidate) && result.HitGameObject.CompareTag(candidate);
        }
        catch (UnityException)
        {
            // Deliberately ignored: presumably raised when the tag is not defined in
            // the project; such a tag simply never matches.
            matches = false;
        }

        if (matches)
        {
            result.HitTaggedObject = true;
            result.HitTagIndex = tagIndex;
            break;
        }
    }

    return result;
}
| | } |
| | } |
| |
|