namespace Unity.MLAgents.Inference.Utils
{
    /// <summary>
    /// Multinomial - Draws samples from a multinomial distribution given a (potentially unscaled)
    /// cumulative mass function (CMF). This means that the CMF need not "end" with probability
    /// mass of 1.0. For instance: [0.1, 0.2, 0.5] is a valid (unscaled) CMF. What is important is
    /// that it is a cumulative function, not a probability function. In other words, entry[i] is
    /// the total (possibly unscaled) mass of classes 0 through i, NOT the mass of class i alone.
    /// </summary>
    /// <remarks>
    /// Each instance owns a <see cref="System.Random"/>, which is not thread-safe, so a single
    /// instance should not be sampled from concurrently without external synchronization.
    /// </remarks>
    internal class Multinomial
    {
        readonly System.Random m_Random;

        /// <summary>
        /// Constructor.
        /// </summary>
        /// <param name="seed">
        /// Seed for the random number generator used in the sampling process.
        /// </param>
        public Multinomial(int seed)
        {
            m_Random = new System.Random(seed);
        }

        /// <summary>
        /// Samples from the multinomial distribution defined by the first
        /// <paramref name="branchSize"/> entries of the provided cumulative mass function.
        /// Classes with zero mass (flat steps in the CMF) are never returned.
        /// </summary>
        /// <param name="cmf">
        /// Cumulative mass function, which may be unscaled. The entries in this array need
        /// to be non-decreasing. If the CMF is scaled, then the last used entry
        /// (<c>cmf[branchSize - 1]</c>) will be 1.0.
        /// </param>
        /// <param name="branchSize">The number of possible branches, i.e. the effective size of the cmf array.</param>
        /// <returns>
        /// A sampled index from the CMF ranging from 0 to branchSize-1. If the total mass
        /// <c>cmf[branchSize - 1]</c> is zero, negative or NaN, 0 is returned.
        /// </returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="cmf"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException">
        /// <paramref name="branchSize"/> is less than 1 or greater than <c>cmf.Length</c>.
        /// </exception>
        public int Sample(float[] cmf, int branchSize)
        {
            if (cmf == null)
            {
                throw new System.ArgumentNullException(nameof(cmf));
            }
            if (branchSize < 1 || branchSize > cmf.Length)
            {
                throw new System.ArgumentOutOfRangeException(
                    nameof(branchSize), branchSize, "branchSize must be between 1 and cmf.Length.");
            }

            var last = branchSize - 1;

            // Draw before any early return so the generator advances exactly once per call,
            // keeping seeded sample sequences independent of the input values.
            var u = m_Random.NextDouble();
            var total = cmf[last];

            // A total that is zero, negative or NaN defines no distribution; fall back to the
            // first class rather than scanning with a meaningless threshold.
            if (!(total > 0f))
            {
                return 0;
            }

            // NextDouble() is in [0, 1). Keeping the product in double precision guarantees
            // p < total; the original (float) cast could round u up to exactly 1.0f. The
            // '<=' comparison below relies on p staying strictly below the total.
            var p = u * total;

            // Choose the first class whose cumulative mass strictly exceeds p: class i is picked
            // exactly when cmf[i - 1] <= p < cmf[i]. Zero-mass classes (including a leading
            // cmf[0] == 0 when p == 0) therefore can never be drawn. The cls < last bound keeps
            // the result inside [0, branchSize - 1] even for malformed (decreasing) input.
            var cls = 0;
            while (cls < last && cmf[cls] <= p)
            {
                ++cls;
            }
            return cls;
        }

        /// <summary>
        /// Samples from the multinomial distribution defined by the provided cumulative
        /// mass function, using every entry of the array.
        /// </summary>
        /// <param name="cmf">Cumulative (possibly unscaled), non-decreasing mass function.</param>
        /// <returns>A sampled index from the CMF ranging from 0 to cmf.Length-1.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="cmf"/> is null.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="cmf"/> is empty.</exception>
        public int Sample(float[] cmf)
        {
            if (cmf == null)
            {
                throw new System.ArgumentNullException(nameof(cmf));
            }
            return Sample(cmf, cmf.Length);
        }
    }
}