English
File size: 5,639 Bytes
cbff41a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#
# Original: https://github.com/microsoft/hi-ml/blob/main/hi-ml/src/health_ml/utils/box_utils.py
#  ------------------------------------------------------------------------------------------

from dataclasses import dataclass
from typing import Optional, Sequence, Tuple

import numpy as np
from scipy import ndimage


@dataclass(frozen=True)
class Box:
    """Utility class representing rectangular regions in 2D images.

    Instances are immutable; every operation returns a new `Box`.

    :param x: Horizontal coordinate of the top-left corner.
    :param y: Vertical coordinate of the top-left corner.
    :param w: Box width.
    :param h: Box height.
    :raises ValueError: If either `w` or `h` are <= 0.
    """
    x: int
    y: int
    w: int
    h: int

    def __post_init__(self) -> None:
        """Validates that the box has strictly positive width and height.

        :raises ValueError: If either `w` or `h` are <= 0.
        """
        if self.w <= 0:
            raise ValueError(f"Width must be strictly positive, received {self.w}")
        if self.h <= 0:
            raise ValueError(f"Height must be strictly positive, received {self.h}")

    def __add__(self, shift: Sequence[int]) -> 'Box':
        """Translates the box's location by a given shift.

        :param shift: A length-2 sequence containing horizontal and vertical shifts.
        :return: A new box with updated `x = x + shift[0]` and `y = y + shift[1]`.
        :raises ValueError: If `shift` does not have two elements.
        """
        if len(shift) != 2:
            raise ValueError("Shift must be two-dimensional")
        return Box(x=self.x + shift[0],
                   y=self.y + shift[1],
                   w=self.w,
                   h=self.h)

    def __mul__(self, factor: float) -> 'Box':
        """Scales the box by a given factor, e.g. when changing resolution.

        :param factor: The factor by which to multiply the box's location and dimensions.
        :return: The updated box, with location and dimensions truncated towards zero by `int()`.
        :raises ValueError: If the scaled width or height truncates to a value <= 0.
        """
        return Box(x=int(self.x * factor),
                   y=int(self.y * factor),
                   w=int(self.w * factor),
                   h=int(self.h * factor))

    def __rmul__(self, factor: float) -> 'Box':
        """Scales the box by a given factor, e.g. when changing resolution.

        Allows writing `factor * box` in addition to `box * factor`.

        :param factor: The factor by which to multiply the box's location and dimensions.
        :return: The updated box, with location and dimensions truncated towards zero by `int()`.
        :raises ValueError: If the scaled width or height truncates to a value <= 0.
        """
        return self * factor

    def __truediv__(self, factor: float) -> 'Box':
        """Scales the box by the inverse of a given factor, e.g. when changing resolution.

        :param factor: The factor by which to divide the box's location and dimensions.
        :return: The updated box, with location and dimensions truncated towards zero by `int()`.
        :raises ValueError: If the scaled width or height truncates to a value <= 0.
        :raises ZeroDivisionError: If `factor` is zero.
        """
        # Divide each field directly instead of multiplying by `1. / factor`: the reciprocal is
        # not exact in floating point (e.g. `49 * (1. / 49) == 0.9999999999999999`), which would
        # truncate an exact quotient down to the next lower integer.
        return Box(x=int(self.x / factor),
                   y=int(self.y / factor),
                   w=int(self.w / factor),
                   h=int(self.h / factor))

    def add_margin(self, margin: int) -> 'Box':
        """Adds a symmetric margin on all sides of the box.

        :param margin: The amount by which to enlarge the box.
        :return: A new box enlarged by `margin` on all sides.
        :raises ValueError: If a negative `margin` shrinks the width or height to a value <= 0.
        """
        return Box(x=self.x - margin,
                   y=self.y - margin,
                   w=self.w + 2 * margin,
                   h=self.h + 2 * margin)

    def clip(self, other: 'Box') -> Optional['Box']:
        """Clips a box to the interior of another.

        This is useful to constrain a region to the interior of an image.

        :param other: Box representing the new constraints.
        :return: A new constrained box, or `None` if the boxes do not overlap.
        """
        x0 = max(self.x, other.x)
        y0 = max(self.y, other.y)
        x1 = min(self.x + self.w, other.x + other.w)
        y1 = min(self.y + self.h, other.y + other.h)
        clipped_w = x1 - x0
        clipped_h = y1 - y0
        # An empty intersection (including boxes that merely touch) is not a valid box.
        if clipped_w <= 0 or clipped_h <= 0:
            return None
        return Box(x=x0, y=y0, w=clipped_w, h=clipped_h)

    def to_slices(self) -> Tuple[slice, slice]:
        """Converts the box to slices for indexing arrays.

        For example: `my_2d_array[my_box.to_slices()]`.

        :return: A 2-tuple with vertical and horizontal slices.
        """
        return (slice(self.y, self.y + self.h),
                slice(self.x, self.x + self.w))

    @staticmethod
    def from_slices(slices: Sequence[slice]) -> 'Box':
        """Converts a pair of vertical and horizontal slices into a box.

        Both slices must have explicit (non-`None`) `start` and `stop` values; `step` is ignored.

        :param slices: A length-2 sequence containing vertical and horizontal `slice` objects.
        :return: A box with corresponding location and dimensions.
        :raises ValueError: If either slice spans a length <= 0.
        """
        vert_slice, horz_slice = slices
        return Box(x=horz_slice.start,
                   y=vert_slice.start,
                   w=horz_slice.stop - horz_slice.start,
                   h=vert_slice.stop - vert_slice.start)


def get_bounding_box(mask: np.ndarray) -> Box:
    """Computes the tightest box enclosing the foreground of a 2D mask.

    Every element strictly greater than zero (or `True`) counts as foreground. The mask is
    binarised before the search, so all foreground elements are enclosed by a single box.

    :param mask: A 2D array with 0 (or `False`) as background and >0 (or `True`) as foreground.
    :return: The smallest box covering all non-zero elements of `mask`.
    :raises TypeError: When the input mask does not have exactly two dimensions.
    :raises RuntimeError: When all elements in the mask are zero.
    """
    if mask.ndim != 2:
        raise TypeError(f"Expected a 2D array but got an array with shape {mask.shape}")

    # Binarise first so that the search sees at most one label.
    foreground = mask > 0
    found_regions = ndimage.find_objects(foreground)
    if not found_regions:
        raise RuntimeError("The input mask is empty")
    # A boolean input has a single label, hence exactly one (vertical, horizontal) slice pair.
    assert len(found_regions) == 1

    region_slices = found_regions[0]
    return Box.from_slices(region_slices)