| | using System.Collections.Generic; |
| | using Unity.Barracuda; |
| | using Unity.MLAgents.Sensors; |
| |
|
namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Maps tensor names to <see cref="IGenerator"/> instances and, for each inference step,
    /// asks the generator registered under a tensor's name to populate that tensor.
    /// Input generators (batch size, sequence length, recurrent memories, previous actions,
    /// action masks, random epsilon) and output generators (continuous/discrete actions,
    /// recurrent output, value estimate) are registered in the constructor; observation
    /// generators are registered separately via <see cref="InitializeObservations"/>.
    /// Registering a second generator under an existing name replaces the first.
    /// </summary>
    internal class TensorGenerator
    {
        /// <summary>
        /// A component that populates a single named tensor for the current batch.
        /// </summary>
        public interface IGenerator
        {
            /// <summary>
            /// Populate <paramref name="tensorProxy"/> for the current batch. Called once per
            /// tensor by <see cref="TensorGenerator.GenerateTensors"/>; what "populate" means
            /// (reshape, fill, or both) is up to the implementation.
            /// </summary>
            /// <param name="tensorProxy">The tensor to populate.</param>
            /// <param name="batchSize">Number of entries in the current batch.</param>
            /// <param name="infos">Per-agent data for the current batch.</param>
            void Generate(
                TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
        }

        // Tensor name -> generator responsible for that tensor.
        readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();

        // API version reported by the model; stays 0 when no model was supplied.
        readonly int m_ApiVersion;

        /// <summary>
        /// Registers the non-observation generators for the given model.
        /// </summary>
        /// <param name="seed">Seed passed to the random-normal epsilon input generator.</param>
        /// <param name="allocator">Allocator handed to every generator.</param>
        /// <param name="memories">Memory storage passed to the recurrent input generator.</param>
        /// <param name="barracudaModel">
        /// The model, as an object that must be a Barracuda <c>Model</c> (it is cast directly,
        /// so any other type throws <see cref="System.InvalidCastException"/>). When null,
        /// no generators are registered and the API version stays 0.
        /// </param>
        /// <param name="deterministicInference">
        /// Forwarded to the model's output-name queries; selects which action output
        /// names get an output generator.
        /// </param>
        public TensorGenerator(
            int seed,
            ITensorAllocator allocator,
            Dictionary<int, List<float>> memories,
            object barracudaModel = null,
            bool deterministicInference = false)
        {
            // Nothing to register without a model.
            if (barracudaModel == null)
            {
                return;
            }
            var model = (Model)barracudaModel;

            m_ApiVersion = model.GetVersion();

            // Generators for model inputs.
            m_Dict[TensorNames.BatchSizePlaceholder] =
                new BatchSizeGenerator(allocator);
            m_Dict[TensorNames.SequenceLengthPlaceholder] =
                new SequenceLengthGenerator(allocator);
            m_Dict[TensorNames.RecurrentInPlaceholder] =
                new RecurrentInputGenerator(allocator, memories);

            m_Dict[TensorNames.PreviousActionPlaceholder] =
                new PreviousActionInputGenerator(allocator);
            m_Dict[TensorNames.ActionMaskPlaceholder] =
                new ActionMaskInputGenerator(allocator);
            m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] =
                new RandomNormalInputGenerator(seed, allocator);

            // Generators for model outputs. Action outputs are only registered when the
            // model reports having them.
            if (model.HasContinuousOutputs(deterministicInference))
            {
                m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
            }
            if (model.HasDiscreteOutputs(deterministicInference))
            {
                m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
            }
            m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
            m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
        }

        /// <summary>
        /// Registers one observation generator per sensor, using the tensor naming scheme
        /// of the model's API version. For API versions other than MLAgents1_0 and
        /// MLAgents2_0 (including the "no model" case), nothing is registered.
        /// </summary>
        /// <param name="sensors">Sensors whose observations feed the model, in index order.</param>
        /// <param name="allocator">Allocator handed to each observation generator.</param>
        /// <exception cref="UnityAgentsException">
        /// MLAgents1_0 only: a sensor's observation rank is not 1, 2 or 3.
        /// </exception>
        public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
        {
            if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
            {
                AddObservationGeneratorsMLAgents1_0(sensors, allocator);
            }

            if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
            {
                AddObservationGeneratorsMLAgents2_0(sensors, allocator);
            }
        }

        // MLAgents1_0 naming: all rank-1 sensors share one generator under the vector
        // observation placeholder; rank-2 sensors are named by sensor index; rank-3 sensors
        // are named by a counter that only counts rank-3 sensors.
        void AddObservationGeneratorsMLAgents1_0(List<ISensor> sensors, ITensorAllocator allocator)
        {
            var visIndex = 0;
            ObservationGenerator vecObsGen = null;
            for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
            {
                var sensor = sensors[sensorIndex];
                var rank = sensor.GetObservationSpec().Rank;
                ObservationGenerator obsGen = null;
                string obsGenName = null;
                switch (rank)
                {
                    case 1:
                        // Created lazily and shared, so every rank-1 sensor index is
                        // added to the same generator.
                        if (vecObsGen == null)
                        {
                            vecObsGen = new ObservationGenerator(allocator);
                        }
                        obsGen = vecObsGen;
                        obsGenName = TensorNames.VectorObservationPlaceholder;
                        break;
                    case 2:
                        // Named by the sensor's position in the full sensor list.
                        obsGen = new ObservationGenerator(allocator);
                        obsGenName = TensorNames.GetObservationName(sensorIndex);
                        break;
                    case 3:
                        // Named by the visual-observation counter, which skips
                        // non-rank-3 sensors.
                        obsGen = new ObservationGenerator(allocator);
                        obsGenName = TensorNames.GetVisualObservationName(visIndex);
                        visIndex++;
                        break;
                    default:
                        throw new UnityAgentsException(
                            $"Sensor {sensor.GetName()} have an invalid rank {rank}");
                }
                obsGen.AddSensorIndex(sensorIndex);
                m_Dict[obsGenName] = obsGen;
            }
        }

        // MLAgents2_0 naming: one generator per sensor, named by sensor index regardless of rank.
        void AddObservationGeneratorsMLAgents2_0(List<ISensor> sensors, ITensorAllocator allocator)
        {
            for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
            {
                var obsGen = new ObservationGenerator(allocator);
                var obsGenName = TensorNames.GetObservationName(sensorIndex);
                obsGen.AddSensorIndex(sensorIndex);
                m_Dict[obsGenName] = obsGen;
            }
        }

        /// <summary>
        /// For each tensor, invokes the generator registered under the tensor's name.
        /// </summary>
        /// <param name="tensors">Tensors to populate.</param>
        /// <param name="currentBatchSize">Number of entries in the current batch.</param>
        /// <param name="infos">Per-agent data for the current batch.</param>
        /// <exception cref="UnityAgentsException">
        /// A tensor's name has no registered generator.
        /// </exception>
        public void GenerateTensors(
            IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
        {
            // An index loop avoids allocating an enumerator through the
            // IReadOnlyList<T> interface on every call.
            for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
            {
                var tensor = tensors[tensorIndex];
                // Single lookup instead of ContainsKey followed by the indexer.
                IGenerator generator;
                if (!m_Dict.TryGetValue(tensor.name, out generator))
                {
                    throw new UnityAgentsException(
                        $"Unknown tensorProxy expected as input : {tensor.name}");
                }
                generator.Generate(tensor, currentBatchSize, infos);
            }
        }
    }
}
| |
|