| import re | |
| from typing import Any, Dict, Iterable, List, Optional, Tuple | |
| import numpy as np | |
| import cv2 | |
| import copy | |
def process_segmentation(
    segmentation: np.ndarray,
    segmentation_id_map: Optional[Dict[int, Any]],
    color_map: Dict[int, List[int]],
    current_segment: Any,
    current_subgoal_segment: Optional[str],
    previous_subgoal_segment: Optional[str],
    current_task_name: str,
    existing_points: Optional[List[List[int]]] = None,
    existing_subgoal_filled: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Shared helper to compute segmentation filtering and grounded subgoal text.

    Target matching between values of ``segmentation_id_map`` and
    ``current_segment`` uses object identity (``is``), not equality.

    Side effect: any object in ``segmentation_id_map`` whose ``name``
    attribute equals ``"table-workspace"`` has its ``color_map`` entry
    overwritten with black ``[0, 0, 0]`` (mutates the caller's dict).

    Caching contract: centers and the filled subgoal text are recomputed
    only when ``current_subgoal_segment != previous_subgoal_segment``;
    otherwise ``existing_points`` / ``existing_subgoal_filled`` are echoed
    back and ``no_object_flag`` stays False.

    Returns a dict with:
        - segmentation_result: segmentation mask filtered to visible ids
        - segmentation_result_2d: squeezed version of segmentation_result
        - segmentation_points: cached center points for current targets
        - current_subgoal_segment_filled: subgoal string with centers filled in
        - no_object_flag: whether the target ids are missing in the mask
        - updated_previous_subgoal_segment: equals current_subgoal_segment for caller caching
        - vis_obj_id_list: ids kept in segmentation_result
    """
    # Work on a 2-D view of the mask; extra singleton dims are squeezed away.
    segmentation_2d = segmentation.squeeze() if segmentation.ndim > 2 else segmentation

    # Normalize the current target(s) into a list (possibly empty).
    if isinstance(current_segment, (list, tuple)):
        active_segments = list(current_segment)
    elif current_segment is None:
        active_segments = []
    else:
        active_segments = [current_segment]

    # Collect, per target index, the mask ids that refer to that target.
    segment_ids_by_index: Dict[int, List[int]] = {
        idx: [] for idx in range(len(active_segments))
    }
    vis_obj_id_list: List[int] = []
    if isinstance(segmentation_id_map, dict):
        # sorted() makes iteration (and hence id ordering) deterministic.
        for obj_id, obj in sorted(segmentation_id_map.items()):
            if active_segments:
                for idx, target in enumerate(active_segments):
                    if obj is target:
                        vis_obj_id_list.append(obj_id)
                        segment_ids_by_index[idx].append(obj_id)
                        break
            # Paint the workspace table black regardless of targets
            # (mutates the caller-supplied color_map).
            if getattr(obj, "name", None) == "table-workspace":
                color_map[obj_id] = [0, 0, 0]

    # Keep only the target ids in the mask; everything else becomes 0.
    segmentation_result = np.where(
        np.isin(segmentation_2d, vis_obj_id_list), segmentation_2d, 0
    )
    segmentation_result_2d = segmentation_result.squeeze()

    # Defaults reuse the caller's cached values; overwritten below only
    # when the subgoal changed.
    segmentation_points = existing_points or []
    current_subgoal_segment_filled = existing_subgoal_filled
    no_object_flag = False

    if current_subgoal_segment != previous_subgoal_segment:

        def compute_center_from_ids(
            segmentation_mask: np.ndarray, ids: Iterable[int]
        ) -> Optional[List[int]]:
            """Mean [y, x] of pixels whose id is in *ids*, or None if absent."""
            nonlocal no_object_flag
            ids = list(ids)
            if not ids:
                return None
            mask = np.isin(segmentation_mask, ids)
            if not np.any(mask):
                # Target exists in the id map but is not visible in this frame.
                no_object_flag = True
                return None
            coords = np.argwhere(mask)
            # int() truncates the float mean, matching prior behavior.
            return [int(coords[:, 0].mean()), int(coords[:, 1].mean())]

        segment_centers: List[Optional[List[int]]] = []
        if active_segments:
            for idx in range(len(active_segments)):
                segment_centers.append(
                    compute_center_from_ids(
                        segmentation_2d, segment_ids_by_index.get(idx, [])
                    )
                )
        else:
            # No explicit targets: fall back to whatever ids were kept.
            segment_centers.append(
                compute_center_from_ids(segmentation_2d, vis_obj_id_list)
            )
        segmentation_points = [c for c in segment_centers if c is not None]

        if current_subgoal_segment:
            # Render each center as "<y, x>"; None marks a missing target.
            normalized_centers: List[Optional[str]] = [
                None if c is None else f"<{c[0]}, {c[1]}>" for c in segment_centers
            ]
            placeholder_pattern = re.compile(r"<[^>]*>")
            placeholders = list(placeholder_pattern.finditer(current_subgoal_segment))
            placeholder_count = len(placeholders)
            if placeholder_count > 0 and normalized_centers:
                replacements = normalized_centers.copy()
                if len(replacements) == 1 and placeholder_count > 1:
                    # A single target grounds every placeholder.
                    replacements = replacements * placeholder_count
                elif len(replacements) < placeholder_count:
                    # Pad with None so each placeholder has an entry.
                    replacements.extend(
                        [None] * (placeholder_count - len(replacements))
                    )
                missing_placeholder = False
                new_text_parts: List[str] = []
                last_idx = 0
                for idx, match in enumerate(placeholders):
                    new_text_parts.append(
                        current_subgoal_segment[last_idx : match.start()]
                    )
                    replacement_text = replacements[idx]
                    if replacement_text is None:
                        # Some placeholder has no center: fall back to the raw
                        # task name instead of emitting partially-grounded text.
                        missing_placeholder = True
                    else:
                        new_text_parts.append(replacement_text)
                    last_idx = match.end()
                new_text_parts.append(current_subgoal_segment[last_idx:])
                current_subgoal_segment_filled = (
                    current_task_name if missing_placeholder else "".join(new_text_parts)
                )
            else:
                current_subgoal_segment_filled = current_subgoal_segment
        else:
            current_subgoal_segment_filled = current_subgoal_segment

    return {
        "segmentation_result": segmentation_result,
        "segmentation_result_2d": segmentation_result_2d,
        "segmentation_points": segmentation_points,
        "current_subgoal_segment_filled": current_subgoal_segment_filled,
        "no_object_flag": no_object_flag,
        "updated_previous_subgoal_segment": current_subgoal_segment,
        "vis_obj_id_list": vis_obj_id_list,
    }
def create_segmentation_visuals(
    segmentation: np.ndarray,
    segmentation_result: np.ndarray,
    base_frame: np.ndarray,
    color_map: Dict[int, List[int]],
    segmentation_points: List[List[int]],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build colored segmentation visualizations and target overlay for video export.

    Args:
        segmentation: full id mask (2-D, or higher-rank with squeezable dims).
        segmentation_result: id mask already filtered to the current targets.
        base_frame: frame the target markers are drawn onto; it is NOT
            mutated — markers go on a copy.
        color_map: id -> color triple; ids without an entry render white.
        segmentation_points: [y, x] centers to mark on the frame copy.

    Returns:
        (segmentation_vis, segmentation_result_vis, target_for_video); the
        two visualizations are resized (nearest-neighbor, so label colors
        stay crisp) to base_frame's spatial size when they differ.
    """

    def colorize(id_mask: np.ndarray) -> np.ndarray:
        """Render each positive id in the mask with its mapped color."""
        mask_2d = id_mask.squeeze() if id_mask.ndim > 2 else id_mask
        vis = np.zeros((*id_mask.shape[:2], 3), dtype=np.uint8)
        for seg_id in np.unique(mask_2d):
            if seg_id > 0:  # id 0 is background and stays black
                vis[mask_2d == seg_id] = color_map.get(seg_id, [255, 255, 255])
        return vis

    # The masks are only read, so no defensive copy is needed
    # (the previous deepcopy of each ndarray was pure overhead).
    segmentation_vis = colorize(segmentation)
    segmentation_result_vis = colorize(segmentation_result)

    # cv2.circle mutates its image argument, so draw on a copy.
    target_for_video = base_frame.copy()

    frame_size = (base_frame.shape[1], base_frame.shape[0])  # cv2 wants (w, h)
    if segmentation_vis.shape[:2] != base_frame.shape[:2]:
        segmentation_vis = cv2.resize(
            segmentation_vis, frame_size, interpolation=cv2.INTER_NEAREST
        )
    if segmentation_result_vis.shape[:2] != base_frame.shape[:2]:
        segmentation_result_vis = cv2.resize(
            segmentation_result_vis, frame_size, interpolation=cv2.INTER_NEAREST
        )

    if segmentation_points:
        # cv2.circle's third positional argument is a radius (the old local
        # was misleadingly named "diameter"); value is unchanged.
        radius = 5
        for center_y, center_x in segmentation_points:
            cv2.circle(target_for_video, (center_x, center_y), radius, (255, 0, 0), -1)

    return segmentation_vis, segmentation_result_vis, target_for_video