Naungth commited on
Commit
b69c6e3
·
1 Parent(s): cfcecd9

kept the tree.map_structure call and perform the check inside its handler

Browse files
packages/openpi-client/src/openpi_client/action_chunk_broker.py CHANGED
@@ -18,7 +18,6 @@ class ActionChunkBroker(_base_policy.BasePolicy):
18
 
19
  def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
20
  self._policy = policy
21
-
22
  self._action_horizon = action_horizon
23
  self._cur_step: int = 0
24
 
@@ -30,13 +29,13 @@ class ActionChunkBroker(_base_policy.BasePolicy):
30
  self._last_results = self._policy.infer(obs)
31
  self._cur_step = 0
32
 
33
- results: Dict = {}
34
- for key, val in self._last_results.items():
35
- if isinstance(val, np.ndarray):
36
- results[key] = val[self._cur_step, ...]
37
  else:
38
- results[key] = val
39
 
 
40
  self._cur_step += 1
41
 
42
  if self._cur_step >= self._action_horizon:
 
18
 
19
  def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
20
  self._policy = policy
 
21
  self._action_horizon = action_horizon
22
  self._cur_step: int = 0
23
 
 
29
  self._last_results = self._policy.infer(obs)
30
  self._cur_step = 0
31
 
32
+ def slicer(x):
33
+ if isinstance(x, np.ndarray):
34
+ return x[self._cur_step, ...]
 
35
  else:
36
+ return x
37
 
38
+ results = tree.map_structure(slicer, self._last_results)
39
  self._cur_step += 1
40
 
41
  if self._cur_step >= self._action_horizon: