File size: 1,404 Bytes
ce5618e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 | from typing import Dict
import numpy as np
import tree
from typing_extensions import override
from openpi_client import base_policy as _base_policy
class ActionChunkBroker(_base_policy.BasePolicy):
    """Wraps a policy to return action chunks one-at-a-time.

    The inner policy's ``infer`` result is cached, and successive calls to
    ``infer`` return the slice for the next step of that cached result. A new
    inference call to the inner policy is only made once ``action_horizon``
    steps of the cached result have been handed out (or after ``reset``).

    Every ``np.ndarray`` leaf of rank >= 1 in the inner policy's output is
    indexed along its first axis -- not only action fields -- so each such
    array must have a leading dimension of at least ``action_horizon``.
    Non-array leaves and 0-d arrays are returned unchanged.
    """

    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
        """Create a broker around ``policy``.

        Args:
            policy: Inner policy whose ``infer`` output is cached and sliced.
            action_horizon: Number of consecutive steps served from one cached
                inner-policy result before the inner policy is queried again.
        """
        self._policy = policy
        self._action_horizon = action_horizon
        # Index of the next step to hand out from the cached result.
        self._cur_step: int = 0
        # Cached inner-policy output; None means the next infer() call must
        # query the inner policy.
        self._last_results: Dict[str, np.ndarray] | None = None

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        """Return the next single-step slice of the cached inner-policy result.

        Args:
            obs: Observation forwarded to the inner policy. It is only used
                when no cached result is available; while a cached result is
                being consumed, ``obs`` is ignored.

        Returns:
            The cached inner-policy output with every ``np.ndarray`` leaf of
            rank >= 1 replaced by its element at the current step; all other
            leaves are returned as-is.
        """
        if self._last_results is None:
            self._last_results = self._policy.infer(obs)
            self._cur_step = 0

        step = self._cur_step

        def slicer(x):
            # A 0-d array has no leading axis to index (numpy raises
            # IndexError), so it is passed through like any non-chunked value.
            if isinstance(x, np.ndarray) and x.ndim > 0:
                return x[step, ...]
            return x

        results = tree.map_structure(slicer, self._last_results)
        self._cur_step += 1

        # Drop the cache once the horizon is consumed so the next call
        # re-queries the inner policy.
        if self._cur_step >= self._action_horizon:
            self._last_results = None

        return results

    @override
    def reset(self) -> None:
        """Reset the inner policy and discard any partially consumed result."""
        self._policy.reset()
        self._last_results = None
        self._cur_step = 0
|