File size: 4,981 Bytes
05c9ac2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 | using System.Collections.Generic;
using System;
using UnityEngine;
namespace Unity.MLAgents.SideChannels
{
/// <summary>
/// Identifies how the payload of an environment-parameter message is encoded.
/// The integer values are read from the wire, so they must not be renumbered.
/// </summary>
internal enum EnvironmentDataTypes : int
{
    /// <summary>The payload is a single float value.</summary>
    Float = 0,

    /// <summary>The payload describes a sampler (seed, sampler type, sampler arguments).</summary>
    Sampler = 1
}
/// <summary>
/// Distribution families that a reset parameter can be sampled from.
/// The integer values are read from the wire, so they must not be renumbered.
/// </summary>
internal enum SamplerType : int
{
    /// <summary>
    /// Draw the reset parameter from a uniform distribution.
    /// </summary>
    Uniform = 0,

    /// <summary>
    /// Draw the reset parameter from a Gaussian distribution.
    /// </summary>
    Gaussian = 1,

    /// <summary>
    /// Draw the reset parameter from a multi-range uniform distribution.
    /// </summary>
    MultiRangeUniform = 2
}
/// <summary>
/// A side channel that manages the environment parameter values from Python. Currently
/// limited to parameters of type float.
/// </summary>
internal class EnvironmentParametersChannel : SideChannel
{
    // Maps each parameter key to a function yielding its value: a constant for Float
    // messages, or a sampler built by SamplerFactory for Sampler messages.
    readonly Dictionary<string, Func<float>> m_Parameters = new Dictionary<string, Func<float>>();

    // Per-key callbacks, invoked when a Float value arrives for that key.
    readonly Dictionary<string, Action<float>> m_RegisteredActions =
        new Dictionary<string, Action<float>>();

    const string k_EnvParamsId = "534c891e-810f-11ea-a9d0-822485860400";

    /// <summary>
    /// Initializes the side channel. The constructor is internal because only one instance is
    /// supported at a time, and is created by the Academy.
    /// </summary>
    internal EnvironmentParametersChannel()
    {
        ChannelId = new Guid(k_EnvParamsId);
    }

    /// <inheritdoc/>
    protected override void OnMessageReceived(IncomingMessage msg)
    {
        var key = msg.ReadString();
        var type = msg.ReadInt32();
        if ((int)EnvironmentDataTypes.Float == type)
        {
            var value = msg.ReadFloat32();
            m_Parameters[key] = () => value;

            Action<float> action;
            if (m_RegisteredActions.TryGetValue(key, out action))
            {
                // A null callback may have been registered; treat it as "no callback".
                action?.Invoke(value);
            }
        }
        else if ((int)EnvironmentDataTypes.Sampler == type)
        {
            Func<float> sampler;
            // Only replace the stored parameter when the sampler was understood; an
            // unrecognized sampler type must not silently overwrite it with a constant 0.
            if (TryReadSampler(msg, out sampler))
            {
                m_Parameters[key] = sampler;
            }
        }
        else
        {
            Debug.LogWarning("EnvironmentParametersChannel received an unknown data type.");
        }
    }

    /// <summary>
    /// Reads a sampler description (seed, sampler type, then type-specific arguments) from
    /// <paramref name="msg"/> and builds the matching sampler via <c>SamplerFactory</c>.
    /// </summary>
    /// <param name="msg">Message positioned just after the data-type field.</param>
    /// <param name="sampler">The constructed sampler, or null when the type is unknown.</param>
    /// <returns>True if the sampler type was recognized; false (after logging a warning) otherwise.</returns>
    static bool TryReadSampler(IncomingMessage msg, out Func<float> sampler)
    {
        int seed = msg.ReadInt32();
        int samplerType = msg.ReadInt32();
        switch (samplerType)
        {
            case (int)SamplerType.Uniform:
            {
                float min = msg.ReadFloat32();
                float max = msg.ReadFloat32();
                sampler = SamplerFactory.CreateUniformSampler(min, max, seed);
                return true;
            }
            case (int)SamplerType.Gaussian:
            {
                float mean = msg.ReadFloat32();
                float stddev = msg.ReadFloat32();
                sampler = SamplerFactory.CreateGaussianSampler(mean, stddev, seed);
                return true;
            }
            case (int)SamplerType.MultiRangeUniform:
            {
                IList<float> intervals = msg.ReadFloatList();
                sampler = SamplerFactory.CreateMultiRangeUniformSampler(intervals, seed);
                return true;
            }
            default:
                Debug.LogWarning($"EnvironmentParametersChannel received an unknown sampler type: {samplerType}.");
                sampler = null;
                return false;
        }
    }

    /// <summary>
    /// Returns the parameter value associated with the provided key. Returns the default
    /// value if one doesn't exist.
    /// </summary>
    /// <remarks>
    /// The stored function is invoked on every call, so a sampler-backed parameter may
    /// yield a different value each time.
    /// </remarks>
    /// <param name="key">Parameter key.</param>
    /// <param name="defaultValue">Default value to return.</param>
    /// <returns>The current value for <paramref name="key"/>, or <paramref name="defaultValue"/> if none is stored.</returns>
    public float GetWithDefault(string key, float defaultValue)
    {
        Func<float> valueOut;
        return m_Parameters.TryGetValue(key, out valueOut) ? valueOut.Invoke() : defaultValue;
    }

    /// <summary>
    /// Registers a callback for the associated parameter key. Will overwrite any existing
    /// actions for this parameter key.
    /// </summary>
    /// <remarks>
    /// The callback is invoked only when a Float value is received for the key. It is not
    /// invoked for sampler updates, nor retroactively for values received before registration.
    /// </remarks>
    /// <param name="key">The parameter key.</param>
    /// <param name="action">The callback.</param>
    public void RegisterCallback(string key, Action<float> action)
    {
        m_RegisteredActions[key] = action;
    }

    /// <summary>
    /// Returns all parameter keys that have a registered value.
    /// </summary>
    /// <returns>A new list (a snapshot copy) of the keys currently stored.</returns>
    public IList<string> ListParameters()
    {
        return new List<string>(m_Parameters.Keys);
    }
}
}
|