| using System.Collections.Generic; | |
| using UnityEngine; | |
| using Unity.MLAgents.Sensors; | |
| namespace Unity.MLAgents.Tests | |
| { | |
| public static class TestGridSensorConfig | |
| { | |
| public static int ObservationSize; | |
| public static bool IsNormalized; | |
| public static bool ParseAllColliders; | |
| public static void SetParameters(int observationSize, bool isNormalized, bool parseAllColliders) | |
| { | |
| ObservationSize = observationSize; | |
| IsNormalized = isNormalized; | |
| ParseAllColliders = parseAllColliders; | |
| } | |
| public static void Reset() | |
| { | |
| ObservationSize = 0; | |
| IsNormalized = false; | |
| ParseAllColliders = false; | |
| } | |
| } | |
| public class SimpleTestGridSensor : GridSensorBase | |
| { | |
| public float[] DummyData; | |
| public SimpleTestGridSensor( | |
| string name, | |
| Vector3 cellScale, | |
| Vector3Int gridSize, | |
| string[] detectableTags, | |
| SensorCompressionType compression | |
| ) : base( | |
| name, | |
| cellScale, | |
| gridSize, | |
| detectableTags, | |
| compression) | |
| { } | |
| protected override int GetCellObservationSize() | |
| { | |
| return TestGridSensorConfig.ObservationSize; | |
| } | |
| protected override bool IsDataNormalized() | |
| { | |
| return TestGridSensorConfig.IsNormalized; | |
| } | |
| protected internal override ProcessCollidersMethod GetProcessCollidersMethod() | |
| { | |
| return TestGridSensorConfig.ParseAllColliders ? ProcessCollidersMethod.ProcessAllColliders : ProcessCollidersMethod.ProcessClosestColliders; | |
| } | |
| protected override void GetObjectData(GameObject detectedObject, int typeIndex, float[] dataBuffer) | |
| { | |
| for (var i = 0; i < DummyData.Length; i++) | |
| { | |
| dataBuffer[i] = DummyData[i]; | |
| } | |
| } | |
| } | |
| public class SimpleTestGridSensorComponent : GridSensorComponent | |
| { | |
| bool m_UseOneHotTag; | |
| bool m_UseTestingGridSensor; | |
| bool m_UseGridSensorBase; | |
| protected override GridSensorBase[] GetGridSensors() | |
| { | |
| List<GridSensorBase> sensorList = new List<GridSensorBase>(); | |
| if (m_UseOneHotTag) | |
| { | |
| var testSensor = new OneHotGridSensor( | |
| SensorName, | |
| CellScale, | |
| GridSize, | |
| DetectableTags, | |
| CompressionType | |
| ); | |
| sensorList.Add(testSensor); | |
| } | |
| if (m_UseGridSensorBase) | |
| { | |
| var testSensor = new GridSensorBase( | |
| SensorName, | |
| CellScale, | |
| GridSize, | |
| DetectableTags, | |
| CompressionType | |
| ); | |
| sensorList.Add(testSensor); | |
| } | |
| if (m_UseTestingGridSensor) | |
| { | |
| var testSensor = new SimpleTestGridSensor( | |
| SensorName, | |
| CellScale, | |
| GridSize, | |
| DetectableTags, | |
| CompressionType | |
| ); | |
| sensorList.Add(testSensor); | |
| } | |
| return sensorList.ToArray(); | |
| } | |
| public void SetComponentParameters( | |
| string[] detectableTags = null, | |
| float cellScaleX = 1f, | |
| float cellScaleZ = 1f, | |
| int gridSizeX = 10, | |
| int gridSizeY = 1, | |
| int gridSizeZ = 10, | |
| int colliderMaskInt = -1, | |
| SensorCompressionType compression = SensorCompressionType.None, | |
| bool rotateWithAgent = false, | |
| bool useOneHotTag = false, | |
| bool useTestingGridSensor = false, | |
| bool useGridSensorBase = false | |
| ) | |
| { | |
| DetectableTags = detectableTags; | |
| CellScale = new Vector3(cellScaleX, 0.01f, cellScaleZ); | |
| GridSize = new Vector3Int(gridSizeX, gridSizeY, gridSizeZ); | |
| ColliderMask = colliderMaskInt < 0 ? LayerMask.GetMask("Default") : colliderMaskInt; | |
| RotateWithAgent = rotateWithAgent; | |
| CompressionType = compression; | |
| m_UseOneHotTag = useOneHotTag; | |
| m_UseGridSensorBase = useGridSensorBase; | |
| m_UseTestingGridSensor = useTestingGridSensor; | |
| } | |
| } | |
| } | |