File size: 509 Bytes
ce5618e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from typing_extensions import override
from openpi_client import base_policy as _base_policy
from openpi_client.runtime import agent as _agent
class PolicyAgent(_agent.Agent):
    """An agent that uses a policy to determine actions.

    Every method delegates directly to the wrapped policy object; the agent
    itself keeps no state other than the reference to that policy.
    """

    def __init__(self, policy: _base_policy.BasePolicy) -> None:
        """Store the policy that the other methods delegate to.

        Args:
            policy: Object whose ``infer`` and ``reset`` methods are invoked
                by ``get_action`` and ``reset`` respectively.
        """
        self._policy = policy

    @override
    def get_action(self, observation: dict) -> dict:
        """Return the policy's inference result for ``observation``.

        Args:
            observation: Passed to ``policy.infer`` unchanged.

        Returns:
            Whatever ``policy.infer`` returns, unmodified.
        """
        return self._policy.infer(observation)

    # NOTE(review): marked @override for consistency with get_action, so a
    # type checker flags drift from the base class. This assumes
    # _agent.Agent declares reset() -- confirm.
    @override
    def reset(self) -> None:
        """Forward the reset request to the wrapped policy."""
        self._policy.reset()
|