ppo-Pyramids-Training / Project /Assets /ML-Agents /Examples /SharedAssets /Scripts /ModelOverrider.cs
using System;
using System.Collections.Generic;
using System.IO;
using Unity.Barracuda;
using Unity.Barracuda.ONNX;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
using UnityEngine;
#if UNITY_EDITOR
using UnityEditor;
#endif
| namespace Unity.MLAgentsExamples | |
| { | |
/// <summary>
/// Utility class to allow the NNModel file for an agent to be overridden during inference.
/// This is used internally to validate the file after training is done.
/// The directory containing the override models is specified on the commandline, e.g.
/// player.exe --mlagents-override-model-directory /path/to/models
/// The model file is looked up in that directory using the agent's behavior name.
///
/// Additionally, a number of episodes to run can be specified; after this, the application will quit.
/// Note this will only work with example scenes that have 1:1 Agent:Behaviors. More complicated scenes like WallJump
/// probably won't override correctly.
/// </summary>
| public class ModelOverrider : MonoBehaviour | |
| { | |
// File extensions (without the leading dot) that are accepted for override models.
// Shared and never mutated, so it is static readonly rather than a per-instance field.
static readonly HashSet<string> k_SupportedExtensions = new HashSet<string> { "nn", "onnx" };

// Commandline flags understood by GetAssetPathFromCommandLine.
const string k_CommandLineModelOverrideDirectoryFlag = "--mlagents-override-model-directory";
const string k_CommandLineModelOverrideExtensionFlag = "--mlagents-override-model-extension";
const string k_CommandLineQuitAfterEpisodesFlag = "--mlagents-quit-after-episodes";
const string k_CommandLineQuitAfterSeconds = "--mlagents-quit-after-seconds";
const string k_CommandLineQuitOnLoadFailure = "--mlagents-quit-on-load-failure";

// The attached Agent
Agent m_Agent;

// Whether or not the commandline args have already been processed.
// Used to make sure that HasOverrides doesn't spam the logs if it's called multiple times.
private bool m_HaveProcessedCommandLine;

// Directory that override model files are read from; null or empty when no override was requested.
string m_BehaviorNameOverrideDirectory;

// Lazily cached by the OriginalBehaviorName property.
private string m_OriginalBehaviorName;

// Extensions requested on the commandline, in the order they were given.
private List<string> m_OverrideExtensions = new List<string>();

// Cached loaded NNModels, with the behavior name as the key.
Dictionary<string, NNModel> m_CachedModels = new Dictionary<string, NNModel>();

// Max episodes to run. Only used if > 0
// Will default to 1 if override models are specified, otherwise 0.
int m_MaxEpisodes;

// Deadline - exit if the time exceeds this
DateTime m_Deadline = DateTime.MaxValue;

// Number of FixedUpdate calls seen by this instance.
int m_NumSteps;

// Step and episode counts carried over from the static counters in OnEnable.
int m_PreviousNumSteps;
int m_PreviousAgentCompletedEpisodes;

// Set by the --mlagents-quit-on-load-failure flag.
bool m_QuitOnLoadFailure;

// Only honored when running in the Editor; split on spaces and used instead of the process arguments.
[Tooltip("Debug values to be used in place of the command line for overriding models.")]
public string debugCommandLineOverride;

// Static values to keep track of completed episodes and steps across resets
// These are updated in OnDisable.
static int s_PreviousAgentCompletedEpisodes;
static int s_PreviousNumSteps;
// Carried-over episode count plus the episodes completed by the attached Agent, if there is one.
int TotalCompletedEpisodes
{
    get
    {
        if (m_Agent == null)
        {
            return m_PreviousAgentCompletedEpisodes;
        }

        return m_PreviousAgentCompletedEpisodes + m_Agent.CompletedEpisodes;
    }
}
// Carried-over step count plus the steps counted by this instance.
int TotalNumSteps => m_PreviousNumSteps + m_NumSteps;
/// <summary>
/// True when an override model directory has been specified.
/// Processes the commandline first, so it is safe to query at any time.
/// </summary>
public bool HasOverrides
{
    get
    {
        GetAssetPathFromCommandLine();
        var directoryMissing = string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory);
        return !directoryMissing;
    }
}
/// <summary>
/// The original behavior name of the agent. The actual behavior name will change when it is overridden.
/// The value is read from the BehaviorParameters component once and then cached.
/// Returns null (or empty) if no BehaviorParameters component can be found yet.
/// </summary>
public string OriginalBehaviorName
{
    get
    {
        if (string.IsNullOrEmpty(m_OriginalBehaviorName))
        {
            // m_Agent is only assigned in OnEnable; fall back to this GameObject (where the Agent
            // is looked up from) so reading the property earlier doesn't throw.
            var bp = m_Agent != null
                ? m_Agent.GetComponent<BehaviorParameters>()
                : GetComponent<BehaviorParameters>();
            if (bp != null)
            {
                m_OriginalBehaviorName = bp.BehaviorName;
            }
        }

        return m_OriginalBehaviorName;
    }
}
/// <summary>
/// Build the behavior name that is used for an agent once its model has been overridden.
/// </summary>
/// <param name="originalBehaviorName">The agent's behavior name before overriding.</param>
/// <returns>The original name prefixed with "Override_".</returns>
public static string GetOverrideBehaviorName(string originalBehaviorName)
{
    const string prefix = "Override_";
    return prefix + originalBehaviorName;
}
/// <summary>
/// Parse the override-related commandline flags and store the results in the member fields.
/// When running in the Editor, a non-empty debugCommandLineOverride is used instead of the process arguments.
/// Can be called multiple times - if m_HaveProcessedCommandLine is set, will have no effect.
/// </summary>
void GetAssetPathFromCommandLine()
{
    if (m_HaveProcessedCommandLine)
    {
        return;
    }

    var maxEpisodes = 0;
    var timeoutSeconds = 0;

    string[] commandLineArgsOverride = null;
    if (!string.IsNullOrEmpty(debugCommandLineOverride) && Application.isEditor)
    {
        // Drop empty entries so repeated spaces don't show up as blank arguments.
        commandLineArgsOverride = debugCommandLineOverride.Split(
            new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
    }

    var args = commandLineArgsOverride ?? Environment.GetCommandLineArgs();
    for (var i = 0; i < args.Length; i++)
    {
        // Flags that take a value are only recognized when a following argument exists.
        var hasValue = i < args.Length - 1;

        if (args[i] == k_CommandLineModelOverrideDirectoryFlag && hasValue)
        {
            m_BehaviorNameOverrideDirectory = args[i + 1].Trim();
        }
        else if (args[i] == k_CommandLineModelOverrideExtensionFlag && hasValue)
        {
            // Extensions are identifiers, not display text, so lower-case them culture-independently.
            var overrideExtension = args[i + 1].Trim().ToLowerInvariant();
            if (k_SupportedExtensions.Contains(overrideExtension))
            {
                m_OverrideExtensions.Add(overrideExtension);
            }
            else
            {
                // Don't record the extension: Application.Quit does not stop execution immediately,
                // and an unknown extension must not be tried as a model file later on.
                Debug.LogError($"loading unsupported format: {overrideExtension}");
                Application.Quit(1);
#if UNITY_EDITOR
                EditorApplication.isPlaying = false;
#endif
            }
        }
        else if (args[i] == k_CommandLineQuitAfterEpisodesFlag && hasValue)
        {
            // On a parse failure TryParse leaves 0, which means "not specified" below.
            Int32.TryParse(args[i + 1], out maxEpisodes);
        }
        else if (args[i] == k_CommandLineQuitAfterSeconds && hasValue)
        {
            Int32.TryParse(args[i + 1], out timeoutSeconds);
        }
        else if (args[i] == k_CommandLineQuitOnLoadFailure)
        {
            m_QuitOnLoadFailure = true;
        }
    }

    if (!string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
    {
        // If overriding models, set maxEpisodes to 1 or the command line value
        m_MaxEpisodes = maxEpisodes > 0 ? maxEpisodes : 1;
        // Log the value that is actually in effect, not the raw parsed one (which may be 0).
        Debug.Log($"setting m_MaxEpisodes to {m_MaxEpisodes}");
    }

    if (timeoutSeconds > 0)
    {
        m_Deadline = DateTime.Now + TimeSpan.FromSeconds(timeoutSeconds);
        Debug.Log($"setting deadline to {timeoutSeconds} from now.");
    }

    m_HaveProcessedCommandLine = true;
}
// Restores the carried-over counters, looks up the Agent on this GameObject, and applies the
// override model when an override directory was given.
void OnEnable()
{
    m_Agent = GetComponent<Agent>();

    // Start from the previous totals in the case where we're resetting scenes.
    m_PreviousAgentCompletedEpisodes = s_PreviousAgentCompletedEpisodes;
    m_PreviousNumSteps = s_PreviousNumSteps;

    GetAssetPathFromCommandLine();
    if (!HasOverrides)
    {
        return;
    }

    OverrideModel();
}
// Publishes this instance's totals to the static counters, keeping whichever value is larger.
// With a single agent in the scene that is a plain increment; with several agents the counters
// end up following the Agent that completed the most episodes / steps.
void OnDisable()
{
    var completedEpisodes = TotalCompletedEpisodes;
    if (completedEpisodes > s_PreviousAgentCompletedEpisodes)
    {
        s_PreviousAgentCompletedEpisodes = completedEpisodes;
    }

    var numSteps = TotalNumSteps;
    if (numSteps > s_PreviousNumSteps)
    {
        s_PreviousNumSteps = numSteps;
    }
}
// Counts steps and quits the application once the episode/step target or the deadline is reached.
// NOTE(review): the deadline is only checked while m_MaxEpisodes > 0, i.e. when an override
// directory was given - confirm that --mlagents-quit-after-seconds is meant to depend on it.
void FixedUpdate()
{
    if (m_MaxEpisodes > 0)
    {
        // Treat a missing Agent like TotalCompletedEpisodes does instead of throwing every physics step.
        var maxStep = m_Agent == null ? 0 : m_Agent.MaxStep;
        var targetSteps = m_MaxEpisodes * maxStep;

        // For Agents without maxSteps, exit as soon as we've hit the target number of episodes.
        // For Agents that specify MaxStep, also make sure we've gone at least that many steps.
        // Since we exit as soon as *any* Agent hits its target, the maxSteps condition keeps us running
        // a bit longer in case there's an early failure.
        if (TotalCompletedEpisodes >= m_MaxEpisodes && TotalNumSteps > targetSteps)
        {
            Debug.Log($"ModelOverride reached {TotalCompletedEpisodes} episodes and {TotalNumSteps} steps. Exiting.");
            Application.Quit(0);
#if UNITY_EDITOR
            EditorApplication.isPlaying = false;
#endif
        }
        else if (DateTime.Now >= m_Deadline)
        {
            Debug.Log(
                $"Deadline exceeded. " +
                $"{TotalCompletedEpisodes}/{m_MaxEpisodes} episodes and " +
                $"{TotalNumSteps}/{targetSteps} steps completed. Exiting.");
            Application.Quit(0);
#if UNITY_EDITOR
            EditorApplication.isPlaying = false;
#endif
        }
    }

    m_NumSteps++;
}
/// <summary>
/// Load the override model for a behavior name from the override directory.
/// Results (including a failed lookup) are cached per behavior name.
/// </summary>
/// <param name="behaviorName">Behavior name; the file looked for is "{behaviorName}.{extension}".</param>
/// <returns>The loaded model, or null if no directory is set or no model file could be read.</returns>
public NNModel GetModelForBehaviorName(string behaviorName)
{
    NNModel cachedModel;
    if (m_CachedModels.TryGetValue(behaviorName, out cachedModel))
    {
        // May be a cached null from an earlier failed load.
        return cachedModel;
    }

    if (string.IsNullOrEmpty(m_BehaviorNameOverrideDirectory))
    {
        Debug.Log($"No override directory set.");
        return null;
    }

    // Try the override extensions in order. If they weren't set, try .nn first, then .onnx.
    var overrideExtensions = (m_OverrideExtensions.Count > 0)
        ? m_OverrideExtensions.ToArray()
        : new[] { "nn", "onnx" };

    byte[] rawModel = null;
    bool isOnnx = false;
    string assetName = null;
    foreach (var overrideExtension in overrideExtensions)
    {
        var assetPath = Path.Combine(m_BehaviorNameOverrideDirectory, $"{behaviorName}.{overrideExtension}");
        try
        {
            rawModel = File.ReadAllBytes(assetPath);
            isOnnx = overrideExtension.Equals("onnx");
            assetName = "Override - " + Path.GetFileName(assetPath);
            break;
        }
        catch (IOException)
        {
            // Do nothing - try the next extension, or we'll exit if nothing loaded.
        }
    }

    if (rawModel == null)
    {
        Debug.Log($"Couldn't load model file(s) for {behaviorName} in {m_BehaviorNameOverrideDirectory} (full path: {Path.GetFullPath(m_BehaviorNameOverrideDirectory)})");
        // Cache the null so we don't repeatedly try to load a missing file
        m_CachedModels[behaviorName] = null;
        return null;
    }

    var asset = isOnnx ? LoadOnnxModel(rawModel) : LoadBarracudaModel(rawModel);
    asset.name = assetName;
    m_CachedModels[behaviorName] = asset;
    return asset;
}
// Wraps the raw bytes of a .nn file in a new NNModel asset without converting them.
NNModel LoadBarracudaModel(byte[] rawModel)
{
    var model = ScriptableObject.CreateInstance<NNModel>();
    var data = ScriptableObject.CreateInstance<NNModelData>();
    data.Value = rawModel;
    model.modelData = data;
    return model;
}
// Converts the raw bytes of an .onnx file with ONNXModelConverter, serializes the result with
// ModelWriter, and wraps the serialized bytes in a new NNModel asset.
NNModel LoadOnnxModel(byte[] rawModel)
{
    var convertedModel = new ONNXModelConverter(true).Convert(rawModel);

    var data = ScriptableObject.CreateInstance<NNModelData>();
    using (var buffer = new MemoryStream())
    {
        using (var binaryWriter = new BinaryWriter(buffer))
        {
            ModelWriter.Save(binaryWriter, convertedModel);
            data.Value = buffer.ToArray();
        }
    }
    data.name = "Data";
    data.hideFlags = HideFlags.HideInHierarchy;

    var model = ScriptableObject.CreateInstance<NNModel>();
    model.modelData = data;
    return model;
}
/// <summary>
/// Load the NNModel file for the agent's original behavior name and give it to the attached agent.
/// On failure the reason is logged as a warning; the application additionally quits with exit
/// code 1 when the quit-on-load-failure flag was given.
/// </summary>
void OverrideModel()
{
    var overrideOk = false;
    string overrideError = null;

    if (m_Agent == null)
    {
        // Nothing to give the model to; report it instead of throwing a NullReferenceException.
        overrideError = "ModelOverrider could not find an Agent component to override.";
    }
    else
    {
        m_Agent.LazyInitialize();

        NNModel nnModel = null;
        try
        {
            nnModel = GetModelForBehaviorName(OriginalBehaviorName);
        }
        catch (Exception e)
        {
            overrideError = $"Exception calling GetModelForBehaviorName: {e}";
        }

        if (nnModel == null)
        {
            if (string.IsNullOrEmpty(overrideError))
            {
                overrideError =
                    $"Didn't find a model for behaviorName {OriginalBehaviorName}. Make " +
                    "sure the behaviorName is set correctly in the commandline " +
                    "and that the model file exists";
            }
        }
        else
        {
            Debug.Log($"Overriding behavior {OriginalBehaviorName} for agent with model {nnModel.name}");
            try
            {
                m_Agent.SetModel(GetOverrideBehaviorName(OriginalBehaviorName), nnModel);
                overrideOk = true;
            }
            catch (Exception e)
            {
                overrideError = $"Exception calling Agent.SetModel: {e}";
            }
        }
    }

    if (overrideOk)
    {
        return;
    }

    // Always surface the reason, even when we keep running with the original model.
    if (!string.IsNullOrEmpty(overrideError))
    {
        Debug.LogWarning(overrideError);
    }

    if (m_QuitOnLoadFailure)
    {
        Application.Quit(1);
#if UNITY_EDITOR
        EditorApplication.isPlaying = false;
#endif
    }
}
| } | |
| } | |