| | using System; |
| | using System.Collections.Generic; |
| |
|
namespace Unity.MLAgents.Actuators
{
    /// <summary>
    /// Implementation of <see cref="IDiscreteActionMask"/> that lets each <see cref="IActuator"/> write into one
    /// flat mask covering the discrete branches of all actuators, concatenated in actuator order.
    /// </summary>
    /// <remarks>
    /// The mask array is allocated lazily by the first call to <see cref="SetActionEnabled"/>; until then
    /// <see cref="GetMask"/> returns <c>null</c>. A mask entry of <c>true</c> means the action is masked (disabled).
    /// </remarks>
    internal class ActuatorDiscreteActionMask : IDiscreteActionMask
    {
        // Index into m_CurrentMask of the first action of each branch (cumulative sum of m_BranchSizes).
        // AreAllActionsMasked reads entries [branch] and [branch + 1] as the [start, end) range of a branch,
        // so Utilities.CumSum is expected to return numBranches + 1 values starting at 0 — TODO confirm.
        int[] m_StartingActionIndices;

        // Number of actions in each branch, over all actuators. Either supplied to the constructor or
        // built lazily by concatenating each actuator's ActionSpec.BranchSizes.
        int[] m_BranchSizes;

        // Flattened mask over every action of every branch; true = masked. Null until first use.
        bool[] m_CurrentMask;

        // Actuators whose branch sizes are concatenated when no branch sizes were supplied.
        readonly IList<IActuator> m_Actuators;

        // Total number of actions over all branches (length of m_CurrentMask).
        readonly int m_SumOfDiscreteBranchSizes;

        // Total number of branches over all actuators (length of m_BranchSizes).
        readonly int m_NumBranches;

        /// <summary>
        /// Offset added to the branch index passed to <see cref="SetActionEnabled"/>, so that a caller can
        /// address its own branches starting from 0. Presumably set by the owner to the first branch of the
        /// actuator that is currently writing (verify against callers).
        /// </summary>
        public int CurrentBranchOffset { get; set; }

        /// <summary>
        /// Creates a mask over the discrete branches of <paramref name="actuators"/>.
        /// </summary>
        /// <param name="actuators">Actuators whose <c>ActionSpec.BranchSizes</c> are concatenated when
        /// <paramref name="branchSizes"/> is <c>null</c>.</param>
        /// <param name="sumOfDiscreteBranchSizes">Total number of discrete actions over all branches; the
        /// length of the mask array.</param>
        /// <param name="numBranches">Total number of discrete branches over all actuators.</param>
        /// <param name="branchSizes">Optional precomputed per-branch sizes. When <c>null</c>, they are derived
        /// from <paramref name="actuators"/> on first use.</param>
        internal ActuatorDiscreteActionMask(IList<IActuator> actuators, int sumOfDiscreteBranchSizes, int numBranches, int[] branchSizes = null)
        {
            m_Actuators = actuators;
            m_SumOfDiscreteBranchSizes = sumOfDiscreteBranchSizes;
            m_NumBranches = numBranches;
            m_BranchSizes = branchSizes;
        }

        /// <summary>
        /// Enables or disables a single action of a discrete branch.
        /// </summary>
        /// <param name="branch">Branch index, relative to <see cref="CurrentBranchOffset"/>.</param>
        /// <param name="actionIndex">Index of the action within that branch.</param>
        /// <param name="isEnabled"><c>false</c> masks the action; <c>true</c> unmasks it.</param>
        /// <exception cref="UnityAgentsException">DEBUG builds only: the branch (after applying
        /// <see cref="CurrentBranchOffset"/>) or the action index is out of range.</exception>
        public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
        {
            LazyInitialize();

            // The mask is shared by all actuators, so the caller's branch index must be shifted by the
            // current offset before it can index the global per-branch arrays.
            var absoluteBranch = CurrentBranchOffset + branch;
#if DEBUG
            // Validate the absolute branch (not only the relative one) so an out-of-range branch surfaces
            // as a UnityAgentsException rather than an IndexOutOfRangeException, and reject negative
            // indices, which would otherwise silently write into a neighbouring branch's slots.
            if (branch < 0 || absoluteBranch < 0 || absoluteBranch >= m_NumBranches ||
                actionIndex < 0 || actionIndex >= m_BranchSizes[absoluteBranch])
            {
                throw new UnityAgentsException(
                    "Invalid Action Masking: Action Mask is too large for specified branch.");
            }
#endif
            m_CurrentMask[actionIndex + m_StartingActionIndices[absoluteBranch]] = !isEnabled;
        }

        /// <summary>
        /// Creates the branch sizes, the mask array and the per-branch start indices on first use.
        /// Each piece is built only if it does not exist yet.
        /// </summary>
        void LazyInitialize()
        {
            if (m_BranchSizes == null)
            {
                // Concatenate every actuator's branch sizes, in actuator order.
                m_BranchSizes = new int[m_NumBranches];
                var start = 0;
                for (var i = 0; i < m_Actuators.Count; i++)
                {
                    var branchSizes = m_Actuators[i].ActionSpec.BranchSizes;
                    Array.Copy(branchSizes, 0, m_BranchSizes, start, branchSizes.Length);
                    start += branchSizes.Length;
                }
            }

            // A freshly allocated bool[] is all false, i.e. every action starts out enabled.
            if (m_CurrentMask == null)
            {
                m_CurrentMask = new bool[m_SumOfDiscreteBranchSizes];
            }

            // Computed once: position of the first action of each branch within m_CurrentMask.
            if (m_StartingActionIndices == null)
            {
                m_StartingActionIndices = Utilities.CumSum(m_BranchSizes);
            }
        }

        /// <summary>
        /// Gets the current mask.
        /// </summary>
        /// <returns>A boolean array of length equal to the total number of discrete actions, where
        /// <c>true</c> marks a masked action; or <c>null</c> if <see cref="SetActionEnabled"/> has never
        /// been called.</returns>
        /// <exception cref="UnityAgentsException">DEBUG builds only: some branch has every action masked.
        /// </exception>
        internal bool[] GetMask()
        {
#if DEBUG
            if (m_CurrentMask != null)
            {
                AssertMask();
            }
#endif
            return m_CurrentMask;
        }

        /// <summary>
        /// Checks that every branch still has at least one enabled action. Compiles to a no-op outside
        /// DEBUG builds.
        /// </summary>
        /// <exception cref="UnityAgentsException">A branch has all of its actions masked.</exception>
        void AssertMask()
        {
#if DEBUG
            for (var branchIndex = 0; branchIndex < m_NumBranches; branchIndex++)
            {
                if (AreAllActionsMasked(branchIndex))
                {
                    throw new UnityAgentsException(
                        "Invalid Action Masking : All the actions of branch " + branchIndex +
                        " are masked.");
                }
            }
#endif
        }

        /// <summary>
        /// Re-enables every action. The existing mask array is cleared in place rather than released.
        /// </summary>
        internal void ResetMask()
        {
            if (m_CurrentMask != null)
            {
                Array.Clear(m_CurrentMask, 0, m_CurrentMask.Length);
            }
        }

        /// <summary>
        /// Checks whether every action of the given branch is masked.
        /// </summary>
        /// <param name="branch">Absolute index of the branch to check.</param>
        /// <returns><c>true</c> if all actions of the branch are masked; <c>false</c> if any is enabled
        /// or the mask has not been allocated yet.</returns>
        bool AreAllActionsMasked(int branch)
        {
            if (m_CurrentMask == null)
            {
                return false;
            }
            var start = m_StartingActionIndices[branch];
            var end = m_StartingActionIndices[branch + 1];
            for (var i = start; i < end; i++)
            {
                if (!m_CurrentMask[i])
                {
                    return false;
                }
            }
            return true;
        }
    }
}
| |
|