hanjang's picture
Upload folder using huggingface_hub
24e5510 verified
from time import time
from time import time
from typing import List, Union, Tuple
import numpy as np
import torch
def generate_bounding_boxes(mask, bbox_size=(192, 192, 192), stride: Union[List[int], Tuple[int, int, int], str] = (16, 16, 16), margin=(10, 10, 10), max_depth=5, current_depth=0):
    """
    Generate overlapping bounding boxes to cover a 3D binary segmentation mask using PyTorch tensors.

    Candidate centers are foreground voxels sampled on a regular grid (spacing ``stride``) inside the
    object's bounding box. Boxes are then chosen greedily: each round picks the candidate whose box,
    shrunk by ``margin`` on every side and clipped to the mask, covers the most still-uncovered voxels.
    Voxels that remain uncovered are handled by recursing with half the stride, or by
    ``random_sampling_fallback`` when few voxels remain or the recursion limit is exceeded.

    Parameters:
    - mask: 3D PyTorch tensor with values 0 or 1 (binary mask). It is not modified.
    - bbox_size: Tuple or list of three integers specifying the size of bounding boxes per dimension (x, y, z)
    - stride: Tuple or list of three integers specifying the stride for subsampling centers per dimension,
      or the string 'auto' to use a quarter of the object's extent per dimension (at least 1)
    - margin: Tuple or list of three integers; voxels closer than this to a box face do not count as covered
    - max_depth: Maximum recursion depth to prevent infinite recursion
    - current_depth: Current recursion depth (used internally)

    Returns:
    - List of boxes [[x_min, x_max], [y_min, y_max], [z_min, z_max]], each axis being a half-open
      interval [min, max) of Python ints. Boxes always have the full bbox_size and are NOT clipped to
      the mask, so coordinates may be negative or exceed the mask shape.

    Raises:
    - ValueError: if stride is a string other than 'auto'.
    """
    # Prevent infinite recursion
    if current_depth > max_depth:
        # Pass a copy so the caller's mask can never be modified by the fallback.
        return random_sampling_fallback(mask.clone(), bbox_size, margin, 25)

    bbox_size = list(bbox_size)
    margin = list(margin)
    half_size = [bs // 2 for bs in bbox_size]
    # Asymmetric end offset keeps the full bbox_size for odd sizes, e.g. 193 - 96 = 97
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]
    shape = mask.shape

    # Step 1: Find all object voxels
    object_voxels = torch.nonzero(mask, as_tuple=False)
    if object_voxels.numel() == 0:
        return []

    # Step 2: Compute the object's bounding box to limit potential centers
    min_coords = object_voxels.min(dim=0)[0].tolist()
    max_coords = object_voxels.max(dim=0)[0].tolist()

    if isinstance(stride, str):
        if stride != 'auto':
            raise ValueError(f"stride must be a sequence of three ints or 'auto', got {stride!r}")
        stride = [max(1, round((hi - lo) / 4)) for lo, hi in zip(min_coords, max_coords)]
    stride = list(stride)

    # Step 3: Candidate centers = foreground voxels on the strided grid anchored at min_coords.
    # torch.nonzero returns indices in C order, i.e. the same order nested x/y/z loops would visit.
    grid = tuple(slice(lo, hi + 1, step) for lo, hi, step in zip(min_coords, max_coords, stride))
    grid_hits = torch.nonzero(mask[grid], as_tuple=False)
    if grid_hits.shape[0] == 0:
        # The grid missed every foreground voxel: retry with a finer grid.
        return generate_bounding_boxes(
            mask, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
        )
    # Map grid indices back to voxel coordinates.
    potential_centers = grid_hits * grid_hits.new_tensor(stride) + grid_hits.new_tensor(min_coords)

    # Step 4: Greedy set cover algorithm
    uncovered = mask.clone().byte()
    bboxes = []
    while len(potential_centers) > 0 and uncovered.any():
        best_center = None
        best_covered = 0
        best_window = None
        # Plain Python ints keep the per-candidate arithmetic off the device.
        for center in potential_centers.tolist():
            window = []
            for c, hs, eo, m, dim in zip(center, half_size, end_offset, margin, shape):
                start = max(0, c - hs + m)
                # Clamp so an oversized margin gives an empty window instead of a
                # negative stop that would wrap around.
                stop = max(start, min(dim, c + eo - m))
                window.append(slice(start, stop))
            window = tuple(window)
            num_covered = uncovered[window].sum().item()
            # Strict '>' keeps the first candidate on ties.
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_window = window

        # If no new voxels are covered, stop
        if best_center is None:
            break

        # Record the full-size (unclipped, margin-free) box around the winner.
        bboxes.append([[c - hs, c + eo] for c, hs, eo in zip(best_center, half_size, end_offset)])

        # Mark voxels as covered, respecting the margin
        uncovered[best_window] = 0

        # Keep only candidates that are themselves still uncovered.
        still_open = uncovered[potential_centers[:, 0], potential_centers[:, 1], potential_centers[:, 2]] > 0
        potential_centers = potential_centers[still_open]

    # Step 5: Cover remaining voxels, using uncovered as the new mask
    if uncovered.any():
        if int(uncovered.sum().item()) < int(np.prod([i // 3 for i in bbox_size])):
            # Few voxels left: random sampling is cheaper than another grid pass.
            bboxes.extend(random_sampling_fallback(uncovered, bbox_size, margin, 25))
        else:
            bboxes.extend(generate_bounding_boxes(
                uncovered, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
            ))
    return bboxes
def random_sampling_fallback(mask: torch.Tensor, bbox_size=(192, 192, 192), margin=(10, 10, 10), n_samples: int = 25):
    """
    Cover a 3D binary mask with bounding boxes centered on randomly sampled foreground voxels.

    Until every foreground voxel is covered, n_samples uncovered voxels are drawn (with
    np.random.choice) as candidate centers. The candidate whose box, shrunk by margin on every
    side and clipped to the mask, covers the most uncovered voxels is kept.

    Parameters:
    - mask: 3D PyTorch tensor, non-zero where voxels must be covered. It is not modified.
    - bbox_size: Tuple or list of three integers specifying the size of bounding boxes per dimension
    - margin: Tuple or list of three integers; voxels closer than this to a box face do not count as covered
    - n_samples: Number of random candidate centers evaluated per box

    Returns:
    - List of boxes [[x_min, x_max], [y_min, y_max], [z_min, z_max]], each axis being a half-open
      interval [min, max) of Python ints. Boxes always have the full bbox_size and are NOT clipped to
      the mask, so coordinates may be negative or exceed the mask shape.

    Raises:
    - ValueError: if voxels remain but no sampled candidate covers any of them (margin too large
      for bbox_size, or n_samples < 1). Without this check the loop could never make progress.
    """
    half_size = [bs // 2 for bs in bbox_size]
    # Asymmetric end offset keeps the full bbox_size for odd sizes, e.g. 193 - 96 = 97
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]
    shape = mask.shape

    # Work on a private copy so the caller's tensor is left untouched.
    remaining = mask.clone()
    bboxes = []
    while remaining.any():
        candidates = torch.nonzero(remaining)  # (n, 3) coordinates of uncovered voxels
        best_center = None
        best_covered = 0
        best_window = None
        # Find the sampled center that covers the most uncovered voxels
        for _ in range(n_samples):
            pick = np.random.choice(len(candidates))
            center = candidates[pick].tolist()
            window = []
            for c, hs, eo, m, dim in zip(center, half_size, end_offset, margin, shape):
                start = max(0, c - hs + m)
                # Clamp so an oversized margin gives an empty window instead of a
                # negative stop that would wrap around.
                stop = max(start, min(dim, c + eo - m))
                window.append(slice(start, stop))
            window = tuple(window)
            num_covered = remaining[window].sum().item()
            # Strict '>' keeps the first candidate on ties.
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_window = window

        if best_center is None:
            raise ValueError(
                "no sampled center covers any voxel: margin must be smaller than half of "
                "bbox_size in every dimension and n_samples must be at least 1"
            )

        # Record the full-size (unclipped, margin-free) box around the winner.
        bboxes.append([[c - hs, c + eo] for c, hs, eo in zip(best_center, half_size, end_offset)])

        # Mark voxels as covered, respecting the margin
        remaining[best_window] = 0
    return bboxes
if __name__ == '__main__':
    # Smoke test: time box generation for a cubic object with an odd bbox size.
    times = []
    torch.set_num_threads(8)
    # Use the first GPU when present; otherwise run on CPU so the script still works.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    for _ in range(1):
        st = time()
        mask = torch.zeros((256, 256, 256), dtype=torch.uint8, device=device)
        mask[50:150, 50:150, 50:150] = 1  # A cubic object
        # Generate bounding boxes with an odd size to test.
        # generate_bounding_boxes is the entry point that accepts `stride`;
        # random_sampling_fallback has no such parameter.
        bboxes = generate_bounding_boxes(
            mask,
            bbox_size=(193, 193, 193),  # Odd size
            stride='auto',
            margin=(10, 10, 10)
        )
        # Print results
        print(f"Number of bounding boxes: {len(bboxes)}")
        end = time()
        times.append(end - st)
    print(times)