ppo-Pyramids-Training / com.unity.ml-agents /Tests /Editor /Inference /EditModeTestInternalBrainTensorGenerator.cs
| using System.Collections.Generic; | |
| using Unity.Barracuda; | |
| using NUnit.Framework; | |
| using UnityEngine; | |
| using Unity.MLAgents.Actuators; | |
| using Unity.MLAgents.Inference; | |
| using Unity.MLAgents.Policies; | |
| using Unity.MLAgents.Utils.Tests; | |
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Edit-mode tests for the inference input generators: each test feeds a
    /// <c>TensorProxy</c> through one generator and checks the values written to it.
    /// </summary>
    [TestFixture]
    public class EditModeTestInternalBrainTensorGenerator
    {
        /// <summary>
        /// Disposes the Academy singleton, if one is initialized, before every test.
        /// </summary>
        [SetUp]
        public void SetUp()
        {
            if (Academy.IsInitialized)
            {
                Academy.Instance.Dispose();
            }
        }

        /// <summary>
        /// Creates a GameObject carrying a BehaviorParameters (3 vector observations,
        /// 1 stack) and a TestAgent component.
        /// </summary>
        /// <param name="name">Name given to the new GameObject.</param>
        /// <param name="observableAttributeOptions">Value assigned to ObservableAttributeHandling.</param>
        /// <returns>The TestAgent component that was added (not yet initialized here).</returns>
        static TestAgent CreateFakeAgent(string name, ObservableAttributeOptions observableAttributeOptions)
        {
            var gameObject = new GameObject(name);

            // BehaviorParameters is added and configured before the agent component,
            // in the same order the fixture has always used.
            var behaviorParameters = gameObject.AddComponent<BehaviorParameters>();
            behaviorParameters.BrainParameters.VectorObservationSize = 3;
            behaviorParameters.BrainParameters.NumStackedVectorObservations = 1;
            behaviorParameters.ObservableAttributeHandling = observableAttributeOptions;

            return gameObject.AddComponent<TestAgent>();
        }

        /// <summary>
        /// Pairs each agent's AgentInfo with its sensors, in list order, which is
        /// the input shape the generators' Generate method takes.
        /// </summary>
        /// <param name="agents">Agents whose info and sensors are paired.</param>
        /// <returns>One AgentInfoSensorsPair per agent.</returns>
        static List<AgentInfoSensorsPair> ToSensorPairs(List<TestAgent> agents)
        {
            var pairs = new List<AgentInfoSensorsPair>(agents.Count);
            foreach (var agent in agents)
            {
                pairs.Add(new AgentInfoSensorsPair { agentInfo = agent._Info, sensors = agent.sensors });
            }
            return pairs;
        }

        /// <summary>
        /// Builds two initialized test agents ("goA" and "goB") with known
        /// observations, stored discrete actions and action masks.
        /// </summary>
        /// <param name="observableAttributeOptions">
        /// Value assigned to each agent's ObservableAttributeHandling; defaults to Ignore.
        /// </param>
        /// <returns>
        /// A list [agentA, agentB]. agentA observes (1, 2, 3), stores discrete actions
        /// {1, 2} and has no mask; agentB observes (4, 5, 6), stores {3, 4} and masks
        /// only the first of five actions.
        /// </returns>
        static List<TestAgent> GetFakeAgents(ObservableAttributeOptions observableAttributeOptions = ObservableAttributeOptions.Ignore)
        {
            var agentA = CreateFakeAgent("goA", observableAttributeOptions);
            var agentB = CreateFakeAgent("goB", observableAttributeOptions);

            var agents = new List<TestAgent> { agentA, agentB };
            foreach (var agent in agents)
            {
                agent.LazyInitialize();
            }

            agentA.collectObservationsSensor.AddObservation(new Vector3(1, 2, 3));
            agentB.collectObservationsSensor.AddObservation(new Vector3(4, 5, 6));

            agentA._Info = new AgentInfo
            {
                storedActions = new ActionBuffers(null, new[] { 1, 2 }),
                discreteActionMasks = null,
            };
            agentB._Info = new AgentInfo
            {
                storedActions = new ActionBuffers(null, new[] { 3, 4 }),
                discreteActionMasks = new[] { true, false, false, false, false },
            };

            return agents;
        }

        /// <summary>
        /// A TensorGenerator can be constructed from a seed, an allocator and a memory dictionary.
        /// </summary>
        [Test]
        public void Construction()
        {
            // The using statement releases the allocator even when an assertion throws.
            using (var alloc = new TensorCachingAllocator())
            {
                var mem = new Dictionary<int, List<float>>();
                var tensorGenerator = new TensorGenerator(0, alloc, mem);
                Assert.IsNotNull(tensorGenerator);
            }
        }

        /// <summary>
        /// BatchSizeGenerator is expected to write the batch size into element 0.
        /// </summary>
        [Test]
        public void GenerateBatchSize()
        {
            var inputTensor = new TensorProxy();
            const int batchSize = 4;
            using (var alloc = new TensorCachingAllocator())
            {
                var generator = new BatchSizeGenerator(alloc);
                generator.Generate(inputTensor, batchSize, null);

                Assert.IsNotNull(inputTensor.data);
                // NUnit's argument order is (expected, actual).
                Assert.AreEqual(batchSize, inputTensor.data[0]);
            }
        }

        /// <summary>
        /// SequenceLengthGenerator is expected to write 1 into element 0.
        /// </summary>
        [Test]
        public void GenerateSequenceLength()
        {
            var inputTensor = new TensorProxy();
            const int batchSize = 4;
            using (var alloc = new TensorCachingAllocator())
            {
                var generator = new SequenceLengthGenerator(alloc);
                generator.Generate(inputTensor, batchSize, null);

                Assert.IsNotNull(inputTensor.data);
                Assert.AreEqual(1, inputTensor.data[0]);
            }
        }

        /// <summary>
        /// ObservationGenerator is expected to lay the registered sensors out side by
        /// side per agent row: one observable-attribute value followed by the three
        /// vector observation values.
        /// </summary>
        [Test]
        public void GenerateVectorObservation()
        {
            var inputTensor = new TensorProxy
            {
                shape = new long[] { 2, 4 }
            };
            const int batchSize = 4;
            var agentInfos = GetFakeAgents(ObservableAttributeOptions.ExamineAll);
            using (var alloc = new TensorCachingAllocator())
            {
                var generator = new ObservationGenerator(alloc);
                generator.AddSensorIndex(0); // ObservableAttribute (size 1)
                generator.AddSensorIndex(1); // TestSensor (size 0)
                generator.AddSensorIndex(2); // TestSensor (size 0)
                generator.AddSensorIndex(3); // VectorSensor (size 3)

                generator.Generate(inputTensor, batchSize, ToSensorPairs(agentInfos));

                Assert.IsNotNull(inputTensor.data);
                // Column 0 holds the size-1 observable, so the vector observation
                // occupies columns 1..3 of each row.
                Assert.AreEqual(1, inputTensor.data[0, 1]);
                Assert.AreEqual(3, inputTensor.data[0, 3]);
                Assert.AreEqual(4, inputTensor.data[1, 1]);
                Assert.AreEqual(6, inputTensor.data[1, 3]);
            }
        }

        /// <summary>
        /// PreviousActionInputGenerator is expected to copy each agent's stored
        /// discrete actions into that agent's row.
        /// </summary>
        [Test]
        public void GeneratePreviousActionInput()
        {
            var inputTensor = new TensorProxy
            {
                shape = new long[] { 2, 2 },
                valueType = TensorProxy.TensorType.Integer
            };
            const int batchSize = 4;
            var agentInfos = GetFakeAgents();
            using (var alloc = new TensorCachingAllocator())
            {
                var generator = new PreviousActionInputGenerator(alloc);

                generator.Generate(inputTensor, batchSize, ToSensorPairs(agentInfos));

                Assert.IsNotNull(inputTensor.data);
                Assert.AreEqual(1, inputTensor.data[0, 0]);
                Assert.AreEqual(2, inputTensor.data[0, 1]);
                Assert.AreEqual(3, inputTensor.data[1, 0]);
                Assert.AreEqual(4, inputTensor.data[1, 1]);
            }
        }

        /// <summary>
        /// ActionMaskInputGenerator is expected to emit 1 for every action when the
        /// agent has no mask, and 0 for an action whose mask entry is true.
        /// </summary>
        [Test]
        public void GenerateActionMaskInput()
        {
            var inputTensor = new TensorProxy
            {
                shape = new long[] { 2, 5 },
                valueType = TensorProxy.TensorType.FloatingPoint
            };
            const int batchSize = 4;
            var agentInfos = GetFakeAgents();
            using (var alloc = new TensorCachingAllocator())
            {
                var generator = new ActionMaskInputGenerator(alloc);

                generator.Generate(inputTensor, batchSize, ToSensorPairs(agentInfos));

                Assert.IsNotNull(inputTensor.data);
                // Row 0: agentA has a null mask, so every entry is 1.
                Assert.AreEqual(1, inputTensor.data[0, 0]);
                Assert.AreEqual(1, inputTensor.data[0, 4]);
                // Row 1: agentB masks only action 0.
                Assert.AreEqual(0, inputTensor.data[1, 0]);
                Assert.AreEqual(1, inputTensor.data[1, 4]);
            }
        }
    }
}