File size: 4,035 Bytes
1327f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Util functions for Segment Anything models."""

import jax.numpy as jnp
import numpy as np
from scenic.projects.baselines.segment_anything.modeling import nms as nms_lib


def build_point_grid(points_per_side):
  """Generates a 2D grid of points evenly spaced in [0, 1] x [0, 1].

  Points sit at the cell centers of a `points_per_side` x `points_per_side`
  tiling of the unit square, i.e. at (i + 0.5) / points_per_side on each axis.

  Args:
    points_per_side: Number of points along each side of the grid.

  Returns:
    (points_per_side ** 2, 2) array of (x, y) coordinates. x varies fastest:
    the first `points_per_side` rows all share the smallest y value.
  """
  offset = 1. / (2 * points_per_side)
  points_one_side = jnp.linspace(offset, 1 - offset, points_per_side)
  # Default 'xy' indexing: xs[i, j] = points_one_side[j] and
  # ys[i, j] = points_one_side[i], so flattening row-major makes x vary fastest.
  xs, ys = jnp.meshgrid(points_one_side, points_one_side)
  return jnp.stack([xs, ys], axis=-1).reshape(-1, 2)


def _batched_mask_to_box(masks, xp):
  """Shared mask-to-box implementation for the jax.numpy and NumPy variants.

  Args:
    masks: (n, h, w) binary masks; any nonzero entry counts as foreground.
      bool, integer (including uint8) and float dtypes are accepted.
    xp: Array module to compute with (`jax.numpy` or `numpy`).

  Returns:
    (n, 4) boxes as inclusive [left, top, right, bottom] pixel coordinates.
    All-zero masks map to [0, 0, 0, 0]. The dtype is
    `xp.result_type(masks.dtype, <arange dtype>)`, e.g. an integer type for
    bool masks and a float type for float masks. An empty batch (n == 0)
    returns a float32 (0, 4) array.
  """
  if masks.shape[0] == 0:
    return xp.zeros((0, 4), dtype=xp.float32)

  h, w = masks.shape[-2:]
  rows = xp.arange(h)
  cols = xp.arange(w)
  # Same dtype the former `mask * arange` formulation produced; kept so callers
  # see unchanged output dtypes.
  out_dtype = xp.result_type(masks.dtype, rows.dtype)

  # Reduce to booleans before any arithmetic: computing `h * (1 - mask)` in the
  # mask's own dtype overflows for narrow dtypes (e.g. uint8 masks with
  # h or w > 255).
  in_height = xp.any(masks, axis=-1)  # (n, h)
  in_width = xp.any(masks, axis=-2)  # (n, w)

  # Absent rows/cols count as 0 for the max and as h/w for the min, so an
  # all-zero mask ends up with max < min and is detected as empty below.
  top_edges = xp.min(xp.where(in_height, rows, h), axis=-1)  # (n,)
  bottom_edges = xp.max(xp.where(in_height, rows, 0), axis=-1)  # (n,)
  left_edges = xp.min(xp.where(in_width, cols, w), axis=-1)  # (n,)
  right_edges = xp.max(xp.where(in_width, cols, 0), axis=-1)  # (n,)

  # Mark empty masks as [0, 0, 0, 0].
  empty = (right_edges < left_edges) | (bottom_edges < top_edges)
  out = xp.stack(
      [left_edges, top_edges, right_edges, bottom_edges], axis=-1)  # (n, 4)
  out = xp.where(empty[..., None], 0, out)
  return out.astype(out_dtype)


def batched_mask_to_box(masks):
  """Convert binary masks in (n, h, w) to boxes (n, 4) using jax.numpy.

  Boxes are inclusive [left, top, right, bottom] pixel coordinates; see
  `_batched_mask_to_box` for empty-mask and dtype rules.
  """
  return _batched_mask_to_box(masks, jnp)


def batched_mask_to_box_np(masks):
  """Convert binary masks in (n, h, w) to boxes (n, 4) using NumPy.

  Same contract as `batched_mask_to_box`, computed with NumPy arrays.
  """
  return _batched_mask_to_box(masks, np)


def calculate_stability_score(
    mask_logits, mask_threshold, stability_score_offset):
  """Scores how much a mask changes when its binarization threshold moves.

  The logits are binarized twice: at `mask_threshold + stability_score_offset`
  (the stricter cut) and at `mask_threshold - stability_score_offset` (the
  looser cut). The score is the strict foreground pixel count divided by the
  loose one, counted over the last two axes. For a non-negative offset the
  strict foreground is a subset of the loose one, so this is their IoU.

  Args:
    mask_logits: Array of mask logits; the last two axes are spatial.
    mask_threshold: Base threshold applied to the logits.
    stability_score_offset: Amount the threshold is shifted up and down.

  Returns:
    Array of scores with the last two axes of `mask_logits` reduced away.
    The division is not guarded: if no pixel passes the looser cut it is 0 / 0.
  """
  spatial_axes = (-2, -1)
  strict_cut = mask_threshold + stability_score_offset
  loose_cut = mask_threshold - stability_score_offset
  strict_area = (mask_logits > strict_cut).sum(axis=spatial_axes)
  loose_area = (mask_logits > loose_cut).sum(axis=spatial_axes)
  return strict_area / loose_area


def nms(boxes, scores, iou_threshold, num_outputs=100):
  """Runs padded non-max suppression on a single, unbatched set of boxes.

  Args:
    boxes: Boxes for one image; a batch axis of size 1 is added before the
      call to `nms_lib.non_max_suppression_padded`.
    scores: Scores aligned with `boxes`; batched the same way.
    iou_threshold: IoU threshold forwarded to the NMS op.
    num_outputs: Output size forwarded to the NMS op.

  Returns:
    The third output of `nms_lib.non_max_suppression_padded` (requested via
    `return_idx=True`), with the batch axis removed.
  """
  # The library op is batched, so wrap the single example in a batch of one.
  batched_scores = scores[None]
  batched_boxes = boxes[None]
  _, _, selected = nms_lib.non_max_suppression_padded(
      batched_scores, batched_boxes, num_outputs, iou_threshold,
      return_idx=True)  # pytype: disable=wrong-arg-types
  return selected[0]  # drop the batch axis