| | using System; |
| | using System.Collections.Generic; |
| | using System.Reflection; |
| | using UnityEngine; |
| |
|
namespace Unity.MLAgents.Sensors.Reflection
{
    /// <summary>
    /// Marks a field or property whose value should be exposed as an observation
    /// through a reflection-based sensor.
    /// </summary>
    /// <remarks>
    /// Supported member types are <c>int</c>, <c>bool</c>, <c>float</c>,
    /// <c>Vector2</c>, <c>Vector3</c>, <c>Vector4</c>, <c>Quaternion</c> and any enum type.
    /// Members of any other type are skipped when sensors are created and are reported
    /// as errors by <see cref="GetTotalObservationSize"/>.
    /// </remarks>
    [AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)]
    public class ObservableAttribute : Attribute
    {
        /// <summary>Optional user-supplied sensor name; may be null or empty.</summary>
        readonly string m_Name;

        /// <summary>Requested number of stacked observations; stacking is only applied when greater than 1.</summary>
        readonly int m_NumStackedObservations;

        /// <summary>
        /// Base flags used to look up members: instance members of any visibility.
        /// </summary>
        const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;

        /// <summary>
        /// Maps each supported (non-enum) member type to its observation size and to the
        /// sensor type that is instantiated for it.
        /// </summary>
        static readonly Dictionary<Type, (int, Type)> s_TypeToSensorInfo = new Dictionary<Type, (int, Type)>()
        {
            {typeof(int), (1, typeof(IntReflectionSensor))},
            {typeof(bool), (1, typeof(BoolReflectionSensor))},
            {typeof(float), (1, typeof(FloatReflectionSensor))},

            {typeof(Vector2), (2, typeof(Vector2ReflectionSensor))},
            {typeof(Vector3), (3, typeof(Vector3ReflectionSensor))},
            {typeof(Vector4), (4, typeof(Vector4ReflectionSensor))},
            {typeof(Quaternion), (4, typeof(QuaternionReflectionSensor))},
        };

        /// <summary>
        /// Creates the attribute.
        /// </summary>
        /// <param name="name">
        /// Name to give the generated sensor. When null or empty, the name
        /// <c>ObservableAttribute:{DeclaringType}.{Member}</c> is used instead.
        /// </param>
        /// <param name="numStackedObservations">
        /// Number of observations to stack. Values of 1 or less produce an unstacked sensor.
        /// </param>
        public ObservableAttribute(string name = null, int numStackedObservations = 1)
        {
            m_Name = name;
            m_NumStackedObservations = numStackedObservations;
        }

        /// <summary>
        /// Enumerates the instance fields of <paramref name="o"/> that carry an
        /// <see cref="ObservableAttribute"/>, together with that attribute.
        /// </summary>
        /// <param name="o">Object whose runtime type is inspected.</param>
        /// <param name="excludeInherited">When true, only members declared on the runtime type itself are considered.</param>
        /// <returns>A lazily evaluated sequence of (field, attribute) pairs.</returns>
        static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited)
        {
            var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
            foreach (var field in o.GetType().GetFields(bindingFlags))
            {
                var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute));
                if (attr != null)
                {
                    yield return (field, attr);
                }
            }
        }

        /// <summary>
        /// Enumerates the instance properties of <paramref name="o"/> that carry an
        /// <see cref="ObservableAttribute"/>, together with that attribute.
        /// </summary>
        /// <param name="o">Object whose runtime type is inspected.</param>
        /// <param name="excludeInherited">When true, only members declared on the runtime type itself are considered.</param>
        /// <returns>A lazily evaluated sequence of (property, attribute) pairs.</returns>
        static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited)
        {
            var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0);
            foreach (var prop in o.GetType().GetProperties(bindingFlags))
            {
                var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute));
                if (attr != null)
                {
                    yield return (prop, attr);
                }
            }
        }

        /// <summary>
        /// Creates one sensor for every observable field and readable observable property of
        /// <paramref name="o"/> whose type is supported. Fields are processed before properties.
        /// </summary>
        /// <param name="o">Object to observe.</param>
        /// <param name="excludeInherited">When true, only members declared on the runtime type itself are considered.</param>
        /// <returns>The created sensors; unsupported members and write-only properties are silently skipped.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="o"/> is null.</exception>
        internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited)
        {
            if (o == null)
            {
                throw new ArgumentNullException(nameof(o));
            }

            var sensorsOut = new List<ISensor>();
            foreach (var (field, attr) in GetObservableFields(o, excludeInherited))
            {
                var sensor = CreateReflectionSensor(o, field, null, attr);
                if (sensor != null)
                {
                    sensorsOut.Add(sensor);
                }
            }

            foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited))
            {
                if (!prop.CanRead)
                {
                    // A write-only property cannot be observed; GetTotalObservationSize reports it.
                    continue;
                }

                var sensor = CreateReflectionSensor(o, null, prop, attr);
                if (sensor != null)
                {
                    sensorsOut.Add(sensor);
                }
            }

            return sensorsOut;
        }

        /// <summary>
        /// Creates the sensor for a single observable member.
        /// </summary>
        /// <param name="o">Object that owns the member.</param>
        /// <param name="fieldInfo">The field to observe, or null when observing a property.</param>
        /// <param name="propertyInfo">The property to observe; only used when <paramref name="fieldInfo"/> is null.</param>
        /// <param name="observableAttribute">The attribute found on the member.</param>
        /// <returns>
        /// The sensor (wrapped in a <c>StackingSensor</c> when more than one stacked observation
        /// was requested), or null when the member type is not supported.
        /// </returns>
        static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute)
        {
            MemberInfo member;
            Type memberType;
            if (fieldInfo != null)
            {
                member = fieldInfo;
                memberType = fieldInfo.FieldType;
            }
            else
            {
                member = propertyInfo;
                memberType = propertyInfo.PropertyType;
            }

            // Enums get a dedicated sensor; every other type must be present in the lookup table.
            Type sensorType = null;
            if (!memberType.IsEnum)
            {
                if (!s_TypeToSensorInfo.TryGetValue(memberType, out var sensorInfo))
                {
                    // Unsupported type: no sensor can be built for it.
                    return null;
                }

                sensorType = sensorInfo.Item2;
            }

            var sensorName = string.IsNullOrEmpty(observableAttribute.m_Name)
                ? $"ObservableAttribute:{member.DeclaringType.Name}.{member.Name}"
                : observableAttribute.m_Name;

            var reflectionSensorInfo = new ReflectionSensorInfo
            {
                Object = o,
                FieldInfo = fieldInfo,
                PropertyInfo = propertyInfo,
                ObservableAttribute = observableAttribute,
                SensorName = sensorName
            };

            ISensor sensor;
            if (memberType.IsEnum)
            {
                sensor = new EnumReflectionSensor(reflectionSensorInfo);
            }
            else
            {
                sensor = (ISensor)Activator.CreateInstance(sensorType, reflectionSensorInfo);
            }

            // Only wrap when stacking was actually requested.
            if (observableAttribute.m_NumStackedObservations > 1)
            {
                return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations);
            }

            return sensor;
        }

        /// <summary>
        /// Computes the observation size contributed by one member of the given type.
        /// </summary>
        /// <param name="memberType">Declared type of the field or property.</param>
        /// <param name="attr">The attribute found on the member.</param>
        /// <param name="size">Receives the size, including stacking; 0 when the type is unsupported.</param>
        /// <returns>True when <paramref name="memberType"/> is a supported observable type.</returns>
        static bool TryGetObservationSize(Type memberType, ObservableAttribute attr, out int size)
        {
            int baseSize;
            if (s_TypeToSensorInfo.TryGetValue(memberType, out var sensorInfo))
            {
                baseSize = sensorInfo.Item1;
            }
            else if (memberType.IsEnum)
            {
                baseSize = EnumReflectionSensor.GetEnumObservationSize(memberType);
            }
            else
            {
                size = 0;
                return false;
            }

            // Mirror CreateReflectionSensor: every sensor kind (enums included) is stacked when
            // more than one observation is requested, and a count of 1 or less means no stacking.
            size = baseSize * Math.Max(1, attr.m_NumStackedObservations);
            return true;
        }

        /// <summary>
        /// Computes the total observation size of all supported observable fields and readable
        /// observable properties of <paramref name="o"/>, taking stacking into account.
        /// </summary>
        /// <param name="o">Object to inspect.</param>
        /// <param name="excludeInherited">When true, only members declared on the runtime type itself are considered.</param>
        /// <param name="errorsOut">
        /// Receives a message for every write-only property and every member of an unsupported type.
        /// May be null, in which case the messages are discarded.
        /// </param>
        /// <returns>The summed observation size of the supported members.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="o"/> is null.</exception>
        internal static int GetTotalObservationSize(object o, bool excludeInherited, List<string> errorsOut)
        {
            if (o == null)
            {
                throw new ArgumentNullException(nameof(o));
            }

            int sizeOut = 0;
            foreach (var (field, attr) in GetObservableFields(o, excludeInherited))
            {
                if (TryGetObservationSize(field.FieldType, attr, out var fieldSize))
                {
                    sizeOut += fieldSize;
                }
                else
                {
                    errorsOut?.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}");
                }
            }

            foreach (var (prop, attr) in GetObservableProperties(o, excludeInherited))
            {
                if (!prop.CanRead)
                {
                    errorsOut?.Add($"Observable property {prop.Name} is write-only.");
                }
                else if (TryGetObservationSize(prop.PropertyType, attr, out var propSize))
                {
                    sizeOut += propSize;
                }
                else
                {
                    errorsOut?.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}");
                }
            }

            return sizeOut;
        }
    }
}
| |
|