fcxfcx's picture
Upload 2446 files
1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Util functions for Segment Anything models."""
import jax.numpy as jnp
import numpy as np
from scenic.projects.baselines.segment_anything.modeling import nms as nms_lib
def build_point_grid(points_per_side):
  """Generates a 2D grid of points evenly spaced in [0, 1] x [0, 1].

  The unit square is split into `points_per_side` equal cells along each axis
  and one point is placed at the center of every cell, so the first and last
  coordinates sit half a cell away from the borders.

  Args:
    points_per_side: Number of points along each axis of the grid.

  Returns:
    An array of shape (points_per_side ** 2, 2) whose rows are (x, y) pairs.
    Rows are ordered with x varying fastest.
  """
  # Half of one cell width: distance from the border to the first point.
  offset = 1. / (2 * points_per_side)
  axis_coords = jnp.linspace(offset, 1 - offset, points_per_side)
  # points_x[i, j] == axis_coords[j] and points_y[i, j] == axis_coords[i].
  points_x = jnp.tile(axis_coords[None, :], (points_per_side, 1))
  points_y = jnp.tile(axis_coords[:, None], (1, points_per_side))
  grid = jnp.stack([points_x, points_y], axis=-1)  # (side, side, 2)
  return grid.reshape(-1, 2)  # (points_per_side ** 2, 2)
def batched_mask_to_box(masks):
  """Convert binary masks in (n, h, w) to boxes (n, 4).

  Each box is [left, top, right, bottom], given as the smallest / largest
  column and row indices that contain a foreground pixel (both ends
  inclusive). A mask without any foreground pixel yields [0, 0, 0, 0].

  Args:
    masks: Binary masks of shape (n, h, w); foreground is 1 (or True) and
      background is 0 (or False).

  Returns:
    An array of shape (n, 4). When n == 0 a float32 array of shape (0, 4) is
    returned.
  """
  if masks.shape[0] == 0:
    return jnp.zeros((0, 4), dtype=jnp.float32)
  h, w = masks.shape[-2:]

  # Rows that contain at least one foreground pixel.
  in_height = jnp.max(masks, axis=-1)  # (n, h)
  in_height_coords = in_height * jnp.arange(h)[None]  # (n, h)
  bottom_edges = jnp.max(in_height_coords, axis=-1)  # (n,)
  # Mark unoccupied rows as "h" so that we can take min. `jnp.where` keeps the
  # sentinel in the (already promoted) coordinate dtype; multiplying `h` by
  # the mask itself would be evaluated in the mask dtype, which cannot hold
  # `h` for narrow integer masks (e.g. uint8 with h > 255).
  top_edges = jnp.min(
      jnp.where(in_height > 0, in_height_coords, h), axis=-1)  # (n,)

  # Columns that contain at least one foreground pixel.
  in_width = jnp.max(masks, axis=-2)  # (n, w)
  in_width_coords = in_width * jnp.arange(w)[None]  # (n, w)
  right_edges = jnp.max(in_width_coords, axis=-1)  # (n,)
  # Same sentinel trick as above, with "w" for unoccupied columns.
  left_edges = jnp.min(
      jnp.where(in_width > 0, in_width_coords, w), axis=-1)  # (n,)

  # An empty mask ends up with min edge == h (or w) and max edge == 0; mark
  # those boxes as [0, 0, 0, 0].
  empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
  out = jnp.stack(
      [left_edges, top_edges, right_edges, bottom_edges], axis=-1)  # (n, 4)
  return jnp.where(empty_filter[:, None], 0, out)
def batched_mask_to_box_np(masks):
  """Convert binary masks in (n, h, w) to boxes (n, 4), using NumPy.

  Each box is [left, top, right, bottom], given as the smallest / largest
  column and row indices that contain a foreground pixel (both ends
  inclusive). A mask without any foreground pixel yields [0, 0, 0, 0].

  Args:
    masks: Binary masks of shape (n, h, w); foreground is 1 (or True) and
      background is 0 (or False).

  Returns:
    An array of shape (n, 4). When n == 0 a float32 array of shape (0, 4) is
    returned.
  """
  if masks.shape[0] == 0:
    return np.zeros((0, 4), dtype=np.float32)
  h, w = masks.shape[-2:]

  # Rows that contain at least one foreground pixel.
  in_height = np.max(masks, axis=-1)  # (n, h)
  in_height_coords = in_height * np.arange(h)[None]  # (n, h)
  bottom_edges = np.max(in_height_coords, axis=-1)  # (n,)
  # Mark unoccupied rows as "h" so that we can take min. `np.where` keeps the
  # sentinel in the (already promoted) coordinate dtype; multiplying `h` by
  # the mask itself would be evaluated in the mask dtype, which cannot hold
  # `h` for narrow integer masks (e.g. uint8 with h > 255).
  top_edges = np.min(
      np.where(in_height > 0, in_height_coords, h), axis=-1)  # (n,)

  # Columns that contain at least one foreground pixel.
  in_width = np.max(masks, axis=-2)  # (n, w)
  in_width_coords = in_width * np.arange(w)[None]  # (n, w)
  right_edges = np.max(in_width_coords, axis=-1)  # (n,)
  # Same sentinel trick as above, with "w" for unoccupied columns.
  left_edges = np.min(
      np.where(in_width > 0, in_width_coords, w), axis=-1)  # (n,)

  # An empty mask ends up with min edge == h (or w) and max edge == 0; mark
  # those boxes as [0, 0, 0, 0].
  empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
  out = np.stack(
      [left_edges, top_edges, right_edges, bottom_edges], axis=-1)  # (n, 4)
  return np.where(empty_filter[:, None], 0, out)
def calculate_stability_score(
    mask_logits, mask_threshold, stability_score_offset):
  """Scores how little a mask changes when its binarization threshold moves.

  The logits are binarized twice, once at `mask_threshold +
  stability_score_offset` and once at `mask_threshold -
  stability_score_offset`, and the ratio of the two foreground pixel counts
  is returned.

  Args:
    mask_logits: Array of mask logits whose last two axes are summed over.
    mask_threshold: Threshold around which the two cut-offs are placed.
    stability_score_offset: Amount added to / subtracted from the threshold.

  Returns:
    Pixel count above the raised threshold divided by the pixel count above
    the lowered threshold, with the last two axes of `mask_logits` reduced.
    No guard is applied when the denominator count is zero.
  """
  raised_threshold = mask_threshold + stability_score_offset
  lowered_threshold = mask_threshold - stability_score_offset
  # Foreground area under the stricter and the more permissive cut-off.
  strict_area = (mask_logits > raised_threshold).sum(axis=(-2, -1))
  permissive_area = (mask_logits > lowered_threshold).sum(axis=(-2, -1))
  return strict_area / permissive_area
def nms(boxes, scores, iou_threshold, num_outputs=100):
  """Runs padded non-max suppression on a single, unbatched set of boxes.

  A leading batch axis of size one is added to `scores` and `boxes` before
  calling `nms_lib.non_max_suppression_padded`, and removed again from the
  result.

  Args:
    boxes: Boxes to suppress, without a batch axis.
    scores: Scores of the boxes, without a batch axis.
    iou_threshold: IoU threshold forwarded to the NMS implementation.
    num_outputs: Output size forwarded to the NMS implementation.

  Returns:
    The third output of `non_max_suppression_padded` (called with
    `return_idx=True`) for the single batch element -- presumably the indices
    of the kept boxes; confirm against `nms_lib`.
  """
  batched_scores = scores[None]
  batched_boxes = boxes[None]
  _, _, batched_keep = nms_lib.non_max_suppression_padded(
      batched_scores, batched_boxes, num_outputs, iou_threshold,
      return_idx=True)  # pytype: disable=wrong-arg-types
  return batched_keep[0]  # undo batch