| | using System; |
| | using System.Collections.Generic; |
| | using Unity.MLAgents.Inference.Utils; |
| | using Random = System.Random; |
| |
|
namespace Unity.MLAgents
{
    /// <summary>
    /// Factory methods that build parameterless sampling functions. Every returned
    /// <see cref="Func{TResult}"/> owns its own random generator(s), seeded from the
    /// <c>seed</c> argument passed to the factory method.
    /// </summary>
    internal static class SamplerFactory
    {
        /// <summary>
        /// Creates a sampler that draws values uniformly between <paramref name="min"/>
        /// and <paramref name="max"/>, computed as <c>min + u * (max - min)</c> with
        /// <c>u</c> drawn from <see cref="Random.NextDouble"/> (i.e. in [0, 1)).
        /// </summary>
        /// <param name="min">Lower bound of the range.</param>
        /// <param name="max">Upper bound of the range.</param>
        /// <param name="seed">Seed for the underlying <see cref="Random"/> instance.</param>
        /// <returns>A function that returns a new sample each time it is invoked.</returns>
        public static Func<float> CreateUniformSampler(float min, float max, int seed)
        {
            var distr = new Random(seed);
            return () => min + (float)distr.NextDouble() * (max - min);
        }

        /// <summary>
        /// Creates a sampler backed by <c>RandomNormal</c>, constructed with the given
        /// seed, mean and standard deviation.
        /// NOTE(review): assumed to draw from N(mean, stddev^2) — confirm against RandomNormal.
        /// </summary>
        /// <param name="mean">Mean passed to <c>RandomNormal</c>.</param>
        /// <param name="stddev">Standard deviation passed to <c>RandomNormal</c>.</param>
        /// <param name="seed">Seed passed to <c>RandomNormal</c>.</param>
        /// <returns>A function that returns a new sample (narrowed to float) on each call.</returns>
        public static Func<float> CreateGaussianSampler(float mean, float stddev, int seed)
        {
            var distr = new RandomNormal(seed, mean, stddev);
            return () => (float)distr.NextDouble();
        }

        /// <summary>
        /// Creates a sampler over a union of intervals. Each call first picks one interval,
        /// using weights equal to that interval's width divided by the total width. It then
        /// draws uniformly inside the chosen interval.
        /// </summary>
        /// <param name="intervals">
        /// Flat list of bounds laid out as <c>[min0, max0, min1, max1, ...]</c>. It must be
        /// non-empty, have an even number of elements, and satisfy <c>min_i &lt;= max_i</c>
        /// for every pair.
        /// </param>
        /// <param name="seed">
        /// Seeds the in-interval generator. <c>seed + 1</c> seeds the interval chooser.
        /// </param>
        /// <returns>A function that returns a new sample each time it is invoked.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="intervals"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="intervals"/> is empty, has an odd length, or contains a pair
        /// whose max is less than its min (or is NaN).
        /// </exception>
        public static Func<float> CreateMultiRangeUniformSampler(IList<float> intervals, int seed)
        {
            if (intervals == null)
            {
                throw new ArgumentNullException(nameof(intervals));
            }
            if (intervals.Count == 0 || intervals.Count % 2 != 0)
            {
                throw new ArgumentException(
                    "Expected a non-empty, even-length list of [min, max] pairs.", nameof(intervals));
            }

            // Generator for the value inside an interval. The interval itself is chosen
            // by a separate Multinomial seeded with seed + 1 (created below).
            var distr = new Random(seed);

            var numIntervals = intervals.Count / 2;
            var cumulativeWeights = new float[numIntervals];
            var intervalFuncs = new Func<float>[numIntervals];
            float sumIntervalSizes = 0;

            for (var i = 0; i < numIntervals; i++)
            {
                // Declared inside the loop so each lambda captures its own copies.
                var min = intervals[2 * i];
                var max = intervals[2 * i + 1];

                // Written as !(min <= max) so NaN bounds are rejected too; a negative width
                // would otherwise corrupt the cumulative weights below.
                if (!(min <= max))
                {
                    throw new ArgumentException(
                        $"Interval {i} must satisfy min <= max (got min={min}, max={max}).",
                        nameof(intervals));
                }

                var intervalSize = max - min;
                sumIntervalSizes += intervalSize;
                cumulativeWeights[i] = intervalSize;
                intervalFuncs[i] = () => min + (float)distr.NextDouble() * intervalSize;
            }

            // Normalize widths into probabilities. If every interval is a single point, the
            // total width is 0 and dividing would give NaN weights. Weight the points
            // equally instead.
            for (var i = 0; i < numIntervals; i++)
            {
                cumulativeWeights[i] = sumIntervalSizes > 0
                    ? cumulativeWeights[i] / sumIntervalSizes
                    : 1f / numIntervals;
            }

            // Running sum of the probabilities, handed to Multinomial.Sample.
            // NOTE(review): Sample is assumed to take a cumulative distribution — confirm.
            for (var i = 1; i < numIntervals; i++)
            {
                cumulativeWeights[i] += cumulativeWeights[i - 1];
            }

            var intervalDistr = new Multinomial(seed + 1);

            float MultiRange()
            {
                var sampledInterval = intervalDistr.Sample(cumulativeWeights);
                return intervalFuncs[sampledInterval]();
            }

            return MultiRange;
        }
    }
}
| |
|