| | using System; |
| | using System.Collections.Generic; |
| | using System.Linq; |
| | using Unity.Barracuda; |
| | using Unity.MLAgents.Actuators; |
| | using Unity.MLAgents.Sensors; |
| | using Unity.MLAgents.Policies; |
| |
|
| | namespace Unity.MLAgents.Inference |
| | { |
| | |
| | |
| | |
| | |
| | internal class BarracudaModelParamLoader |
| | { |
/// <summary>
/// Model API versions this loader can validate. These values are compared
/// against the integer returned by <c>Model.GetVersion()</c>.
/// </summary>
internal enum ModelApiVersion
{
    /// <summary>
    /// Models from com.unity.ml-agents 1.x; validated with the "Legacy" checks.
    /// </summary>
    MLAgents1_0 = 2,

    /// <summary>
    /// Model API version named for com.unity.ml-agents 2.0; validated with the non-legacy checks.
    /// </summary>
    MLAgents2_0 = 3,

    /// <summary>Lowest version accepted by <c>CheckModelVersion</c>.</summary>
    MinSupportedVersion = MLAgents1_0,

    /// <summary>Highest version accepted by <c>CheckModelVersion</c>.</summary>
    MaxSupportedVersion = MLAgents2_0
}
| |
|
/// <summary>
/// Outcome of one model-validation check: a severity level plus a message.
/// Create instances through the <see cref="Info"/>, <see cref="Warning"/>
/// and <see cref="Error"/> factory methods.
/// </summary>
internal class FailedCheck
{
    /// <summary>Severity level of a failed check.</summary>
    public enum CheckTypeEnum
    {
        Info = 0,
        Warning = 1,
        Error = 2
    }

    /// <summary>Severity level of this check.</summary>
    public CheckTypeEnum CheckType;

    /// <summary>Human-readable description of what failed.</summary>
    public string Message;

    /// <summary>Creates a check with <see cref="CheckTypeEnum.Info"/> severity.</summary>
    /// <param name="message">Text to report.</param>
    public static FailedCheck Info(string message) => Create(CheckTypeEnum.Info, message);

    /// <summary>Creates a check with <see cref="CheckTypeEnum.Warning"/> severity.</summary>
    /// <param name="message">Text to report.</param>
    public static FailedCheck Warning(string message) => Create(CheckTypeEnum.Warning, message);

    /// <summary>Creates a check with <see cref="CheckTypeEnum.Error"/> severity.</summary>
    /// <param name="message">Text to report.</param>
    public static FailedCheck Error(string message) => Create(CheckTypeEnum.Error, message);

    // Single construction path shared by the three severity factories.
    static FailedCheck Create(CheckTypeEnum checkType, string message)
    {
        var check = new FailedCheck();
        check.CheckType = checkType;
        check.Message = message;
        return check;
    }
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks whether this package can run the model's API version.
/// </summary>
/// <param name="model">The Barracuda model to inspect.</param>
/// <returns>
/// An Error <see cref="FailedCheck"/> in any of these cases:
/// the model version is below <c>ModelApiVersion.MinSupportedVersion</c>;
/// the model version is above <c>ModelApiVersion.MaxSupportedVersion</c>;
/// the model is a <c>MLAgents1_0</c> model whose memory size is positive.
/// Otherwise null.
/// </returns>
public static FailedCheck CheckModelVersion(Model model)
{
    var modelApiVersion = model.GetVersion();
    if (modelApiVersion < (int)ModelApiVersion.MinSupportedVersion)
    {
        return FailedCheck.Error(
            "Model was trained with an older version of the trainer than is supported. " +
            "Either retrain with a newer trainer, or use an older version of com.unity.ml-agents.\n" +
            $"Model version: {modelApiVersion} Minimum supported version: {(int)ModelApiVersion.MinSupportedVersion}"
        );
    }

    if (modelApiVersion > (int)ModelApiVersion.MaxSupportedVersion)
    {
        return FailedCheck.Error(
            "Model was trained with a newer version of the trainer than is supported. " +
            "Either retrain with an older trainer, or update to a newer version of com.unity.ml-agents.\n" +
            $"Model version: {modelApiVersion} Maximum supported version: {(int)ModelApiVersion.MaxSupportedVersion}"
        );
    }

    // Only 1.x models have a further restriction. Returning early means the
    // memory-size tensor is never touched for other versions.
    if (modelApiVersion != (int)ModelApiVersion.MLAgents1_0)
    {
        return null;
    }

    // GetTensorByName may return null when the node is absent. Treat that as
    // "no memory" here instead of throwing; CheckModel separately reports a
    // missing memory-size node.
    var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
    var memorySize = memorySizeTensor == null ? 0 : (int)memorySizeTensor[0];
    if (memorySize > 0)
    {
        // Recurrent (memory > 0) 1.x models are rejected outright.
        return FailedCheck.Error(
            "Models from com.unity.ml-agents 1.x that use recurrent neural networks are not supported in newer versions. " +
            "Either retrain with a newer trainer, or use an older version of com.unity.ml-agents.\n"
        );
    }
    return null;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Runs every applicable consistency check between a model and the agent's
/// brain parameters, sensors and actuator components.
/// </summary>
/// <param name="model">The model to validate. May be null, which yields a single Info result.</param>
/// <param name="brainParameters">The agent's brain parameters.</param>
/// <param name="sensors">The agent's sensors, in observation order.</param>
/// <param name="actuatorComponents">Actuator components whose action specs are checked together with the brain's.</param>
/// <param name="observableAttributeTotalSize">Total size of [Observable] attributes; forwarded to the input-shape checks.</param>
/// <param name="behaviorType">Only used to choose the wording of the "no model" message.</param>
/// <param name="deterministicInference">Forwarded to the checks that depend on which outputs the model exposes.</param>
/// <returns>Every failed check found; an empty collection when the model is consistent.</returns>
public static IEnumerable<FailedCheck> CheckModel(
    Model model,
    BrainParameters brainParameters,
    ISensor[] sensors,
    ActuatorComponent[] actuatorComponents,
    int observableAttributeTotalSize = 0,
    BehaviorType behaviorType = BehaviorType.Default,
    bool deterministicInference = false
)
{
    var checks = new List<FailedCheck>();

    if (model == null)
    {
        // Without a model only inference is impossible; say so informationally.
        var advice = behaviorType == BehaviorType.InferenceOnly
            ? "Either assign a model, or change to a different Behavior Type."
            : "(But can still train)";
        checks.Add(FailedCheck.Info("There is no model for this Brain; cannot run inference. " + advice));
        return checks;
    }

    // CheckExpectedTensors receives the list (presumably appending its own
    // failures); nothing further is checked if it reports a failure.
    if (!model.CheckExpectedTensors(checks, deterministicInference))
    {
        return checks;
    }

    var apiVersion = model.GetVersion();
    var versionFailure = CheckModelVersion(model);
    if (versionFailure != null)
    {
        checks.Add(versionFailure);
    }

    // A memory size of -1 is treated as "node missing".
    var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
    if (memory == -1)
    {
        checks.Add(FailedCheck.Warning($"Missing node in the model provided : {TensorNames.MemorySize}"));
        return checks;
    }

    // Input checks depend on the model API version; unknown versions skip them.
    switch (apiVersion)
    {
        case (int)ModelApiVersion.MLAgents1_0:
            checks.AddRange(CheckInputTensorPresenceLegacy(model, brainParameters, memory, sensors));
            checks.AddRange(CheckInputTensorShapeLegacy(model, brainParameters, sensors, observableAttributeTotalSize));
            break;
        case (int)ModelApiVersion.MLAgents2_0:
            checks.AddRange(CheckInputTensorPresence(model, brainParameters, memory, sensors, deterministicInference));
            checks.AddRange(CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize));
            break;
    }

    // Output checks run for every version.
    checks.AddRange(CheckOutputTensorShape(model, brainParameters, actuatorComponents));
    checks.AddRange(CheckOutputTensorPresence(model, memory, deterministicInference));
    return checks;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks that a 1.x ("legacy") model declares every input the agent will feed it.
/// </summary>
/// <param name="model">The model whose input names are inspected.</param>
/// <param name="brainParameters">Supplies the vector observation size.</param>
/// <param name="memory">The model's memory size; recurrent inputs are required when it is positive.</param>
/// <param name="sensors">The agent's sensors, in observation order.</param>
/// <returns>One warning per missing input; empty when everything is present.</returns>
static IEnumerable<FailedCheck> CheckInputTensorPresenceLegacy(
    Model model,
    BrainParameters brainParameters,
    int memory,
    ISensor[] sensors
)
{
    var failedModelChecks = new List<FailedCheck>();
    var tensorsNames = model.GetInputNames();

    // A non-zero vector observation size requires the vector observation input.
    if ((brainParameters.VectorObservationSize != 0) &&
        (!tensorsNames.Contains(TensorNames.VectorObservationPlaceholder)))
    {
        failedModelChecks.Add(
            FailedCheck.Warning("The model does not contain a Vector Observation Placeholder Input. " +
                "You must set the Vector Observation Space Size to 0.")
        );
    }

    // Rank-3 sensors use visual observation inputs numbered among rank-3 sensors only.
    // Rank-2 sensors use observation inputs numbered by their index in the sensor array.
    var visObsIndex = 0;
    for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
    {
        var sensor = sensors[sensorIndex];
        // Query the spec once; the two rank cases are mutually exclusive.
        var rank = sensor.GetObservationSpec().Shape.Length;
        if (rank == 3)
        {
            if (!tensorsNames.Contains(TensorNames.GetVisualObservationName(visObsIndex)))
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model does not contain a Visual Observation Placeholder Input " +
                        $"for sensor component {visObsIndex} ({sensor.GetType().Name}).")
                );
            }
            visObsIndex++;
        }
        else if (rank == 2)
        {
            if (!tensorsNames.Contains(TensorNames.GetObservationName(sensorIndex)))
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model does not contain an Observation Placeholder Input " +
                        $"for sensor component {sensorIndex} ({sensor.GetType().Name}).")
                );
            }
        }
    }

    // The model may declare more visual inputs than there are rank-3 sensors to fill them.
    var expectedVisualObs = model.GetNumVisualInputs();
    if (expectedVisualObs > visObsIndex)
    {
        failedModelChecks.Add(
            FailedCheck.Warning($"The model expects {expectedVisualObs} visual inputs," +
                $" but only found {visObsIndex} visual sensors.")
        );
    }

    if (memory > 0)
    {
        // Requires input names ending in both "_h" and "_c" (presumably the
        // recurrent hidden/cell state). Tensor names are identifiers, so match
        // the suffix ordinally rather than with culture-sensitive rules.
        if (!tensorsNames.Any(x => x.EndsWith("_h", StringComparison.Ordinal)) ||
            !tensorsNames.Any(x => x.EndsWith("_c", StringComparison.Ordinal)))
        {
            failedModelChecks.Add(
                FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.")
            );
        }
    }

    // Discrete outputs require the action mask input.
    if (model.HasDiscreteOutputs())
    {
        if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
        {
            failedModelChecks.Add(
                FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.")
            );
        }
    }
    return failedModelChecks;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks that a 2.x model declares every input the agent will feed it.
/// </summary>
/// <param name="model">The model whose input names are inspected.</param>
/// <param name="brainParameters">Not read by this check; kept for parity with the legacy variant.</param>
/// <param name="memory">The model's memory size; the recurrent input is required when it is positive.</param>
/// <param name="sensors">The agent's sensors; sensor <c>i</c> needs the observation input named for index <c>i</c>.</param>
/// <param name="deterministicInference">Forwarded to <c>HasDiscreteOutputs</c> when deciding whether an action mask is needed.</param>
/// <returns>One warning per missing input; empty when everything is present.</returns>
static IEnumerable<FailedCheck> CheckInputTensorPresence(
    Model model,
    BrainParameters brainParameters,
    int memory,
    ISensor[] sensors,
    bool deterministicInference = false
)
{
    var failedModelChecks = new List<FailedCheck>();
    var tensorsNames = model.GetInputNames();

    // Every sensor, regardless of rank, maps to the observation input at its index.
    for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
    {
        if (!tensorsNames.Contains(TensorNames.GetObservationName(sensorIndex)))
        {
            var sensor = sensors[sensorIndex];
            failedModelChecks.Add(
                FailedCheck.Warning("The model does not contain an Observation Placeholder Input " +
                    $"for sensor component {sensorIndex} ({sensor.GetType().Name}).")
            );
        }
    }

    // A positive memory size requires the recurrent state input.
    if (memory > 0 && !tensorsNames.Contains(TensorNames.RecurrentInPlaceholder))
    {
        failedModelChecks.Add(
            FailedCheck.Warning("The model does not contain a Recurrent Input Node but has memory_size.")
        );
    }

    // Discrete outputs require the action mask input.
    if (model.HasDiscreteOutputs(deterministicInference) &&
        !tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
    {
        failedModelChecks.Add(
            FailedCheck.Warning("The model does not contain an Action Mask but is using Discrete Control.")
        );
    }
    return failedModelChecks;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks that a model with memory exposes the recurrent output.
/// </summary>
/// <param name="model">The model whose output names are inspected.</param>
/// <param name="memory">The model's memory size; nothing is checked unless it is positive.</param>
/// <param name="deterministicInference">Forwarded to <c>GetOutputNames</c>.</param>
/// <returns>A single warning if the recurrent output is missing; otherwise empty.</returns>
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false)
{
    var failedModelChecks = new List<FailedCheck>();

    // Without memory there is no recurrent output to require.
    if (memory <= 0)
    {
        return failedModelChecks;
    }

    var outputNames = model.GetOutputNames(deterministicInference).ToList();
    if (!outputNames.Contains(TensorNames.RecurrentOutput))
    {
        failedModelChecks.Add(
            FailedCheck.Warning("The model does not contain a Recurrent Output Node but has memory_size.")
        );
    }
    return failedModelChecks;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Compares a rank-3 sensor's observation shape with the height, width and
/// channel count of the matching model input.
/// </summary>
/// <param name="tensorProxy">The model input tensor.</param>
/// <param name="sensor">The sensor feeding that input; its shape is read as (height, width, channels).</param>
/// <returns>A warning showing both shapes if any dimension differs; otherwise null.</returns>
static FailedCheck CheckVisualObsShape(
    TensorProxy tensorProxy, ISensor sensor)
{
    var sensorShape = sensor.GetObservationSpec().Shape;
    var sensorHeight = sensorShape[0];
    var sensorWidth = sensorShape[1];
    var sensorChannels = sensorShape[2];

    var modelHeight = tensorProxy.Height;
    var modelWidth = tensorProxy.Width;
    var modelChannels = tensorProxy.Channels;

    var shapesMatch = sensorWidth == modelWidth
        && sensorHeight == modelHeight
        && sensorChannels == modelChannels;
    if (shapesMatch)
    {
        return null;
    }

    // Both shapes are printed width-first: [?xWxHxC].
    return FailedCheck.Warning($"The visual Observation of the model does not match. " +
        $"Received TensorProxy of shape [?x{sensorWidth}x{sensorHeight}x{sensorChannels}] but " +
        $"was expecting [?x{modelWidth}x{modelHeight}x{modelChannels}] for the {sensor.GetName()} Sensor."
    );
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Compares a rank-2 sensor's observation shape with the matching model input.
/// The sensor's first dimension is compared to the tensor's Channels and its
/// second to the tensor's Width.
/// </summary>
/// <param name="tensorProxy">The model input tensor.</param>
/// <param name="sensor">The rank-2 sensor feeding that input.</param>
/// <returns>A warning showing both shapes if they differ; otherwise null.</returns>
static FailedCheck CheckRankTwoObsShape(
    TensorProxy tensorProxy, ISensor sensor)
{
    var sensorShape = sensor.GetObservationSpec().Shape;
    var sensorDim1 = sensorShape[0];
    var sensorDim2 = sensorShape[1];
    var tensorChannels = tensorProxy.Channels;
    var tensorWidth = tensorProxy.Width;
    var tensorHeight = tensorProxy.Height;

    if (sensorDim1 == tensorChannels && sensorDim2 == tensorWidth)
    {
        return null;
    }

    // Height only appears in the message, and only when it is greater than 1.
    var proxyDimStr = tensorHeight > 1
        ? $"[?x{tensorHeight}x{tensorWidth}x{tensorChannels}]"
        : $"[?x{tensorChannels}x{tensorWidth}]";
    return FailedCheck.Warning($"An Observation of the model does not match. " +
        $"Received TensorProxy of shape [?x{sensorDim1}x{sensorDim2}] but " +
        $"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
    );
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Compares a rank-1 sensor's size with the Channels dimension of the matching model input.
/// </summary>
/// <param name="tensorProxy">The model input tensor.</param>
/// <param name="sensor">The rank-1 sensor feeding that input.</param>
/// <returns>A warning showing both shapes if they differ; otherwise null.</returns>
static FailedCheck CheckRankOneObsShape(
    TensorProxy tensorProxy, ISensor sensor)
{
    var sensorSize = sensor.GetObservationSpec().Shape[0];
    var tensorChannels = tensorProxy.Channels;
    var tensorWidth = tensorProxy.Width;
    var tensorHeight = tensorProxy.Height;

    if (sensorSize == tensorChannels)
    {
        return null;
    }

    // Print as many tensor dimensions as are actually larger than 1.
    string proxyDimStr;
    if (tensorHeight > 1)
    {
        proxyDimStr = $"[?x{tensorHeight}x{tensorWidth}x{tensorChannels}]";
    }
    else if (tensorWidth > 1)
    {
        proxyDimStr = $"[?x{tensorChannels}x{tensorWidth}]";
    }
    else
    {
        proxyDimStr = $"[?x{tensorChannels}]";
    }

    return FailedCheck.Warning($"An Observation of the model does not match. " +
        $"Received TensorProxy of shape [?x{sensorSize}] but " +
        $"was expecting {proxyDimStr} for the {sensor.GetName()} Sensor."
    );
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks the shapes of a 1.x ("legacy") model's inputs against the agent's
/// sensors and brain parameters. Also flags inputs the loader does not recognise.
/// </summary>
/// <param name="model">The model whose input tensors are inspected.</param>
/// <param name="brainParameters">The agent's brain parameters.</param>
/// <param name="sensors">The agent's sensors, in observation order.</param>
/// <param name="observableAttributeTotalSize">Forwarded to the individual shape checks.</param>
/// <returns>Every shape mismatch or unexpected input found; empty when all match.</returns>
static IEnumerable<FailedCheck> CheckInputTensorShapeLegacy(
    Model model, BrainParameters brainParameters, ISensor[] sensors,
    int observableAttributeTotalSize)
{
    var failedModelChecks = new List<FailedCheck>();

    // Known input name -> shape check. Entries that return null accept any shape.
    var tensorTester =
        new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck>>()
        {
            {TensorNames.VectorObservationPlaceholder, CheckVectorObsShapeLegacy},
            {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
            {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
        };

    // Memory inputs are accepted without a shape check.
    foreach (var mem in model.memories)
    {
        tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
    }

    // Rank-3 sensors map to visual inputs numbered among rank-3 sensors only.
    // Rank-2 sensors map to observation inputs numbered by sensor index.
    // Rank-1 sensors are covered by the vector observation entry above.
    var visObsIndex = 0;
    for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
    {
        // Declared inside the loop so each lambda captures its own sensor.
        var sens = sensors[sensorIndex];
        var rank = sens.GetObservationSpec().Shape.Length;
        if (rank == 3)
        {
            tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =
                (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens);
            visObsIndex++;
        }
        else if (rank == 2)
        {
            tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
        }
    }

    foreach (var tensor in model.GetInputTensors())
    {
        Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck> tester;
        if (tensorTester.TryGetValue(tensor.name, out tester))
        {
            var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize);
            if (error != null)
            {
                failedModelChecks.Add(error);
            }
        }
        else if (!tensor.name.Contains("visual_observation"))
        {
            // Unmatched visual inputs are not flagged here; CheckInputTensorPresenceLegacy
            // compares the model's visual input count with the rank-3 sensors.
            failedModelChecks.Add(
                FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name)
            );
        }
    }
    return failedModelChecks;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks that the legacy vector observation input is as wide as all rank-1
/// sensors combined.
/// </summary>
/// <param name="brainParameters">Supplies the vector observation size and stack count; used only in the message.</param>
/// <param name="tensorProxy">The model's vector observation input; its last dimension is the expected total.</param>
/// <param name="sensors">The agent's sensors; only rank-1 sensors contribute to the total.</param>
/// <param name="observableAttributeTotalSize">Used only in the message.</param>
/// <returns>A warning listing the sizes involved if the totals differ; otherwise null.</returns>
static FailedCheck CheckVectorObsShapeLegacy(
    BrainParameters brainParameters, TensorProxy tensorProxy, ISensor[] sensors,
    int observableAttributeTotalSize)
{
    var vecObsSizeBp = brainParameters.VectorObservationSize;
    var numStackedVector = brainParameters.NumStackedVectorObservations;
    var totalVecObsSizeT = tensorProxy.shape[tensorProxy.shape.Length - 1];

    // Gather rank-1 sensor sizes once; they feed both the total and the message.
    var vectorSensorSizes = new List<int>();
    foreach (var sens in sensors)
    {
        var shape = sens.GetObservationSpec().Shape;
        if (shape.Length == 1)
        {
            vectorSensorSizes.Add(shape[0]);
        }
    }

    var totalVectorSensorSize = vectorSensorSizes.Sum();
    if (totalVectorSensorSize != totalVecObsSizeT)
    {
        // string.Join keeps the list bracketed even when it is empty ("[]"),
        // instead of emitting an unmatched "]".
        var sensorSizes = "[" + string.Join(", ", vectorSensorSizes) + "]";
        return FailedCheck.Warning(
            $"Vector Observation Size of the model does not match. Was expecting {totalVecObsSizeT} " +
            $"but received: \n" +
            $"Vector observations: {vecObsSizeBp} x {numStackedVector}\n" +
            $"Total [Observable] attributes: {observableAttributeTotalSize}\n" +
            $"Sensor sizes: {sensorSizes}."
        );
    }
    return null;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks the shapes of a 2.x model's inputs against the agent's sensors and
/// brain parameters. Also flags inputs the loader does not recognise.
/// </summary>
/// <param name="model">The model whose input tensors are inspected.</param>
/// <param name="brainParameters">The agent's brain parameters.</param>
/// <param name="sensors">The agent's sensors; sensor <c>i</c> is checked against observation input <c>i</c>.</param>
/// <param name="observableAttributeTotalSize">Forwarded to the individual shape checks.</param>
/// <returns>Every shape mismatch or unexpected input found; empty when all match.</returns>
static IEnumerable<FailedCheck> CheckInputTensorShape(
    Model model, BrainParameters brainParameters, ISensor[] sensors,
    int observableAttributeTotalSize)
{
    var failedModelChecks = new List<FailedCheck>();

    // Known input name -> shape check. Entries that return null accept any shape.
    var tensorTester =
        new Dictionary<string, Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck>>()
        {
            {TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
            {TensorNames.RandomNormalEpsilonPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.ActionMaskPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.SequenceLengthPlaceholder, ((bp, tensor, scs, i) => null)},
            {TensorNames.RecurrentInPlaceholder, ((bp, tensor, scs, i) => null)},
        };

    // Memory inputs are accepted without a shape check.
    foreach (var mem in model.memories)
    {
        tensorTester[mem.input] = ((bp, tensor, scs, i) => null);
    }

    // Each sensor owns the observation input at its index. The shape check
    // depends on the sensor's rank; other ranks register nothing.
    for (var sensorIndex = 0; sensorIndex < sensors.Length; sensorIndex++)
    {
        // Declared inside the loop so each lambda captures its own sensor.
        var sens = sensors[sensorIndex];
        switch (sens.GetObservationSpec().Rank)
        {
            case 3:
                tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                    (bp, tensor, scs, i) => CheckVisualObsShape(tensor, sens);
                break;
            case 2:
                tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                    (bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
                break;
            case 1:
                tensorTester[TensorNames.GetObservationName(sensorIndex)] =
                    (bp, tensor, scs, i) => CheckRankOneObsShape(tensor, sens);
                break;
        }
    }

    foreach (var tensor in model.GetInputTensors())
    {
        Func<BrainParameters, TensorProxy, ISensor[], int, FailedCheck> tester;
        if (!tensorTester.TryGetValue(tensor.name, out tester))
        {
            failedModelChecks.Add(FailedCheck.Warning("Model contains an unexpected input named : " + tensor.name
            ));
            continue;
        }

        var error = tester.Invoke(brainParameters, tensor, sensors, observableAttributeTotalSize);
        if (error != null)
        {
            failedModelChecks.Add(error);
        }
    }
    return failedModelChecks;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks that the previous-action input is as wide as the brain's number of
/// discrete actions.
/// </summary>
/// <param name="brainParameters">Supplies <c>ActionSpec.NumDiscreteActions</c>.</param>
/// <param name="tensorProxy">The previous-action input; its last dimension is compared.</param>
/// <param name="sensors">Not read; present so the method fits the shape-check delegate.</param>
/// <param name="observableAttributeTotalSize">Not read; present so the method fits the shape-check delegate.</param>
/// <returns>A warning if the sizes differ; otherwise null.</returns>
static FailedCheck CheckPreviousActionShape(
    BrainParameters brainParameters, TensorProxy tensorProxy,
    ISensor[] sensors, int observableAttributeTotalSize)
{
    var brainActionCount = brainParameters.ActionSpec.NumDiscreteActions;
    var dims = tensorProxy.shape;
    var modelActionCount = dims[dims.Length - 1];

    return brainActionCount == modelActionCount
        ? null
        : FailedCheck.Warning("Previous Action Size of the model does not match. " +
            $"Received {brainActionCount} but was expecting {modelActionCount}.");
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks the model's continuous and discrete action outputs against the
/// action specs of the brain parameters and actuator components.
/// </summary>
/// <param name="model">The model whose outputs are inspected.</param>
/// <param name="brainParameters">The agent's brain parameters.</param>
/// <param name="actuatorComponents">Actuator components whose action specs add to the brain's.</param>
/// <returns>Zero, one or two warnings (continuous and/or discrete mismatch).</returns>
static IEnumerable<FailedCheck> CheckOutputTensorShape(
    Model model,
    BrainParameters brainParameters,
    ActuatorComponent[] actuatorComponents)
{
    var failedModelChecks = new List<FailedCheck>();

    var continuousError = CheckContinuousActionOutputShape(
        brainParameters, actuatorComponents, model.ContinuousOutputSize());
    if (continuousError != null)
    {
        failedModelChecks.Add(continuousError);
    }

    // 1.x models only expose the summed branch sizes; 2.x models expose
    // per-branch sizes. Other versions get no discrete check.
    FailedCheck discreteError;
    switch (model.GetVersion())
    {
        case (int)ModelApiVersion.MLAgents1_0:
            discreteError = CheckDiscreteActionOutputShapeLegacy(
                brainParameters, actuatorComponents, model.DiscreteOutputSize());
            break;
        case (int)ModelApiVersion.MLAgents2_0:
            discreteError = CheckDiscreteActionOutputShape(
                brainParameters, actuatorComponents, model.GetTensorByName(TensorNames.DiscreteActionOutputShape));
            break;
        default:
            discreteError = null;
            break;
    }

    if (discreteError != null)
    {
        failedModelChecks.Add(discreteError);
    }
    return failedModelChecks;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks a 2.x model's discrete branches against the expected branches. The
/// expected branches are the brain's branches followed by each actuator
/// component's, in order.
/// </summary>
/// <param name="brainParameters">Supplies the brain's discrete branch sizes.</param>
/// <param name="actuatorComponents">Actuator components whose branch sizes are appended.</param>
/// <param name="modelDiscreteBranches">Per-branch sizes read from the model; null counts as zero branches.</param>
/// <returns>A warning for a branch-count mismatch, or for the first branch whose size differs; otherwise null.</returns>
static FailedCheck CheckDiscreteActionOutputShape(
    BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, Tensor modelDiscreteBranches)
{
    var expectedBranches = new List<int>(brainParameters.ActionSpec.BranchSizes);
    foreach (var component in actuatorComponents)
    {
        expectedBranches.AddRange(component.ActionSpec.BranchSizes);
    }

    var modelBranchCount = modelDiscreteBranches?.length ?? 0;
    if (modelBranchCount != expectedBranches.Count)
    {
        return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
            $"{expectedBranches.Count} branches but the model contains {modelBranchCount}."
        );
    }

    // The counts agree. A non-zero count implies the tensor is non-null, so
    // index it directly and stop at the first differing branch.
    for (var branch = 0; branch < modelBranchCount; branch++)
    {
        var modelSize = modelDiscreteBranches[branch];
        var expectedSize = expectedBranches[branch];
        if (modelSize != expectedSize)
        {
            return FailedCheck.Warning($"The number of Discrete Actions of branch {branch} does not match. " +
                $"Was expecting {expectedSize} but the model contains {modelSize} "
            );
        }
    }
    return null;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks a 1.x model's total discrete output size. The expected total is the
/// sum of all discrete branch sizes of the brain and of every actuator component.
/// </summary>
/// <param name="brainParameters">Supplies the brain's summed discrete branch sizes.</param>
/// <param name="actuatorComponents">Actuator components whose summed branch sizes are added.</param>
/// <param name="modelSumDiscreteBranchSizes">The model's total discrete output size.</param>
/// <returns>A warning if the totals differ; otherwise null.</returns>
static FailedCheck CheckDiscreteActionOutputShapeLegacy(
    BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelSumDiscreteBranchSizes)
{
    var expectedSum = brainParameters.ActionSpec.SumOfDiscreteBranchSizes;
    for (var idx = 0; idx < actuatorComponents.Length; idx++)
    {
        expectedSum += actuatorComponents[idx].ActionSpec.SumOfDiscreteBranchSizes;
    }

    if (modelSumDiscreteBranchSizes == expectedSum)
    {
        return null;
    }
    return FailedCheck.Warning("Discrete Action Size of the model does not match. The BrainParameters expect " +
        $"{expectedSum} but the model contains {modelSumDiscreteBranchSizes}."
    );
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Checks the model's continuous output size against the continuous action
/// counts of the brain and every actuator component combined.
/// </summary>
/// <param name="brainParameters">Supplies the brain's continuous action count.</param>
/// <param name="actuatorComponents">Actuator components whose continuous action counts are added.</param>
/// <param name="modelContinuousActionSize">The model's continuous output size.</param>
/// <returns>A warning if the sizes differ; otherwise null.</returns>
static FailedCheck CheckContinuousActionOutputShape(
    BrainParameters brainParameters, ActuatorComponent[] actuatorComponents, int modelContinuousActionSize)
{
    var expectedContinuous = brainParameters.ActionSpec.NumContinuousActions;
    for (var idx = 0; idx < actuatorComponents.Length; idx++)
    {
        expectedContinuous += actuatorComponents[idx].ActionSpec.NumContinuousActions;
    }

    if (modelContinuousActionSize == expectedContinuous)
    {
        return null;
    }
    return FailedCheck.Warning(
        "Continuous Action Size of the model does not match. The BrainParameters and ActuatorComponents expect " +
        $"{expectedContinuous} but the model contains {modelContinuousActionSize}."
    );
}
| | } |
| | } |
| |
|