| | using System.Collections.Generic; |
| | using System.Linq; |
| | using UnityEngine; |
| |
|
namespace Unity.MLAgents.Sensors
{
    /// <summary>
    /// A <see cref="SensorComponent"/> that builds grid sensors. A <see cref="BoxOverlapChecker"/> is created as the
    /// shared grid perception, and every <see cref="GridSensorBase"/> returned by <see cref="GetGridSensors"/> is
    /// registered with it.
    /// </summary>
    [AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
    public class GridSensorComponent : SensorComponent
    {
        // Sensor used only to collect per-cell values for OnDrawGizmos; it is never returned from CreateSensors.
        GridSensorBase m_DebugSensor;

        // Grid sensors built by the most recent call to CreateSensors (unwrapped, i.e. before any StackingSensor).
        List<GridSensorBase> m_Sensors;

        // Perception object that all grid sensors (and the debug sensor) are registered with.
        internal IGridPerception m_GridPerception;

        [HideInInspector, SerializeField]
        protected internal string m_SensorName = "GridSensor";

        /// <summary>
        /// Name of the sensor. The default implementation of <see cref="GetGridSensors"/> uses it as the prefix of
        /// the one-hot sensor's name ("&lt;SensorName&gt;-OneHot").
        /// </summary>
        public string SensorName
        {
            get { return m_SensorName; }
            set { m_SensorName = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f);

        /// <summary>
        /// Scale of a single grid cell. It is passed to the overlap checker and sensors when they are created, and
        /// is also used as the size of each gizmo cube. Changing it after <see cref="CreateSensors"/> does not
        /// update existing sensors.
        /// </summary>
        public Vector3 CellScale
        {
            get { return m_CellScale; }
            set { m_CellScale = value; }
        }

        [HideInInspector, SerializeField]
        internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16);

        /// <summary>
        /// Number of cells along each axis. The Y component is always stored as 1; any other Y value is replaced
        /// with 1. Changing it after <see cref="CreateSensors"/> does not update existing sensors.
        /// </summary>
        public Vector3Int GridSize
        {
            get { return m_GridSize; }
            set
            {
                if (value.y != 1)
                {
                    m_GridSize = new Vector3Int(value.x, 1, value.z);
                }
                else
                {
                    m_GridSize = value;
                }
            }
        }

        [HideInInspector, SerializeField]
        internal bool m_RotateWithAgent = true;

        /// <summary>
        /// Whether the grid's orientation follows the agent's rotation. Safe to change at runtime: the new value
        /// is forwarded to the existing perception, if sensors have already been created.
        /// </summary>
        public bool RotateWithAgent
        {
            get { return m_RotateWithAgent; }
            set
            {
                m_RotateWithAgent = value;
                // Keep the live perception in sync, matching how CompressionType behaves.
                UpdateSensor();
            }
        }

        [HideInInspector, SerializeField]
        internal GameObject m_AgentGameObject;

        /// <summary>
        /// The GameObject treated as the agent by the overlap checker. Falls back to this component's own
        /// GameObject when unset.
        /// </summary>
        public GameObject AgentGameObject
        {
            get { return (m_AgentGameObject == null ? gameObject : m_AgentGameObject); }
            set { m_AgentGameObject = value; }
        }

        [HideInInspector, SerializeField]
        internal string[] m_DetectableTags;

        /// <summary>
        /// Tags passed to the overlap checker and to every grid sensor; presumably the tags of the objects the
        /// grid should detect. Changing it after <see cref="CreateSensors"/> does not update existing sensors.
        /// </summary>
        public string[] DetectableTags
        {
            get { return m_DetectableTags; }
            set { m_DetectableTags = value; }
        }

        [HideInInspector, SerializeField]
        internal LayerMask m_ColliderMask;

        /// <summary>
        /// Layer mask used by the overlap checker. Safe to change at runtime: the new value is forwarded to the
        /// existing perception, if sensors have already been created.
        /// </summary>
        public LayerMask ColliderMask
        {
            get { return m_ColliderMask; }
            set
            {
                m_ColliderMask = value;
                // Keep the live perception in sync, matching how CompressionType behaves.
                UpdateSensor();
            }
        }

        [HideInInspector, SerializeField]
        internal int m_MaxColliderBufferSize = 500;

        /// <summary>
        /// Maximum collider buffer size passed to the overlap checker (presumably the upper bound on colliders
        /// gathered per query). Only read when <see cref="CreateSensors"/> runs.
        /// </summary>
        public int MaxColliderBufferSize
        {
            get { return m_MaxColliderBufferSize; }
            set { m_MaxColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal int m_InitialColliderBufferSize = 4;

        /// <summary>
        /// Initial collider buffer size passed to the overlap checker. Only read when
        /// <see cref="CreateSensors"/> runs.
        /// </summary>
        public int InitialColliderBufferSize
        {
            get { return m_InitialColliderBufferSize; }
            set { m_InitialColliderBufferSize = value; }
        }

        [HideInInspector, SerializeField]
        internal Color[] m_DebugColors;

        /// <summary>
        /// Colors used when drawing gizmos. A cell whose debug value is v is drawn with DebugColors[v - 1];
        /// cells with no matching entry (or when this array is null) are drawn white.
        /// </summary>
        public Color[] DebugColors
        {
            get { return m_DebugColors; }
            set { m_DebugColors = value; }
        }

        [HideInInspector, SerializeField]
        internal float m_GizmoYOffset = 0f;

        /// <summary>
        /// Vertical offset added to every gizmo cube's position.
        /// </summary>
        public float GizmoYOffset
        {
            get { return m_GizmoYOffset; }
            set { m_GizmoYOffset = value; }
        }

        [HideInInspector, SerializeField]
        internal bool m_ShowGizmos = false;

        /// <summary>
        /// Whether <c>OnDrawGizmos</c> draws the grid cells.
        /// </summary>
        public bool ShowGizmos
        {
            get { return m_ShowGizmos; }
            set { m_ShowGizmos = value; }
        }

        [HideInInspector, SerializeField]
        internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG;

        /// <summary>
        /// Compression type for the grid sensors' observations. Safe to change at runtime: the new value is pushed
        /// to every existing sensor.
        /// </summary>
        public SensorCompressionType CompressionType
        {
            get { return m_CompressionType; }
            set { m_CompressionType = value; UpdateSensor(); }
        }

        [HideInInspector, SerializeField]
        [Range(1, 50)]
        [Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")]
        internal int m_ObservationStacks = 1;

        /// <summary>
        /// Number of observation frames to stack. Any value other than 1 wraps each grid sensor in a
        /// <see cref="StackingSensor"/>. Only read when <see cref="CreateSensors"/> runs.
        /// </summary>
        public int ObservationStacks
        {
            get { return m_ObservationStacks; }
            set { m_ObservationStacks = value; }
        }

        /// <summary>
        /// Creates the shared grid perception, the debug sensor and the grid sensors returned by
        /// <see cref="GetGridSensors"/>, and returns them (optionally wrapped for stacking).
        /// </summary>
        /// <returns>One sensor per grid sensor, wrapped in a <see cref="StackingSensor"/> when
        /// <see cref="ObservationStacks"/> is not 1.</returns>
        /// <exception cref="UnityAgentsException">Thrown when <see cref="GetGridSensors"/> returns null or an
        /// empty array.</exception>
        public override ISensor[] CreateSensors()
        {
            m_GridPerception = new BoxOverlapChecker(
                m_CellScale,
                m_GridSize,
                m_RotateWithAgent,
                m_ColliderMask,
                gameObject,
                AgentGameObject,
                m_DetectableTags,
                m_InitialColliderBufferSize,
                m_MaxColliderBufferSize
            );

            // The debug sensor only feeds gizmo drawing; it is registered separately and never returned.
            m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
            m_GridPerception.RegisterDebugSensor(m_DebugSensor);

            // GetGridSensors is virtual; check for null before ToList() so an override returning null gets the
            // descriptive exception below rather than an ArgumentNullException from LINQ.
            var gridSensors = GetGridSensors();
            m_Sensors = gridSensors == null ? null : gridSensors.ToList();
            if (m_Sensors == null || m_Sensors.Count < 1)
            {
                throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors. " +
                    "If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor.");
            }

            // Only the first sensor receives the perception reference; presumably so the shared perception is
            // driven once rather than once per sensor (NOTE(review): confirm against GridSensorBase).
            m_Sensors[0].m_GridPerception = m_GridPerception;
            foreach (var sensor in m_Sensors)
            {
                m_GridPerception.RegisterSensor(sensor);
            }

            if (ObservationStacks != 1)
            {
                var sensors = new ISensor[m_Sensors.Count];
                for (var i = 0; i < m_Sensors.Count; i++)
                {
                    sensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks);
                }
                return sensors;
            }

            return m_Sensors.ToArray();
        }

        /// <summary>
        /// Returns the grid sensors this component should create. The default implementation returns a single
        /// <see cref="OneHotGridSensor"/> named "&lt;SensorName&gt;-OneHot". Overrides must return at least one
        /// sensor.
        /// </summary>
        /// <returns>The grid sensors to register with the shared perception.</returns>
        protected virtual GridSensorBase[] GetGridSensors()
        {
            var oneHot = new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType);
            return new GridSensorBase[] { oneHot };
        }

        /// <summary>
        /// Pushes the settings that are safe to change at runtime (<see cref="RotateWithAgent"/>,
        /// <see cref="ColliderMask"/>, <see cref="CompressionType"/>) to the already-created perception and
        /// sensors. Does nothing before <see cref="CreateSensors"/> has been called.
        /// </summary>
        internal void UpdateSensor()
        {
            if (m_Sensors == null)
            {
                return;
            }

            if (m_GridPerception != null)
            {
                m_GridPerception.RotateWithAgent = m_RotateWithAgent;
                m_GridPerception.ColliderMask = m_ColliderMask;
            }

            foreach (var sensor in m_Sensors)
            {
                sensor.CompressionType = m_CompressionType;
            }
        }

        /// <summary>
        /// Draws one semi-transparent cube per grid cell when <see cref="ShowGizmos"/> is enabled. The cube color is
        /// looked up in <see cref="DebugColors"/> from the debug sensor's value for that cell.
        /// </summary>
        void OnDrawGizmos()
        {
            if (!m_ShowGizmos || m_GridPerception == null || m_DebugSensor == null)
            {
                return;
            }

            // Refresh the debug sensor's buffer from the current perception state.
            m_DebugSensor.ResetPerceptionBuffer();
            m_GridPerception.UpdateGizmo();
            var cellColors = m_DebugSensor.PerceptionBuffer;
            var rotation = m_GridPerception.GetGridRotation();

            var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
            var oldGizmoMatrix = Gizmos.matrix;
            for (var i = 0; i < cellColors.Length; i++)
            {
                var cellPosition = m_GridPerception.GetCellGlobalPosition(i);
                var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, m_CellScale);
                Gizmos.matrix = oldGizmoMatrix * cubeTransform;

                // Cell values are 1-based color indices; 0 (or anything without a matching color) draws white.
                var colorIndex = cellColors[i] - 1;
                var debugRayColor = Color.white;
                if (colorIndex > -1 && m_DebugColors != null && m_DebugColors.Length > colorIndex)
                {
                    debugRayColor = m_DebugColors[(int)colorIndex];
                }
                Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
                Gizmos.DrawCube(Vector3.zero, Vector3.one);
            }

            Gizmos.matrix = oldGizmoMatrix;
        }
    }
}
| |
|