using System;
using System.Collections.Generic;
using System.Linq;
using Google.Protobuf;
using Unity.MLAgents.CommunicatorObjects;
using UnityEngine;
using System.Runtime.CompilerServices;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Demonstrations;
using Unity.MLAgents.Policies;

using Unity.MLAgents.Analytics;

[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Utils.Tests")]

namespace Unity.MLAgents
{
    internal static class GrpcExtensions
    {
        #region AgentInfo
        /// <summary>
        /// Static flag to make sure we only fire the "no multi-agent group support" warning once.
        /// </summary>
        private static bool s_HaveWarnedTrainerCapabilitiesAgentGroup;

        /// <summary>
        /// Converts an AgentInfo to an AgentInfoActionPairProto, pairing the agent's
        /// info with the actions it stored for that step.
        /// </summary>
        /// <returns>The protobuf version of the AgentInfoActionPairProto.</returns>
        public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
        {
            var agentInfoProto = ai.ToAgentInfoProto();

            var agentActionProto = new AgentActionProto();

            if (!ai.storedActions.IsEmpty())
            {
                if (!ai.storedActions.ContinuousActions.IsEmpty())
                {
                    agentActionProto.ContinuousActions.AddRange(ai.storedActions.ContinuousActions.Array);
                }
                if (!ai.storedActions.DiscreteActions.IsEmpty())
                {
                    agentActionProto.DiscreteActions.AddRange(ai.storedActions.DiscreteActions.Array);
                }
            }

            return new AgentInfoActionPairProto
            {
                AgentInfo = agentInfoProto,
                ActionInfo = agentActionProto
            };
        }

        /// <summary>
        /// Converts an AgentInfo to a protobuf-generated AgentInfoProto.
        /// </summary>
        /// <returns>The protobuf version of the AgentInfo.</returns>
        public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
        {
            if (ai.groupId > 0)
            {
                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.MultiAgentGroups;
                if (!trainerCanHandle)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesAgentGroup)
                    {
                        Debug.LogWarning(
                            "Attached trainer doesn't support Multi Agent Groups; group rewards will be ignored. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesAgentGroup = true;
                    }
                }
            }
            var agentInfoProto = new AgentInfoProto
            {
                Reward = ai.reward,
                GroupReward = ai.groupReward,
                MaxStepReached = ai.maxStepReached,
                Done = ai.done,
                Id = ai.episodeId,
                GroupId = ai.groupId,
            };

            if (ai.discreteActionMasks != null)
            {
                agentInfoProto.ActionMask.AddRange(ai.discreteActionMasks);
            }

            return agentInfoProto;
        }

        /// <summary>
        /// Gets summaries of the observations stored in an AgentInfoActionPairProto.
        /// </summary>
        /// <returns>A list of ObservationSummary objects, one per stored observation.</returns>
        public static List<ObservationSummary> GetObservationSummaries(this AgentInfoActionPairProto infoActionPair)
        {
            var summariesOut = new List<ObservationSummary>();
            var agentInfo = infoActionPair.AgentInfo;
            foreach (var obs in agentInfo.Observations)
            {
                var summary = new ObservationSummary();
                summary.shape = obs.Shape.ToArray();
                summariesOut.Add(summary);
            }

            return summariesOut;
        }

        #endregion

        #region BrainParameters
        /// <summary>
        /// Converts a BrainParameters instance to a protobuf-generated BrainParametersProto.
        /// </summary>
        /// <param name="bp">The BrainParameters to convert.</param>
        /// <param name="name">The name of the brain.</param>
        /// <param name="isTraining">Whether the brain is currently training.</param>
        /// <returns>The protobuf version of the BrainParameters.</returns>
        public static BrainParametersProto ToProto(this BrainParameters bp, string name, bool isTraining)
        {
            // The deprecated fields are still populated so that older trainers can read them.
#pragma warning disable CS0618
            var brainParametersProto = new BrainParametersProto
            {
                VectorActionSpaceTypeDeprecated = (SpaceTypeProto)bp.VectorActionSpaceType,
                BrainName = name,
                IsTraining = isTraining,
                ActionSpec = ToActionSpecProto(bp.ActionSpec),
            };
            if (bp.VectorActionSize != null)
            {
                brainParametersProto.VectorActionSizeDeprecated.AddRange(bp.VectorActionSize);
            }
            if (bp.VectorActionDescriptions != null)
            {
                brainParametersProto.VectorActionDescriptionsDeprecated.AddRange(bp.VectorActionDescriptions);
            }
#pragma warning restore CS0618
            return brainParametersProto;
        }

        /// <summary>
        /// Converts an ActionSpec to a protobuf-generated BrainParametersProto.
        /// </summary>
        /// <param name="actionSpec">The ActionSpec to convert.</param>
        /// <param name="name">The name of the brain.</param>
        /// <param name="isTraining">Whether the brain is currently training.</param>
        /// <returns>The protobuf version of the BrainParameters.</returns>
        public static BrainParametersProto ToBrainParametersProto(this ActionSpec actionSpec, string name, bool isTraining)
        {
            var brainParametersProto = new BrainParametersProto
            {
                BrainName = name,
                IsTraining = isTraining,
                ActionSpec = ToActionSpecProto(actionSpec),
            };

            var supportHybrid = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.HybridActions;
            if (!supportHybrid)
            {
                // Fall back to the deprecated single-space fields for trainers that can't handle hybrid actions.
                actionSpec.CheckAllContinuousOrDiscrete();
                if (actionSpec.NumContinuousActions > 0)
                {
                    brainParametersProto.VectorActionSizeDeprecated.Add(actionSpec.NumContinuousActions);
                    brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Continuous;
                }
                else if (actionSpec.NumDiscreteActions > 0)
                {
                    brainParametersProto.VectorActionSizeDeprecated.AddRange(actionSpec.BranchSizes);
                    brainParametersProto.VectorActionSpaceTypeDeprecated = SpaceTypeProto.Discrete;
                }
            }

            return brainParametersProto;
        }

        /// <summary>
        /// Converts a BrainParametersProto to a BrainParameters struct.
        /// </summary>
        /// <param name="bpp">An instance of a BrainParametersProto object.</param>
        /// <returns>A BrainParameters struct.</returns>
        public static BrainParameters ToBrainParameters(this BrainParametersProto bpp)
        {
            ActionSpec actionSpec;
            if (bpp.ActionSpec == null)
            {
                // Older protos only carry the deprecated action fields; rebuild the ActionSpec from them.
#pragma warning disable CS0618
                var spaceType = (SpaceType)bpp.VectorActionSpaceTypeDeprecated;
                if (spaceType == SpaceType.Continuous)
                {
                    actionSpec = ActionSpec.MakeContinuous(bpp.VectorActionSizeDeprecated.ToArray()[0]);
                }
                else
                {
                    actionSpec = ActionSpec.MakeDiscrete(bpp.VectorActionSizeDeprecated.ToArray());
                }
#pragma warning restore CS0618
            }
            else
            {
                actionSpec = bpp.ActionSpec.ToActionSpec();
            }
            var bp = new BrainParameters
            {
                VectorActionDescriptions = bpp.VectorActionDescriptionsDeprecated.ToArray(),
                ActionSpec = actionSpec,
            };
            return bp;
        }

        /// <summary>
        /// Converts an ActionSpecProto to an ActionSpec struct.
        /// </summary>
        /// <param name="actionSpecProto">An instance of an ActionSpecProto object.</param>
        /// <returns>An ActionSpec struct.</returns>
        public static ActionSpec ToActionSpec(this ActionSpecProto actionSpecProto)
        {
            var actionSpec = new ActionSpec(actionSpecProto.NumContinuousActions);
            if (actionSpecProto.DiscreteBranchSizes != null)
            {
                actionSpec.BranchSizes = actionSpecProto.DiscreteBranchSizes.ToArray();
            }
            return actionSpec;
        }

        /// <summary>
        /// Converts an ActionSpec struct to an ActionSpecProto.
        /// </summary>
        /// <param name="actionSpec">An instance of an ActionSpec struct.</param>
        /// <returns>An ActionSpecProto.</returns>
        public static ActionSpecProto ToActionSpecProto(this ActionSpec actionSpec)
        {
            var actionSpecProto = new ActionSpecProto
            {
                NumContinuousActions = actionSpec.NumContinuousActions,
                NumDiscreteActions = actionSpec.NumDiscreteActions,
            };
            if (actionSpec.BranchSizes != null)
            {
                actionSpecProto.DiscreteBranchSizes.AddRange(actionSpec.BranchSizes);
            }
            return actionSpecProto;
        }

        #endregion

        #region DemonstrationMetaData
        /// <summary>
        /// Converts a demonstration metadata object to its proto form.
        /// </summary>
        public static DemonstrationMetaProto ToProto(this DemonstrationMetaData dm)
        {
            var demonstrationName = dm.demonstrationName ?? "";
            var demoProto = new DemonstrationMetaProto
            {
                ApiVersion = DemonstrationMetaData.ApiVersion,
                MeanReward = dm.meanReward,
                NumberSteps = dm.numberSteps,
                NumberEpisodes = dm.numberEpisodes,
                DemonstrationName = demonstrationName
            };
            return demoProto;
        }

        /// <summary>
        /// Initializes demonstration metadata values based on a proto object.
        /// </summary>
        public static DemonstrationMetaData ToDemonstrationMetaData(this DemonstrationMetaProto demoProto)
        {
            var dm = new DemonstrationMetaData
            {
                numberEpisodes = demoProto.NumberEpisodes,
                numberSteps = demoProto.NumberSteps,
                meanReward = demoProto.MeanReward,
                demonstrationName = demoProto.DemonstrationName
            };
            if (demoProto.ApiVersion != DemonstrationMetaData.ApiVersion)
            {
                throw new Exception("API versions of demonstration are incompatible.");
            }
            return dm;
        }

        #endregion

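        /// <summary>
        /// Converts the initialization input proto received from the Python side into UnityRLInitParameters.
        /// </summary>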
        public static UnityRLInitParameters ToUnityRLInitParameters(this UnityRLInitializationInputProto inputProto)
        {
            return new UnityRLInitParameters
            {
                seed = inputProto.Seed,
                numAreas = inputProto.NumAreas,
                pythonLibraryVersion = inputProto.PackageVersion,
                pythonCommunicationVersion = inputProto.CommunicationVersion,
                TrainerCapabilities = inputProto.Capabilities.ToRLCapabilities()
            };
        }

        #region AgentAction
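        /// <summary>
        /// Converts a ListAgentActionProto into a list of ActionBuffers, one per agent.
        /// </summary>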
        public static List<ActionBuffers> ToAgentActionList(this UnityRLInputProto.Types.ListAgentActionProto proto)
        {
            var agentActions = new List<ActionBuffers>(proto.Value.Count);
            foreach (var ap in proto.Value)
            {
                agentActions.Add(ap.ToActionBuffers());
            }
            return agentActions;
        }

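        /// <summary>
        /// Converts an AgentActionProto into an ActionBuffers containing its continuous and discrete actions.
        /// </summary>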
        public static ActionBuffers ToActionBuffers(this AgentActionProto proto)
        {
            return new ActionBuffers(proto.ContinuousActions.ToArray(), proto.DiscreteActions.ToArray());
        }

        #endregion

        #region Observations
        /// <summary>
        /// Static flags to make sure we only fire each trainer-capability warning once.
        /// </summary>
        private static bool s_HaveWarnedTrainerCapabilitiesMultiPng;
        private static bool s_HaveWarnedTrainerCapabilitiesMapping;

        /// <summary>
        /// Generates an ObservationProto for the sensor using the provided ObservationWriter,
        /// compressing the observation if the sensor requests it and the trainer supports it.
        /// </summary>
        /// <param name="sensor">The sensor to generate the observation for.</param>
        /// <param name="observationWriter">The writer used to copy uncompressed float data.</param>
        /// <returns>The protobuf version of the observation.</returns>
        public static ObservationProto GetObservationProto(this ISensor sensor, ObservationWriter observationWriter)
        {
            var obsSpec = sensor.GetObservationSpec();
            var shape = obsSpec.Shape;
            ObservationProto observationProto = null;
            var compressionSpec = sensor.GetCompressionSpec();
            var compressionType = compressionSpec.SensorCompressionType;
            // Check trainer capabilities if we would need to concatenate PNGs (more than 3 channels).
            if (compressionType == SensorCompressionType.PNG && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandle = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.ConcatenatedPngObservations;
                if (!trainerCanHandle)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMultiPng)
                    {
                        Debug.LogWarning(
                            $"Attached trainer doesn't support multiple PNGs. Switching to uncompressed observations for sensor {sensor.GetName()}. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMultiPng = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }
            // Check trainer capabilities if the compressed observation uses a channel mapping.
            if (compressionType != SensorCompressionType.None && shape.Length == 3 && shape[2] > 3)
            {
                var trainerCanHandleMapping = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.CompressedChannelMapping;
                var isTrivialMapping = compressionSpec.IsTrivialMapping();
                if (!trainerCanHandleMapping && !isTrivialMapping)
                {
                    if (!s_HaveWarnedTrainerCapabilitiesMapping)
                    {
                        Debug.LogWarning(
                            $"The sensor {sensor.GetName()} is using non-trivial mapping and " +
                            "the attached trainer doesn't support compression mapping. " +
                            "Switching to uncompressed observations. " +
                            "Please find the versions that work best together from our release page: " +
                            "https://github.com/Unity-Technologies/ml-agents/releases"
                        );
                        s_HaveWarnedTrainerCapabilitiesMapping = true;
                    }
                    compressionType = SensorCompressionType.None;
                }
            }

            if (compressionType == SensorCompressionType.None)
            {
                var numFloats = sensor.ObservationSize();
                var floatDataProto = new ObservationProto.Types.FloatData();
                // Resize the float data to the observation size before the sensor writes into it.
                for (var i = 0; i < numFloats; i++)
                {
                    floatDataProto.Data.Add(0.0f);
                }

                observationWriter.SetTarget(floatDataProto.Data, sensor.GetObservationSpec(), 0);
                sensor.Write(observationWriter);

                observationProto = new ObservationProto
                {
                    FloatData = floatDataProto,
                    CompressionType = (CompressionTypeProto)SensorCompressionType.None,
                };
            }
            else
            {
                var compressedObs = sensor.GetCompressedObservation();
                if (compressedObs == null)
                {
                    throw new UnityAgentsException(
                        $"GetCompressedObservation() returned null data for sensor named {sensor.GetName()}. " +
                        "You must return a byte[]. If you don't want to use compressed observations, " +
                        "return CompressionSpec.Default() from GetCompressionSpec()."
                    );
                }
                observationProto = new ObservationProto
                {
                    CompressedData = ByteString.CopyFrom(compressedObs),
                    CompressionType = (CompressionTypeProto)sensor.GetCompressionSpec().SensorCompressionType,
                };
                if (compressionSpec.CompressedChannelMapping != null)
                {
                    observationProto.CompressedChannelMapping.AddRange(compressionSpec.CompressedChannelMapping);
                }
            }

            // Add the dimension properties to the observation proto.
            var dimensionProperties = obsSpec.DimensionProperties;
            for (int i = 0; i < dimensionProperties.Length; i++)
            {
                observationProto.DimensionProperties.Add((int)dimensionProperties[i]);
            }

            // Check trainer compatibility with variable-length observations.
            if (dimensionProperties == new InplaceArray<DimensionProperty>(DimensionProperty.VariableSize, DimensionProperty.None))
            {
                var trainerCanHandleVarLenObs = Academy.Instance.TrainerCapabilities == null || Academy.Instance.TrainerCapabilities.VariableLengthObservation;
                if (!trainerCanHandleVarLenObs)
                {
                    throw new UnityAgentsException("Variable Length Observations are not supported by the trainer");
                }
            }

            for (var i = 0; i < shape.Length; i++)
            {
                observationProto.Shape.Add(shape[i]);
            }

            var sensorName = sensor.GetName();
            if (!string.IsNullOrEmpty(sensorName))
            {
                observationProto.Name = sensorName;
            }

            observationProto.ObservationType = (ObservationTypeProto)obsSpec.ObservationType;
            return observationProto;
        }

        #endregion

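        /// <summary>
        /// Converts the trainer capabilities proto received from the Python side into UnityRLCapabilities.
        /// </summary>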
        public static UnityRLCapabilities ToRLCapabilities(this UnityRLCapabilitiesProto proto)
        {
            return new UnityRLCapabilities
            {
                BaseRLCapabilities = proto.BaseRLCapabilities,
                ConcatenatedPngObservations = proto.ConcatenatedPngObservations,
                CompressedChannelMapping = proto.CompressedChannelMapping,
                HybridActions = proto.HybridActions,
                TrainingAnalytics = proto.TrainingAnalytics,
                VariableLengthObservation = proto.VariableLengthObservation,
                MultiAgentGroups = proto.MultiAgentGroups,
            };
        }

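        /// <summary>
        /// Converts UnityRLCapabilities back into the proto form sent to the Python side.
        /// </summary>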
        public static UnityRLCapabilitiesProto ToProto(this UnityRLCapabilities rlCaps)
        {
            return new UnityRLCapabilitiesProto
            {
                BaseRLCapabilities = rlCaps.BaseRLCapabilities,
                ConcatenatedPngObservations = rlCaps.ConcatenatedPngObservations,
                CompressedChannelMapping = rlCaps.CompressedChannelMapping,
                HybridActions = rlCaps.HybridActions,
                TrainingAnalytics = rlCaps.TrainingAnalytics,
                VariableLengthObservation = rlCaps.VariableLengthObservation,
                MultiAgentGroups = rlCaps.MultiAgentGroups,
            };
        }

        #region Analytics
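        /// <summary>
        /// Converts a TrainingEnvironmentInitialized proto into the analytics event reported by the C# side.
        /// </summary>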
        internal static TrainingEnvironmentInitializedEvent ToTrainingEnvironmentInitializedEvent(
            this TrainingEnvironmentInitialized inputProto)
        {
            return new TrainingEnvironmentInitializedEvent
            {
                TrainerPythonVersion = inputProto.PythonVersion,
                MLAgentsVersion = inputProto.MlagentsVersion,
                MLAgentsEnvsVersion = inputProto.MlagentsEnvsVersion,
                TorchVersion = inputProto.TorchVersion,
                TorchDeviceType = inputProto.TorchDeviceType,
                NumEnvironments = inputProto.NumEnvs,
                NumEnvironmentParameters = inputProto.NumEnvironmentParameters,
                RunOptions = inputProto.RunOptions,
            };
        }

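        /// <summary>
        /// Converts a TrainingBehaviorInitialized proto into the analytics event, folding the
        /// individual boolean fields into the RewardSignals and TrainingFeatures flag enums.
        /// </summary>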
        internal static TrainingBehaviorInitializedEvent ToTrainingBehaviorInitializedEvent(
            this TrainingBehaviorInitialized inputProto)
        {
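            // Aggregate the enabled reward signals into a single flags value.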
            RewardSignals rewardSignals = 0;
            rewardSignals |= inputProto.ExtrinsicRewardEnabled ? RewardSignals.Extrinsic : 0;
            rewardSignals |= inputProto.GailRewardEnabled ? RewardSignals.Gail : 0;
            rewardSignals |= inputProto.CuriosityRewardEnabled ? RewardSignals.Curiosity : 0;
            rewardSignals |= inputProto.RndRewardEnabled ? RewardSignals.Rnd : 0;

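            // Aggregate the enabled training features into a single flags value.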
            TrainingFeatures trainingFeatures = 0;
            trainingFeatures |= inputProto.BehavioralCloningEnabled ? TrainingFeatures.BehavioralCloning : 0;
            trainingFeatures |= inputProto.RecurrentEnabled ? TrainingFeatures.Recurrent : 0;
            trainingFeatures |= inputProto.TrainerThreaded ? TrainingFeatures.Threaded : 0;
            trainingFeatures |= inputProto.SelfPlayEnabled ? TrainingFeatures.SelfPlay : 0;
            trainingFeatures |= inputProto.CurriculumEnabled ? TrainingFeatures.Curriculum : 0;

            return new TrainingBehaviorInitializedEvent
            {
                BehaviorName = inputProto.BehaviorName,
                TrainerType = inputProto.TrainerType,
                RewardSignalFlags = rewardSignals,
                TrainingFeatureFlags = trainingFeatures,
                VisualEncoder = inputProto.VisualEncoder,
                NumNetworkLayers = inputProto.NumNetworkLayers,
                NumNetworkHiddenUnits = inputProto.NumNetworkHiddenUnits,
                Config = inputProto.Config,
            };
        }

        #endregion
    }
}