Naungth commited on
Commit ·
b69c6e3
1
Parent(s): cfcecd9
kept the tree.map_structure call and perform the check inside its handler
Browse files
packages/openpi-client/src/openpi_client/action_chunk_broker.py
CHANGED
|
@@ -18,7 +18,6 @@ class ActionChunkBroker(_base_policy.BasePolicy):
|
|
| 18 |
|
| 19 |
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
|
| 20 |
self._policy = policy
|
| 21 |
-
|
| 22 |
self._action_horizon = action_horizon
|
| 23 |
self._cur_step: int = 0
|
| 24 |
|
|
@@ -30,13 +29,13 @@ class ActionChunkBroker(_base_policy.BasePolicy):
|
|
| 30 |
self._last_results = self._policy.infer(obs)
|
| 31 |
self._cur_step = 0
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
results[key] = val[self._cur_step, ...]
|
| 37 |
else:
|
| 38 |
-
|
| 39 |
|
|
|
|
| 40 |
self._cur_step += 1
|
| 41 |
|
| 42 |
if self._cur_step >= self._action_horizon:
|
|
|
|
| 18 |
|
| 19 |
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
|
| 20 |
self._policy = policy
|
|
|
|
| 21 |
self._action_horizon = action_horizon
|
| 22 |
self._cur_step: int = 0
|
| 23 |
|
|
|
|
| 29 |
self._last_results = self._policy.infer(obs)
|
| 30 |
self._cur_step = 0
|
| 31 |
|
| 32 |
+
def slicer(x):
|
| 33 |
+
if isinstance(x, np.ndarray):
|
| 34 |
+
return x[self._cur_step, ...]
|
|
|
|
| 35 |
else:
|
| 36 |
+
return x
|
| 37 |
|
| 38 |
+
results = tree.map_structure(slicer, self._last_results)
|
| 39 |
self._cur_step += 1
|
| 40 |
|
| 41 |
if self._cur_step >= self._action_horizon:
|