| """Utilities for picking out pickup tasks for failure recovery.""" |
| from typing import Any, Iterable, List, Tuple, Union |
|
|
| from .subgoal_planner_func import solve_pickup, solve_pickup_bin |
| import torch |
|
|
# A task entry is either a dict (current format) or an old-format tuple/list
# that is read positionally (demo flag at index 2, solve at index 4).
TaskEntry = Union[dict, tuple, list]

# Failure modes accepted by inject_fail_grasp and forwarded as ``mode``.
FAIL_GRASP_MODES: Tuple[str, ...] = ("xy", "z")
|
|
|
|
| def _get_demo_flag(task: TaskEntry) -> bool: |
| """Return the demonstration flag, defaulting to False when missing.""" |
| if isinstance(task, dict): |
| if "demonstration" in task: |
| return bool(task.get("demonstration")) |
| return bool(task.get("demo", False)) |
| if isinstance(task, (list, tuple)) and len(task) >= 3: |
| return bool(task[2]) |
| return False |
|
|
|
|
| def _extract_solve(task: TaskEntry) -> Any: |
| """Fetch the solve callable from a task entry if present.""" |
| if isinstance(task, dict): |
| return task.get("solve") |
| if isinstance(task, (list, tuple)) and len(task) >= 5: |
| return task[4] |
| return None |
|
|
|
|
| def _resolve_pickup_solver(solve_callable: Any): |
| """Return the pickup solver callable referenced by a task, if any.""" |
| if solve_callable is None: |
| return None |
| if isinstance(solve_callable, (list, tuple)): |
| for cb in solve_callable: |
| solver = _resolve_pickup_solver(cb) |
| if solver: |
| return solver |
| return None |
|
|
| if solve_callable in (solve_pickup, solve_pickup_bin): |
| return solve_callable |
|
|
| name = getattr(solve_callable, "__name__", "") |
| if name == "solve_pickup": |
| return solve_pickup |
| if name == "solve_pickup_bin": |
| return solve_pickup_bin |
|
|
| underlying = getattr(solve_callable, "func", None) |
| if underlying and underlying is not solve_callable: |
| solver = _resolve_pickup_solver(underlying) |
| if solver: |
| return solver |
|
|
| code_obj = getattr(solve_callable, "__code__", None) |
| if code_obj: |
| if "solve_pickup_bin" in code_obj.co_names: |
| return solve_pickup_bin |
| if "solve_pickup" in code_obj.co_names: |
| return solve_pickup |
|
|
| wrapped = getattr(solve_callable, "__wrapped__", None) |
| if wrapped and wrapped is not solve_callable: |
| solver = _resolve_pickup_solver(wrapped) |
| if solver: |
| return solver |
|
|
| return None |
|
|
|
|
| def _normalize_single_obj(obj: Any) -> Any: |
| """ |
| Some tasks store a single segment as a list/tuple with one element. |
| For pickup we only need the underlying object, not the container. |
| """ |
| if isinstance(obj, (list, tuple)): |
| return obj[0] if obj else None |
| return obj |
|
|
|
|
def _solve_refs_pickup(solve_callable: Any) -> bool:
    """Report whether ``solve_callable`` statically refers to a pickup solver.

    True when ``_resolve_pickup_solver`` can trace the entry to
    ``solve_pickup`` or ``solve_pickup_bin``. Plain callables,
    ``functools.partial`` objects and containers are inspected; nothing is
    executed.
    """
    resolved = _resolve_pickup_solver(solve_callable)
    return resolved is not None
|
|
|
|
def task4recovery(task_list: Iterable[TaskEntry]) -> Tuple[List[int], List[TaskEntry]]:
    """Select the non-demonstration tasks whose solve relies on a pickup solver.

    A task qualifies when its demonstration flag is falsy and its ``solve``
    entry resolves to ``solve_pickup`` or ``solve_pickup_bin``. The check is
    static; nothing is executed.

    Args:
        task_list: Sequential task list (dicts, or old-format tuples/lists).

    Returns:
        ``(pickup_indices, pickup_tasks)``: the positions in ``task_list`` and
        the matching entries, both in original order.
    """
    selected = [
        (position, entry)
        for position, entry in enumerate(task_list)
        if not _get_demo_flag(entry) and _solve_refs_pickup(_extract_solve(entry))
    ]
    pickup_indices: List[int] = [position for position, _ in selected]
    pickup_tasks: List[TaskEntry] = [entry for _, entry in selected]
    return pickup_indices, pickup_tasks
|
|
|
|
def _signature_binds(fn: Any, *args: Any, **kwargs: Any):
    """Check, without calling ``fn``, whether ``fn(*args, **kwargs)`` would bind.

    Returns:
        True or False when ``fn`` has an explicit signature.
        None when the answer cannot be trusted. This happens when the
        signature is not introspectable, or when it declares
        ``*args``/``**kwargs``, because a pass-through wrapper may still
        reject the arguments further down the call chain.
    """
    import inspect

    try:
        signature = inspect.signature(fn)
    except (TypeError, ValueError):
        return None
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    if any(param.kind in variadic for param in signature.parameters.values()):
        return None
    try:
        signature.bind(*args, **kwargs)
    except TypeError:
        return False
    return True


def _call_first_accepted(attempts, env, planner):
    """Call the first ``(fn, kwargs)`` in ``attempts`` that accepts its arguments.

    Each candidate is called as ``fn(env, planner, **kwargs)``. Every
    candidate except the last is handled as follows:

    * skipped when its explicit signature rejects the arguments;
    * tried, with a TypeError moving on to the next candidate, when the
      signature is inconclusive.

    The last candidate is always called.
    """
    *probes, (last_fn, last_kwargs) = attempts
    for fn, kwargs in probes:
        verdict = _signature_binds(fn, env, planner, **kwargs)
        if verdict is False:
            continue
        if verdict is True:
            # Explicit signature accepted. Any TypeError from here on comes
            # from inside fn, so it must propagate instead of causing fn
            # (which may already have acted on env) to be run a second time.
            return fn(env, planner, **kwargs)
        try:
            return fn(env, planner, **kwargs)
        except TypeError:
            continue
    return last_fn(env, planner, **last_kwargs)


def _make_fail_grasp_solve(solve_callable: Any, obj: Any, mode: str):
    """Wrap a solve callable so it runs the pickup with ``fail_grasp=True``.

    The returned ``_wrapped(env, planner)`` tries, in order:

    * If ``solve_callable`` resolves to a pickup solver: that solver with
      ``obj``, ``fail_grasp=True`` and ``mode``; then without ``mode``.
    * If ``solve_callable`` is None: ``solve_pickup`` with all three.
    * Otherwise: ``solve_callable`` with ``fail_grasp=True`` and ``mode``;
      then with ``fail_grasp=True`` only; then with no extra arguments;
      finally ``solve_pickup`` with all three.

    A variant is skipped up front when the callee's explicit signature
    cannot accept it. A TypeError raised inside a solver therefore
    propagates instead of triggering a re-run with fewer arguments.
    Callees with an uninspectable or ``*args``/``**kwargs`` signature keep
    the try-and-fall-back-on-TypeError behaviour.

    Args:
        solve_callable: The task's original ``solve`` entry (may be None).
        obj: The task's segment. A list/tuple is reduced to its first
            element (None if empty).
        mode: Failure mode forwarded as the ``mode`` keyword.

    Returns:
        A callable taking ``(env, planner)``.
    """
    solver = _resolve_pickup_solver(solve_callable)
    target_obj = _normalize_single_obj(obj)
    pickup_fallback = (solve_pickup, {"obj": target_obj, "fail_grasp": True, "mode": mode})

    if solver is not None:
        attempts = [
            (solver, {"obj": target_obj, "fail_grasp": True, "mode": mode}),
            (solver, {"obj": target_obj, "fail_grasp": True}),
        ]
    elif solve_callable is None:
        attempts = [pickup_fallback]
    else:
        attempts = [
            (solve_callable, {"fail_grasp": True, "mode": mode}),
            (solve_callable, {"fail_grasp": True}),
            # NOTE(review): this variant runs the original solve without
            # fail_grasp, so no failure is injected. Kept for compatibility.
            (solve_callable, {}),
            pickup_fallback,
        ]

    def _wrapped(env, planner):
        return _call_first_accepted(attempts, env, planner)

    return _wrapped
|
|
|
|
def inject_fail_grasp(task_list: Iterable[TaskEntry], generator: torch.Generator = None, mode: str = None):
    """Randomly pick a pickup task and make its solve fail the grasp.

    Candidates come from ``task4recovery``: non-demonstration tasks whose
    solve resolves to ``solve_pickup``/``solve_pickup_bin``. One candidate
    is chosen uniformly at random. For dict entries:

    * ``"solve"`` is replaced in place by a wrapper that runs the pickup
      with ``fail_grasp=True``;
    * ``"fail_grasp_mode"`` and ``"fail_grasp_injected"`` are recorded.

    Args:
        task_list: Task list. Any iterable is accepted. A non-sequence is
            first materialised into a list so the chosen entry can be
            indexed.
        generator: ``torch.Generator`` for reproducible random selection.
            Anything that is not a ``torch.Generator`` (including None)
            falls back to torch's global RNG.
        mode: Failure mode, one of ``FAIL_GRASP_MODES`` (case-insensitive
            for strings). None picks a mode at random using the same RNG.

    Returns:
        Index of the modified task; None if no pickup task exists.

    Raises:
        ValueError: If ``mode`` is given but is not a known failure mode.
    """
    from collections.abc import Sequence

    # Validate first so a bad mode is reported even when there is no pickup
    # task. This consumes no randomness, so seeded draws are unaffected.
    normalized_mode = mode.lower() if isinstance(mode, str) else mode
    if normalized_mode is not None and normalized_mode not in FAIL_GRASP_MODES:
        raise ValueError(f"Unknown fail grasp mode {mode!r}")

    # task4recovery iterates task_list and the chosen entry is then looked up
    # by index. A one-shot iterable (e.g. a generator) must be materialised.
    if not isinstance(task_list, Sequence):
        task_list = list(task_list)

    pickup_indices, _ = task4recovery(task_list)
    if not pickup_indices:
        return None

    torch_gen = generator if isinstance(generator, torch.Generator) else None

    def _draw(upper: int) -> int:
        """Draw a uniform integer in [0, upper) from torch_gen or the global RNG."""
        if torch_gen is not None:
            return torch.randint(0, upper, (1,), generator=torch_gen).item()
        return torch.randint(0, upper, (1,)).item()

    # Draw order (task first, then mode) is preserved so seeded runs reproduce.
    target_idx = pickup_indices[_draw(len(pickup_indices))]
    task = task_list[target_idx]
    if normalized_mode is None:
        mode_choice = FAIL_GRASP_MODES[_draw(len(FAIL_GRASP_MODES))]
    else:
        mode_choice = normalized_mode

    if isinstance(task, dict):
        task["solve"] = _make_fail_grasp_solve(task.get("solve"), task.get("segment"), mode_choice)
        task["fail_grasp_mode"] = mode_choice
        task["fail_grasp_injected"] = True
    # NOTE(review): old-format list/tuple entries can be selected above but are
    # left unmodified, yet their index is still returned. Confirm intended.
    return target_idx
|
|