| | using System; |
| | using System.Collections.Generic; |
| | using Unity.Barracuda; |
| | using Unity.MLAgents.Inference.Utils; |
| |
|
| | namespace Unity.MLAgents.Inference |
| | { |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Describes a named tensor: the kind of values it holds, its shape, and the
/// Barracuda <see cref="Tensor"/> that stores the values.
/// </summary>
[Serializable]
internal class TensorProxy
{
    /// <summary>
    /// Element kinds a proxy can declare for its values.
    /// </summary>
    public enum TensorType
    {
        Integer,
        FloatingPoint
    }

    // Maps each declared element kind to the CLR type reported by DataType.
    static readonly Dictionary<TensorType, Type> k_TypeMap =
        new Dictionary<TensorType, Type>
        {
            [TensorType.Integer] = typeof(int),
            [TensorType.FloatingPoint] = typeof(float)
        };

    /// <summary>Name of the tensor.</summary>
    public string name;

    /// <summary>Declared element kind of the tensor.</summary>
    public TensorType valueType;

    /// <summary>
    /// CLR type matching <see cref="valueType"/> (int for Integer, float for FloatingPoint).
    /// </summary>
    public Type DataType => k_TypeMap[valueType];

    /// <summary>Dimensions of the tensor; the properties below read rank-4 or 8-element shapes.</summary>
    public long[] shape;

    /// <summary>Barracuda tensor backing this proxy; may be null until one is assigned.</summary>
    public Tensor data;

    // The three accessors below treat any shape that is not rank 4 as an
    // 8-element shape, so a shorter shape throws IndexOutOfRangeException.

    /// <summary>
    /// Element 1 of a rank-4 shape, otherwise element 5 of the shape.
    /// </summary>
    public long Height => shape.Length == 4 ? shape[1] : shape[5];

    /// <summary>
    /// Element 2 of a rank-4 shape, otherwise element 6 of the shape.
    /// </summary>
    public long Width => shape.Length == 4 ? shape[2] : shape[6];

    /// <summary>
    /// Element 3 of a rank-4 shape, otherwise element 7 of the shape.
    /// </summary>
    public long Channels => shape.Length == 4 ? shape[3] : shape[7];
}
| |
|
| | internal static class TensorUtils |
| | { |
/// <summary>
/// Makes sure the proxy's tensor has the requested batch size, allocating a
/// replacement tensor through <paramref name="allocator"/> when it does not.
/// </summary>
/// <param name="tensor">Proxy to resize; its <c>shape[0]</c> and <c>data</c> are updated in place.</param>
/// <param name="batch">Requested batch size.</param>
/// <param name="allocator">Allocator used to create the replacement tensor.</param>
public static void ResizeTensor(TensorProxy tensor, int batch, ITensorAllocator allocator)
{
    // Nothing to do when both the recorded shape and the live tensor already match.
    var alreadySized = tensor.shape[0] == batch
        && tensor.data != null
        && tensor.data.batch == batch;
    if (alreadySized)
    {
        return;
    }

    // The previous tensor (if any) is released; its values are not copied over.
    tensor.data?.Dispose();
    tensor.shape[0] = batch;

    var rank = tensor.shape.Length;
    TensorShape resizedShape;
    if (rank == 4 || rank == 8)
    {
        resizedShape = new TensorShape(
            batch,
            (int)tensor.Height,
            (int)tensor.Width,
            (int)tensor.Channels);
    }
    else
    {
        // Any other rank is allocated as (batch, last dimension).
        resizedShape = new TensorShape(batch, (int)tensor.shape[rank - 1]);
    }

    tensor.data = allocator.Alloc(resizedShape);
}
| |
|
/// <summary>
/// Converts a Barracuda <see cref="TensorShape"/> into the <c>long[]</c> form
/// stored on a <see cref="TensorProxy"/>.
/// </summary>
/// <param name="src">Shape to convert.</param>
/// <returns>
/// <c>{ batch, channels }</c> when both height and width are 1; otherwise
/// <c>{ batch, height, width, channels }</c>.
/// </returns>
internal static long[] TensorShapeFromBarracuda(TensorShape src)
{
    var hasSpatialExtent = src.height != 1 || src.width != 1;
    if (hasSpatialExtent)
    {
        return new long[] { src.batch, src.height, src.width, src.channels };
    }

    // Height and width are both 1, so only batch and channels are kept.
    return new long[] { src.batch, src.channels };
}
| |
|
/// <summary>
/// Wraps an existing Barracuda tensor in a <see cref="TensorProxy"/>. The
/// proxy references <paramref name="src"/> directly; no data is copied.
/// </summary>
/// <param name="src">Tensor to wrap.</param>
/// <param name="nameOverride">Name for the proxy; when null, <c>src.name</c> is used.</param>
/// <returns>A proxy whose value type is always <c>FloatingPoint</c>.</returns>
public static TensorProxy TensorProxyFromBarracuda(Tensor src, string nameOverride = null)
{
    // Convert the shape first, before the proxy is constructed.
    var dims = TensorShapeFromBarracuda(src.shape);

    var proxy = new TensorProxy();
    proxy.name = nameOverride != null ? nameOverride : src.name;
    proxy.valueType = TensorProxy.TensorType.FloatingPoint;
    proxy.shape = dims;
    proxy.data = src;
    return proxy;
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Writes <paramref name="fillValue"/> into every (height, width, channel)
/// element of a single batch entry of the proxy's tensor.
/// </summary>
/// <param name="tensorProxy">Proxy whose <c>data</c> tensor is written to; <c>data</c> must not be null.</param>
/// <param name="batch">Index of the batch entry to fill.</param>
/// <param name="fillValue">Value assigned to each element of that entry.</param>
public static void FillTensorBatch(TensorProxy tensorProxy, int batch, float fillValue)
{
    var target = tensorProxy.data;
    var rows = target.height;
    var cols = target.width;
    var depth = target.channels;

    for (var row = 0; row < rows; ++row)
    {
        for (var col = 0; col < cols; ++col)
        {
            for (var ch = 0; ch < depth; ++ch)
            {
                target[batch, row, col, ch] = fillValue;
            }
        }
    }
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
/// <summary>
/// Overwrites every element of the proxy's tensor with successive values
/// obtained from <paramref name="randomNormal"/>.
/// </summary>
/// <param name="tensorProxy">
/// Proxy whose <c>data</c> tensor is filled in place. Its <c>DataType</c> must
/// be float and its <c>data</c> must already be allocated.
/// </param>
/// <param name="randomNormal">
/// Source of values; each element receives one <c>NextDouble()</c> result
/// narrowed to float.
/// </param>
/// <exception cref="NotImplementedException">
/// The proxy's <c>DataType</c> is not float.
/// </exception>
/// <exception cref="ArgumentNullException">
/// The proxy's <c>data</c> is null.
/// </exception>
public static void FillTensorWithRandomNormal(
    TensorProxy tensorProxy, RandomNormal randomNormal)
{
    if (tensorProxy.DataType != typeof(float))
    {
        throw new NotImplementedException("Only float data types are currently supported");
    }

    var target = tensorProxy.data;
    if (target == null)
    {
        // Name the offending argument and say why, so the failure is diagnosable.
        throw new ArgumentNullException(
            nameof(tensorProxy),
            "The tensor proxy has no data allocated to fill.");
    }

    // The element count does not change while filling, so read it once.
    var count = target.length;
    for (var i = 0; i < count; i++)
    {
        target[i] = (float)randomNormal.NextDouble();
    }
}
| | } |
| | } |
| |
|