| | using System.IO; |
| | using Google.Protobuf; |
| | using System.Collections.Generic; |
| | using Unity.MLAgents.Sensors; |
| | using Unity.MLAgents.Policies; |
| |
|
namespace Unity.MLAgents.Demonstrations
{
    /// <summary>
    /// Writes demonstration data to a <see cref="Stream"/>. The stream layout is:
    /// a length-delimited <see cref="DemonstrationMetaData"/> record at offset 0, the
    /// length-delimited brain parameters starting at offset <see cref="MetaDataBytes"/> + 1,
    /// and then one length-delimited agent info/action record per call to
    /// <see cref="Record"/>. The metadata record is rewritten in place by <see cref="Close"/>
    /// once the final step/episode counts and mean reward are known.
    /// </summary>
    /// <remarks>
    /// If the writer was constructed with a null stream (or has been closed), every
    /// writing method returns without doing anything.
    /// </remarks>
    public class DemonstrationWriter
    {
        /// <summary>
        /// Number of bytes reserved at the start of the stream for the serialized
        /// <see cref="DemonstrationMetaData"/>. The brain parameters begin at
        /// <c>MetaDataBytes + 1</c>; the extra byte presumably holds the one-byte varint
        /// length prefix written by <c>WriteDelimitedTo</c>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): nothing in this class checks that the final metadata fits in this
        /// region; a larger record rewritten by <see cref="Close"/> would overwrite the start
        /// of the brain parameters.
        /// </remarks>
        internal const int MetaDataBytes = 32;

        DemonstrationMetaData m_MetaData;
        Stream m_Writer;
        float m_CumulativeReward;
        readonly ObservationWriter m_ObservationWriter = new ObservationWriter();

        // True when at least one step has been recorded since the last episode ended.
        // Lets Close() count a trailing partial episode without double-counting an
        // episode that already ended on the final recorded step.
        bool m_EpisodeInProgress;

        /// <summary>
        /// Creates a writer that outputs to <paramref name="stream"/>.
        /// </summary>
        /// <param name="stream">
        /// Destination stream. It must support seeking, because the writer repositions it
        /// with <see cref="Stream.Seek"/>. <see cref="Close"/> closes this stream. May be
        /// null, in which case all writes are skipped.
        /// </param>
        public DemonstrationWriter(Stream stream)
        {
            m_Writer = stream;
        }

        /// <summary>
        /// Number of steps recorded so far. The count lives in the metadata created by
        /// <see cref="Initialize"/>, so that method must have been called first.
        /// </summary>
        internal int NumSteps
        {
            get { return m_MetaData.numberSteps; }
        }

        /// <summary>
        /// Writes the demonstration header: an initial metadata record carrying
        /// <paramref name="demonstrationName"/>, followed by the brain parameters.
        /// </summary>
        /// <param name="demonstrationName">Name stored in the metadata record.</param>
        /// <param name="brainParameters">Brain parameters written after the metadata region.</param>
        /// <param name="brainName">Name passed along when serializing the brain parameters.</param>
        internal void Initialize(
            string demonstrationName, BrainParameters brainParameters, string brainName)
        {
            if (m_Writer == null)
            {
                // Constructed with a null stream, or already closed.
                return;
            }

            // Final step/episode counts are not known yet; Close() rewrites this record.
            // NOTE(review): written at the stream's current position, presumably 0 for a
            // freshly opened stream — confirm with callers.
            m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
            var metaProto = m_MetaData.ToProto();
            metaProto.WriteDelimitedTo(m_Writer);

            WriteBrainParameters(brainName, brainParameters);
        }

        /// <summary>
        /// Rewrites the metadata record at the start of the stream. Called from
        /// <see cref="Close"/> once the final counts and mean reward are known.
        /// </summary>
        void WriteMetadata()
        {
            if (m_Writer == null)
            {
                // Constructed with a null stream, or already closed.
                return;
            }

            var metaProto = m_MetaData.ToProto();
            var metaProtoBytes = metaProto.ToByteArray();
            // NOTE(review): this appends an undelimited copy of the metadata at the current
            // position (after the last record written) before the delimited copy is written
            // at offset 0. Kept for file-format compatibility; its purpose is not visible here.
            m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);
            m_Writer.Seek(0, SeekOrigin.Begin);
            metaProto.WriteDelimitedTo(m_Writer);
        }

        /// <summary>
        /// Writes the brain parameters as a length-delimited record starting at offset
        /// <see cref="MetaDataBytes"/> + 1, just past the region reserved for the metadata.
        /// </summary>
        /// <param name="brainName">Name passed to the brain-parameter serializer.</param>
        /// <param name="brainParameters">Parameters to serialize.</param>
        void WriteBrainParameters(string brainName, BrainParameters brainParameters)
        {
            if (m_Writer == null)
            {
                // Constructed with a null stream, or already closed.
                return;
            }

            m_Writer.Seek(MetaDataBytes + 1, SeekOrigin.Begin);
            var brainProto = brainParameters.ToProto(brainName, false);
            brainProto.WriteDelimitedTo(m_Writer);
        }

        /// <summary>
        /// Appends one step: <paramref name="info"/> together with an observation from each
        /// sensor in <paramref name="sensors"/>. Also updates the step count, the cumulative
        /// reward and, when <c>info.done</c> is set, the episode count.
        /// </summary>
        /// <param name="info">Agent information (including reward and done flag) for this step.</param>
        /// <param name="sensors">Sensors whose observations are attached to the record.</param>
        internal void Record(AgentInfo info, List<ISensor> sensors)
        {
            if (m_Writer == null)
            {
                // Constructed with a null stream, or already closed.
                return;
            }

            // Update the running statistics that Close() stores in the metadata.
            m_MetaData.numberSteps++;
            m_CumulativeReward += info.reward;
            m_EpisodeInProgress = true;
            if (info.done)
            {
                EndEpisode();
            }

            // Build the step record with one observation per sensor, then append it.
            var agentProto = info.ToInfoActionPairProto();
            foreach (var sensor in sensors)
            {
                agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_ObservationWriter));
            }

            agentProto.WriteDelimitedTo(m_Writer);
        }

        /// <summary>
        /// Finalizes the metadata, rewrites it at the start of the stream and closes the
        /// stream. A trailing partial episode is counted as an episode, and the mean reward
        /// is the cumulative reward divided by the episode count. Calling this more than once
        /// has no further effect. The stream is closed even if writing the metadata throws.
        /// </summary>
        public void Close()
        {
            if (m_Writer == null)
            {
                // Constructed with a null stream, or already closed.
                return;
            }

            try
            {
                // Count a trailing partial episode, but not an episode that already ended on
                // the last recorded step. An empty recording still counts one episode so the
                // division below stays well-defined (mean reward 0).
                if (m_EpisodeInProgress || m_MetaData.numberEpisodes == 0)
                {
                    EndEpisode();
                }

                m_MetaData.meanReward = m_CumulativeReward / m_MetaData.numberEpisodes;
                WriteMetadata();
            }
            finally
            {
                m_Writer.Close();
                m_Writer = null;
            }
        }

        /// <summary>
        /// Marks the current episode as finished and increments the episode count.
        /// </summary>
        void EndEpisode()
        {
            m_MetaData.numberEpisodes += 1;
            m_EpisodeInProgress = false;
        }
    }
}
| |
|