| | using System; |
| | using System.Collections.Generic; |
| | using Unity.Barracuda; |
| | using Unity.MLAgents.Inference; |
| | using UnityEngine; |
| |
|
| | namespace Unity.MLAgents.Sensors |
| | { |
| | |
| | |
| | |
/// <summary>
/// Writes observation values into one of two targets: a flat <see cref="IList{T}"/> of floats,
/// or a single batch row of the tensor held by a <see cref="TensorProxy"/>.
/// The most recent <c>SetTarget</c> call decides the target; every write checks
/// <c>m_Data</c> for null to choose between the two paths.
/// </summary>
public class ObservationWriter
{
    // List target. Non-null means "write into this list"; null means "write into m_Proxy".
    IList<float> m_Data;

    // Added to every write position (flat offset on the list path, last-index offset on the tensor path).
    int m_Offset;

    // Tensor target, only used while m_Data is null.
    TensorProxy m_Proxy;

    // Batch row addressed on the tensor path; reset to 0 by the list-targeting SetTarget.
    int m_Batch;

    // Shape backing the bounds checks and flat-index math of the 3-D indexer's list path.
    TensorShape m_TensorShape;

    /// <summary>
    /// Creates a writer without a target. Call a <c>SetTarget</c> overload before writing.
    /// </summary>
    public ObservationWriter() { }

    /// <summary>
    /// Directs writes into <paramref name="data"/>, taking the layout from <paramref name="observationSpec"/>.
    /// </summary>
    /// <param name="data">Flat list that receives the observation values.</param>
    /// <param name="observationSpec">Spec whose <c>Shape</c> describes the observation layout.</param>
    /// <param name="offset">Index in <paramref name="data"/> at which this observation starts.</param>
    internal void SetTarget(IList<float> data, ObservationSpec observationSpec, int offset) =>
        SetTarget(data, observationSpec.Shape, offset);

    /// <summary>
    /// Directs writes into <paramref name="data"/> using an explicit observation shape.
    /// </summary>
    /// <param name="data">Flat list that receives the observation values.</param>
    /// <param name="shape">
    /// Observation shape. Ranks 1 and 2 are handled explicitly; any other rank is read as
    /// three dimensions (shape[0], shape[1], shape[2]).
    /// </param>
    /// <param name="offset">Index in <paramref name="data"/> at which this observation starts.</param>
    internal void SetTarget(IList<float> data, InplaceArray<int> shape, int offset)
    {
        m_Proxy = null;
        m_Batch = 0;
        m_Data = data;
        m_Offset = offset;

        switch (shape.Length)
        {
            case 1:
                m_TensorShape = new TensorShape(m_Batch, shape[0]);
                break;
            case 2:
                // Rank-2 shapes get a unit dimension inserted right after the batch dimension.
                m_TensorShape = new TensorShape(new[] { m_Batch, 1, shape[0], shape[1] });
                break;
            default:
                m_TensorShape = new TensorShape(m_Batch, shape[0], shape[1], shape[2]);
                break;
        }
    }

    /// <summary>
    /// Directs writes into row <paramref name="batchIndex"/> of the tensor held by <paramref name="tensorProxy"/>.
    /// </summary>
    /// <param name="tensorProxy">Proxy whose <c>data</c> tensor receives the values.</param>
    /// <param name="batchIndex">Batch row to write into.</param>
    /// <param name="channelOffset">Offset added to the last index of every write.</param>
    internal void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
    {
        m_Data = null;
        m_Proxy = tensorProxy;
        m_Batch = batchIndex;
        m_Offset = channelOffset;
        m_TensorShape = tensorProxy.data.shape;
    }

    /// <summary>
    /// Writes one value at flat position <paramref name="index"/>, measured from the current offset.
    /// </summary>
    /// <param name="index">Position relative to the target offset.</param>
    public float this[int index]
    {
        set
        {
            var target = index + m_Offset;
            if (m_Data == null)
            {
                m_Proxy.data[m_Batch, target] = value;
                return;
            }

            m_Data[target] = value;
        }
    }

    /// <summary>
    /// Writes one value at (<paramref name="h"/>, <paramref name="w"/>, <paramref name="ch"/>).
    /// </summary>
    /// <param name="h">Height index.</param>
    /// <param name="w">Width index.</param>
    /// <param name="ch">Channel index, before the target offset is added.</param>
    /// <exception cref="IndexOutOfRangeException">
    /// List target only: <paramref name="h"/>, <paramref name="w"/> or <paramref name="ch"/>
    /// falls outside the current shape's height, width or channel count.
    /// </exception>
    public float this[int h, int w, int ch]
    {
        set
        {
            if (m_Data == null)
            {
                // Tensor target: no range checks are performed on this path.
                m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
                return;
            }

            if (!(0 <= h && h < m_TensorShape.height))
            {
                throw new IndexOutOfRangeException($"height value {h} must be in range [0, {m_TensorShape.height - 1}]");
            }
            if (!(0 <= w && w < m_TensorShape.width))
            {
                throw new IndexOutOfRangeException($"width value {w} must be in range [0, {m_TensorShape.width - 1}]");
            }
            if (!(0 <= ch && ch < m_TensorShape.channels))
            {
                throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {m_TensorShape.channels - 1}]");
            }

            // The target offset is folded into the channel term of the flat index.
            m_Data[m_TensorShape.Index(m_Batch, h, w, ch + m_Offset)] = value;
        }
    }

    /// <summary>
    /// Copies every element of <paramref name="data"/> into consecutive positions,
    /// starting <paramref name="writeOffset"/> past the target offset.
    /// </summary>
    /// <param name="data">Values to copy, in order.</param>
    /// <param name="writeOffset">Extra offset added on top of the target offset.</param>
    public void AddList(IList<float> data, int writeOffset = 0)
    {
        var start = m_Offset + writeOffset;
        if (m_Data == null)
        {
            for (var i = 0; i < data.Count; i++)
            {
                m_Proxy.data[m_Batch, start + i] = data[i];
            }
            return;
        }

        for (var i = 0; i < data.Count; i++)
        {
            m_Data[start + i] = data[i];
        }
    }

    /// <summary>
    /// Writes x, y and z of <paramref name="vec"/> into three consecutive positions.
    /// </summary>
    /// <param name="vec">Vector to write.</param>
    /// <param name="writeOffset">Extra offset added on top of the target offset.</param>
    public void Add(Vector3 vec, int writeOffset = 0)
    {
        this[writeOffset + 0] = vec.x;
        this[writeOffset + 1] = vec.y;
        this[writeOffset + 2] = vec.z;
    }

    /// <summary>
    /// Writes x, y, z and w of <paramref name="vec"/> into four consecutive positions.
    /// </summary>
    /// <param name="vec">Vector to write.</param>
    /// <param name="writeOffset">Extra offset added on top of the target offset.</param>
    public void Add(Vector4 vec, int writeOffset = 0)
    {
        this[writeOffset + 0] = vec.x;
        this[writeOffset + 1] = vec.y;
        this[writeOffset + 2] = vec.z;
        this[writeOffset + 3] = vec.w;
    }

    /// <summary>
    /// Writes the x, y, z and w components of <paramref name="quat"/> into four consecutive positions.
    /// </summary>
    /// <param name="quat">Quaternion to write.</param>
    /// <param name="writeOffset">Extra offset added on top of the target offset.</param>
    public void Add(Quaternion quat, int writeOffset = 0)
    {
        this[writeOffset + 0] = quat.x;
        this[writeOffset + 1] = quat.y;
        this[writeOffset + 2] = quat.z;
        this[writeOffset + 3] = quat.w;
    }
}
| |
|
| | |
| | |
| | |
/// <summary>
/// Extension methods that write texture contents into an <see cref="ObservationWriter"/>.
/// </summary>
public static class ObservationWriterExtension
{
    /// <summary>
    /// Writes the pixels of <paramref name="texture"/> through the writer's
    /// (height, width, channel) indexer. Each 0-255 component is divided by 255, and
    /// rows are flipped vertically: source row r lands in destination row (height - 1 - r).
    /// RGB24 textures are read from their raw bytes; all other formats go through
    /// <c>GetPixels32</c>.
    /// </summary>
    /// <param name="obsWriter">Destination writer.</param>
    /// <param name="texture">Source texture.</param>
    /// <param name="grayScale">
    /// When true, writes a single channel holding the mean of R, G and B;
    /// otherwise writes R, G and B into channels 0, 1 and 2.
    /// </param>
    /// <returns>The number of values written: height * width * (1 when grayscale, else 3).</returns>
    public static int WriteTexture(
        this ObservationWriter obsWriter,
        Texture2D texture,
        bool grayScale)
    {
        if (texture.format == TextureFormat.RGB24)
        {
            return obsWriter.WriteTextureRGB24(texture, grayScale);
        }

        var width = texture.width;
        var height = texture.height;
        var pixels = texture.GetPixels32();

        for (var srcRow = 0; srcRow < height; srcRow++)
        {
            var h = height - 1 - srcRow;
            for (var w = 0; w < width; w++)
            {
                var pixel = pixels[srcRow * width + w];
                WritePixel(obsWriter, h, w, pixel.r, pixel.g, pixel.b, grayScale);
            }
        }

        return height * width * (grayScale ? 1 : 3);
    }

    /// <summary>
    /// RGB24 variant of <see cref="WriteTexture"/>: reads R, G and B from three consecutive
    /// bytes of the texture's raw data. It uses the same vertical flip and the same scaling.
    /// NOTE(review): assumes pixels are packed 3 bytes each with no row padding — confirm.
    /// </summary>
    /// <param name="obsWriter">Destination writer.</param>
    /// <param name="texture">Source texture, expected to be in RGB24 format.</param>
    /// <param name="grayScale">Same meaning as in <see cref="WriteTexture"/>.</param>
    /// <returns>The number of values written: height * width * (1 when grayscale, else 3).</returns>
    internal static int WriteTextureRGB24(
        this ObservationWriter obsWriter,
        Texture2D texture,
        bool grayScale
    )
    {
        var width = texture.width;
        var height = texture.height;
        var rawBytes = texture.GetRawTextureData<byte>();

        for (var srcRow = 0; srcRow < height; srcRow++)
        {
            var h = height - 1 - srcRow;
            for (var w = 0; w < width; w++)
            {
                var byteIndex = 3 * (srcRow * width + w);
                WritePixel(
                    obsWriter,
                    h,
                    w,
                    rawBytes[byteIndex],
                    rawBytes[byteIndex + 1],
                    rawBytes[byteIndex + 2],
                    grayScale);
            }
        }

        return height * width * (grayScale ? 1 : 3);
    }

    // Writes one pixel at (h, w): a single averaged channel when grayScale, else three channels.
    static void WritePixel(ObservationWriter obsWriter, int h, int w, byte r, byte g, byte b, bool grayScale)
    {
        if (grayScale)
        {
            obsWriter[h, w, 0] = (r + g + b) / 3f / 255.0f;
            return;
        }

        obsWriter[h, w, 0] = r / 255.0f;
        obsWriter[h, w, 1] = g / 255.0f;
        obsWriter[h, w, 2] = b / 255.0f;
    }
}
| | } |
| |
|