English
weiheng-1009's picture
added code for running
cbff41a
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#
# Original: https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/box_utils.py
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple
import numpy as np
from scipy import ndimage
@dataclass(frozen=True)
class Box:
    """Utility class representing rectangular regions in 2D images.

    :param x: Horizontal coordinate of the top-left corner.
    :param y: Vertical coordinate of the top-left corner.
    :param w: Box width.
    :param h: Box height.
    :raises ValueError: If either `w` or `h` are <= 0.
    """
    x: int
    y: int
    w: int
    h: int

    def __post_init__(self) -> None:
        # Validate eagerly so that every Box instance is guaranteed to be non-empty.
        if self.w <= 0:
            raise ValueError(f"Width must be strictly positive, received {self.w}")
        if self.h <= 0:
            raise ValueError(f"Height must be strictly positive, received {self.h}")

    def __add__(self, shift: Sequence[int]) -> 'Box':
        """Translates the box's location by a given shift.

        :param shift: A length-2 sequence containing horizontal and vertical shifts.
        :return: A new box with updated `x = x + shift[0]` and `y = y + shift[1]`.
        :raises ValueError: If `shift` does not have two elements.
        """
        if len(shift) != 2:
            raise ValueError("Shift must be two-dimensional")
        return Box(x=self.x + shift[0],
                   y=self.y + shift[1],
                   w=self.w,
                   h=self.h)

    def __mul__(self, factor: float) -> 'Box':
        """Scales the box by a given factor, e.g. when changing resolution.

        :param factor: The factor by which to multiply the box's location and dimensions.
        :return: The updated box, with location and dimensions converted with `int()`
            (i.e. truncated toward zero).
        :raises ValueError: If the scaled width or height is not strictly positive.
        """
        return Box(x=int(self.x * factor),
                   y=int(self.y * factor),
                   w=int(self.w * factor),
                   h=int(self.h * factor))

    def __rmul__(self, factor: float) -> 'Box':
        """Scales the box by a given factor, e.g. when changing resolution.

        Supports `factor * box` by delegating to `box * factor`.

        :param factor: The factor by which to multiply the box's location and dimensions.
        :return: The updated box, with location and dimensions converted with `int()`
            (i.e. truncated toward zero).
        """
        return self * factor

    def __truediv__(self, factor: float) -> 'Box':
        """Scales the box by the inverse of a given factor, e.g. when changing resolution.

        :param factor: The factor by which to divide the box's location and dimensions.
        :return: The updated box, with location and dimensions converted with `int()`
            (i.e. truncated toward zero).
        :raises ZeroDivisionError: If `factor` is zero.
        """
        return self * (1. / factor)

    def add_margin(self, margin: int) -> 'Box':
        """Adds a symmetric margin on all sides of the box.

        :param margin: The amount by which to enlarge the box (a negative value shrinks it).
        :return: A new box enlarged by `margin` on all sides.
        :raises ValueError: If the resulting width or height is not strictly positive.
        """
        return Box(x=self.x - margin,
                   y=self.y - margin,
                   w=self.w + 2 * margin,
                   h=self.h + 2 * margin)

    def clip(self, other: 'Box') -> Optional['Box']:
        """Clips a box to the interior of another.

        This is useful to constrain a region to the interior of an image.

        :param other: Box representing the new constraints.
        :return: A new constrained box, or `None` if the boxes do not overlap
            (boxes that merely touch along an edge do not overlap).
        """
        x0 = max(self.x, other.x)
        y0 = max(self.y, other.y)
        x1 = min(self.x + self.w, other.x + other.w)
        y1 = min(self.y + self.h, other.y + other.h)
        # An empty intersection would produce a non-positive width or height.
        if x1 <= x0 or y1 <= y0:
            return None
        return Box(x=x0, y=y0, w=x1 - x0, h=y1 - y0)

    def to_slices(self) -> Tuple[slice, slice]:
        """Converts the box to slices for indexing arrays.

        For example: `my_2d_array[my_box.to_slices()]`.

        :return: A 2-tuple with vertical and horizontal slices (row slice first).
        """
        return (slice(self.y, self.y + self.h),
                slice(self.x, self.x + self.w))

    @staticmethod
    def from_slices(slices: Sequence[slice]) -> 'Box':
        """Converts a pair of vertical and horizontal slices into a box.

        :param slices: A length-2 sequence containing vertical and horizontal `slice` objects
            with explicit integer `start` and `stop`.
        :return: A box with corresponding location and dimensions.
        :raises ValueError: If either slice has `stop <= start`.
        """
        vert_slice, horz_slice = slices
        return Box(x=horz_slice.start,
                   y=vert_slice.start,
                   w=horz_slice.stop - horz_slice.start,
                   h=vert_slice.stop - vert_slice.start)
def get_bounding_box(mask: np.ndarray) -> Box:
    """Extracts a bounding box from a binary 2D array.

    :param mask: A 2D array (or array-like) with 0 (or `False`) as background and >0 (or `True`)
        as foreground. Negative values are treated as background.
    :return: The smallest box covering all positive elements of `mask`.
    :raises TypeError: When the input mask is not two-dimensional.
    :raises RuntimeError: When no element in the mask is positive (including zero-size masks).
    """
    mask = np.asarray(mask)  # no-op for ndarrays; lets nested lists etc. be passed too
    if mask.ndim != 2:
        raise TypeError(f"Expected a 2D array but got an array with shape {mask.shape}")
    # Binarise so every foreground pixel shares label 1 (True); find_objects then yields at most
    # one (vertical, horizontal) slice pair covering all foreground pixels.
    foreground = mask > 0
    # find_objects takes the array maximum to determine the label count, which is undefined for
    # zero-size arrays, so such masks are treated as empty up front.
    slices = ndimage.find_objects(foreground) if foreground.size > 0 else []
    if not slices:
        raise RuntimeError("The input mask is empty")
    # Only a single label can be present, so there is exactly one entry.
    assert len(slices) == 1
    return Box.from_slices(slices[0])