| | using System; |
| | using System.Collections.Generic; |
| | using System.Linq; |
| | using NUnit.Framework; |
| | using Unity.MLAgents.Actuators; |
| | using Assert = UnityEngine.Assertions.Assert; |
| |
|
namespace Unity.MLAgents.Tests.Actuators
{
    /// <summary>
    /// Tests for <see cref="VectorActuator"/>: construction from an <see cref="ActionSpec"/>,
    /// forwarding of received actions, resetting of stored action data, discrete action
    /// masking, and heuristic delegation.
    /// </summary>
    [TestFixture]
    public class VectorActuatorTests
    {
        /// <summary>
        /// Test double that records what the actuator forwards to it.
        /// </summary>
        class TestActionReceiver : IActionReceiver, IHeuristicProvider
        {
            /// <summary>The buffers passed to the most recent <see cref="OnActionReceived"/> call.</summary>
            public ActionBuffers LastActionBuffers;

            /// <summary>Branch whose actions are disabled by <see cref="WriteDiscreteActionMask"/>.</summary>
            public int Branch;

            /// <summary>
            /// Action indices within <see cref="Branch"/> to disable. Defaults to an empty list so
            /// that writing the mask before a test configures it is a no-op instead of a
            /// <see cref="NullReferenceException"/>.
            /// </summary>
            public IList<int> Mask = Array.Empty<int>();

            /// <summary>Never assigned in this stub, so it always holds the default value.</summary>
            public ActionSpec ActionSpec { get; }

            /// <summary>Set to <c>true</c> once <see cref="Heuristic"/> has been invoked.</summary>
            public bool HeuristicCalled;

            /// <summary>Stores the received buffers so tests can inspect them.</summary>
            /// <param name="actionBuffers">The buffers handed over by the actuator.</param>
            public void OnActionReceived(ActionBuffers actionBuffers)
            {
                LastActionBuffers = actionBuffers;
            }

            /// <summary>Disables every action index listed in <see cref="Mask"/> on <see cref="Branch"/>.</summary>
            /// <param name="actionMask">The mask to write into.</param>
            public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
            {
                foreach (var actionIndex in Mask)
                {
                    actionMask.SetActionEnabled(Branch, actionIndex, false);
                }
            }

            /// <summary>Records that the heuristic was requested; writes no actions.</summary>
            /// <param name="actionBuffersOut">Output buffers (left untouched).</param>
            public void Heuristic(in ActionBuffers actionBuffersOut)
            {
                HeuristicCalled = true;
            }
        }

        /// <summary>
        /// Checks the action counts exposed by the actuator for a discrete and a continuous
        /// spec, and the name given to the continuous actuator.
        /// </summary>
        [Test]
        public void TestConstruct()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");

            // AreEqual(expected, actual) reports both values when the check fails.
            Assert.AreEqual(3, va.ActionSpec.NumDiscreteActions);
            Assert.AreEqual(6, va.ActionSpec.SumOfDiscreteBranchSizes);
            Assert.AreEqual(0, va.ActionSpec.NumContinuousActions);

            var va1 = new VectorActuator(ar, ActionSpec.MakeContinuous(4), "name");

            Assert.AreEqual(4, va1.ActionSpec.NumContinuousActions);
            Assert.AreEqual(0, va1.ActionSpec.SumOfDiscreteBranchSizes);
            Assert.AreEqual("name-Continuous", va1.Name);
        }

        /// <summary>
        /// Checks that buffers given to the actuator are passed on to the action receiver.
        /// </summary>
        [Test]
        public void TestOnActionReceived()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");

            var discreteActions = new[] { 0, 1, 1 };
            var ab = new ActionBuffers(ActionSegment<float>.Empty,
                new ActionSegment<int>(discreteActions, 0, 3));

            va.OnActionReceived(ab);

            Assert.AreEqual(ab, ar.LastActionBuffers);
        }

        /// <summary>
        /// Checks that, after actions have been received, <c>ResetData</c> leaves both the
        /// continuous and the discrete action segments equal to the empty segment.
        /// </summary>
        [Test]
        public void TestResetData()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");

            var discreteActions = new[] { 0, 1, 1 };
            var ab = new ActionBuffers(ActionSegment<float>.Empty,
                new ActionSegment<int>(discreteActions, 0, 3));

            va.OnActionReceived(ab);
            va.ResetData();

            Assert.AreEqual(ActionSegment<float>.Empty, va.ActionBuffers.ContinuousActions);
            Assert.AreEqual(ActionSegment<int>.Empty, va.ActionBuffers.DiscreteActions);
        }

        /// <summary>
        /// Checks that actions disabled by the receiver show up in the flattened mask.
        /// </summary>
        [Test]
        public void TestWriteDiscreteActionMask()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");
            // 6 and 3 equal the summed branch sizes and the branch count of the spec above;
            // presumably that is what these constructor arguments mean - confirm against the API.
            var bdam = new ActuatorDiscreteActionMask(new[] { va }, 6, 3);

            // Expected mask over branches of size 1, 2 and 3 laid out back to back:
            // true at action 0 of branch 1 and at actions 1 and 2 of branch 2.
            var groundTruthMask = new[] { false, true, false, false, true, true };

            ar.Branch = 1;
            ar.Mask = new[] { 0 };
            va.WriteDiscreteActionMask(bdam);

            ar.Branch = 2;
            ar.Mask = new[] { 1, 2 };
            va.WriteDiscreteActionMask(bdam);

            Assert.IsTrue(groundTruthMask.SequenceEqual(bdam.GetMask()));
        }

        /// <summary>
        /// Checks that asking the actuator for a heuristic invokes the receiver's heuristic.
        /// </summary>
        [Test]
        public void TestHeuristic()
        {
            var ar = new TestActionReceiver();
            var va = new VectorActuator(ar, ActionSpec.MakeDiscrete(1, 2, 3), "name");

            va.Heuristic(new ActionBuffers(Array.Empty<float>(), va.ActionSpec.BranchSizes));

            Assert.IsTrue(ar.HeuristicCalled);
        }
    }
}
| |
|