s3y's picture
Upload folder using huggingface_hub
93a69f8 verified
raw
history blame contribute delete
509 Bytes
from typing_extensions import override
from openpi_client import base_policy as _base_policy
from openpi_client.runtime import agent as _agent
class PolicyAgent(_agent.Agent):
    """An agent that uses a policy to determine actions.

    This is a thin adapter: each agent call is forwarded to the wrapped
    policy. ``get_action`` returns ``policy.infer(observation)`` unchanged,
    and ``reset`` calls ``policy.reset()``.
    """

    def __init__(self, policy: _base_policy.BasePolicy) -> None:
        """Store the policy this agent delegates to.

        Args:
            policy: The policy whose ``infer`` and ``reset`` methods back
                this agent.
        """
        self._policy = policy

    @override
    def get_action(self, observation: dict) -> dict:
        """Return the wrapped policy's output for ``observation``.

        Args:
            observation: Observation dict, passed to ``policy.infer``
                unmodified.

        Returns:
            The value returned by ``policy.infer``, without post-processing.
        """
        return self._policy.infer(observation)

    # Marked @override for consistency with get_action. This assumes
    # _agent.Agent declares reset() -- a type checker will report it if not.
    # The decorator has no runtime effect beyond setting __override__.
    @override
    def reset(self) -> None:
        """Reset the wrapped policy by calling its ``reset`` method."""
        self._policy.reset()