| | using System; |
| | using System.Collections; |
| | using System.Collections.Generic; |
| | using UnityEngine; |
| | using UnityEngine.Profiling; |
| |
|
namespace Unity.MLAgents.Actuators
{
    /// <summary>
    /// An <see cref="IList{T}"/> of <see cref="IActuator"/>s that keeps running totals of their
    /// action sizes and, once readied for execution, owns the shared action buffers and discrete
    /// action mask used to route actions to each actuator.
    /// </summary>
    internal class ActuatorManager : IList<IActuator>
    {
        // Managed actuators. Sorted by name the first time the manager is readied for execution;
        // that order defines which slice of the combined action buffers each actuator receives.
        readonly List<IActuator> m_Actuators;

        // Mask shared by all actuators; assigned in ReadyActuatorsForExecution.
        ActuatorDiscreteActionMask m_DiscreteActionMask;

        // Concatenation of all actuators' action specs; assigned in ReadyActuatorsForExecution.
        ActionSpec m_CombinedActionSpec;

        // True once ReadyActuatorsForExecution has allocated the buffers. After that point,
        // mutating the actuator list trips a Debug.Assert because the buffers would be stale.
        bool m_ReadyForExecution;

        /// <summary>
        /// The sum of the discrete branch sizes of all managed actuators.
        /// </summary>
        internal int SumOfDiscreteBranchSizes { get; private set; }

        /// <summary>
        /// The number of discrete action branches across all managed actuators.
        /// </summary>
        internal int NumDiscreteActions { get; private set; }

        /// <summary>
        /// The number of continuous actions across all managed actuators.
        /// </summary>
        internal int NumContinuousActions { get; private set; }

        /// <summary>
        /// The total number of continuous and discrete actions across all managed actuators.
        /// </summary>
        public int TotalNumberOfActions => NumContinuousActions + NumDiscreteActions;

        /// <summary>
        /// The discrete action mask shared by all actuators. Only assigned once the manager has
        /// been readied for execution.
        /// </summary>
        public ActuatorDiscreteActionMask DiscreteActionMask => m_DiscreteActionMask;

        /// <summary>
        /// The buffers that <see cref="UpdateActions"/> copies into and <see cref="ExecuteActions"/>
        /// reads from. Only allocated once the manager has been readied for execution.
        /// </summary>
        public ActionBuffers StoredActions { get; private set; }

        /// <summary>
        /// Creates an empty manager.
        /// </summary>
        /// <param name="capacity">Initial capacity of the underlying actuator list.</param>
        public ActuatorManager(int capacity = 0)
        {
            m_Actuators = new List<IActuator>(capacity);
        }

        /// <summary>
        /// Readies the managed actuators using the sizes accumulated by Add/Insert/Remove/indexer.
        /// </summary>
        void ReadyActuatorsForExecution()
        {
            ReadyActuatorsForExecution(m_Actuators, NumContinuousActions, SumOfDiscreteBranchSizes,
                NumDiscreteActions);
        }

        /// <summary>
        /// Sorts the managed actuators by name, allocates <see cref="StoredActions"/>, and builds the
        /// combined <see cref="ActionSpec"/> and discrete action mask. Subsequent calls do nothing.
        /// </summary>
        /// <param name="actuators">Actuators whose action specs are combined and handed to the mask.</param>
        /// <param name="numContinuousActions">Length of the continuous action buffer.</param>
        /// <param name="sumOfDiscreteBranches">Sum of all discrete branch sizes, passed to the mask.</param>
        /// <param name="numDiscreteBranches">Length of the discrete action buffer (number of branches).</param>
        internal void ReadyActuatorsForExecution(IList<IActuator> actuators, int numContinuousActions, int sumOfDiscreteBranches, int numDiscreteBranches)
        {
            if (m_ReadyForExecution)
            {
                return;
            }

            // Sort by name so the action layout does not depend on insertion order.
            // NOTE(review): this sorts m_Actuators rather than `actuators`; callers are expected to
            // pass the same actuators (the parameterless overload passes m_Actuators itself).
            SortActuators(m_Actuators);
#if DEBUG
            // Make sure the names are actually unique.
            ValidateActuators();
#endif

            var continuousActions = numContinuousActions == 0
                ? ActionSegment<float>.Empty
                : new ActionSegment<float>(new float[numContinuousActions]);
            var discreteActions = numDiscreteBranches == 0
                ? ActionSegment<int>.Empty
                : new ActionSegment<int>(new int[numDiscreteBranches]);

            StoredActions = new ActionBuffers(continuousActions, discreteActions);
            m_CombinedActionSpec = CombineActionSpecs(actuators);
            m_DiscreteActionMask = new ActuatorDiscreteActionMask(actuators, sumOfDiscreteBranches, numDiscreteBranches, m_CombinedActionSpec.BranchSizes);
            m_ReadyForExecution = true;
        }

        /// <summary>
        /// Builds a single <see cref="ActionSpec"/> whose continuous count is the sum over all
        /// actuators and whose branch sizes are the actuators' branch sizes concatenated in list order.
        /// </summary>
        /// <param name="actuators">The actuators to combine.</param>
        /// <returns>The combined action spec.</returns>
        internal static ActionSpec CombineActionSpecs(IList<IActuator> actuators)
        {
            int numContinuousActions = 0;
            int numDiscreteActions = 0;

            foreach (var actuator in actuators)
            {
                numContinuousActions += actuator.ActionSpec.NumContinuousActions;
                numDiscreteActions += actuator.ActionSpec.NumDiscreteActions;
            }

            int[] combinedBranchSizes;
            if (numDiscreteActions == 0)
            {
                combinedBranchSizes = Array.Empty<int>();
            }
            else
            {
                combinedBranchSizes = new int[numDiscreteActions];
                var start = 0;
                for (var i = 0; i < actuators.Count; i++)
                {
                    var branchSizes = actuators[i].ActionSpec.BranchSizes;
                    if (branchSizes != null)
                    {
                        Array.Copy(branchSizes, 0, combinedBranchSizes, start, branchSizes.Length);
                        start += branchSizes.Length;
                    }
                }
            }

            return new ActionSpec(numContinuousActions, combinedBranchSizes);
        }

        /// <summary>
        /// Readies the manager if needed and returns the combined action spec of all actuators.
        /// </summary>
        /// <returns>The combined <see cref="ActionSpec"/>.</returns>
        public ActionSpec GetCombinedActionSpec()
        {
            ReadyActuatorsForExecution();
            return m_CombinedActionSpec;
        }

        /// <summary>
        /// Readies the manager if needed and copies <paramref name="actions"/> into
        /// <see cref="StoredActions"/>. An empty source segment clears the matching stored segment.
        /// </summary>
        /// <param name="actions">The actions to store for the next <see cref="ExecuteActions"/>.</param>
        public void UpdateActions(ActionBuffers actions)
        {
            Profiler.BeginSample("ActuatorManager.UpdateActions");
            ReadyActuatorsForExecution();
            UpdateActionArray(actions.ContinuousActions, StoredActions.ContinuousActions);
            UpdateActionArray(actions.DiscreteActions, StoredActions.DiscreteActions);
            Profiler.EndSample();
        }

        /// <summary>
        /// Copies <paramref name="sourceActionBuffer"/> into <paramref name="destination"/>, or clears
        /// <paramref name="destination"/> when the source is empty.
        /// </summary>
        /// <remarks>
        /// On a length mismatch this asserts, then copies only the overlapping prefix and zeroes any
        /// destination elements the source did not cover. It never reads past the source segment or
        /// writes past the destination segment.
        /// </remarks>
        static void UpdateActionArray<T>(ActionSegment<T> sourceActionBuffer, ActionSegment<T> destination)
            where T : struct
        {
            if (sourceActionBuffer.Length <= 0)
            {
                destination.Clear();
                return;
            }

            // Guarded so the params/boxing cost of AssertFormat is only paid on a mismatch.
            if (sourceActionBuffer.Length != destination.Length)
            {
                Debug.AssertFormat(sourceActionBuffer.Length == destination.Length,
                    "sourceActionBuffer: {0} is a different size than destination: {1}.",
                    sourceActionBuffer.Length,
                    destination.Length);
            }

            var count = Math.Min(sourceActionBuffer.Length, destination.Length);
            Array.Copy(sourceActionBuffer.Array,
                sourceActionBuffer.Offset,
                destination.Array,
                destination.Offset,
                count);

            // Zero the uncovered tail so stale actions from a previous step are not replayed.
            if (count < destination.Length)
            {
                Array.Clear(destination.Array, destination.Offset + count, destination.Length - count);
            }
        }

        /// <summary>
        /// Resets the discrete action mask and lets every actuator that has discrete actions write its
        /// part of it. Before each actuator writes, <c>CurrentBranchOffset</c> is set to the index of
        /// that actuator's first branch in the combined layout.
        /// </summary>
        public void WriteActionMask()
        {
            ReadyActuatorsForExecution();
            m_DiscreteActionMask.ResetMask();
            var offset = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                if (actuator.ActionSpec.NumDiscreteActions > 0)
                {
                    m_DiscreteActionMask.CurrentBranchOffset = offset;
                    actuator.WriteDiscreteActionMask(m_DiscreteActionMask);
                    offset += actuator.ActionSpec.NumDiscreteActions;
                }
            }
        }

        /// <summary>
        /// Splits <paramref name="actionBuffersOut"/> into consecutive per-actuator views and calls
        /// <c>Heuristic</c> on each actuator that has at least one action. Each actuator writes into
        /// its own region of the caller's buffers.
        /// </summary>
        /// <param name="actionBuffersOut">Buffers laid out like the combined action spec.</param>
        public void ApplyHeuristic(in ActionBuffers actionBuffersOut)
        {
            Profiler.BeginSample("ActuatorManager.ApplyHeuristic");
            // Ensure the actuators are sorted, so the heuristic writes actions in the same layout
            // that ExecuteActions later reads them from.
            ReadyActuatorsForExecution();
            var continuousStart = 0;
            var discreteStart = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
                var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

                if (numContinuousActions == 0 && numDiscreteActions == 0)
                {
                    continue;
                }

                actuator.Heuristic(new ActionBuffers(
                    SubSegment(actionBuffersOut.ContinuousActions, continuousStart, numContinuousActions),
                    SubSegment(actionBuffersOut.DiscreteActions, discreteStart, numDiscreteActions)));
                continuousStart += numContinuousActions;
                discreteStart += numDiscreteActions;
            }
            Profiler.EndSample();
        }

        /// <summary>
        /// Readies the manager if needed, splits <see cref="StoredActions"/> into consecutive
        /// per-actuator views, and calls <c>OnActionReceived</c> on each actuator that has at least
        /// one action.
        /// </summary>
        public void ExecuteActions()
        {
            Profiler.BeginSample("ActuatorManager.ExecuteActions");
            ReadyActuatorsForExecution();
            var continuousStart = 0;
            var discreteStart = 0;
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                var actuator = m_Actuators[i];
                var numContinuousActions = actuator.ActionSpec.NumContinuousActions;
                var numDiscreteActions = actuator.ActionSpec.NumDiscreteActions;

                if (numContinuousActions == 0 && numDiscreteActions == 0)
                {
                    continue;
                }

                actuator.OnActionReceived(new ActionBuffers(
                    SubSegment(StoredActions.ContinuousActions, continuousStart, numContinuousActions),
                    SubSegment(StoredActions.DiscreteActions, discreteStart, numDiscreteActions)));
                continuousStart += numContinuousActions;
                discreteStart += numDiscreteActions;
            }
            Profiler.EndSample();
        }

        /// <summary>
        /// Returns the <paramref name="length"/> elements of <paramref name="source"/> starting at
        /// <paramref name="start"/>, measured from the segment's own offset rather than from index 0
        /// of its backing array. Returns an empty segment when <paramref name="length"/> is not positive.
        /// </summary>
        static ActionSegment<T> SubSegment<T>(ActionSegment<T> source, int start, int length)
            where T : struct
        {
            return length > 0
                ? new ActionSegment<T>(source.Array, source.Offset + start, length)
                : ActionSegment<T>.Empty;
        }

        /// <summary>
        /// Clears <see cref="StoredActions"/>, calls <c>ResetData</c> on every actuator, and resets the
        /// discrete action mask. Does nothing until the manager has been readied for execution.
        /// </summary>
        public void ResetData()
        {
            if (!m_ReadyForExecution)
            {
                return;
            }
            StoredActions.Clear();
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                m_Actuators[i].ResetData();
            }
            m_DiscreteActionMask.ResetMask();
        }

        /// <summary>
        /// Sorts <paramref name="actuators"/> by <c>Name</c> using a culture-invariant comparison.
        /// The resulting order determines each actuator's slice of the action buffers, so changing
        /// the comparison changes the action layout.
        /// </summary>
        /// <param name="actuators">The list to sort in place.</param>
        internal static void SortActuators(List<IActuator> actuators)
        {
            actuators.Sort((x, y) => string.Compare(x.Name, y.Name, StringComparison.InvariantCulture));
        }

        /// <summary>
        /// Asserts that no two managed actuators share a <c>Name</c> (ordinal comparison). The check
        /// does not depend on the list's order.
        /// </summary>
        void ValidateActuators()
        {
            var seenNames = new HashSet<string>(StringComparer.Ordinal);
            for (var i = 0; i < m_Actuators.Count; i++)
            {
                // Keep the side effect outside Debug.Assert: its arguments are stripped when
                // assertions are compiled out.
                var isUnique = seenNames.Add(m_Actuators[i].Name);
                Debug.Assert(isUnique, "Actuator names must be unique.");
            }
        }

        /// <summary>
        /// Adds <paramref name="actuatorItem"/>'s action sizes to the running totals. Ignores null.
        /// </summary>
        void AddToBufferSizes(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions += actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions += actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes += actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Subtracts <paramref name="actuatorItem"/>'s action sizes from the running totals. Ignores null.
        /// </summary>
        void SubtractFromBufferSize(IActuator actuatorItem)
        {
            if (actuatorItem == null)
            {
                return;
            }

            NumContinuousActions -= actuatorItem.ActionSpec.NumContinuousActions;
            NumDiscreteActions -= actuatorItem.ActionSpec.NumDiscreteActions;
            SumOfDiscreteBranchSizes -= actuatorItem.ActionSpec.SumOfDiscreteBranchSizes;
        }

        /// <summary>
        /// Resets all running action-size totals to zero.
        /// </summary>
        void ClearBufferSizes()
        {
            NumContinuousActions = NumDiscreteActions = SumOfDiscreteBranchSizes = 0;
        }

        /// <summary>
        /// Adds each actuator in <paramref name="actuators"/> via <see cref="Add"/>, in array order.
        /// </summary>
        /// <param name="actuators">The actuators to add.</param>
        public void AddActuators(IActuator[] actuators)
        {
            for (var i = 0; i < actuators.Length; i++)
            {
                Add(actuators[i]);
            }
        }

        /*********************************************************************************
         * IList implementation that delegates to m_Actuators List and keeps the running
         * action-size totals in sync with the list contents.
         *********************************************************************************/

        /// <inheritdoc/>
        public IEnumerator<IActuator> GetEnumerator()
        {
            return m_Actuators.GetEnumerator();
        }

        /// <inheritdoc/>
        IEnumerator IEnumerable.GetEnumerator()
        {
            return ((IEnumerable)m_Actuators).GetEnumerator();
        }

        /// <inheritdoc/>
        public void Add(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot add to the ActuatorManager after its buffers have been initialized");
            m_Actuators.Add(item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void Clear()
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot clear the ActuatorManager after its buffers have been initialized");
            m_Actuators.Clear();
            ClearBufferSizes();
        }

        /// <inheritdoc/>
        public bool Contains(IActuator item)
        {
            return m_Actuators.Contains(item);
        }

        /// <inheritdoc/>
        public void CopyTo(IActuator[] array, int arrayIndex)
        {
            m_Actuators.CopyTo(array, arrayIndex);
        }

        /// <inheritdoc/>
        public bool Remove(IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            if (m_Actuators.Remove(item))
            {
                SubtractFromBufferSize(item);
                return true;
            }
            return false;
        }

        /// <inheritdoc/>
        public int Count => m_Actuators.Count;

        /// <inheritdoc/>
        public bool IsReadOnly => false;

        /// <inheritdoc/>
        public int IndexOf(IActuator item)
        {
            return m_Actuators.IndexOf(item);
        }

        /// <inheritdoc/>
        public void Insert(int index, IActuator item)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot insert into the ActuatorManager after its buffers have been initialized");
            m_Actuators.Insert(index, item);
            AddToBufferSizes(item);
        }

        /// <inheritdoc/>
        public void RemoveAt(int index)
        {
            Debug.Assert(m_ReadyForExecution == false,
                "Cannot remove from the ActuatorManager after its buffers have been initialized");
            var actuator = m_Actuators[index];
            SubtractFromBufferSize(actuator);
            m_Actuators.RemoveAt(index);
        }

        /// <inheritdoc/>
        public IActuator this[int index]
        {
            get => m_Actuators[index];
            set
            {
                Debug.Assert(m_ReadyForExecution == false,
                    "Cannot modify the ActuatorManager after its buffers have been initialized");
                var old = m_Actuators[index];
                SubtractFromBufferSize(old);
                m_Actuators[index] = value;
                AddToBufferSizes(value);
            }
        }
    }
}
| |
|