File size: 8,424 Bytes
24e5510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
from time import time
from time import time
from typing import List, Union, Tuple

import numpy as np
import torch


def generate_bounding_boxes(mask, bbox_size=(192, 192, 192), stride: Union[List[int], Tuple[int, int, int], str] = (16, 16, 16), margin=(10, 10, 10), max_depth=5, current_depth=0):
    """
    Generate overlapping bounding boxes that cover a 3D mask, using a greedy set cover.

    Candidate box centers are the nonzero mask voxels lying on a strided grid over the mask's
    nonzero extent. Boxes are picked greedily: each round takes the candidate whose
    margin-shrunk box covers the most still-uncovered voxels. Voxels left uncovered afterwards
    are handled recursively with a halved stride, or by ``random_sampling_fallback`` when only a
    few remain or the recursion limit is exceeded.

    Parameters:
    - mask: 3D PyTorch tensor; its nonzero voxels are the ones to cover. It is not modified.
    - bbox_size: three ints, the full box size per dimension (x, y, z). Odd sizes are supported.
    - stride: three ints giving the candidate-center grid spacing per dimension, or the string
      'auto' to use about a quarter of the mask's nonzero extent per dimension (at least 1).
    - margin: three ints; only the box interior shrunk by `margin` on each side counts as
      covering voxels, so neighbouring boxes overlap by at least the margin.
    - max_depth: maximum recursion depth before switching to random sampling.
    - current_depth: current recursion depth (used internally).

    Returns:
    - List of boxes ``[[x_min, x_max], [y_min, y_max], [z_min, z_max]]`` of Python ints, each a
      half-open interval of length ``bbox_size`` per dimension. Boxes are NOT clipped to the
      mask shape, so coordinates may be negative or exceed it. Returns [] for an empty mask.

    Raises:
    - ValueError: if `stride` is a string other than 'auto', or if `margin` is so large that a
      box interior is empty (no box could ever cover a voxel).
    """
    # Normalise sizes to lists so they can be indexed per dimension.
    bbox_size = list(bbox_size)
    margin = list(margin)

    # A box around center c spans [c - half_size, c + end_offset); splitting it this way keeps
    # the full bbox_size for odd sizes (e.g. 193 -> 96 below the center, 97 from it upwards).
    half_size = [bs // 2 for bs in bbox_size]
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]

    def _fallback(remaining):
        # random_sampling_fallback may zero voxels of the tensor it is given, so hand it a
        # copy; it may also yield tensor coordinates, so normalise them to ints to match the
        # boxes produced by the greedy pass below.
        boxes = random_sampling_fallback(remaining.clone(), bbox_size, margin, 25)
        return [[[int(v) for v in axis] for axis in box] for box in boxes]

    # Prevent infinite recursion: hand whatever is left to random sampling.
    if current_depth > max_depth:
        return _fallback(mask)

    # Step 1: find all object voxels.
    object_voxels = torch.nonzero(mask, as_tuple=False)
    if object_voxels.numel() == 0:
        return []

    # The covering interior is [c - half_size + margin, c + end_offset - margin); it is empty
    # exactly when margin >= end_offset, in which case the search can never make progress.
    if any(m >= eo for m, eo in zip(margin, end_offset)):
        raise ValueError(
            f"margin {margin} leaves no coverable interior in boxes of size {bbox_size}"
        )

    # Step 2: limit candidate centers to the object's bounding box.
    lo = object_voxels.min(dim=0)[0].tolist()
    hi = object_voxels.max(dim=0)[0].tolist()

    if isinstance(stride, str):
        if stride != 'auto':
            raise ValueError(f"stride must be 'auto' or three ints, got {stride!r}")
        stride = [max(1, round((b - a) / 4)) for a, b in zip(lo, hi)]
    stride = list(stride)

    # Step 3: candidate centers are the object voxels on a strided grid over that bounding box.
    # index_select + nonzero replaces a per-voxel Python loop; nonzero returns indices in
    # row-major order, i.e. the same x-major order as a nested x/y/z loop (order breaks ties).
    axes = [
        torch.arange(max(0, lo[d]), min(mask.shape[d], hi[d] + 1), stride[d], device=mask.device)
        for d in range(3)
    ]
    grid = mask.index_select(0, axes[0]).index_select(1, axes[1]).index_select(2, axes[2])
    hits = torch.nonzero(grid, as_tuple=False)
    candidates = torch.stack([axes[d][hits[:, d]] for d in range(3)], dim=1)

    if len(candidates) == 0:
        # The grid missed every object voxel: retry with a finer grid.
        return generate_bounding_boxes(
            mask, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
        )

    # Step 4: greedy set cover. Track coverage as 0/1 so counts are voxel counts, regardless of
    # the mask's dtype or label values.
    uncovered = (mask != 0).to(torch.uint8)
    bboxes = []

    def _interior(center):
        # Margin-shrunk box interior around `center`, clipped to the mask, as
        # (x_start, x_end, y_start, y_end, z_start, z_end).
        bounds = []
        for d in range(3):
            bounds.append(max(0, center[d] - half_size[d] + margin[d]))
            bounds.append(min(mask.shape[d], center[d] + end_offset[d] - margin[d]))
        return tuple(bounds)

    while len(candidates) > 0 and uncovered.any():
        # Move the centers to the host once per round instead of one scalar at a time.
        centers = candidates.tolist()
        best_idx, best_covered, best_bounds = None, 0, None

        for idx, center in enumerate(centers):
            bounds = _interior(center)
            x_s, x_e, y_s, y_e, z_s, z_e = bounds
            num_covered = uncovered[x_s:x_e, y_s:y_e, z_s:z_e].sum().item()
            # Strict '>' keeps the earliest candidate on ties.
            if num_covered > best_covered:
                best_idx, best_covered, best_bounds = idx, num_covered, bounds

        # No candidate covers any new voxel: stop and leave the rest to step 5.
        if best_covered == 0:
            break

        # Record the full (unclipped, margin-free) box.
        c = centers[best_idx]
        bboxes.append([[c[d] - half_size[d], c[d] + end_offset[d]] for d in range(3)])

        # Mark only the margin-shrunk interior as covered.
        x_s, x_e, y_s, y_e, z_s, z_e = best_bounds
        uncovered[x_s:x_e, y_s:y_e, z_s:z_e] = 0

        # Drop candidates whose own voxel is now covered. This always removes the chosen center,
        # because a non-empty interior contains its center.
        candidates = candidates[uncovered[tuple(candidates.T)] > 0]

    # Step 5: cover leftovers. Use random sampling when fewer voxels remain than the volume of a
    # box a third of the size; otherwise recurse on the leftovers with a finer grid.
    if uncovered.any():
        if uncovered.sum() < np.prod([i // 3 for i in bbox_size]):
            bboxes.extend(_fallback(uncovered))
        else:
            bboxes.extend(generate_bounding_boxes(
                uncovered, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
            ))

    return bboxes


def random_sampling_fallback(mask: torch.Tensor, bbox_size=(192, 192, 192), margin=(10, 10, 10), n_samples: int = 25):
    """
    Cover the nonzero voxels of a 3D mask with boxes chosen by random search.

    Each round draws `n_samples` still-uncovered voxels (uniformly, with replacement, using
    ``np.random``) as candidate centers. It keeps the one whose margin-shrunk box covers the most
    still-uncovered voxels. Rounds repeat until every nonzero voxel is covered.

    Parameters:
    - mask: 3D PyTorch tensor; its nonzero voxels are the ones to cover. It is not modified.
    - bbox_size: three ints, the full box size per dimension. Odd sizes are supported.
    - margin: three ints; only the box interior shrunk by `margin` on each side counts as
      covering voxels.
    - n_samples: number of candidate centers drawn per box; must be >= 1.

    Returns:
    - List of boxes ``[[x_min, x_max], [y_min, y_max], [z_min, z_max]]`` of Python ints, each a
      half-open interval of length ``bbox_size`` per dimension. Boxes are not clipped to the mask
      shape. Returns [] for an empty mask.

    Raises:
    - ValueError: if no sampled box covers anything. This happens when `margin` leaves an empty
      box interior, or when `n_samples` < 1.
    """
    half_size = [bs // 2 for bs in bbox_size]
    # Boxes span [c - half_size, c + end_offset) so odd sizes keep their full extent
    # (e.g. 193 -> 96 below the center, 97 from it upwards).
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]

    # Work on a boolean copy: sums then count voxels (not label values), and the caller's
    # tensor is left untouched.
    remaining = mask != 0
    shape = remaining.shape
    bboxes = []

    while remaining.any():
        indices = torch.nonzero(remaining)  # (n, 3) coordinates of still-uncovered voxels

        best_center = None
        best_covered = 0
        best_bounds = None

        for _ in range(n_samples):
            center = indices[np.random.choice(len(indices))].tolist()
            # Margin-shrunk interior, clipped to the mask, as (x_s, x_e, y_s, y_e, z_s, z_e).
            bounds = []
            for d in range(3):
                bounds.append(max(0, center[d] - half_size[d] + margin[d]))
                bounds.append(min(shape[d], center[d] + end_offset[d] - margin[d]))
            x_s, x_e, y_s, y_e, z_s, z_e = bounds
            num_covered = remaining[x_s:x_e, y_s:y_e, z_s:z_e].sum().item()
            if num_covered > best_covered:
                best_covered, best_center, best_bounds = num_covered, center, bounds

        if best_center is None:
            # A sampled center is itself uncovered and lies inside any non-empty interior, so
            # this only happens when the interior is empty or nothing was sampled.
            raise ValueError(
                f"cannot cover mask: margin={list(margin)} leaves no box interior for "
                f"bbox_size={list(bbox_size)}, or n_samples={n_samples} < 1"
            )

        # Record the full (unclipped, margin-free) box.
        bboxes.append([
            [best_center[d] - half_size[d], best_center[d] + end_offset[d]] for d in range(3)
        ])

        # Mark only the margin-shrunk interior as covered.
        x_s, x_e, y_s, y_e, z_s, z_e = best_bounds
        remaining[x_s:x_e, y_s:y_e, z_s:z_e] = False
    return bboxes


if __name__ == '__main__':
    # Smoke test / timing run: cover a 100^3 cube inside a 256^3 volume.
    times = []
    torch.set_num_threads(8)
    # Prefer the first GPU, but still run on CPU-only machines.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    for _ in range(1):
        st = time()
        mask = torch.zeros((256, 256, 256), dtype=torch.uint8, device=device)
        mask[50:150, 50:150, 50:150] = 1  # A cubic object

        # Generate bounding boxes with an odd size to test. This goes through
        # generate_bounding_boxes, the entry point that accepts `stride`.
        bboxes = generate_bounding_boxes(
            mask,
            bbox_size=(193, 193, 193),  # Odd size
            stride='auto',
            margin=(10, 10, 10),
        )

        # Print results
        print(f"Number of bounding boxes: {len(bboxes)}")
        end = time()
        times.append(end - st)
    print(times)