Spaces:
Sleeping
Sleeping
File size: 8,424 Bytes
24e5510 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 |
from time import time
from time import time
from typing import List, Union, Tuple
import numpy as np
import torch
def _coverage_slices(center, half_size, end_offset, margin, shape):
    """Return the slices of the margin-shrunken box around ``center``, clipped to ``shape``."""
    return tuple(
        slice(max(0, c - hs + m), min(dim, c + eo - m))
        for c, hs, eo, m, dim in zip(center, half_size, end_offset, margin, shape)
    )


def generate_bounding_boxes(mask, bbox_size=(192, 192, 192), stride: Union[List[int], Tuple[int, int, int], str] = (16, 16, 16), margin=(10, 10, 10), max_depth=5, current_depth=0):
    """
    Generate overlapping bounding boxes to cover a 3D binary segmentation mask using PyTorch tensors.

    Candidate box centers are object voxels subsampled with ``stride`` inside the object's bounding
    box. A greedy set cover repeatedly picks the center whose margin-shrunken box covers the most
    still-uncovered voxels. Whatever remains uncovered afterwards is handled either by
    ``random_sampling_fallback`` (small remainder) or by recursing with a halved stride.

    The input ``mask`` is never modified.

    Parameters:
    - mask: 3D PyTorch tensor; non-zero entries are treated as object voxels
    - bbox_size: Tuple or list of three integers specifying the size of bounding boxes per dimension (x, y, z)
    - stride: Tuple or list of three integers specifying the stride for subsampling centers per dimension,
      or the string 'auto' to use a quarter of the object's extent per dimension (at least 1).
      Values below 1 are clamped to 1.
    - margin: Tuple or list of three integers specifying the margin to leave uncovered per dimension.
      A box only counts as covering the voxels that are at least ``margin`` away from its faces.
    - max_depth: Maximum recursion depth to prevent infinite recursion
    - current_depth: Current recursion depth (used internally)

    Returns:
    - List of boxes ``[[x_min, x_max], [y_min, y_max], [z_min, z_max]]``, each axis being a half-open
      interval [min, max). Boxes are NOT clipped to the mask shape, so coordinates may be negative or
      exceed the shape. Boxes found by the greedy stage hold Python ints; boxes coming from
      ``random_sampling_fallback`` are appended exactly as that function returns them.

    Raises:
    - ValueError: if ``stride`` is a string other than 'auto', or if ``2 * margin >= bbox_size`` in any
      dimension (the shrunken box would be empty, so no voxel could ever be covered).
    """
    # Ensure bbox_size and margin are lists
    bbox_size = list(bbox_size)
    margin = list(margin)
    if any(2 * m >= bs for m, bs in zip(margin, bbox_size)):
        raise ValueError(
            f'margin {margin} leaves no coverable region inside bbox_size {bbox_size}; '
            f'need 2 * margin < bbox_size in every dimension'
        )
    if isinstance(stride, str) and stride != 'auto':
        raise ValueError(f"stride must be a sequence of three ints or 'auto', got {stride!r}")

    # Prevent infinite recursion
    if current_depth > max_depth:
        # The fallback zeroes the tensor it is given, so hand it a copy: this branch can be
        # reached with the caller's own mask (the "no centers found" recursion forwards it).
        return random_sampling_fallback(mask.clone(), bbox_size, margin, 25)

    # Compute half sizes for each dimension
    half_size = [bs // 2 for bs in bbox_size]
    # Adjust end offsets to ensure full bbox_size (handles odd sizes)
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]  # e.g., 193 - 96 = 97

    # Step 1: Find all object voxels
    object_voxels = torch.nonzero(mask, as_tuple=False)
    if object_voxels.numel() == 0:
        return []

    # Step 2: Compute the object's bounding box to limit potential centers
    min_coords = object_voxels.min(dim=0)[0]
    max_coords = object_voxels.max(dim=0)[0]
    lo = min_coords.tolist()
    hi = max_coords.tolist()
    if isinstance(stride, str):
        stride = [max(1, round((h - l) / 4)) for l, h in zip(lo, hi)]
    # A step below 1 can never produce a usable sampling grid.
    stride = [max(1, int(s)) for s in stride]
    halved_stride = [max(1, s // 2) for s in stride]

    # Step 3: Generate potential centers within the object's bounding box.
    # A strided slice visits the same grid positions, in the same (x, y, z) order, as three
    # nested range() loops would, without indexing the tensor once per voxel.
    sampled = mask[
        lo[0]:hi[0] + 1:stride[0],
        lo[1]:hi[1] + 1:stride[1],
        lo[2]:hi[2] + 1:stride[2],
    ]
    grid_hits = torch.nonzero(sampled, as_tuple=False)
    if grid_hits.shape[0] == 0:
        # The sampling grid missed the object entirely: retry on a finer grid.
        return generate_bounding_boxes(
            mask, bbox_size, halved_stride, margin, max_depth, current_depth + 1
        )
    # Map grid indices back to voxel coordinates.
    potential_centers = grid_hits * torch.tensor(stride, device=mask.device) + min_coords

    # Step 4: Greedy set cover algorithm
    # Binarize explicitly so values other than 0/1 cannot be truncated or wrapped by a uint8 cast.
    uncovered = (mask != 0).to(torch.uint8)
    bboxes = []
    while potential_centers.shape[0] > 0 and uncovered.any():
        best_center = None
        best_window = None
        best_covered = 0
        # Find the center that covers the most uncovered voxels (first one wins on ties)
        for center in potential_centers.tolist():
            window = _coverage_slices(center, half_size, end_offset, margin, mask.shape)
            num_covered = uncovered[window].sum().item()
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_window = window

        # If no new voxels are covered, stop
        if best_covered == 0:
            break

        # Add the best bounding box (full size, not shrunk by the margin and not clipped)
        bboxes.append([[c - hs, c + eo] for c, hs, eo in zip(best_center, half_size, end_offset)])

        # Mark voxels as covered, respecting the margin
        uncovered[best_window] = 0

        # Drop every center that is now covered (this includes the one just used)
        still_open = uncovered[tuple(potential_centers.unbind(dim=1))] > 0
        potential_centers = potential_centers[still_open]

    # Step 5: Cover remaining voxels using uncovered as the mask
    if uncovered.any():
        small_remainder = int(np.prod([bs // 3 for bs in bbox_size]))
        if uncovered.sum().item() < small_remainder:
            # Few voxels left: random sampling is cheaper than another greedy pass.
            # `uncovered` is a private tensor, so the fallback may consume it in place.
            bboxes.extend(random_sampling_fallback(uncovered, bbox_size, margin, 25))
        else:
            bboxes.extend(
                generate_bounding_boxes(
                    uncovered, bbox_size, halved_stride, margin, max_depth, current_depth + 1
                )
            )
    return bboxes
def random_sampling_fallback(mask: torch.Tensor, bbox_size=(192, 192, 192), margin=(10, 10, 10), n_samples: int = 25):
    """
    Cover the non-zero voxels of a 3D mask with boxes centered on randomly sampled object voxels.

    Each round draws ``n_samples`` random non-zero voxels (with replacement, via ``np.random``) as
    candidate centers, keeps the one whose margin-shrunken box contains the most non-zero voxels,
    records its full-size box and zeroes the shrunken region. Rounds repeat until the mask is empty.

    NOTE: ``mask`` is modified in place and is all zeros when the function returns. Pass a clone if
    the original is still needed.

    Parameters:
    - mask: 3D PyTorch tensor; non-zero entries are the voxels still to be covered
    - bbox_size: Tuple or list of three integers specifying the size of bounding boxes per dimension (x, y, z)
    - margin: Tuple or list of three integers; only voxels at least ``margin`` away from a box's faces
      count as covered by that box
    - n_samples: Number of random candidate centers evaluated per round (must be >= 1)

    Returns:
    - List of boxes ``[[x_min, x_max], [y_min, y_max], [z_min, z_max]]`` made of Python ints, each axis
      being a half-open interval [min, max). Boxes are not clipped to the mask shape.

    Raises:
    - ValueError: if ``n_samples < 1`` or if ``2 * margin >= bbox_size`` in any dimension (the
      shrunken box would be empty and the loop could never make progress).
    """
    if n_samples < 1:
        raise ValueError(f'n_samples must be at least 1, got {n_samples}')
    if any(2 * m >= bs for m, bs in zip(margin, bbox_size)):
        raise ValueError(
            f'margin {list(margin)} leaves no coverable region inside bbox_size {list(bbox_size)}; '
            f'need 2 * margin < bbox_size in every dimension'
        )

    half_size = [bs // 2 for bs in bbox_size]
    # Adjust end offsets to ensure full bbox_size (handles odd sizes)
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]  # e.g., 193 - 96 = 97
    shape = mask.shape

    bboxes = []
    while mask.any():
        candidates = torch.nonzero(mask)  # n x 3
        best_center = None
        best_window = None
        best_covered = 0
        # Find the sampled center that covers the most uncovered voxels
        for _ in range(n_samples):
            pick = np.random.choice(len(candidates))
            # Plain ints keep the returned boxes free of (possibly GPU-resident) 0-dim tensors.
            center = candidates[pick].tolist()
            window = tuple(
                slice(max(0, c - hs + m), min(dim, c + eo - m))
                for c, hs, eo, m, dim in zip(center, half_size, end_offset, margin, shape)
            )
            # Count non-zero entries instead of summing values: the sampled center always lies
            # inside its own window, so the count is >= 1 and a best candidate always exists.
            num_covered = torch.count_nonzero(mask[window]).item()
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_window = window

        # Add the best bounding box (full size, not shrunk by the margin)
        bboxes.append([[c - hs, c + eo] for c, hs, eo in zip(best_center, half_size, end_offset)])

        # Mark voxels as covered, respecting the margin
        mask[best_window] = 0
    return bboxes
if __name__ == '__main__':
    # Small timing harness: build a synthetic cubic object and measure how long it takes to
    # cover it with bounding boxes.
    # Use the first CUDA device when present; otherwise fall back to the CPU so the script
    # still runs on machines without a GPU.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    times = []
    torch.set_num_threads(8)
    for _ in range(1):
        st = time()
        mask = torch.zeros((256, 256, 256), dtype=torch.uint8, device=device)
        mask[50:150, 50:150, 50:150] = 1  # A cubic object
        # Generate bounding boxes with an odd size to test.
        # `stride` is only accepted by generate_bounding_boxes; random_sampling_fallback has no
        # such parameter and would reject the call with a TypeError.
        bboxes = generate_bounding_boxes(
            mask,
            bbox_size=(193, 193, 193),  # Odd size
            stride='auto',
            margin=(10, 10, 10)
        )
        # Print results
        print(f"Number of bounding boxes: {len(bboxes)}")
        end = time()
        times.append(end - st)
    print(times)
|