| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Util functions for Segment Anything models.""" |
|
|
| import jax.numpy as jnp |
| import numpy as np |
| from scenic.projects.baselines.segment_anything.modeling import nms as nms_lib |
|
|
|
|
def build_point_grid(points_per_side):
    """Builds a flattened, evenly spaced grid of 2D points in the unit square.

    Along each axis the coordinates run from 1 / (2 * points_per_side) to
    1 - 1 / (2 * points_per_side), i.e. they are inset from the borders by half
    of the spacing between neighbouring points.

    Args:
      points_per_side: Number of points along each axis.

    Returns:
      Array of shape (points_per_side ** 2, 2). Each row is an (x, y) pair, with
      x varying fastest.
    """
    half_cell = 1. / (2 * points_per_side)
    axis_coords = jnp.linspace(half_cell, 1 - half_cell, points_per_side)
    # 'xy' indexing: xs changes along columns, ys changes along rows.
    xs, ys = jnp.meshgrid(axis_coords, axis_coords, indexing='xy')
    return jnp.stack([xs, ys], axis=-1).reshape(-1, 2)
|
|
|
|
def batched_mask_to_box(masks):
    """Computes tight bounding boxes for a batch of binary masks (JAX version).

    Args:
      masks: Binary masks of shape (n, h, w).

    Returns:
      Array of shape (n, 4) holding (left, top, right, bottom) for each mask,
      expressed as inclusive pixel indices. A mask without any foreground pixel
      yields [0, 0, 0, 0]. For an empty batch (n == 0) a float32 array of shape
      (0, 4) is returned.
    """
    if masks.shape[0] == 0:
        return jnp.zeros((0, 4), dtype=jnp.float32)

    height, width = masks.shape[-2:]

    # Vertical extent: flag the rows that contain foreground, then take the
    # largest / smallest flagged row index.
    row_hit = jnp.max(masks, axis=-1)
    row_idx = row_hit * jnp.arange(height)[None]
    y_max = jnp.max(row_idx, axis=-1)
    # Push unflagged rows up to `height` so they can never win the minimum.
    row_idx = row_idx + height * (1 - row_hit)
    y_min = jnp.min(row_idx, axis=-1)

    # Horizontal extent: same procedure applied to the columns.
    col_hit = jnp.max(masks, axis=-2)
    col_idx = col_hit * jnp.arange(width)[None]
    x_max = jnp.max(col_idx, axis=-1)
    col_idx = col_idx + width * (1 - col_hit)
    x_min = jnp.min(col_idx, axis=-1)

    # An all-zero mask ends up with max edge 0 and min edge h (or w), i.e. an
    # inverted box; such boxes are zeroed out.
    is_empty = (x_max < x_min) | (y_max < y_min)
    boxes = jnp.stack([x_min, y_min, x_max, y_max], axis=-1)
    boxes = boxes * (1 - is_empty)[:, None]
    return boxes
|
|
|
|
def batched_mask_to_box_np(masks):
    """Computes tight bounding boxes for a batch of binary masks (NumPy version).

    Args:
      masks: Binary masks of shape (n, h, w).

    Returns:
      Array of shape (n, 4) holding (left, top, right, bottom) for each mask,
      expressed as inclusive pixel indices. A mask without any foreground pixel
      yields [0, 0, 0, 0]. For an empty batch (n == 0) a float32 array of shape
      (0, 4) is returned.
    """
    if masks.shape[0] == 0:
        return np.zeros((0, 4), dtype=np.float32)

    height, width = masks.shape[-2:]

    # Vertical extent: flag the rows that contain foreground, then take the
    # largest / smallest flagged row index.
    row_hit = np.max(masks, axis=-1)
    row_idx = row_hit * np.arange(height)[None]
    y_max = np.max(row_idx, axis=-1)
    # Push unflagged rows up to `height` so they can never win the minimum.
    row_idx = row_idx + height * (1 - row_hit)
    y_min = np.min(row_idx, axis=-1)

    # Horizontal extent: same procedure applied to the columns.
    col_hit = np.max(masks, axis=-2)
    col_idx = col_hit * np.arange(width)[None]
    x_max = np.max(col_idx, axis=-1)
    col_idx = col_idx + width * (1 - col_hit)
    x_min = np.min(col_idx, axis=-1)

    # An all-zero mask ends up with max edge 0 and min edge h (or w), i.e. an
    # inverted box; such boxes are zeroed out.
    is_empty = (x_max < x_min) | (y_max < y_min)
    boxes = np.stack([x_min, y_min, x_max, y_max], axis=-1)
    boxes = boxes * (1 - is_empty)[:, None]
    return boxes
|
|
|
|
def calculate_stability_score(
    mask_logits, mask_threshold, stability_score_offset):
    """Scores how much a mask changes when the binarization threshold is shifted.

    The logits are thresholded at `mask_threshold + stability_score_offset` and
    at `mask_threshold - stability_score_offset`; the score is the ratio of the
    two foreground pixel counts (stricter count over looser count).

    Args:
      mask_logits: Array whose last two axes are the spatial axes of the masks.
      mask_threshold: Threshold around which the two cut-offs are placed.
      stability_score_offset: Amount added to / subtracted from the threshold.

    Returns:
      Array with the two spatial axes reduced away. The denominator is not
      guarded: if no logit exceeds the looser cut-off the division is by zero.
    """
    strict_cutoff = mask_threshold + stability_score_offset
    loose_cutoff = mask_threshold - stability_score_offset
    # Count foreground pixels over both spatial axes in one reduction.
    strict_area = (mask_logits > strict_cutoff).sum(axis=(-1, -2))
    loose_area = (mask_logits > loose_cutoff).sum(axis=(-1, -2))
    return strict_area / loose_area
|
|
|
|
def nms(boxes, scores, iou_threshold, num_outputs=100):
    """Runs padded non-max suppression on a single, unbatched set of boxes.

    Thin wrapper around `nms_lib.non_max_suppression_padded`: a leading batch
    axis of size one is added to `scores` and `boxes`, and the third output
    (requested via `return_idx=True`) is returned for that single batch entry.

    Args:
      boxes: Boxes for one example, without a batch axis.
      scores: Scores for one example, without a batch axis.
      iou_threshold: Forwarded to `non_max_suppression_padded`.
      num_outputs: Forwarded to `non_max_suppression_padded`; presumably the
        padded number of kept entries -- confirm against nms_lib.

    Returns:
      The third output of `non_max_suppression_padded` with the batch axis
      removed.
    """
    batched_scores = scores[None]
    batched_boxes = boxes[None]
    _, _, kept = nms_lib.non_max_suppression_padded(
        batched_scores, batched_boxes, num_outputs, iou_threshold,
        return_idx=True)
    return kept[0]
|
|