| | from typing import Dict |
| |
|
| | import numpy as np |
| | import tree |
| | from typing_extensions import override |
| |
|
| | from openpi_client import base_policy as _base_policy |
| |
|
| |
|
class ActionChunkBroker(_base_policy.BasePolicy):
    """Serve the output of a chunk-producing policy one step per call.

    The result of the wrapped policy is cached. On every ``infer`` call each
    ``np.ndarray`` leaf of the cached result is indexed along its first axis
    at the current step; leaves of any other type are returned unchanged.
    The first axis of every array field is assumed to be the chunk axis.

    The wrapped policy is only queried again once ``action_horizon`` steps
    have been served from the cached result, or after ``reset`` is called.
    """

    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
        """Store the wrapped policy and the number of steps served per chunk."""
        self._policy = policy
        self._action_horizon = action_horizon
        # Index of the next step to serve from the cached chunk.
        self._cur_step: int = 0
        # Cached output of the wrapped policy; None means that the next
        # ``infer`` call has to run the wrapped policy first.
        self._last_results: Dict[str, np.ndarray] | None = None

    @override
    def infer(self, obs: Dict) -> Dict:
        """Return the next single step, refreshing the cached chunk if needed.

        ``obs`` is forwarded to the wrapped policy only when no chunk is
        cached; otherwise it is not used.
        """
        chunk = self._last_results
        if chunk is None:
            # Nothing cached: run the wrapped policy and restart the counter.
            chunk = self._policy.infer(obs)
            self._last_results = chunk
            self._cur_step = 0

        step = self._cur_step

        def take_step(leaf):
            # Only arrays are sliced; every other leaf is passed through.
            return leaf[step, ...] if isinstance(leaf, np.ndarray) else leaf

        current = tree.map_structure(take_step, chunk)

        # Advance only after slicing succeeded, then drop the cache once the
        # configured number of steps has been served.
        self._cur_step = step + 1
        if self._cur_step >= self._action_horizon:
            self._last_results = None

        return current

    @override
    def reset(self) -> None:
        """Reset the wrapped policy, then discard the cached chunk and counter."""
        self._policy.reset()
        self._last_results = None
        self._cur_step = 0
| |
|