File size: 15,993 Bytes
06c11b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import gymnasium as gym
import numpy as np
import torch

from robomme.robomme_env.utils.vqa_options import get_vqa_options
from mani_skill.examples.motionplanning.panda.motionplanner import (
    PandaArmMotionPlanningSolver,
)
from mani_skill.examples.motionplanning.panda.motionplanner_stick import (
    PandaStickMotionPlanningSolver,
)
from ..robomme_env.utils import planner_denseStep
from ..robomme_env.utils.oracle_action_matcher import (
    find_exact_label_option_index,
)
from ..robomme_env.utils.choice_action_mapping import select_target_with_pixel
from ..logging_utils import logger


# -----------------------------------------------------------------------------
# Module: Oracle Planner Demonstration Wrapper
# Connects the Robomme Oracle planning logic to a Gym environment and supports
# step-by-step observation collection. The Oracle logic below is inlined from
# history_bench_sim.oracle_logic; it works together with planner_denseStep to
# aggregate multiple internal env.step calls into a single batched return.
# -----------------------------------------------------------------------------


class OraclePlannerDemonstrationWrapper(gym.Wrapper):
    """
    Wrap Robomme environment with Oracle planning logic into Gym Wrapper for demonstration/evaluation;
    Input to step is command_dict (containing choice and optional pixel point).
    step returns obs as dict-of-lists and reward/terminated/truncated as last-step values.
    """

    def __init__(self, env, env_id, gui_render=True):
        """
        Store the wrapper configuration and initialise its mutable state.

        Args:
            env: Environment to wrap; forwarded to ``gym.Wrapper.__init__``.
            env_id: Task identifier kept on the instance as ``self.env_id``.
            gui_render: Flag kept on the instance as ``self.gui_render``.
        """
        super().__init__(env)

        # Both spaces are deliberately empty Dicts: the actual command and
        # observation layout is agreed on with the caller outside of Gym spaces.
        self.observation_space = gym.spaces.Dict({})
        self.action_space = gym.spaces.Dict({})

        # Static configuration supplied by the caller.
        self.env_id = env_id
        self.gui_render = gui_render

        # Retry budgets used by the screw -> RRT* planning fallback.
        self._oracle_screw_max_attempts = 3
        self._oracle_rrt_max_attempts = 3

        # Populated later: planner instance, language goal and option summaries.
        self.planner = None
        self.language_goal = None
        self.available_options = []

        # Front camera cache, filled from observations / info when available.
        self._front_rgb_shape = None
        self._front_camera_intrinsic_cv = None
        self._front_camera_extrinsic_cv = None

    def _wrap_planner_with_screw_then_rrt_retry(self, planner, screw_failure_exc):
        original_move_to_pose_with_screw = planner.move_to_pose_with_screw
        original_move_to_pose_with_rrt = planner.move_to_pose_with_RRTStar

        def _move_to_pose_with_screw_then_rrt_retry(*args, **kwargs):
            for attempt in range(1, self._oracle_screw_max_attempts + 1):
                try:
                    result = original_move_to_pose_with_screw(*args, **kwargs)
                except screw_failure_exc as exc:
                    logger.debug(
                        f"[OraclePlannerWrapper] screw planning failed "
                        f"(attempt {attempt}/{self._oracle_screw_max_attempts}): {exc}"
                    )
                    continue

                if isinstance(result, int) and result == -1:
                    logger.debug(
                        f"[OraclePlannerWrapper] screw planning returned -1 "
                        f"(attempt {attempt}/{self._oracle_screw_max_attempts})"
                    )
                    continue

                return result

            logger.debug(
                "[OraclePlannerWrapper] screw planning exhausted; "
                f"fallback to RRT* (max {self._oracle_rrt_max_attempts} attempts)"
            )
            for attempt in range(1, self._oracle_rrt_max_attempts + 1):
                try:
                    result = original_move_to_pose_with_rrt(*args, **kwargs)
                except Exception as exc:
                    logger.debug(
                        f"[OraclePlannerWrapper] RRT* planning failed "
                        f"(attempt {attempt}/{self._oracle_rrt_max_attempts}): {exc}"
                    )
                    continue

                if isinstance(result, int) and result == -1:
                    logger.debug(
                        f"[OraclePlannerWrapper] RRT* planning returned -1 "
                        f"(attempt {attempt}/{self._oracle_rrt_max_attempts})"
                    )
                    continue

                return result

            raise RuntimeError(
                "[OraclePlannerWrapper] screw->RRT* planning exhausted; "
                f"screw_attempts={self._oracle_screw_max_attempts}, "
                f"rrt_attempts={self._oracle_rrt_max_attempts}"
            )

        planner.move_to_pose_with_screw = _move_to_pose_with_screw_then_rrt_retry
        return planner

    def reset(self, **kwargs):
        """
        Create a new motion planner, reset the wrapped env and refresh options.

        Args:
            **kwargs: Forwarded unchanged to ``self.env.reset``.

        Returns:
            ``(obs, info)``. When the wrapped reset does not return a 2-tuple,
            its return value is used as ``obs`` and ``info`` is an empty dict.
            ``info["available_multi_choices"]`` holds the current option
            summaries whenever ``info`` is a dict.
        """
        # Fail-aware planners are preferred; if importing them fails for any
        # reason, the stock ManiSkill planners are used instead.
        try:
            from ..robomme_env.utils.planner_fail_safe import (
                FailAwarePandaArmMotionPlanningSolver as arm_solver_cls,
                FailAwarePandaStickMotionPlanningSolver as stick_solver_cls,
                ScrewPlanFailure as screw_failure_cls,
            )
        except Exception as exc:
            logger.debug(
                "[OraclePlannerWrapper] Warning: failed to import planner_fail_safe, "
                f"fallback to base planners: {exc}"
            )
            arm_solver_cls = PandaArmMotionPlanningSolver
            stick_solver_cls = PandaStickMotionPlanningSolver

            class ScrewPlanFailure(RuntimeError):
                """Placeholder exception type when fail-aware planner import is unavailable."""

            screw_failure_cls = ScrewPlanFailure

        # Keyword arguments common to the stick and the arm planner.
        # NOTE(review): the planner is built from the robot pose as it is
        # *before* self.env.reset() runs below - confirm this order is intended.
        shared_planner_kwargs = {
            "debug": False,
            "vis": self.gui_render,
            "base_pose": self.env.unwrapped.agent.robot.pose,
            "visualize_target_grasp_pose": False,
            "print_env_info": False,
        }
        if self.env_id in ("PatternLock", "RouteStick"):
            # These two tasks use the stick planner with a joint velocity limit.
            self.planner = stick_solver_cls(
                self.env,
                joint_vel_limits=0.3,
                **shared_planner_kwargs,
            )
        else:
            self.planner = arm_solver_cls(self.env, **shared_planner_kwargs)

        self._wrap_planner_with_screw_then_rrt_retry(
            self.planner,
            screw_failure_exc=screw_failure_cls,
        )

        reset_result = self.env.reset(**kwargs)
        returned_pair = isinstance(reset_result, tuple) and len(reset_result) == 2
        obs, info = reset_result if returned_pair else (reset_result, {})

        self._update_front_camera_cache(obs_like=obs, info_like=info)
        self._build_step_options()
        if isinstance(info, dict):
            info["available_multi_choices"] = self.available_options
        return obs, info

    @staticmethod
    def _flatten_info_batch(info_batch: dict) -> dict:
        return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

    @staticmethod
    def _take_last_step_value(value):
        if isinstance(value, torch.Tensor):
            if value.numel() == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, np.ndarray):
            if value.size == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, (list, tuple)):
            return value[-1] if value else value
        return value

    @staticmethod
    def _to_numpy(value):
        if value is None:
            return None
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value)

    @staticmethod
    def _take_last_columnar(value):
        if isinstance(value, list):
            return value[-1] if value else None
        return value

    @classmethod
    def _normalize_intrinsic_cv(cls, intrinsic_like):
        intrinsic = cls._to_numpy(intrinsic_like)
        if intrinsic is None:
            return None
        intrinsic = intrinsic.reshape(-1)
        if intrinsic.size < 9:
            return None
        intrinsic = intrinsic[:9].reshape(3, 3)
        if not np.all(np.isfinite(intrinsic)):
            return None
        return intrinsic.astype(np.float64, copy=False)

    @classmethod
    def _normalize_extrinsic_cv(cls, extrinsic_like):
        extrinsic = cls._to_numpy(extrinsic_like)
        if extrinsic is None:
            return None
        extrinsic = extrinsic.reshape(-1)
        if extrinsic.size < 12:
            return None
        extrinsic = extrinsic[:12].reshape(3, 4)
        if not np.all(np.isfinite(extrinsic)):
            return None
        return extrinsic.astype(np.float64, copy=False)

    def _update_front_camera_cache(self, obs_like=None, info_like=None):
        obs_dict = obs_like if isinstance(obs_like, dict) else {}
        info_dict = info_like if isinstance(info_like, dict) else {}

        front_rgb = self._take_last_columnar(obs_dict.get("front_rgb_list"))
        front_rgb_np = self._to_numpy(front_rgb)
        if front_rgb_np is not None and front_rgb_np.ndim >= 2:
            self._front_rgb_shape = tuple(front_rgb_np.shape[:2])

        front_extrinsic = self._take_last_columnar(
            obs_dict.get("front_camera_extrinsic_list")
        )
        front_extrinsic_np = self._normalize_extrinsic_cv(front_extrinsic)
        if front_extrinsic_np is not None:
            self._front_camera_extrinsic_cv = front_extrinsic_np

        front_intrinsic = self._take_last_columnar(info_dict.get("front_camera_intrinsic"))
        front_intrinsic_np = self._normalize_intrinsic_cv(front_intrinsic)
        if front_intrinsic_np is not None:
            self._front_camera_intrinsic_cv = front_intrinsic_np

    @staticmethod
    def _empty_target():
        return {
            "obj": None,
            "name": None,
            "seg_id": None,
            "position": None,
            "match_distance": None,
            "selection_mode": None,
        }

    def _build_step_options(self):
        """
        Query the solver options and publish a summary of them.

        A fresh empty target holder is handed to ``get_vqa_options`` together
        with the env, planner and env id. ``self.available_options`` is replaced
        by one ``{"label", "action", "need_parameter"}`` dict per option.

        Returns:
            ``(selected_target, solve_options)`` - the target holder passed to
            ``get_vqa_options`` and the options it returned.
        """
        selected_target = self._empty_target()
        solve_options = get_vqa_options(self.env, self.planner, selected_target, self.env_id)

        def _summarize(opt):
            # NOTE(review): need_parameter is the truthiness of "available",
            # while step() decides by key presence; the two differ when
            # "available" exists but is empty - confirm which is intended.
            return {
                "label": opt.get("label"),
                "action": opt.get("action", "Unknown"),
                "need_parameter": bool(opt.get("available")),
            }

        self.available_options = [_summarize(opt) for opt in solve_options]
        return selected_target, solve_options

    def _resolve_command(self, command_dict, solve_options):
        if not isinstance(command_dict, dict):
            return None, None
        if "choice" not in command_dict:
            return None, None

        target_choice = command_dict.get("choice")
        if not isinstance(target_choice, str):
            return None, None
        target_choice = target_choice.strip()
        if not target_choice:
            return None, None
        target_label = target_choice.lower()

        found_idx = find_exact_label_option_index(target_label, solve_options)
        if found_idx == -1:
            logger.debug(
                f"Error: Choice '{target_choice}' not found in current options by exact label match."
            )
            return None, None

        point = command_dict.get("point")
        if point is None:
            return found_idx, None
        if not isinstance(point, (list, tuple, np.ndarray)) or len(point) < 2:
            return found_idx, None
        try:
            y = float(point[0])
            x = float(point[1])
        except (TypeError, ValueError):
            return found_idx, None
        if not np.isfinite(x) or not np.isfinite(y):
            return found_idx, None
        # select_target_with_pixel expects [x, y].
        return found_idx, [int(np.rint(x)), int(np.rint(y))]

    def _apply_position_target(self, selected_target, option, target_pixel):
        if target_pixel is None:
            return

        best_cand = select_target_with_pixel(
            available=option.get("available"),
            pixel_like=target_pixel,
            intrinsic_cv=self._front_camera_intrinsic_cv,
            extrinsic_cv=self._front_camera_extrinsic_cv,
            image_shape=self._front_rgb_shape,
        )
        if best_cand is not None:
            selected_target.update(best_cand)

    def _execute_selected_option(self, option_idx, solve_options):
        """
        Run the chosen option's ``solve`` callable under dense step collection.

        Args:
            option_idx: Zero-based index into ``solve_options``.
            solve_options: Option dicts; the selected one must provide ``"solve"``.

        Returns:
            Whatever ``planner_denseStep._run_with_dense_collection`` returns.

        Raises:
            RuntimeError: If that result equals ``-1``.
        """
        option = solve_options[option_idx]
        # Options are reported 1-based in log and error messages.
        display_index = option_idx + 1
        logger.debug(f"Executing option: {display_index} - {option.get('action')}")

        def _invoke_solve():
            # "solve" is looked up only when the collector actually calls us.
            return option.get("solve")()

        result = planner_denseStep._run_with_dense_collection(
            self.planner,
            _invoke_solve,
        )
        if result == -1:
            action_text = option.get("action", "Unknown")
            raise RuntimeError(
                f"Oracle solve failed after screw->RRT* retries for env '{self.env_id}', "
                f"action '{action_text}' (index {display_index})."
            )
        return result

    def _post_eval(self):
        self.env.unwrapped.evaluate()
        evaluation = self.env.unwrapped.evaluate(solve_complete_eval=True)
        logger.debug(f"Evaluation result: {evaluation}")

    def _format_step_output(self, batch):
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
        self._update_front_camera_cache(obs_like=obs_batch, info_like=info_batch)
        info_flat = self._flatten_info_batch(info_batch)
        info_flat["available_multi_choices"] = getattr(self, "available_options", [])
        return (
            obs_batch,
            self._take_last_step_value(reward_batch),
            self._take_last_step_value(terminated_batch),
            self._take_last_step_value(truncated_batch),
            info_flat,
        )

    def step(self, action):
        """
        Execute one step: action is command_dict, must contain "choice", optional
        pixel `point=[y, x]` in front_rgb.
        Return last-step signals for reward/terminated/truncated while keeping obs as dict-of-lists.

        Raises:
            ValueError: If the chosen option has an "available" entry but the command
                carried no usable point, or the point matched no candidate.
            RuntimeError: If executing the selected option reports failure (-1).
        """
        # 1) Build solver options once and prepare a mutable selected_target holder for solve() closures.
        # NOTE(review): the option summary returned in info is the one built here,
        # i.e. before the action runs - confirm callers do not expect post-action options.
        selected_target, solve_options = self._build_step_options()
        # 2) Validate/resolve the incoming command into (option index, optional target position).
        found_idx, target_pixel = self._resolve_command(action, solve_options)

        # 3) For invalid command or unmatched choice, keep legacy behavior: return an empty dense batch.
        if found_idx is None:
            return self._format_step_output(planner_denseStep.empty_step_batch())

        # 4) If a point is provided, map it to the nearest candidate target.
        option = solve_options[found_idx]
        self._apply_position_target(
            selected_target=selected_target,
            option=option,
            target_pixel=target_pixel,
        )

        # An option that carries an "available" entry cannot run without a matched target.
        requires_target = "available" in option
        if requires_target:
            if target_pixel is None:
                raise ValueError(
                    f"Multi-choice action '{option.get('action', 'Unknown')}' requires "
                    "a target pixel point=[y, x], but command did not provide it."
                )
            if selected_target.get("obj") is None:
                raise ValueError(
                    f"Multi-choice action '{option.get('action', 'Unknown')}' could not match "
                    f"any available candidate from point={target_pixel}."
                )

        # 5) Execute selected solve() with dense step collection; raise on solve == -1.
        batch = self._execute_selected_option(found_idx, solve_options)
        # 6) Run post-solve environment evaluation to keep existing side effects and logging.
        self._post_eval()
        # 7) Convert batch to wrapper output contract (last reward/terminated/truncated + flattened info).
        return self._format_step_output(batch)