| | using NUnit.Framework; |
| | using Unity.MLAgents.Inference.Utils; |
| |
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// Seeded, deterministic checks of <see cref="Multinomial.Sample"/>.
    /// Each test builds a sampler from the same fixed seed, samples repeatedly
    /// from a cumulative distribution (<c>cdf</c>), and pins the exact sequence
    /// of returned category indices.
    /// </summary>
    public class MultinomialTest
    {
        // Shared seed: the pinned draw sequences below are only valid for this value.
        private const int k_Seed = 2018;

        /// <summary>
        /// Samples <paramref name="sampler"/> once per entry of
        /// <paramref name="expectedDraws"/> and asserts each result matches, in order.
        /// Stops at the first mismatching draw.
        /// </summary>
        /// <param name="sampler">Sampler under test; its random state advances with each draw.</param>
        /// <param name="cdf">Cumulative distribution passed unchanged to every draw.</param>
        /// <param name="expectedDraws">Expected category index for each successive draw.</param>
        private static void AssertDraws(Multinomial sampler, float[] cdf, params int[] expectedDraws)
        {
            foreach (var expected in expectedDraws)
            {
                Assert.AreEqual(expected, sampler.Sample(cdf));
            }
        }

        /// <summary>A single-category CDF ending at 1: every draw is expected to be index 0.</summary>
        [Test]
        public void TestDim1()
        {
            AssertDraws(new Multinomial(k_Seed), new[] { 1f }, 0, 0, 0);
        }

        /// <summary>
        /// A single-category CDF whose last entry is not 1 (0.1): draws are still
        /// expected to be index 0.
        /// </summary>
        [Test]
        public void TestDim1Unscaled()
        {
            AssertDraws(new Multinomial(k_Seed), new[] { 0.1f }, 0, 0, 0);
        }

        /// <summary>Three categories with a CDF ending at 1.0; pins the first four draws for the fixed seed.</summary>
        [Test]
        public void TestDim3()
        {
            AssertDraws(new Multinomial(k_Seed), new[] { 0.1f, 0.3f, 1.0f }, 2, 2, 2, 1);
        }

        /// <summary>
        /// Same CDF as <see cref="TestDim3"/> multiplied by 0.5. With the same seed
        /// it expects the identical draw sequence, so sampling must not depend on
        /// the CDF's overall scale.
        /// </summary>
        [Test]
        public void TestDim3Unscaled()
        {
            AssertDraws(new Multinomial(k_Seed), new[] { 0.05f, 0.15f, 0.5f }, 2, 2, 2, 1);
        }
    }
}
| |
|