| | using System; |
| | using NUnit.Framework; |
| | using Unity.Barracuda; |
| | using Unity.MLAgents.Inference; |
| | using Unity.MLAgents.Inference.Utils; |
| |
|
namespace Unity.MLAgents.Tests
{
    /// <summary>
    /// NUnit tests for <see cref="TensorUtils"/>, covering
    /// <c>ResizeTensor</c> and <c>FillTensorWithRandomNormal</c>.
    /// </summary>
    public class TensorUtilsTest
    {
        /// <summary>
        /// Builds a <see cref="TensorProxy"/> whose last three dimensions are
        /// height, width and channels (all leading dimensions are 1), calls
        /// <c>TensorUtils.ResizeTensor</c> on it, and checks that those three
        /// dimensions are unchanged both in the proxy's shape array and in the
        /// backing tensor's shape.
        /// </summary>
        /// <param name="dimension">Rank of the shape array to test (4 or 8).</param>
        [TestCase(4, TestName = "TestResizeTensor_4D")]
        [TestCase(8, TestName = "TestResizeTensor_8D")]
        public void TestResizeTensor(int dimension)
        {
            var height = 64;
            var width = 84;
            var channels = 3;

            // Every dimension defaults to 1; only the trailing H, W, C are set.
            var shape = new long[dimension];
            for (var i = 0; i < dimension; i++)
            {
                shape[i] = 1;
            }

            shape[dimension - 3] = height;
            shape[dimension - 2] = width;
            shape[dimension - 1] = channels;

            // The Tensor constructor takes an int[] shape, the proxy a long[] one.
            var intShape = new int[dimension];
            for (var i = 0; i < dimension; i++)
            {
                intShape[i] = (int)shape[i];
            }

            var alloc = new TensorCachingAllocator();
            try
            {
                var tensorProxy = new TensorProxy
                {
                    valueType = TensorProxy.TensorType.Integer,
                    data = new Tensor(intShape),
                    shape = shape,
                };

                // Sanity-check the tensor before resizing.
                Assert.AreEqual(height, tensorProxy.data.shape.height);
                Assert.AreEqual(width, tensorProxy.data.shape.width);
                Assert.AreEqual(channels, tensorProxy.data.shape.channels);

                // 42 is presumably the new batch size -- confirm against ResizeTensor.
                TensorUtils.ResizeTensor(tensorProxy, 42, alloc);

                // The proxy's own shape array must keep H, W, C in the last three slots.
                Assert.AreEqual(height, tensorProxy.shape[dimension - 3]);
                Assert.AreEqual(width, tensorProxy.shape[dimension - 2]);
                Assert.AreEqual(channels, tensorProxy.shape[dimension - 1]);

                // ...and so must the backing tensor after the resize.
                Assert.AreEqual(height, tensorProxy.data.shape.height);
                Assert.AreEqual(width, tensorProxy.data.shape.width);
                Assert.AreEqual(channels, tensorProxy.data.shape.channels);
            }
            finally
            {
                // Dispose even when an assertion throws, so a failing test
                // does not leak the allocator.
                alloc.Dispose();
            }
        }

        /// <summary>
        /// Filling a proxy whose value type is <c>Integer</c> is expected to
        /// throw <see cref="NotImplementedException"/>.
        /// </summary>
        [Test]
        public void RandomNormalTestTensorInt()
        {
            // 1982 is presumably the RNG seed -- confirm against RandomNormal.
            var rn = new RandomNormal(1982);
            var t = new TensorProxy
            {
                valueType = TensorProxy.TensorType.Integer
            };

            Assert.Throws<NotImplementedException>(
                () => TensorUtils.FillTensorWithRandomNormal(t, rn));
        }

        /// <summary>
        /// Filling a floating-point proxy that has no backing tensor
        /// (<c>data</c> is left unset) is expected to throw
        /// <see cref="ArgumentNullException"/>.
        /// </summary>
        [Test]
        public void RandomNormalTestDataNull()
        {
            var rn = new RandomNormal(1982);
            var t = new TensorProxy
            {
                valueType = TensorProxy.TensorType.FloatingPoint
            };

            Assert.Throws<ArgumentNullException>(
                () => TensorUtils.FillTensorWithRandomNormal(t, rn));
        }

        /// <summary>
        /// Fills a 1x3x4x2 floating-point tensor via
        /// <c>TensorUtils.FillTensorWithRandomNormal</c> and compares every
        /// element against a fixed table of reference values with a tolerance
        /// of 0.0001.
        /// </summary>
        [Test]
        public void RandomNormalTestTensor()
        {
            // Reference values recorded for RandomNormal(1982).
            var reference = new[]
            {
                -0.4315872f,
                -1.11074f,
                0.3414804f,
                -1.130287f,
                0.1413168f,
                -0.5105762f,
                -0.3027347f,
                -0.2645015f,
                1.225356f,
                -0.02921959f,
                0.3716498f,
                -1.092338f,
                0.9561074f,
                -0.5018106f,
                1.167787f,
                -0.7763879f,
                -0.07491868f,
                0.5396146f,
                -0.1377991f,
                0.3331701f,
                0.06144788f,
                0.9520947f,
                1.088157f,
                -1.177194f,
            };

            var rn = new RandomNormal(1982);
            var t = new TensorProxy
            {
                valueType = TensorProxy.TensorType.FloatingPoint,
                data = new Tensor(1, 3, 4, 2)
            };

            try
            {
                TensorUtils.FillTensorWithRandomNormal(t, rn);

                // Guard the element-wise loop: a size mismatch should surface as
                // a clear assertion failure rather than an index error or a
                // silently partial comparison.
                Assert.AreEqual(reference.Length, t.data.length);

                for (var i = 0; i < reference.Length; i++)
                {
                    // NUnit's argument order is (expected, actual, delta).
                    Assert.AreEqual(reference[i], t.data[i], 0.0001);
                }
            }
            finally
            {
                // The tensor was created by this test, so release it here.
                t.data.Dispose();
            }
        }
    }
}
| |
|