| from typing import Dict |
|
|
| import numpy as np |
| import tree |
| from typing_extensions import override |
|
|
| from openpi_client import base_policy as _base_policy |
|
|
|
|
class ActionChunkBroker(_base_policy.BasePolicy):
    """Wraps a policy to return action chunks one-at-a-time.

    Assumes that the first dimension of all action fields is the chunk size.

    A new inference call to the inner policy is only made when the current
    list of chunks is exhausted.
    """

    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
        """Initializes the broker.

        Args:
            policy: Inner policy whose `infer` returns chunked actions
                (leading dimension of each array field is the chunk axis).
            action_horizon: Number of steps served from each chunk before the
                inner policy is queried again. Must be >= 1, and should not
                exceed the chunk length returned by the inner policy
                (otherwise slicing past the chunk raises IndexError).

        Raises:
            ValueError: If `action_horizon` is less than 1.
        """
        if action_horizon < 1:
            raise ValueError(f"action_horizon must be >= 1, got {action_horizon}")

        self._policy = policy
        self._action_horizon = action_horizon
        self._cur_step: int = 0

        # Cached result of the last inner-policy call; None means exhausted
        # (or never queried), so the next `infer` triggers a fresh inference.
        self._last_results: Dict[str, np.ndarray] | None = None

    @override
    def infer(self, obs: Dict) -> Dict:
        """Returns the next single-step slice of the current action chunk.

        NOTE: `obs` is only forwarded to the inner policy when a fresh chunk
        is needed; observations passed mid-chunk are ignored.
        """
        if self._last_results is None:
            self._last_results = self._policy.infer(obs)
            self._cur_step = 0

        def slicer(x):
            # Slice the chunk axis for arrays; pass non-array leaves through
            # unchanged (e.g. metadata fields in the result dict).
            if isinstance(x, np.ndarray):
                return x[self._cur_step, ...]
            else:
                return x

        results = tree.map_structure(slicer, self._last_results)
        self._cur_step += 1

        # Drop the cache once the horizon is consumed so the next call
        # re-queries the inner policy.
        if self._cur_step >= self._action_horizon:
            self._last_results = None

        return results

    @override
    def reset(self) -> None:
        """Resets the inner policy and discards any cached action chunk."""
        self._policy.reset()
        self._last_results = None
        self._cur_step = 0
|
|