| | using System.Collections.Generic; |
| | using System; |
| | using UnityEngine; |
| |
|
| | namespace Unity.MLAgents.SideChannels |
| | { |
/// <summary>
/// Payload kinds for an environment-parameter message. The numeric value is compared against
/// the int32 tag that <c>EnvironmentParametersChannel</c> reads right after the parameter key,
/// so these values are part of the message format and must not be renumbered.
/// </summary>
internal enum EnvironmentDataTypes : int
{
    /// <summary>The tag is followed by a single float32 value.</summary>
    Float = 0,

    /// <summary>
    /// The tag is followed by an int32 seed, an int32 <c>SamplerType</c> tag and that
    /// sampler's arguments.
    /// </summary>
    Sampler = 1,
}
| |
|
/// <summary>
/// Sampler kinds that may appear in a <c>Sampler</c> environment-parameter message. The
/// numeric value is compared against an int32 tag read from the message, so these values are
/// part of the message format and must not be renumbered.
/// </summary>
internal enum SamplerType : int
{
    /// <summary>
    /// Followed by two float32 values (min, max), which are passed to
    /// <c>SamplerFactory.CreateUniformSampler</c> together with the seed.
    /// </summary>
    Uniform = 0,

    /// <summary>
    /// Followed by two float32 values (mean, stddev), which are passed to
    /// <c>SamplerFactory.CreateGaussianSampler</c> together with the seed.
    /// </summary>
    Gaussian = 1,

    /// <summary>
    /// Followed by a float list of intervals, which is passed to
    /// <c>SamplerFactory.CreateMultiRangeUniformSampler</c> together with the seed.
    /// </summary>
    MultiRangeUniform = 2,
}
| |
|
/// <summary>
/// Side channel that receives environment parameters and stores them by string key.
/// A parameter is either a fixed float or a sampler delegate that is invoked each time the
/// parameter is read through <see cref="GetWithDefault"/>.
/// </summary>
internal class EnvironmentParametersChannel : SideChannel
{
    // Per-key value producers. Fixed floats are stored as constant-returning delegates so
    // both message kinds are read the same way in GetWithDefault.
    readonly Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();

    // Per-key callbacks. They are invoked only when a Float message arrives for the key;
    // Sampler messages do not trigger them.
    readonly Dictionary<string, Action<float>> m_RegisteredActions =
        new Dictionary<string, Action<float>>();

    // Fixed channel id. Presumably this must match the id used by the sending side — TODO confirm.
    const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400";

    /// <summary>
    /// Creates the channel and assigns its fixed <c>ChannelId</c>.
    /// </summary>
    internal EnvironmentParametersChannel()
    {
        ChannelId = new Guid(k_EnvParamsId);
    }

    /// <summary>
    /// Decodes one parameter message. The layout is a string key, then an int32
    /// <see cref="EnvironmentDataTypes"/> tag, then a tag-specific payload.
    /// A Float payload replaces the stored value for the key and invokes the callback
    /// registered for that key, if any. A Sampler payload replaces the stored value only when
    /// the sampler type is recognized. Messages with an unrecognized tag are logged and ignored.
    /// </summary>
    /// <param name="msg">The incoming message to decode.</param>
    protected override void OnMessageReceived(IncomingMessage msg)
    {
        var key = msg.ReadString();
        var type = msg.ReadInt32();
        if ((int)EnvironmentDataTypes.Float == type)
        {
            var value = msg.ReadFloat32();
            // Capture the value so later reads keep returning it.
            m_Parameters[key] = () => value;

            Action<float> action;
            if (m_RegisteredActions.TryGetValue(key, out action))
            {
                // A null callback may have been registered explicitly.
                action?.Invoke(value);
            }
        }
        else if ((int)EnvironmentDataTypes.Sampler == type)
        {
            var sampler = ReadSampler(msg);
            // Unknown sampler types must not clobber a previously stored value
            // with a placeholder; keep whatever was there (or nothing).
            if (sampler != null)
            {
                m_Parameters[key] = sampler;
            }
        }
        else
        {
            Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
        }
    }

    /// <summary>
    /// Reads the Sampler payload: an int32 seed, an int32 <see cref="SamplerType"/> tag and the
    /// arguments for that sampler kind. It then builds the sampler through <c>SamplerFactory</c>.
    /// </summary>
    /// <param name="msg">Message positioned just after the data-type tag.</param>
    /// <returns>The sampler delegate, or <c>null</c> if the sampler type is not recognized.</returns>
    static Func<float> ReadSampler(IncomingMessage msg)
    {
        // Read order matters: the seed precedes the sampler tag on the wire.
        var seed = msg.ReadInt32();
        var samplerType = msg.ReadInt32();
        switch (samplerType)
        {
            case (int)SamplerType.Uniform:
            {
                var min = msg.ReadFloat32();
                var max = msg.ReadFloat32();
                return SamplerFactory.CreateUniformSampler(min, max, seed);
            }
            case (int)SamplerType.Gaussian:
            {
                var mean = msg.ReadFloat32();
                var stddev = msg.ReadFloat32();
                return SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
            }
            case (int)SamplerType.MultiRangeUniform:
            {
                IList<float> intervals = msg.ReadFloatList();
                return SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
            }
            default:
                Debug.LogWarning(
                    $"EnvironmentParametersChannel received an unknown sampler type: {samplerType}.");
                return null;
        }
    }

    /// <summary>
    /// Returns the value for <paramref name="key"/>. For a sampled parameter, the sampler is
    /// invoked on every call.
    /// </summary>
    /// <param name="key">The parameter key.</param>
    /// <param name="defaultValue">Value returned when no parameter is stored for the key.</param>
    /// <returns>The stored or sampled value, or <paramref name="defaultValue"/>.</returns>
    public float GetWithDefault(string key, float defaultValue)
    {
        Func<float> valueOut;
        return m_Parameters.TryGetValue(key, out valueOut) ? valueOut.Invoke() : defaultValue;
    }

    /// <summary>
    /// Registers the callback invoked when a Float value arrives for <paramref name="key"/>.
    /// Any callback previously registered for the same key is replaced.
    /// </summary>
    /// <param name="key">The parameter key.</param>
    /// <param name="action">The callback; receives the newly received value.</param>
    public void RegisterCallback(string key, Action<float> action)
    {
        m_RegisteredActions[key] = action;
    }

    /// <summary>
    /// Returns the keys that currently have a stored value, as a new list (a snapshot).
    /// Later changes to the channel are not reflected in it.
    /// </summary>
    /// <returns>A new list containing the parameter keys.</returns>
    public IList<string> ListParameters()
    {
        return new List<string>(m_Parameters.Keys);
    }
}
| | } |
| |
|