| | using System; |
| | using System.Collections.Generic; |
| | using System.Linq; |
| | using Unity.Barracuda; |
| | using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck; |
| |
|
namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Extension methods that query a Barracuda <see cref="Model"/> for the input/output
    /// tensor names, shapes and action-output metadata used by ML-Agents inference.
    /// Most methods accept a <c>null</c> model and return an empty, zero or false result for it.
    /// </summary>
    /// <remarks>
    /// Two model layouts are handled. In the newer layout (see
    /// <see cref="SupportsContinuousAndDiscrete"/>), continuous and discrete actions have their
    /// own output tensors. In the deprecated layout, there is a single
    /// <see cref="TensorNames.ActionOutputDeprecated"/> output, and a
    /// <see cref="TensorNames.IsContinuousControlDeprecated"/> constant selects which kind it is.
    /// </remarks>
    internal static class BarracudaModelExtensions
    {
        /// <summary>
        /// Get the names of all model inputs, including the input names of the model's
        /// memories (<c>model.memories</c>).
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <returns>
        /// The names sorted with <see cref="StringComparer.InvariantCulture"/>, or an empty
        /// array when <paramref name="model"/> is <c>null</c>.
        /// </returns>
        public static string[] GetInputNames(this Model model)
        {
            var names = new List<string>();

            if (model == null)
                return names.ToArray();

            foreach (var input in model.inputs)
            {
                names.Add(input.name);
            }

            foreach (var mem in model.memories)
            {
                names.Add(mem.input);
            }

            names.Sort(StringComparer.InvariantCulture);

            return names.ToArray();
        }

        /// <summary>
        /// Get the API version number stored in the model's
        /// <see cref="TensorNames.VersionNumber"/> constant tensor.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect. Must not be <c>null</c>.</param>
        /// <returns>The first element of the version tensor, truncated to an <c>int</c>.</returns>
        /// <remarks>
        /// No validation is performed. Indexing throws if the model or the tensor is missing,
        /// so run <see cref="CheckExpectedTensors"/> first on models that may be malformed.
        /// </remarks>
        public static int GetVersion(this Model model)
        {
            return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
        }

        /// <summary>
        /// Build a <see cref="TensorProxy"/> description for every entry of <c>model.inputs</c>.
        /// Memory inputs are not included, unlike in <see cref="GetInputNames"/>.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <returns>
        /// Floating-point proxies with no data attached and the input shape converted to
        /// <c>long[]</c>, sorted by name with an invariant-culture comparison. The list is empty
        /// when <paramref name="model"/> is <c>null</c>.
        /// </returns>
        public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
        {
            var tensors = new List<TensorProxy>();

            if (model == null)
                return tensors;

            foreach (var input in model.inputs)
            {
                tensors.Add(new TensorProxy
                {
                    name = input.name,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = null,
                    shape = input.shape.Select(i => (long)i).ToArray()
                });
            }

            // Same ordering as GetInputNames so that names and proxies line up.
            tensors.Sort((el1, el2) => string.Compare(el1.name, el2.name, StringComparison.InvariantCulture));

            return tensors;
        }

        /// <summary>
        /// Count the model inputs whose name starts with
        /// <see cref="TensorNames.VisualObservationPlaceholderPrefix"/>.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <returns>The number of matching inputs, or 0 when <paramref name="model"/> is <c>null</c>.</returns>
        public static int GetNumVisualInputs(this Model model)
        {
            var count = 0;
            if (model == null)
                return count;

            foreach (var input in model.inputs)
            {
                // Tensor names are identifiers, so match them ordinally rather than by culture.
                if (input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix, StringComparison.Ordinal))
                {
                    count++;
                }
            }

            return count;
        }

        /// <summary>
        /// Get the names of the output tensors that inference should read for this model.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <param name="deterministicInference">
        /// When <c>true</c>, select the deterministic action outputs instead of the stochastic ones.
        /// </param>
        /// <returns>
        /// The selected continuous and/or discrete action output names. If the
        /// <see cref="TensorNames.MemorySize"/> constant is positive,
        /// <see cref="TensorNames.RecurrentOutput"/> is also included. The result is sorted with
        /// <see cref="StringComparer.InvariantCulture"/>, and is empty when
        /// <paramref name="model"/> is <c>null</c>.
        /// </returns>
        public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
        {
            var names = new List<string>();

            if (model == null)
            {
                return names.ToArray();
            }

            if (model.HasContinuousOutputs(deterministicInference))
            {
                names.Add(model.ContinuousOutputName(deterministicInference));
            }
            if (model.HasDiscreteOutputs(deterministicInference))
            {
                names.Add(model.DiscreteOutputName(deterministicInference));
            }

            // A model without the memory-size constant is treated as having no memory,
            // instead of throwing a NullReferenceException.
            var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
            if (memorySizeTensor != null && (int)memorySizeTensor[0] > 0)
            {
                names.Add(TensorNames.RecurrentOutput);
            }

            names.Sort(StringComparer.InvariantCulture);

            return names.ToArray();
        }

        /// <summary>
        /// Check whether the model produces continuous actions for the selected inference mode.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <param name="deterministicInference">
        /// When <c>true</c>, look for <see cref="TensorNames.DeterministicContinuousActionOutput"/>;
        /// otherwise look for <see cref="TensorNames.ContinuousActionOutput"/>.
        /// </param>
        /// <returns>
        /// In the deprecated layout, <c>true</c> when
        /// <see cref="TensorNames.IsContinuousControlDeprecated"/> is positive. Otherwise,
        /// <c>true</c> when the mode's output tensor is present and
        /// <see cref="ContinuousOutputSize"/> is positive. Always <c>false</c> for a <c>null</c> model.
        /// </returns>
        public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
        {
            if (model == null)
                return false;
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
            }
            else
            {
                bool hasStochasticOutput = !deterministicInference &&
                    model.outputs.Contains(TensorNames.ContinuousActionOutput);
                bool hasDeterministicOutput = deterministicInference &&
                    model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);

                // ContinuousOutputSize is null-safe for the shape constant (mirrors HasDiscreteOutputs).
                return (hasStochasticOutput || hasDeterministicOutput) &&
                    model.ContinuousOutputSize() > 0;
            }
        }

        /// <summary>
        /// Get the number of continuous actions the model outputs.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <returns>
        /// In the deprecated layout, the first element of
        /// <see cref="TensorNames.ActionOutputShapeDeprecated"/> if the model is continuous, otherwise 0.
        /// In the newer layout, the first element of
        /// <see cref="TensorNames.ContinuousActionOutputShape"/>, or 0 when that constant is missing.
        /// Returns 0 for a <c>null</c> model.
        /// </returns>
        public static int ContinuousOutputSize(this Model model)
        {
            if (model == null)
                return 0;
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                    (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0;
            }
            else
            {
                var continuousOutputShape = model.GetTensorByName(TensorNames.ContinuousActionOutputShape);
                return continuousOutputShape == null ? 0 : (int)continuousOutputShape[0];
            }
        }

        /// <summary>
        /// Get the name of the output tensor that holds continuous actions.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <param name="deterministicInference">Select the deterministic output name when <c>true</c>.</param>
        /// <returns>
        /// <see cref="TensorNames.ActionOutputDeprecated"/> for the deprecated layout. Otherwise the
        /// deterministic or stochastic continuous output name, chosen by the mode. Returns
        /// <c>null</c> for a <c>null</c> model. The tensor's presence is not checked.
        /// </returns>
        public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
        {
            if (model == null)
                return null;
            if (!model.SupportsContinuousAndDiscrete())
            {
                return TensorNames.ActionOutputDeprecated;
            }
            else
            {
                return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
            }
        }

        /// <summary>
        /// Check whether the model produces discrete actions for the selected inference mode.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <param name="deterministicInference">
        /// When <c>true</c>, look for <see cref="TensorNames.DeterministicDiscreteActionOutput"/>;
        /// otherwise look for <see cref="TensorNames.DiscreteActionOutput"/>.
        /// </param>
        /// <returns>
        /// In the deprecated layout, <c>true</c> when
        /// <see cref="TensorNames.IsContinuousControlDeprecated"/> equals 0. Otherwise, <c>true</c>
        /// when the mode's output tensor is present and <see cref="DiscreteOutputSize"/> is positive.
        /// Always <c>false</c> for a <c>null</c> model.
        /// </returns>
        public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
        {
            if (model == null)
                return false;
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
            }
            else
            {
                bool hasStochasticOutput = !deterministicInference &&
                    model.outputs.Contains(TensorNames.DiscreteActionOutput);
                bool hasDeterministicOutput = deterministicInference &&
                    model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);
                return (hasStochasticOutput || hasDeterministicOutput) &&
                    model.DiscreteOutputSize() > 0;
            }
        }

        /// <summary>
        /// Get the size of the discrete action output.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <returns>
        /// In the deprecated layout, the first element of
        /// <see cref="TensorNames.ActionOutputShapeDeprecated"/> if the model is discrete, otherwise 0.
        /// In the newer layout, the sum of all elements of
        /// <see cref="TensorNames.DiscreteActionOutputShape"/>, or 0 when that constant is missing.
        /// The elements are presumably the sizes of the individual discrete branches; confirm
        /// this against the exporter. Returns 0 for a <c>null</c> model.
        /// </returns>
        public static int DiscreteOutputSize(this Model model)
        {
            if (model == null)
                return 0;
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                    0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0];
            }
            else
            {
                var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
                if (discreteOutputShape == null)
                {
                    return 0;
                }
                else
                {
                    int result = 0;
                    for (int i = 0; i < discreteOutputShape.length; i++)
                    {
                        result += (int)discreteOutputShape[i];
                    }
                    return result;
                }
            }
        }

        /// <summary>
        /// Get the name of the output tensor that holds discrete actions.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <param name="deterministicInference">Select the deterministic output name when <c>true</c>.</param>
        /// <returns>
        /// <see cref="TensorNames.ActionOutputDeprecated"/> for the deprecated layout. Otherwise the
        /// deterministic or stochastic discrete output name, chosen by the mode. Returns
        /// <c>null</c> for a <c>null</c> model. The tensor's presence is not checked.
        /// </returns>
        public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
        {
            if (model == null)
                return null;
            if (!model.SupportsContinuousAndDiscrete())
            {
                return TensorNames.ActionOutputDeprecated;
            }
            else
            {
                return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
            }
        }

        /// <summary>
        /// Check whether the model uses the newer layout with separate continuous and discrete
        /// action outputs.
        /// </summary>
        /// <param name="model">The Barracuda model to inspect; may be <c>null</c>.</param>
        /// <returns>
        /// <c>true</c> when the model outputs contain <see cref="TensorNames.ContinuousActionOutput"/>
        /// or <see cref="TensorNames.DiscreteActionOutput"/>. Also <c>true</c> for a <c>null</c> model.
        /// Only the stochastic output names are consulted.
        /// </returns>
        public static bool SupportsContinuousAndDiscrete(this Model model)
        {
            return model == null ||
                model.outputs.Contains(TensorNames.ContinuousActionOutput) ||
                model.outputs.Contains(TensorNames.DiscreteActionOutput);
        }

        /// <summary>
        /// Validate that the model contains the constants and action output tensors required for
        /// inference in the selected mode.
        /// </summary>
        /// <param name="model">The Barracuda model to validate. Must not be <c>null</c>.</param>
        /// <param name="failedModelChecks">
        /// Receives a warning describing the first failed check. Nothing is added on success.
        /// </param>
        /// <param name="deterministicInference">
        /// When <c>true</c>, require the deterministic action outputs instead of the stochastic ones.
        /// </param>
        /// <returns><c>true</c> if every check passed; <c>false</c> at the first failure.</returns>
        public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
        {
            // Check the presence of the model version constant.
            var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
            if (modelApiVersionTensor == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
                );
                return false;
            }

            // Check the presence of the memory size constant.
            var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
            if (memorySizeTensor == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
                );
                return false;
            }

            // Check that at least one action output tensor, of either layout, is present.
            if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
                !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
                !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model does not contain any Action Output Node.")
                );
                return false;
            }

            // Check the action output shape constants required by the model's layout.
            if (!model.SupportsContinuousAndDiscrete())
            {
                if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
                    );
                    return false;
                }
                if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
                            "not found in the model file. " +
                            "This is only required for model that uses a deprecated model format.")
                    );
                    return false;
                }
            }
            else
            {
                if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
                {
                    if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
                    {
                        failedModelChecks.Add(
                            FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
                        );
                        return false;
                    }
                    else if (!model.HasContinuousOutputs(deterministicInference))
                    {
                        failedModelChecks.Add(MissingActionOutputWarning("Continuous", deterministicInference));
                        return false;
                    }
                }

                if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
                {
                    if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
                    {
                        failedModelChecks.Add(
                            FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
                        );
                        return false;
                    }
                    else if (!model.HasDiscreteOutputs(deterministicInference))
                    {
                        failedModelChecks.Add(MissingActionOutputWarning("Discrete", deterministicInference));
                        return false;
                    }
                }
            }
            return true;
        }

        /// <summary>
        /// Build the warning reported when the action output tensor required by the selected
        /// inference mode is missing or unusable.
        /// </summary>
        /// <param name="actionKind">"Continuous" or "Discrete"; used verbatim in the message.</param>
        /// <param name="deterministicInference">The inference mode that was requested.</param>
        private static FailedCheck MissingActionOutputWarning(string actionKind, bool deterministicInference)
        {
            var actionType = deterministicInference ? "deterministic" : "stochastic";
            // The trailing space lives in the prefix so the stochastic message has no double space.
            var actionName = deterministicInference ? "Deterministic " : "";
            return FailedCheck.Warning(
                $"The model uses {actionType} inference but does not contain {actionName}{actionKind} Action Output Tensor. Uncheck `Deterministic inference` flag.");
        }
    }
}
| |
|