Sichang0621's picture
Upload folder using huggingface_hub
ce5618e verified
from typing import Dict
import numpy as np
import tree
from typing_extensions import override
from openpi_client import base_policy as _base_policy
class ActionChunkBroker(_base_policy.BasePolicy):
    """Serve the output of a chunk-producing policy one step per call.

    The wrapped policy is queried only when no chunk is cached. Each NumPy
    array in the cached result is indexed along its first axis with the
    current step counter; values that are not NumPy arrays are returned
    unchanged. After ``action_horizon`` steps have been served the cached
    chunk is dropped, so the next call queries the wrapped policy again.
    """

    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
        # Inner policy whose results are cached and served step by step.
        self._policy = policy
        # Number of steps served from one cached result before it is dropped.
        self._action_horizon = action_horizon
        # Index of the next step to serve from the cached result.
        self._cur_step: int = 0
        # Cached result of the inner policy; None means "infer again".
        self._last_results: Dict[str, np.ndarray] | None = None

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        """Return the current step of the cached chunk, refilling it if empty.

        Args:
            obs: Observation passed to the wrapped policy. It is only used on
                calls where a new chunk has to be computed.

        Returns:
            A structure shaped like the wrapped policy's result, with every
            NumPy array replaced by its entry at the current step.
        """
        if self._last_results is None:
            # Nothing cached: run the inner policy and start from step 0.
            self._last_results = self._policy.infer(obs)
            self._cur_step = 0

        step = self._cur_step

        def take_step(leaf):
            # Only NumPy arrays carry a leading chunk axis; pass the rest through.
            return leaf[step, ...] if isinstance(leaf, np.ndarray) else leaf

        current = tree.map_structure(take_step, self._last_results)

        self._cur_step = step + 1
        if self._cur_step >= self._action_horizon:
            # Chunk fully served; drop it so the next call re-infers.
            self._last_results = None

        return current

    @override
    def reset(self) -> None:
        """Reset the wrapped policy and discard any cached chunk."""
        self._policy.reset()
        self._last_results, self._cur_step = None, 0