| | using NUnit.Framework; |
| | using Unity.MLAgents.Sensors; |
| |
|
| | namespace Unity.MLAgents.Tests |
| | { |
/// <summary>
/// Test-only <see cref="ISensor"/> that exposes a 2D float array as a
/// single-channel visual observation of shape (Height, Width, 1).
/// </summary>
public class Float2DSensor : ISensor
{
    /// <summary>Number of columns; the length of dimension 1 of <see cref="floatData"/>.</summary>
    public int Width { get; }

    /// <summary>Number of rows; the length of dimension 0 of <see cref="floatData"/>.</summary>
    public int Height { get; }

    private readonly string m_Name;
    private readonly ObservationSpec m_ObservationSpec;

    /// <summary>
    /// Backing data indexed as [row, column]. Public so tests can fill it directly.
    /// NOTE(review): <see cref="Write"/> iterates using <see cref="Height"/> and
    /// <see cref="Width"/>; replacing this field with an array of another size is not checked.
    /// </summary>
    public float[,] floatData;

    /// <summary>
    /// Creates a sensor that owns a new, zero-initialized array of
    /// <paramref name="height"/> rows by <paramref name="width"/> columns.
    /// </summary>
    /// <param name="width">Number of columns.</param>
    /// <param name="height">Number of rows.</param>
    /// <param name="name">Name returned by <see cref="GetName"/>.</param>
    public Float2DSensor(int width, int height, string name)
    {
        Width = width;
        Height = height;
        m_Name = name;
        m_ObservationSpec = ObservationSpec.Visual(height, width, 1);
        floatData = new float[Height, Width];
    }

    /// <summary>
    /// Creates a sensor that wraps an existing array without copying it; the array's
    /// dimensions determine <see cref="Height"/> (dimension 0) and <see cref="Width"/> (dimension 1).
    /// </summary>
    /// <param name="floatData">Data to expose, indexed as [row, column].</param>
    /// <param name="name">Name returned by <see cref="GetName"/>.</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="floatData"/> is null.</exception>
    public Float2DSensor(float[,] floatData, string name)
    {
        // Fail with a clear argument error rather than a NullReferenceException from GetLength.
        this.floatData = floatData ?? throw new System.ArgumentNullException(nameof(floatData));
        Height = floatData.GetLength(0);
        Width = floatData.GetLength(1);
        m_Name = name;
        m_ObservationSpec = ObservationSpec.Visual(Height, Width, 1);
    }

    /// <summary>Returns the name supplied to the constructor.</summary>
    public string GetName()
    {
        return m_Name;
    }

    /// <summary>Returns the visual spec (Height, Width, 1) built in the constructor.</summary>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <summary>This sensor produces no compressed form; always returns null.</summary>
    public byte[] GetCompressedObservation()
    {
        return null;
    }

    /// <summary>
    /// Copies every element of <see cref="floatData"/> into channel 0 of <paramref name="writer"/>
    /// at the matching (row, column) position.
    /// </summary>
    /// <param name="writer">Destination writer, expected to be targeted at a buffer of this sensor's spec.</param>
    /// <returns>The number of floats written: Height * Width.</returns>
    public int Write(ObservationWriter writer)
    {
        using (TimerStack.Instance.Scoped("Float2DSensor.Write"))
        {
            for (var h = 0; h < Height; h++)
            {
                for (var w = 0; w < Width; w++)
                {
                    writer[h, w, 0] = floatData[h, w];
                }
            }
            var numWritten = Height * Width;
            return numWritten;
        }
    }

    /// <summary>No-op: the data is supplied externally through <see cref="floatData"/>.</summary>
    public void Update() { }

    /// <summary>No-op: there is no per-episode state to clear.</summary>
    public void Reset() { }

    /// <summary>
    /// Returns <c>CompressionSpec.Default()</c>; presumably "no compression", consistent with
    /// <see cref="GetCompressedObservation"/> returning null — TODO confirm against CompressionSpec.
    /// </summary>
    public CompressionSpec GetCompressionSpec()
    {
        return CompressionSpec.Default();
    }
}
| |
|
/// <summary>
/// Tests for <see cref="Float2DSensor"/>.
/// </summary>
public class FloatVisualSensorTests
{
    /// <summary>
    /// Fills a 4-row x 3-column sensor with the values 0..11 (value = 3 * row + column)
    /// and checks that Write reports 12 values and that every one lands at flat index 3 * row + column.
    /// </summary>
    [Test]
    public void TestFloat2DSensorWrite()
    {
        const int width = 3;
        const int height = 4;
        var sensor = new Float2DSensor(width, height, "floatsensor");
        for (var h = 0; h < height; h++)
        {
            for (var w = 0; w < width; w++)
            {
                sensor.floatData[h, w] = width * h + w;
            }
        }

        var output = new float[width * height];
        var writer = new ObservationWriter();
        writer.SetTarget(output, sensor.GetObservationSpec(), 0);
        var numWritten = sensor.Write(writer);

        Assert.AreEqual(output.Length, numWritten);
        // Check the whole buffer, not just a prefix, so errors in the last row are caught too.
        for (var i = 0; i < output.Length; i++)
        {
            Assert.AreEqual(i, output[i]);
        }
    }

    /// <summary>
    /// The array-wrapping constructor takes Height/Width from the array's dimensions
    /// and keeps a reference to the caller's array rather than a copy.
    /// </summary>
    [Test]
    public void TestFloat2DSensorExternalData()
    {
        var data = new float[4, 3];
        var sensor = new Float2DSensor(data, "floatsensor");

        // NUnit's AreEqual takes (expected, actual).
        Assert.AreEqual(4, sensor.Height);
        Assert.AreEqual(3, sensor.Width);
        Assert.AreSame(data, sensor.floatData);
    }
}
| | } |
| |
|