Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| SAM2 utilities for BMP demo: | |
| - Build and prepare SAM model | |
| - Convert poses to segmentation | |
| - Compute mask-pose consistency | |
| """ | |
| from typing import Any, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from mmengine.structures import InstanceData | |
| from pycocotools import mask as Mask | |
| from sam2.build_sam import build_sam2 | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
# Confidence cut-off for keypoints: only keypoints whose confidence is strictly
# greater than this value are counted when scoring mask-pose consistency, and
# the "closest" keypoint-selection method applies the same cut-off.
STRICT_KPT_THRESHOLD: float = 0.5
| def _validate_sam_args(sam_args): | |
| """ | |
| Validate that all required sam_args attributes are present. | |
| """ | |
| required = [ | |
| "crop", | |
| "use_bbox", | |
| "confidence_thr", | |
| "ignore_small_bboxes", | |
| "num_pos_keypoints", | |
| "num_pos_keypoints_if_crowd", | |
| "crowd_by_max_iou", | |
| "batch", | |
| "exclusive_masks", | |
| "extend_bbox", | |
| "pose_mask_consistency", | |
| "visibility_thr", | |
| ] | |
| for param in required: | |
| if not hasattr(sam_args, param): | |
| raise AttributeError(f"Missing required arg {param} in sam_args") | |
| def _get_max_ious(bboxes: List[np.ndarray]) -> np.ndarray: | |
| """ | |
| Compute maximum IoU for each bbox against others. | |
| """ | |
| is_crowd = [0] * len(bboxes) | |
| ious = Mask.iou(bboxes, bboxes, is_crowd) | |
| mat = np.array(ious) | |
| np.fill_diagonal(mat, 0) | |
| return mat.max(axis=1) | |
| def _compute_one_mask_pose_consistency( | |
| mask: np.ndarray, pos_keypoints: Optional[np.ndarray] = None, neg_keypoints: Optional[np.ndarray] = None | |
| ) -> float: | |
| """ | |
| Compute a consistency score between a mask and given keypoints. | |
| Args: | |
| mask (np.ndarray): Binary mask of shape (H, W). | |
| pos_keypoints (Optional[np.ndarray]): Positive keypoints array (N, 3). | |
| neg_keypoints (Optional[np.ndarray]): Negative keypoints array (M, 3). | |
| Returns: | |
| float: Weighted mean of positive and negative keypoint consistency. | |
| """ | |
| if mask is None: | |
| return 0.0 | |
| def _mean_inside(points: np.ndarray) -> float: | |
| if points.size == 0: | |
| return 0.0 | |
| pts_int = np.floor(points[:, :2]).astype(int) | |
| pts_int[:, 0] = np.clip(pts_int[:, 0], 0, mask.shape[1] - 1) | |
| pts_int[:, 1] = np.clip(pts_int[:, 1], 0, mask.shape[0] - 1) | |
| vals = mask[pts_int[:, 1], pts_int[:, 0]] | |
| return vals.mean() if vals.size > 0 else 0.0 | |
| pos_mean = 0.0 | |
| if pos_keypoints is not None: | |
| valid = pos_keypoints[:, 2] > STRICT_KPT_THRESHOLD | |
| pos_mean = _mean_inside(pos_keypoints[valid]) | |
| neg_mean = 0.0 | |
| if neg_keypoints is not None: | |
| valid = neg_keypoints[:, 2] > STRICT_KPT_THRESHOLD | |
| pts = neg_keypoints[valid][:, :2] | |
| inside = mask[np.floor(pts[:, 1]).astype(int), np.floor(pts[:, 0]).astype(int)] | |
| neg_mean = (~inside.astype(bool)).mean() if inside.size > 0 else 0.0 | |
| return 0.5 * pos_mean + 0.5 * neg_mean | |
| def _select_keypoints( | |
| args: Any, | |
| kpts: np.ndarray, | |
| num_visible: int, | |
| bbox: Optional[Tuple[float, float, float, float]] = None, | |
| method: Optional[str] = "distance+confidence", | |
| ) -> Tuple[np.ndarray, np.ndarray]: | |
| """ | |
| Select and order keypoints for SAM prompting based on specified method. | |
| Args: | |
| args: Configuration object with selection_method and visibility_thr attributes. | |
| kpts (np.ndarray): Keypoints array of shape (K, 3). | |
| num_visible (int): Number of keypoints above visibility threshold. | |
| bbox (Optional[Tuple]): Optional bbox for distance methods. | |
| method (Optional[str]): Override selection method. | |
| Returns: | |
| Tuple[np.ndarray, np.ndarray]: Selected keypoint coordinates (N,2) and confidences (N,). | |
| Raises: | |
| ValueError: If an unknown method is specified. | |
| """ | |
| if num_visible == 0: | |
| return kpts[:, :2], kpts[:, 2] | |
| methods = ["confidence", "distance", "distance+confidence", "closest"] | |
| sel_method = method or args.selection_method | |
| if sel_method not in methods: | |
| raise ValueError("Unknown method for keypoint selection: {}".format(sel_method)) | |
| # Select at maximum keypoint from the face | |
| facial_kpts = kpts[:3, :] | |
| facial_conf = kpts[:3, 2] | |
| facial_point = facial_kpts[np.argmax(facial_conf)] | |
| if facial_point[-1] >= args.visibility_thr: | |
| kpts = np.concatenate([facial_point[None, :], kpts[3:]], axis=0) | |
| conf = kpts[:, 2] | |
| vis_mask = conf >= args.visibility_thr | |
| coords = kpts[vis_mask, :2] | |
| confs = conf[vis_mask] | |
| if sel_method == "confidence": | |
| order = np.argsort(confs)[::-1] | |
| coords = coords[order] | |
| confs = confs[order] | |
| elif sel_method == "distance": | |
| if bbox is None: | |
| bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()]) | |
| else: | |
| bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]) | |
| dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1) | |
| dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2) | |
| np.fill_diagonal(dist_matrix, np.inf) | |
| min_inter_dist = np.min(dist_matrix, axis=1) | |
| order = np.argsort(dists + 3 * min_inter_dist)[::-1] | |
| coords = coords[order, :2] | |
| confs = confs[order] | |
| elif sel_method == "distance+confidence": | |
| order = np.argsort(confs)[::-1] | |
| confidences = kpts[order, 2] | |
| coords = coords[order, :2] | |
| confs = confs[order] | |
| dist_matrix = np.linalg.norm(coords[:, None, :2] - coords[None, :, :2], axis=2) | |
| selected_idx = [0] | |
| confidences[0] = -1 | |
| for _ in range(coords.shape[0] - 1): | |
| min_dist = np.min(dist_matrix[:, selected_idx], axis=1) | |
| min_dist[confidences < np.percentile(confidences, 80)] = -1 | |
| next_idx = np.argmax(min_dist) | |
| selected_idx.append(next_idx) | |
| confidences[next_idx] = -1 | |
| coords = coords[selected_idx] | |
| confs = confs[selected_idx] | |
| elif sel_method == "closest": | |
| coords = coords[confs > STRICT_KPT_THRESHOLD, :] | |
| confs = confs[confs > STRICT_KPT_THRESHOLD] | |
| if bbox is None: | |
| bbox_center = np.array([coords[:, 0].mean(), coords[:, 1].mean()]) | |
| else: | |
| bbox_center = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]) | |
| dists = np.linalg.norm(coords[:, :2] - bbox_center, axis=1) | |
| order = np.argsort(dists) | |
| coords = coords[order, :2] | |
| confs = confs[order] | |
| return coords, confs | |
def prepare_model(model_cfg: Any, model_checkpoint: str) -> SAM2ImagePredictor:
    """
    Construct a SAM2 image predictor on the best available device.

    Device preference is CUDA, then Apple MPS, then CPU.

    Args:
        model_cfg: Configuration for SAM2 model, forwarded to ``build_sam2``.
        model_checkpoint (str): Path to model checkpoint.
    Returns:
        SAM2ImagePredictor: Predictor wrapping the built model, created with
        ``max_hole_area=10.0`` and ``max_sprinkle_area=50.0``.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    device = torch.device(device_name)

    sam2_model = build_sam2(model_cfg, model_checkpoint, device=device, apply_postprocessing=True)
    return SAM2ImagePredictor(sam2_model, max_hole_area=10.0, max_sprinkle_area=50.0)
| def _compute_mask_pose_consistency(masks: List[np.ndarray], keypoints_list: List[np.ndarray]) -> np.ndarray: | |
| """ | |
| Compute mask-pose consistency score for each mask-keypoints pair. | |
| Args: | |
| masks (List[np.ndarray]): Binary masks list. | |
| keypoints_list (List[np.ndarray]): List of keypoint arrays per instance. | |
| Returns: | |
| np.ndarray: Consistency scores array of shape (N,). | |
| """ | |
| scores: List[float] = [] | |
| for mask, kpts in zip(masks, keypoints_list): | |
| other_kpts = np.concatenate([keypoints_list[:idx], keypoints_list[idx + 1 :]], axis=0).reshape(-1, 3) | |
| score = _compute_one_mask_pose_consistency(mask, kpts, other_kpts) | |
| scores.append(score) | |
| return np.array(scores) | |
def _pose2seg(
    args: Any,
    model: SAM2ImagePredictor,
    bbox_xyxy: Optional[List[float]] = None,
    pos_kpts: Optional[np.ndarray] = None,
    neg_kpts: Optional[np.ndarray] = None,
    image: Optional[np.ndarray] = None,
    gt_mask: Optional[Any] = None,
    num_pos_keypoints: Optional[int] = None,
    gt_mask_is_binary: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    """
    Run SAM segmentation conditioned on pose keypoints and optional ground truth mask.

    Args:
        args: Configuration object with prompting settings. Reads
            ``num_pos_keypoints``, ``visibility_thr``, ``use_bbox``,
            ``extend_bbox``, ``crop``, ``pose_mask_consistency`` and, only
            when ``neg_kpts`` is given, ``num_neg_keypoints``.
        model (SAM2ImagePredictor): Prepared SAM2 model. When no ``image`` is
            passed, the caller must already have called ``model.set_image``.
        bbox_xyxy (Optional[List[float]]): Bounding box coordinates in xyxy
            format. It is never modified.
        pos_kpts (Optional[np.ndarray]): Positive keypoints array.
        neg_kpts (Optional[np.ndarray]): Negative keypoints array.
        image (Optional[np.ndarray]): Input image array; required for cropping.
        gt_mask (Optional[Any]): Ground truth mask (optional).
        num_pos_keypoints (Optional[int]): Number of positive keypoints to use.
        gt_mask_is_binary (bool): Flag indicating if ground truth mask is binary
            (otherwise it is decoded with ``pycocotools.mask.decode``).
    Returns:
        Tuple of (mask, pos_kpts_backup, neg_kpts_backup, score).
    """
    num_pos_keypoints = args.num_pos_keypoints if num_pos_keypoints is None else num_pos_keypoints

    # Positive prompts: select/order the visible keypoints, then cap them.
    if pos_kpts is not None:
        pos_kpts = pos_kpts.reshape(-1, 3)
        valid_kpts = pos_kpts[:, 2] > args.visibility_thr
        pos_kpts, conf = _select_keypoints(args, pos_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy)
        pos_kpts_backup = np.concatenate([pos_kpts, conf[:, None]], axis=1)
        if pos_kpts.shape[0] > num_pos_keypoints:
            pos_kpts = pos_kpts[:num_pos_keypoints, :]
            pos_kpts_backup = pos_kpts_backup[:num_pos_keypoints, :]
    else:
        pos_kpts = np.empty((0, 2), dtype=np.float32)
        pos_kpts_backup = np.empty((0, 3), dtype=np.float32)

    # Negative prompts: other people's keypoints closest to this instance.
    # The backup keeps all of them; only the prompt list is capped.
    if neg_kpts is not None:
        neg_kpts = neg_kpts.reshape(-1, 3)
        valid_kpts = neg_kpts[:, 2] > args.visibility_thr
        neg_kpts, conf = _select_keypoints(
            args, neg_kpts, num_visible=valid_kpts.sum(), bbox=bbox_xyxy, method="closest"
        )
        selected_neg_kpts = neg_kpts
        neg_kpts_backup = np.concatenate([neg_kpts, conf[:, None]], axis=1)
        if neg_kpts.shape[0] > args.num_neg_keypoints:
            selected_neg_kpts = neg_kpts[: args.num_neg_keypoints, :]
    else:
        selected_neg_kpts = np.empty((0, 2), dtype=np.float32)
        neg_kpts_backup = np.empty((0, 3), dtype=np.float32)

    # Concatenate positive and negative keypoints (label 1 = positive, 0 = negative)
    kpts = np.concatenate([pos_kpts, selected_neg_kpts], axis=0)
    kpts_labels = np.concatenate([np.ones(pos_kpts.shape[0]), np.zeros(selected_neg_kpts.shape[0])], axis=0)

    # Private float copy: the in-place edits below must not leak into the
    # caller's bbox.
    bbox = None
    if args.use_bbox and bbox_xyxy is not None:
        bbox = np.array(bbox_xyxy, dtype=float)

    if args.extend_bbox and bbox is not None and pos_kpts.shape[0] > 0:
        # Expand the bbox such that it contains all positive keypoints (+2 px margin)
        bbox[0] = min(bbox[0], pos_kpts[:, 0].min() - 2)
        bbox[1] = min(bbox[1], pos_kpts[:, 1].min() - 2)
        bbox[2] = max(bbox[2], pos_kpts[:, 0].max() + 2)
        bbox[3] = max(bbox[3], pos_kpts[:, 1].max() + 2)

    crop_requested = args.crop and args.use_bbox and image is not None
    crop_bbox = None
    original_image_size = None
    if crop_requested and bbox is not None:
        # Crop the image to the 1.5 * bbox size around the bbox centre
        bbox_center = (bbox[:2] + bbox[2:]) / 2
        half_size = 1.5 * (bbox[2:] - bbox[:2]) / 2
        crop_bbox = np.round(np.concatenate([bbox_center - half_size, bbox_center + half_size])).astype(int)
        crop_bbox = np.clip(crop_bbox, 0, [image.shape[1], image.shape[0], image.shape[1], image.shape[0]])
        original_image_size = image.shape[:2]
        model.set_image(image[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2], :])
        # Move the prompts into the crop's coordinate frame
        kpts = kpts - crop_bbox[:2]
        bbox[:2] -= crop_bbox[:2]
        bbox[2:] -= crop_bbox[:2]
    elif crop_requested:
        # Cropping was requested but there is no bbox to crop around (e.g. the
        # caller dropped it for a crowded instance). The predictor may still
        # hold an earlier crop, so embed the full image.
        model.set_image(image)

    masks, scores, _ = model.predict(
        point_coords=kpts,
        point_labels=kpts_labels,
        box=bbox,
        multimask_output=False,
    )
    mask = masks[0]
    score = scores[0]

    if crop_bbox is not None:
        # Pad the mask to the original image size
        mask_padded = np.zeros(original_image_size, dtype=np.uint8)
        mask_padded[crop_bbox[1] : crop_bbox[3], crop_bbox[0] : crop_bbox[2]] = mask
        mask = mask_padded

    if args.pose_mask_consistency:
        if gt_mask_is_binary:
            gt_mask_binary = gt_mask
        else:
            gt_mask_binary = Mask.decode(gt_mask).astype(bool) if gt_mask is not None else None

        # Without a reference mask there is nothing to compare; keep SAM's mask.
        if gt_mask_binary is not None:
            gt_mask_pose_consistency = _compute_one_mask_pose_consistency(
                gt_mask_binary, pos_kpts_backup, neg_kpts_backup
            )
            dt_mask_pose_consistency = _compute_one_mask_pose_consistency(mask, pos_kpts_backup, neg_kpts_backup)
            tol = 0.1
            if np.abs(dt_mask_pose_consistency - gt_mask_pose_consistency) < tol:
                # Comparable consistency: prefer the smaller mask.
                mask = gt_mask_binary if gt_mask_binary.sum() < mask.sum() else mask
            else:
                mask = gt_mask_binary if gt_mask_pose_consistency > dt_mask_pose_consistency else mask

    return mask, pos_kpts_backup, neg_kpts_backup, score
def process_image_with_SAM(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Validate ``sam_args`` and dispatch to batch or per-instance processing.

    ``sam_args.batch`` selects ``_process_image_batch``; otherwise
    ``_process_image_single`` is used. All arguments are forwarded unchanged
    and the selected function's result is returned.
    """
    _validate_sam_args(sam_args)
    handler = _process_image_batch if sam_args.batch else _process_image_single
    return handler(sam_args, image, model, new_dets, old_dets)
def _process_image_single(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Refine instance segmentation masks using SAM2 with pose-conditioned prompts.

    Each new detection is prompted on its own: its keypoints are the positive
    prompts and the keypoints of every other (new or old) detection are the
    negative candidates.

    Args:
        sam_args (Any): DotDict containing required SAM parameters:
            crop (bool), use_bbox (bool), confidence_thr (float),
            ignore_small_bboxes (bool), num_pos_keypoints (int),
            num_pos_keypoints_if_crowd (int), crowd_by_max_iou (Optional[float]),
            batch (bool), exclusive_masks (bool), extend_bbox (bool), pose_mask_consistency (bool).
        image (np.ndarray): BGR image array of shape (H, W, 3).
        model (SAM2ImagePredictor): Initialized SAM2 predictor.
        new_dets (InstanceData): New detections with attributes:
            bboxes (xywh), pred_masks, keypoints, bbox_scores.
        old_dets (Optional[InstanceData]): Previous detections for negative prompts.
            When ``exclusive_masks`` is on, its ``refined_masks`` and
            ``sam_scores`` are read as well.
    Returns:
        InstanceData: `new_dets` updated in-place with
            `.refined_masks`, `.sam_scores`, and `.sam_kpts`. Instances skipped
            as too small keep all-zero entries.
    """
    _validate_sam_args(sam_args)

    # With per-instance cropping, _pose2seg embeds each crop itself; otherwise
    # embed the full image once.
    crop_per_instance = sam_args.crop and sam_args.use_bbox
    if not crop_per_instance:
        model.set_image(image)

    # Ignore all keypoints with confidence below the threshold (on copies)
    new_keypoints = new_dets.keypoints.copy()
    for kpts in new_keypoints:
        kpts[kpts[:, 2] < sam_args.confidence_thr, :] = 0

    n_new_dets = len(new_dets.bboxes)
    n_old_dets = 0
    old_keypoints = None
    if old_dets is not None:
        n_old_dets = len(old_dets.bboxes)
        old_keypoints = old_dets.keypoints.copy()
        for kpts in old_keypoints:
            kpts[kpts[:, 2] < sam_args.confidence_thr, :] = 0

    all_bboxes = new_dets.bboxes.copy()
    if old_dets is not None:
        all_bboxes = np.concatenate([all_bboxes, old_dets.bboxes], axis=0)
    max_ious = _get_max_ious(all_bboxes)

    n_kpt_slots = sam_args.num_pos_keypoints
    new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
    new_dets.sam_scores = np.zeros_like(new_dets.bbox_scores)
    new_dets.sam_kpts = np.zeros((n_new_dets, n_kpt_slots, 3), dtype=np.float32)

    for instance_idx in range(n_new_dets):
        bbox_xywh = new_dets.bboxes[instance_idx]
        if sam_args.ignore_small_bboxes and bbox_xywh[2] * bbox_xywh[3] < 100 * 100:
            continue

        prior_mask = new_dets.pred_masks[instance_idx] if new_dets.pred_masks is not None else None
        bbox_xyxy = [bbox_xywh[0], bbox_xywh[1], bbox_xywh[0] + bbox_xywh[2], bbox_xywh[1] + bbox_xywh[3]]

        this_kpts = new_keypoints[instance_idx].reshape(1, -1, 3)

        # Negative candidates: keypoints of all other instances.
        other_kpts = None
        if old_keypoints is not None and n_old_dets > 0:
            # The n_old_dets guard avoids an ambiguous reshape of an empty array.
            other_kpts = old_keypoints.copy().reshape(n_old_dets, -1, 3)
        if len(new_keypoints) > 1:
            other_new_kpts = np.concatenate([new_keypoints[:instance_idx], new_keypoints[instance_idx + 1 :]], axis=0)
            other_kpts = (
                np.concatenate([other_kpts, other_new_kpts], axis=0) if other_kpts is not None else other_new_kpts
            )

        # In a crowd the bbox is dropped and a different keypoint budget is used.
        num_pos_keypoints = sam_args.num_pos_keypoints
        if sam_args.crowd_by_max_iou is not None and max_ious[instance_idx] > sam_args.crowd_by_max_iou:
            bbox_xyxy = None
            num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

        refined_mask, pos_kpts, _, score = _pose2seg(
            sam_args,
            model,
            bbox_xyxy,
            pos_kpts=this_kpts,
            neg_kpts=other_kpts,
            image=image if crop_per_instance else None,
            gt_mask=prior_mask,
            num_pos_keypoints=num_pos_keypoints,
            gt_mask_is_binary=True,
        )

        new_dets.refined_masks[instance_idx] = refined_mask
        new_dets.sam_scores[instance_idx] = score

        # Fit the used keypoints into the fixed (num_pos_keypoints, 3) slot:
        # zero-pad when fewer, truncate when the crowd budget yielded more
        # (the padding length would otherwise be negative).
        stored_kpts = np.zeros((n_kpt_slots, 3), dtype=np.float32)
        n_copy = min(len(pos_kpts), n_kpt_slots)
        stored_kpts[:n_copy] = pos_kpts[:n_copy]
        new_dets.sam_kpts[instance_idx] = stored_kpts

    n_masks = len(new_dets.refined_masks) + (len(old_dets.refined_masks) if old_dets is not None else 0)

    if sam_args.exclusive_masks and n_masks > 1:
        # Resolve overlaps jointly with the old detections, keep only the new part.
        if old_dets is not None:
            all_masks = np.concatenate([new_dets.refined_masks, old_dets.refined_masks], axis=0)
            all_scores = np.concatenate([new_dets.sam_scores, old_dets.sam_scores], axis=0)
        else:
            all_masks = new_dets.refined_masks
            all_scores = new_dets.sam_scores
        refined_masks = _apply_exclusive_masks(all_masks, all_scores)
        new_dets.refined_masks = refined_masks[: len(new_dets.refined_masks)]

    return new_dets
def _gather_batch_prompts(sam_args: Any, dets: InstanceData) -> Tuple[list, list, list, list]:
    """
    Collect the usable instances of ``dets`` for batched prompting.

    An instance is skipped when ``ignore_small_bboxes`` is set and its xywh
    bbox area is below 100 * 100, or when none of its keypoints exceeds both
    ``visibility_thr`` and ``confidence_thr``.

    Returns:
        (kept_indices, bboxes_xywh, keypoints, num_visible), one entry per kept
        instance. Keypoints are copies in which invisible points have x, y = 0
        and the confidence column is replaced by the 0/1 visibility flag.
    """
    kept_idx, bboxes, keypoints, visible_counts = [], [], [], []
    for idx in range(len(dets.bboxes)):
        bbox_xywh = np.array(dets.bboxes[idx])
        if sam_args.ignore_small_bboxes and bbox_xywh[2] * bbox_xywh[3] < 100 * 100:
            continue
        # Copy so the detections' own keypoints are never modified.
        inst_kpts = dets.keypoints[idx].copy().reshape(-1, 3)
        conf = np.array(inst_kpts[:, 2])
        visible = (conf > sam_args.visibility_thr) & (conf > sam_args.confidence_thr)
        num_visible = visible.sum()
        if num_visible <= 0:
            continue
        inst_kpts[~visible, :2] = 0
        inst_kpts[:, 2] = visible
        kept_idx.append(idx)
        bboxes.append(bbox_xywh)
        keypoints.append(inst_kpts)
        visible_counts.append(num_visible)
    return kept_idx, bboxes, keypoints, visible_counts


def _process_image_batch(
    sam_args: Any,
    image: np.ndarray,
    model: SAM2ImagePredictor,
    new_dets: InstanceData,
    old_dets: Optional[InstanceData] = None,
) -> InstanceData:
    """
    Batch process multiple detection instances with SAM2 refinement.

    All usable new detections, followed by all usable old ones, are prompted in
    a single ``model.predict`` call. Results are written back at the original
    indices of ``new_dets``; new detections that were skipped (small bbox or no
    visible keypoint) keep all-zero masks, scores and keypoints.

    Args:
        sam_args (Any): DotDict of SAM parameters (same as `process_image_with_SAM`).
        image (np.ndarray): Input BGR image.
        model (SAM2ImagePredictor): Prepared SAM2 predictor.
        new_dets (InstanceData): New detection instances.
        old_dets (Optional[InstanceData]): Previous detections for negative prompts.
    Returns:
        InstanceData: `new_dets` updated as in `process_image_with_SAM`.
    """
    n_new_dets = len(new_dets.bboxes)
    model.set_image(image)

    new_idx, image_bboxes, image_kpts, num_valid_kpts = _gather_batch_prompts(sam_args, new_dets)
    n_kept_new = len(new_idx)
    if old_dets is not None:
        _, old_bboxes, old_kpts, old_valid = _gather_batch_prompts(sam_args, old_dets)
        image_bboxes = image_bboxes + old_bboxes
        image_kpts = image_kpts + old_kpts
        num_valid_kpts = num_valid_kpts + old_valid

    if n_kept_new == 0:
        # Nothing to refine: emit empty results instead of failing on the
        # empty reductions below.
        new_dets.refined_masks = np.zeros((n_new_dets, image.shape[0], image.shape[1]), dtype=np.uint8)
        new_dets.sam_scores = np.zeros(n_new_dets, dtype=np.float32)
        new_dets.sam_kpts = np.zeros((n_new_dets, sam_args.num_pos_keypoints, 3), dtype=np.float32)
        return new_dets

    image_bboxes = np.array(image_bboxes)
    num_valid_kpts = np.array(num_valid_kpts)
    max_kpts = num_valid_kpts.max()

    # Prepare keypoints such that all instances have the same number of keypoints:
    # order them with _select_keypoints, then duplicate the last one if some are missing.
    prepared_kpts = []
    prepared_kpts_backup = []
    for bbox, kpts, num_visible in zip(image_bboxes, image_kpts, num_valid_kpts):
        # _select_keypoints derives the bbox centre assuming xyxy coordinates.
        bbox_xyxy = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
        this_kpts, this_conf = _select_keypoints(sam_args, kpts, num_visible, bbox_xyxy)
        n_missing = max_kpts - this_kpts.shape[0]
        if n_missing > 0:
            this_kpts = np.concatenate([this_kpts, np.tile(this_kpts[-1], (n_missing, 1))], axis=0)
            this_conf = np.concatenate([this_conf, np.tile(this_conf[-1], (n_missing,))], axis=0)
        prepared_kpts.append(this_kpts)
        prepared_kpts_backup.append(np.concatenate([this_kpts, this_conf[:, None]], axis=1))
    image_kpts = np.array(prepared_kpts)
    image_kpts_backup = np.array(prepared_kpts_backup)
    kpts_labels = np.ones(image_kpts.shape[:2])

    # Compute IoUs between all bounding boxes
    max_ious = _get_max_ious(image_bboxes)
    num_pos_keypoints = sam_args.num_pos_keypoints
    use_bbox = sam_args.use_bbox
    # NOTE(review): the batch shares one prompt configuration, so crowd mode is
    # enabled as soon as any instance overlaps too much. The previous code
    # tested a single instance via a leftover loop index.
    if sam_args.crowd_by_max_iou is not None and np.any(max_ious > sam_args.crowd_by_max_iou):
        use_bbox = False
        num_pos_keypoints = sam_args.num_pos_keypoints_if_crowd

    # Threshold the number of positive keypoints
    if 0 < num_pos_keypoints < image_kpts.shape[1]:
        image_kpts = image_kpts[:, :num_pos_keypoints, :]
        kpts_labels = kpts_labels[:, :num_pos_keypoints]
        image_kpts_backup = image_kpts_backup[:, :num_pos_keypoints, :]
    elif num_pos_keypoints == 0:
        image_kpts = None
        kpts_labels = None
        # Keep one (empty) keypoint set per instance so later stages stay aligned.
        image_kpts_backup = np.empty((len(image_bboxes), 0, 3), dtype=np.float32)

    image_bboxes_xyxy = None
    if use_bbox:
        image_bboxes_xyxy = np.array(image_bboxes, dtype=float)
        image_bboxes_xyxy[:, 2:] += image_bboxes_xyxy[:, :2]

        # Expand the bbox to include the positive keypoints (+2 px margin)
        if sam_args.extend_bbox and image_kpts is not None:
            pose_bbox = np.stack(
                [
                    np.min(image_kpts[:, :, 0], axis=1) - 2,
                    np.min(image_kpts[:, :, 1], axis=1) - 2,
                    np.max(image_kpts[:, :, 0], axis=1) + 2,
                    np.max(image_kpts[:, :, 1], axis=1) + 2,
                ],
                axis=1,
            )
            image_bboxes_xyxy[:, :2] = np.minimum(image_bboxes_xyxy[:, :2], pose_bbox[:, :2])
            image_bboxes_xyxy[:, 2:] = np.maximum(image_bboxes_xyxy[:, 2:], pose_bbox[:, 2:])

    # Old detections are predicted too, so their masks can compete below.
    masks, scores, _ = model.predict(
        point_coords=image_kpts,
        point_labels=kpts_labels,
        box=image_bboxes_xyxy,
        multimask_output=False,
    )

    # Reshape the masks to (N, C, H, W). If the model outputs (C, H, W), add a number of masks dimension
    if len(masks.shape) == 3:
        masks = masks[None, :, :, :]
    masks = masks[:, 0, :, :]
    N = masks.shape[0]
    scores = scores.reshape(N)

    if sam_args.exclusive_masks and N > 1:
        # Make sure the masks are non-overlapping
        # If two masks overlap, set the pixel to the one with the highest score
        masks = _apply_exclusive_masks(masks, scores)

    kept_masks = masks[:n_kept_new]
    gt_masks = new_dets.pred_masks if new_dets.pred_masks is not None else None
    if sam_args.pose_mask_consistency and gt_masks is not None:
        # Measure mask-pose consistency for the existing masks and the SAM masks
        # of the kept new detections, then keep whichever is more consistent.
        gt_kept = [gt_masks[i] for i in new_idx]
        dt_mask_pose_consistency = _compute_mask_pose_consistency(masks, image_kpts_backup)[:n_kept_new]
        gt_mask_pose_consistency = _compute_mask_pose_consistency(gt_kept, image_kpts_backup)
        dt_masks_area = np.array([m.sum() for m in kept_masks])
        gt_masks_area = np.array([m.sum() for m in gt_kept])

        # If PM-c is approx the same, prefer the smaller mask
        tol = 0.1
        pmc_is_equal = np.isclose(dt_mask_pose_consistency, gt_mask_pose_consistency, atol=tol)
        dt_is_worse = (dt_mask_pose_consistency < (gt_mask_pose_consistency - tol)) | (
            pmc_is_equal & (dt_masks_area > gt_masks_area)
        )
        kept_masks = np.array(
            [gt_mask if worse else dt_mask for dt_mask, gt_mask, worse in zip(kept_masks, gt_kept, dt_is_worse)]
        )

    # Scatter results back to the original positions in new_dets.
    refined_masks = np.zeros((n_new_dets,) + kept_masks.shape[1:], dtype=kept_masks.dtype)
    refined_masks[new_idx] = kept_masks
    sam_scores = np.zeros(n_new_dets, dtype=scores.dtype)
    sam_scores[new_idx] = scores[:n_kept_new]
    sam_kpts = np.zeros((n_new_dets,) + image_kpts_backup.shape[1:], dtype=image_kpts_backup.dtype)
    sam_kpts[new_idx] = image_kpts_backup[:n_kept_new]

    new_dets.refined_masks = refined_masks
    new_dets.sam_scores = sam_scores
    new_dets.sam_kpts = sam_kpts
    return new_dets
| def _apply_exclusive_masks(masks: np.ndarray, scores: np.ndarray) -> np.ndarray: | |
| """ | |
| Ensure masks are non-overlapping by keeping at each pixel the mask with the highest score. | |
| """ | |
| no_mask = masks.sum(axis=0) == 0 | |
| masked_scores = masks * scores[:, None, None] | |
| argmax_masks = np.argmax(masked_scores, axis=0) | |
| new_masks = argmax_masks[None, :, :] == (np.arange(masks.shape[0])[:, None, None]) | |
| new_masks[:, no_mask] = 0 | |
| return new_masks | |