// Source: ppo-Pyramids-Training/Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorAgent.cs
| using UnityEngine; | |
| using Unity.MLAgents; | |
| using Unity.MLAgents.Actuators; | |
| using Unity.MLAgents.Sensors; | |
| using Random = UnityEngine.Random; | |
/// <summary>
/// Agent for the FoodCollector example: it moves and turns from continuous
/// actions, can fire a laser from a discrete action to freeze other agents,
/// and is rewarded +1 for touching objects tagged "food" and -1 for objects
/// tagged "badFood".
/// </summary>
public class FoodCollectorAgent : Agent
{
    // Seconds an agent stays frozen after being hit by a laser.
    const float k_FreezeDuration = 4f;
    // Seconds the poisoned / satiated material stays visible after eating.
    const float k_EffectDuration = 0.5f;
    // Reach and radius of the sphere cast that represents the laser.
    const float k_LaserRange = 25f;
    const float k_LaserRadius = 2f;
    // Squared speed above which the velocity is damped each step.
    const float k_MaxSqrSpeed = 25f;
    const float k_OverspeedDamping = 0.95f;
    // Penalties applied to movement while the laser is being fired.
    const float k_ShootMoveScale = 0.5f;
    const float k_ShootVelocityDamping = 0.75f;
    // Height (relative to the area origin) at which the agent is respawned.
    const float k_SpawnHeight = 2f;

    FoodCollectorSettings m_FoodCollectorSettings;
    public GameObject area;
    FoodCollectorArea m_MyArea;
    bool m_Frozen;
    bool m_Poisoned;
    bool m_Satiated;
    bool m_Shoot;
    float m_FrozenTime;
    float m_EffectTime;
    Rigidbody m_AgentRb;
    float m_LaserLength;
    // Lazily cached renderer whose material reflects the agent's state.
    Renderer m_BodyRenderer;

    // Speed of agent rotation.
    public float turnSpeed = 300;

    // Speed of agent movement.
    public float moveSpeed = 2;
    public Material normalMaterial;
    public Material badMaterial;
    public Material goodMaterial;
    public Material frozenMaterial;
    public GameObject myLaser;
    public bool contribute;
    public bool useVectorObs;
    // NOTE(review): the original attribute text was lost (the source only had an
    // empty "[ ]"); this wording is derived from CollectObservations below.
    [Tooltip("Use only the frozen flag in vector observations. If \"Use Vector Obs\" " +
             "is checked, this option has no effect.")]
    public bool useVectorFrozenFlag;

    EnvironmentParameters m_ResetParams;

    /// <summary>
    /// Caches component references and the environment parameters, then
    /// applies the current reset parameters.
    /// </summary>
    public override void Initialize()
    {
        m_AgentRb = GetComponent<Rigidbody>();
        m_MyArea = area.GetComponent<FoodCollectorArea>();
        m_FoodCollectorSettings = FindObjectOfType<FoodCollectorSettings>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();
    }

    /// <summary>
    /// Adds vector observations. With <see cref="useVectorObs"/> set: local
    /// x/z velocity, the frozen flag and the shoot flag (4 values). Otherwise,
    /// with <see cref="useVectorFrozenFlag"/> set: only the frozen flag.
    /// </summary>
    /// <param name="sensor">Sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        if (useVectorObs)
        {
            var localVelocity = transform.InverseTransformDirection(m_AgentRb.velocity);
            sensor.AddObservation(localVelocity.x);
            sensor.AddObservation(localVelocity.z);
            sensor.AddObservation(m_Frozen);
            sensor.AddObservation(m_Shoot);
        }
        else if (useVectorFrozenFlag)
        {
            sensor.AddObservation(m_Frozen);
        }
    }

    /// <summary>
    /// Converts a 0xRRGGBB integer into an opaque <see cref="Color32"/>.
    /// </summary>
    /// <param name="hexVal">Colour packed as 0xRRGGBB; higher bits are ignored.</param>
    /// <returns>The colour with alpha fixed at 255.</returns>
    public Color32 ToColor(int hexVal)
    {
        var r = (byte)((hexVal >> 16) & 0xFF);
        var g = (byte)((hexVal >> 8) & 0xFF);
        var b = (byte)(hexVal & 0xFF);
        return new Color32(r, g, b, 255);
    }

    /// <summary>
    /// Applies one step of actions: expires timed effects, moves/rotates the
    /// agent (unless frozen), damps excessive speed and handles the laser.
    /// </summary>
    /// <param name="actionBuffers">
    /// Continuous actions [0]=forward, [1]=right, [2]=rotate (each clamped to
    /// [-1, 1]); discrete action [0] &gt; 0 fires the laser.
    /// </param>
    public void MoveAgent(ActionBuffers actionBuffers)
    {
        m_Shoot = false;
        ExpireTimedEffects();

        if (!m_Frozen)
        {
            ApplyMovement(actionBuffers);
        }

        if (m_AgentRb.velocity.sqrMagnitude > k_MaxSqrSpeed) // slow it down
        {
            m_AgentRb.velocity *= k_OverspeedDamping;
        }

        UpdateLaser();
    }

    // Ends the frozen / poisoned / satiated states once their timers elapse.
    void ExpireTimedEffects()
    {
        if (m_Frozen && Time.time > m_FrozenTime + k_FreezeDuration)
        {
            Unfreeze();
        }

        if (Time.time > m_EffectTime + k_EffectDuration)
        {
            if (m_Poisoned)
            {
                Unpoison();
            }
            if (m_Satiated)
            {
                Unsatiate();
            }
        }
    }

    // Translates the action buffers into a velocity change and a rotation,
    // and records whether the agent is shooting this step.
    void ApplyMovement(ActionBuffers actionBuffers)
    {
        var continuousActions = actionBuffers.ContinuousActions;
        var discreteActions = actionBuffers.DiscreteActions;

        var forward = Mathf.Clamp(continuousActions[0], -1f, 1f);
        var right = Mathf.Clamp(continuousActions[1], -1f, 1f);
        var rotate = Mathf.Clamp(continuousActions[2], -1f, 1f);

        var dirToGo = transform.forward * forward;
        dirToGo += transform.right * right;
        var rotateDir = -transform.up * rotate;

        if (discreteActions[0] > 0)
        {
            // Shooting halves the commanded movement and bleeds off speed.
            m_Shoot = true;
            dirToGo *= k_ShootMoveScale;
            m_AgentRb.velocity *= k_ShootVelocityDamping;
        }

        m_AgentRb.AddForce(dirToGo * moveSpeed, ForceMode.VelocityChange);
        transform.Rotate(rotateDir, Time.fixedDeltaTime * turnSpeed);
    }

    // Shows the laser and freezes the first "agent"-tagged object it hits,
    // or hides the laser when the agent is not shooting.
    void UpdateLaser()
    {
        if (!m_Shoot)
        {
            myLaser.transform.localScale = Vector3.zero;
            return;
        }

        var myTransform = transform;
        myLaser.transform.localScale = new Vector3(1f, 1f, m_LaserLength);
        var rayDir = k_LaserRange * myTransform.forward;
        Debug.DrawRay(myTransform.position, rayDir, Color.red, 0f, true);

        RaycastHit hit;
        if (Physics.SphereCast(myTransform.position, k_LaserRadius, rayDir, out hit, k_LaserRange))
        {
            var target = hit.collider.gameObject;
            if (target.CompareTag("agent"))
            {
                // An object can carry the tag without the component; skip it
                // instead of throwing. (Unity's overloaded != is intentional.)
                var victim = target.GetComponent<FoodCollectorAgent>();
                if (victim != null)
                {
                    victim.Freeze();
                }
            }
        }
    }

    // Assigns the state material to the first Renderer found on this object
    // or its children, caching the lookup after the first call.
    void SetBodyMaterial(Material material)
    {
        if (m_BodyRenderer == null)
        {
            m_BodyRenderer = GetComponentInChildren<Renderer>();
        }
        if (m_BodyRenderer != null)
        {
            m_BodyRenderer.material = material;
        }
    }

    // Restores the material after a food effect ends. A still-frozen agent
    // keeps the frozen look rather than appearing normal while immobile.
    void RestoreMaterialAfterEffect()
    {
        SetBodyMaterial(m_Frozen ? frozenMaterial : normalMaterial);
    }

    // Marks the agent frozen: retags it "frozenAgent" and starts the timer.
    void Freeze()
    {
        gameObject.tag = "frozenAgent";
        m_Frozen = true;
        m_FrozenTime = Time.time;
        SetBodyMaterial(frozenMaterial);
    }

    // Clears the frozen state and restores the "agent" tag.
    void Unfreeze()
    {
        m_Frozen = false;
        gameObject.tag = "agent";
        SetBodyMaterial(normalMaterial);
    }

    // Starts the short "ate bad food" visual effect.
    void Poison()
    {
        m_Poisoned = true;
        m_EffectTime = Time.time;
        SetBodyMaterial(badMaterial);
    }

    void Unpoison()
    {
        m_Poisoned = false;
        RestoreMaterialAfterEffect();
    }

    // Starts the short "ate good food" visual effect.
    void Satiate()
    {
        m_Satiated = true;
        m_EffectTime = Time.time;
        SetBodyMaterial(goodMaterial);
    }

    void Unsatiate()
    {
        m_Satiated = false;
        RestoreMaterialAfterEffect();
    }

    /// <summary>Forwards the received actions to <see cref="MoveAgent"/>.</summary>
    /// <param name="actionBuffers">Actions chosen for this step.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        MoveAgent(actionBuffers);
    }

    /// <summary>
    /// Keyboard control: W/S write +1/-1 to continuous action 0 (forward),
    /// D/A write +1/-1 to continuous action 2 (rotate), Space sets discrete
    /// action 0 (shoot). Continuous action 1 (strafe) is not written here.
    /// </summary>
    /// <param name="actionsOut">Buffers to fill with the heuristic actions.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        // Order matters when opposite keys are held: the later write wins.
        if (Input.GetKey(KeyCode.D))
        {
            continuousActionsOut[2] = 1;
        }
        if (Input.GetKey(KeyCode.W))
        {
            continuousActionsOut[0] = 1;
        }
        if (Input.GetKey(KeyCode.A))
        {
            continuousActionsOut[2] = -1;
        }
        if (Input.GetKey(KeyCode.S))
        {
            continuousActionsOut[0] = -1;
        }

        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = Input.GetKey(KeyCode.Space) ? 1 : 0;
    }

    /// <summary>
    /// Resets the agent for a new episode: clears all status effects, stops
    /// it, hides the laser, and respawns it at a random position and heading
    /// inside the area before re-reading the reset parameters.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        Unfreeze();
        Unpoison();
        Unsatiate();
        m_Shoot = false;
        m_AgentRb.velocity = Vector3.zero;
        myLaser.transform.localScale = Vector3.zero;

        // Draw x before z so the random sequence matches the original order.
        var spawnX = Random.Range(-m_MyArea.range, m_MyArea.range);
        var spawnZ = Random.Range(-m_MyArea.range, m_MyArea.range);
        transform.position = new Vector3(spawnX, k_SpawnHeight, spawnZ) + area.transform.position;
        transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));
        SetResetParameters();
    }

    // Handles eating: "food" gives +1 reward, "badFood" gives -1. When
    // `contribute` is set, the shared total score is updated as well.
    void OnCollisionEnter(Collision collision)
    {
        var other = collision.gameObject;
        if (other.CompareTag("food"))
        {
            Satiate();
            other.GetComponent<FoodLogic>().OnEaten();
            AddReward(1f);
            // The settings object comes from FindObjectOfType and may be absent.
            if (contribute && m_FoodCollectorSettings != null)
            {
                m_FoodCollectorSettings.totalScore += 1;
            }
        }
        if (other.CompareTag("badFood"))
        {
            Poison();
            other.GetComponent<FoodLogic>().OnEaten();
            AddReward(-1f);
            if (contribute && m_FoodCollectorSettings != null)
            {
                m_FoodCollectorSettings.totalScore -= 1;
            }
        }
    }

    /// <summary>Reads the "laser_length" environment parameter (default 1.0).</summary>
    public void SetLaserLengths()
    {
        m_LaserLength = m_ResetParams.GetWithDefault("laser_length", 1.0f);
    }

    /// <summary>
    /// Reads the "agent_scale" environment parameter (default 1.0) and applies
    /// it uniformly to this object's local scale.
    /// </summary>
    public void SetAgentScale()
    {
        float agentScale = m_ResetParams.GetWithDefault("agent_scale", 1.0f);
        gameObject.transform.localScale = new Vector3(agentScale, agentScale, agentScale);
    }

    /// <summary>Applies every reset parameter (laser length and agent scale).</summary>
    public void SetResetParameters()
    {
        SetLaserLengths();
        SetAgentScale();
    }
}