| | using System.Collections.Generic; |
| | using Unity.Barracuda; |
| | using UnityEngine.Profiling; |
| | using Unity.MLAgents.Actuators; |
| | using Unity.MLAgents.Policies; |
| | using Unity.MLAgents.Sensors; |
| |
|
| | namespace Unity.MLAgents.Inference |
| | { |
/// <summary>
/// Pairs one agent's <see cref="AgentInfo"/> with the sensors that were
/// submitted alongside it, so both can be queued as a single batch entry.
/// </summary>
internal struct AgentInfoSensorsPair
{
    /// <summary>The agent info submitted with the decision request.</summary>
    public AgentInfo agentInfo;

    /// <summary>The sensors submitted together with <see cref="agentInfo"/>.</summary>
    public List<ISensor> sensors;
}
| |
|
/// <summary>
/// Queues agent observations, runs them as one batch through a Barracuda
/// worker, and keeps the most recent action per episode id.
/// </summary>
internal class ModelRunner
{
    // Entries queued by PutObservations since the last DecideBatch.
    readonly List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>();

    // Latest action per episode id; written by the tensor applier and PutObservations.
    readonly Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>();

    // Episode ids in the same order as m_Infos, handed to the tensor applier.
    readonly List<int> m_OrderedAgentsRequestingDecisions = new List<int>();

    ITensorAllocator m_TensorAllocator;
    TensorGenerator m_TensorGenerator;
    TensorApplier m_TensorApplier;

    NNModel m_Model;
    string m_ModelName;
    InferenceDevice m_InferenceDevice;
    IWorker m_Engine;
    bool m_Verbose = false;
    bool m_DeterministicInference;
    string[] m_OutputNames;
    IReadOnlyList<TensorProxy> m_InferenceInputs;
    readonly List<TensorProxy> m_InferenceOutputs;
    readonly Dictionary<string, Tensor> m_InputsByName;

    // Shared with both the tensor generator and the tensor applier.
    // NOTE(review): presumably per-agent recurrent memories keyed by episode id - confirm.
    readonly Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();

    SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator();

    // Set once the tensor generator has been initialized from the first batch's sensors.
    bool m_ObservationsInitialized;

    /// <summary>
    /// Loads the given model (if any), creates the Barracuda worker for the
    /// requested device, and builds the tensor generator and applier.
    /// </summary>
    /// <param name="model">The model to run. May be null, in which case no worker is created.</param>
    /// <param name="actionSpec">Action specification forwarded to the tensor applier.</param>
    /// <param name="inferenceDevice">Device used to select the Barracuda worker type.</param>
    /// <param name="seed">Seed forwarded to the tensor generator and tensor applier.</param>
    /// <param name="deterministicInference">
    /// Forwarded to output-name lookup, the tensor generator and the tensor applier.
    /// </param>
    /// <exception cref="UnityAgentsException">
    /// Thrown when the model version check reports a failure of type Error.
    /// </exception>
    public ModelRunner(
        NNModel model,
        ActionSpec actionSpec,
        InferenceDevice inferenceDevice,
        int seed = 0,
        bool deterministicInference = false)
    {
        // Stays null (and m_Engine keeps its default null) when no model is supplied.
        Model barracudaModel = null;
        m_Model = model;
        m_ModelName = model?.name;
        m_InferenceDevice = inferenceDevice;
        m_DeterministicInference = deterministicInference;
        m_TensorAllocator = new TensorCachingAllocator();
        if (model != null)
        {
#if BARRACUDA_VERBOSE
            m_Verbose = true;
#endif

            D.logEnabled = m_Verbose;

            barracudaModel = ModelLoader.Load(model);

            // Only failures flagged as Error abort construction; other check types are ignored here.
            var failedCheck = BarracudaModelParamLoader.CheckModelVersion(barracudaModel);
            if (failedCheck != null &&
                failedCheck.CheckType == BarracudaModelParamLoader.FailedCheck.CheckTypeEnum.Error)
            {
                throw new UnityAgentsException(failedCheck.Message);
            }

            WorkerFactory.Type executionDevice;
            switch (inferenceDevice)
            {
                case InferenceDevice.CPU:
                    executionDevice = WorkerFactory.Type.CSharp;
                    break;
                case InferenceDevice.GPU:
                    executionDevice = WorkerFactory.Type.ComputePrecompiled;
                    break;
                case InferenceDevice.Burst:
                    executionDevice = WorkerFactory.Type.CSharpBurst;
                    break;
                case InferenceDevice.Default:
                default:
                    // Default and any unrecognized value use the same worker as Burst.
                    executionDevice = WorkerFactory.Type.CSharpBurst;
                    break;
            }
            m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);
        }

        // NOTE(review): barracudaModel is null here when no model was supplied; presumably
        // GetInputTensors/GetOutputNames are extension methods that tolerate null - confirm.
        m_InferenceInputs = barracudaModel.GetInputTensors();
        m_OutputNames = barracudaModel.GetOutputNames(m_DeterministicInference);

        m_TensorGenerator = new TensorGenerator(
            seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference);
        m_TensorApplier = new TensorApplier(
            actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference);
        m_InputsByName = new Dictionary<string, Tensor>();
        m_InferenceOutputs = new List<TensorProxy>();
    }

    /// <summary>The inference device this runner was constructed with.</summary>
    public InferenceDevice InferenceDevice => m_InferenceDevice;

    /// <summary>The model this runner was constructed with (may be null).</summary>
    public NNModel Model => m_Model;

    /// <summary>
    /// Rebuilds the name-to-tensor map that is passed to the worker on execution.
    /// </summary>
    /// <param name="infInputs">Input tensor proxies whose name and data are mapped.</param>
    void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
    {
        m_InputsByName.Clear();
        for (var index = 0; index < infInputs.Count; index++)
        {
            var proxy = infInputs[index];
            m_InputsByName[proxy.name] = proxy.data;
        }
    }

    /// <summary>
    /// Disposes the worker (if one was created) and resets the tensor allocator.
    /// </summary>
    public void Dispose()
    {
        m_Engine?.Dispose();
        m_TensorAllocator?.Reset(false);
    }

    /// <summary>
    /// Replaces the cached inference outputs with the worker's current outputs
    /// for the given names, in the order the names are listed.
    /// </summary>
    /// <param name="names">Output names to peek from the worker.</param>
    void FetchBarracudaOutputs(string[] names)
    {
        m_InferenceOutputs.Clear();
        foreach (var outputName in names)
        {
            var tensor = m_Engine.PeekOutput(outputName);
            m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(tensor, outputName));
        }
    }

    /// <summary>
    /// Queues an agent's info and sensors for the next <see cref="DecideBatch"/> call.
    /// </summary>
    /// <param name="info">The agent info; its episodeId keys the stored action.</param>
    /// <param name="sensors">The sensors submitted with the info.</param>
    public void PutObservations(AgentInfo info, List<ISensor> sensors)
    {
#if DEBUG
        m_SensorShapeValidator.ValidateSensors(sensors);
#endif
        m_Infos.Add(new AgentInfoSensorsPair
        {
            agentInfo = info,
            sensors = sensors
        });

        // Kept in lockstep with m_Infos so the applier sees ids in batch order.
        m_OrderedAgentsRequestingDecisions.Add(info.episodeId);

        if (info.done)
        {
            // A done agent must end up with no stored action entry.
            m_LastActionsReceived.Remove(info.episodeId);
        }
        else if (!m_LastActionsReceived.ContainsKey(info.episodeId))
        {
            // First request for this episode id: start from an empty action.
            m_LastActionsReceived[info.episodeId] = ActionBuffers.Empty;
        }
    }

    /// <summary>
    /// Runs inference on everything queued since the previous call and hands the
    /// outputs to the tensor applier. Does nothing when nothing is queued.
    /// </summary>
    public void DecideBatch()
    {
        var currentBatchSize = m_Infos.Count;
        if (currentBatchSize == 0)
        {
            return;
        }

        if (!m_ObservationsInitialized)
        {
            // The first queued entry's sensors are used to initialize the generator, once.
            var firstInfo = m_Infos[0];
            m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator);
            m_ObservationsInitialized = true;
        }

        Profiler.BeginSample("ModelRunner.DecideAction");
        Profiler.BeginSample(m_ModelName);

        Profiler.BeginSample("GenerateTensors");
        m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos);
        Profiler.EndSample();

        Profiler.BeginSample("PrepareBarracudaInputs");
        PrepareBarracudaInputs(m_InferenceInputs);
        Profiler.EndSample();

        // NOTE(review): m_Engine is null when the runner was built without a model;
        // confirm callers never reach this point in that configuration.
        Profiler.BeginSample("ExecuteGraph");
        m_Engine.Execute(m_InputsByName);
        Profiler.EndSample();

        Profiler.BeginSample("FetchBarracudaOutputs");
        FetchBarracudaOutputs(m_OutputNames);
        Profiler.EndSample();

        Profiler.BeginSample("ApplyTensors");
        m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
        Profiler.EndSample();

        // Close the model-name sample, then the outer "ModelRunner.DecideAction" sample.
        Profiler.EndSample();
        Profiler.EndSample();

        m_Infos.Clear();

        m_OrderedAgentsRequestingDecisions.Clear();
    }

    /// <summary>
    /// Reports whether this runner was constructed with the given model and device.
    /// </summary>
    /// <param name="other">Model to compare against.</param>
    /// <param name="otherInferenceDevice">Inference device to compare against.</param>
    /// <returns>True when both the model and the device match.</returns>
    public bool HasModel(NNModel other, InferenceDevice otherInferenceDevice)
    {
        return m_Model == other && m_InferenceDevice == otherInferenceDevice;
    }

    /// <summary>
    /// Returns the stored action for the given id, or <c>ActionBuffers.Empty</c>
    /// when none is stored.
    /// </summary>
    /// <param name="agentId">
    /// Lookup key; entries are stored under the episodeId passed to PutObservations.
    /// </param>
    /// <returns>The stored action, or an empty action buffer.</returns>
    public ActionBuffers GetAction(int agentId)
    {
        // Single lookup instead of ContainsKey followed by the indexer.
        ActionBuffers storedAction;
        return m_LastActionsReceived.TryGetValue(agentId, out storedAction)
            ? storedAction
            : ActionBuffers.Empty;
    }
}
| | } |
| |
|