| | """Utilities for picking out pickup tasks for failure recovery.""" |
| | from typing import Any, Iterable, List, Tuple, Union |
| |
|
| | from .subgoal_planner_func import solve_pickup, solve_pickup_bin |
| | import torch |
| |
|
# A task entry is either a dict or an old-format tuple/list.
TaskEntry = Union[dict, tuple, list]

# Mode names accepted by inject_fail_grasp; the chosen one is forwarded to the
# pickup solver as its `mode` keyword argument.
FAIL_GRASP_MODES = ("xy", "z")
| |
|
| |
|
| | def _get_demo_flag(task: TaskEntry) -> bool: |
| | """Return the demonstration flag, defaulting to False when missing.""" |
| | if isinstance(task, dict): |
| | if "demonstration" in task: |
| | return bool(task.get("demonstration")) |
| | return bool(task.get("demo", False)) |
| | if isinstance(task, (list, tuple)) and len(task) >= 3: |
| | return bool(task[2]) |
| | return False |
| |
|
| |
|
| | def _extract_solve(task: TaskEntry) -> Any: |
| | """Fetch the solve callable from a task entry if present.""" |
| | if isinstance(task, dict): |
| | return task.get("solve") |
| | if isinstance(task, (list, tuple)) and len(task) >= 5: |
| | return task[4] |
| | return None |
| |
|
| |
|
| | def _resolve_pickup_solver(solve_callable: Any): |
| | """Return the pickup solver callable referenced by a task, if any.""" |
| | if solve_callable is None: |
| | return None |
| | if isinstance(solve_callable, (list, tuple)): |
| | for cb in solve_callable: |
| | solver = _resolve_pickup_solver(cb) |
| | if solver: |
| | return solver |
| | return None |
| |
|
| | if solve_callable in (solve_pickup, solve_pickup_bin): |
| | return solve_callable |
| |
|
| | name = getattr(solve_callable, "__name__", "") |
| | if name == "solve_pickup": |
| | return solve_pickup |
| | if name == "solve_pickup_bin": |
| | return solve_pickup_bin |
| |
|
| | underlying = getattr(solve_callable, "func", None) |
| | if underlying and underlying is not solve_callable: |
| | solver = _resolve_pickup_solver(underlying) |
| | if solver: |
| | return solver |
| |
|
| | code_obj = getattr(solve_callable, "__code__", None) |
| | if code_obj: |
| | if "solve_pickup_bin" in code_obj.co_names: |
| | return solve_pickup_bin |
| | if "solve_pickup" in code_obj.co_names: |
| | return solve_pickup |
| |
|
| | wrapped = getattr(solve_callable, "__wrapped__", None) |
| | if wrapped and wrapped is not solve_callable: |
| | solver = _resolve_pickup_solver(wrapped) |
| | if solver: |
| | return solver |
| |
|
| | return None |
| |
|
| |
|
| | def _normalize_single_obj(obj: Any) -> Any: |
| | """ |
| | Some tasks store a single segment as a list/tuple with one element. |
| | For pickup we only need the underlying object, not the container. |
| | """ |
| | if isinstance(obj, (list, tuple)): |
| | return obj[0] if obj else None |
| | return obj |
| |
|
| |
|
def _solve_refs_pickup(solve_callable: Any) -> bool:
    """Report whether a solve entry resolves to a pickup solver.

    This is a boolean view of ``_resolve_pickup_solver``: True when that
    lookup finds ``solve_pickup`` or ``solve_pickup_bin``, False otherwise.
    The entry itself is never called.
    """
    resolved = _resolve_pickup_solver(solve_callable)
    return resolved is not None
| |
|
| |
|
def task4recovery(task_list: Iterable[TaskEntry]) -> Tuple[List[int], List[TaskEntry]]:
    """
    Collect the non-demonstration tasks whose solve refers to a pickup solver.

    Args:
        task_list: Sequential task list (dict entries or old-format tuple/list).

    Returns:
        (pickup_indices, pickup_tasks): positions in ``task_list`` and the
        matching task entries, in the same order.
    """
    # Demonstration tasks are skipped before their solve is even inspected.
    matches = [
        (position, entry)
        for position, entry in enumerate(task_list)
        if not _get_demo_flag(entry) and _solve_refs_pickup(_extract_solve(entry))
    ]
    pickup_indices: List[int] = [position for position, _ in matches]
    pickup_tasks: List[TaskEntry] = [entry for _, entry in matches]
    return pickup_indices, pickup_tasks
| |
|
| |
|
def _make_fail_grasp_solve(solve_callable: Any, obj: Any, mode: str):
    """Wrap a solve callable to force fail_grasp=True with a specific failure mode.

    Args:
        solve_callable: The task's original solve entry (may be None).
        obj: The task's segment; a one-element list/tuple is unwrapped.
        mode: Failure mode forwarded to the solver as ``mode`` when accepted.

    Returns:
        A ``_wrapped(env, planner)`` callable. When the original entry
        resolves to a pickup solver, that solver is called with
        ``obj``/``fail_grasp``/``mode``. Otherwise the original callable is
        called with as many of ``fail_grasp``/``mode`` as it accepts, and
        ``solve_pickup`` is the last resort.

    The keyword arguments are chosen by binding against the callable's
    signature *before* calling it. The previous implementation retried on any
    ``TypeError``, which could execute a solver a second time (with fewer
    arguments) after it had already acted on ``env`` and then failed with an
    unrelated ``TypeError`` of its own.
    """
    import inspect  # local import: the module header does not import it

    solver = _resolve_pickup_solver(solve_callable)
    target_obj = _normalize_single_obj(obj)

    def _call_first_accepted(fn, env, planner, kwarg_options):
        """Call ``fn`` once with the first kwargs set it accepts.

        Returns ``(True, result)`` after a call, or ``(False, last_error)``
        when no option in ``kwarg_options`` was accepted.
        """
        try:
            signature = inspect.signature(fn)
        except (TypeError, ValueError):
            # Not introspectable (e.g. some builtins, or not callable at all).
            signature = None

        last_error = None

        if signature is None:
            # Best-effort fallback: the only way to probe is to try the call.
            for kwargs in kwarg_options:
                try:
                    return True, fn(env, planner, **kwargs)
                except TypeError as exc:
                    last_error = exc
            return False, last_error

        for kwargs in kwarg_options:
            try:
                signature.bind(env, planner, **kwargs)
            except TypeError as exc:
                last_error = exc
                continue
            # Deliberately outside the try: a TypeError raised inside fn is a
            # real failure and must propagate instead of triggering a retry.
            return True, fn(env, planner, **kwargs)
        return False, last_error

    def _wrapped(env, planner):
        if solver is not None:
            called, outcome = _call_first_accepted(
                solver,
                env,
                planner,
                (
                    {"obj": target_obj, "fail_grasp": True, "mode": mode},
                    {"obj": target_obj, "fail_grasp": True},
                ),
            )
            if called:
                return outcome
            # Neither call shape fits the solver: surface the mismatch.
            raise outcome

        if solve_callable is not None:
            called, outcome = _call_first_accepted(
                solve_callable,
                env,
                planner,
                (
                    {"fail_grasp": True, "mode": mode},
                    {"fail_grasp": True},
                    {},
                ),
            )
            if called:
                return outcome

        # No usable callable on the task: fall back to the default solver.
        return solve_pickup(env, planner, obj=target_obj, fail_grasp=True, mode=mode)

    return _wrapped
| |
|
| |
|
def inject_fail_grasp(task_list: Iterable[TaskEntry], generator: torch.Generator = None, mode: str = None):
    """
    Randomly select a pickup task and replace its solve with a fail_grasp=True version.

    Only dict task entries can be modified (their "solve" key is replaced in
    place), so the random choice is made among dict pickup tasks only.
    Old-format tuple/list entries are never selected; previously such an entry
    could be chosen, left untouched, and still reported as modified.

    Args:
        task_list: Task list. It is iterated once; the chosen dict is mutated
            in place ("solve", "fail_grasp_mode", "fail_grasp_injected").
        generator: torch.Generator, for reproducible random selection. Any
            other value (including None) uses torch's global RNG.
        mode: One of FAIL_GRASP_MODES (case-insensitive), or None to pick a
            mode at random.

    Returns:
        Index of modified task; None if no modifiable pickup task exists.

    Raises:
        ValueError: If ``mode`` is given but is not in FAIL_GRASP_MODES.
    """
    pickup_indices, pickup_tasks = task4recovery(task_list)
    candidates = [
        (idx, task)
        for idx, task in zip(pickup_indices, pickup_tasks)
        if isinstance(task, dict)
    ]
    if not candidates:
        return None

    # Validate before drawing so a bad mode does not advance the generator.
    normalized_mode = mode.lower() if isinstance(mode, str) else mode
    if normalized_mode is not None and normalized_mode not in FAIL_GRASP_MODES:
        raise ValueError(f"Unknown fail grasp mode {mode!r}")

    torch_gen = generator if isinstance(generator, torch.Generator) else None

    def _draw(upper: int) -> int:
        """Draw an int in [0, upper); generator=None means the global RNG."""
        return int(torch.randint(0, upper, (1,), generator=torch_gen).item())

    # Draw order matters for reproducibility: task first, then mode.
    target_idx, task = candidates[_draw(len(candidates))]
    if normalized_mode is None:
        mode_choice = FAIL_GRASP_MODES[_draw(len(FAIL_GRASP_MODES))]
    else:
        mode_choice = normalized_mode

    task["solve"] = _make_fail_grasp_solve(task.get("solve"), task.get("segment"), mode_choice)
    task["fail_grasp_mode"] = mode_choice
    task["fail_grasp_injected"] = True
    return target_idx
| |
|