Spaces:
Sleeping
Sleeping
| from time import time | |
| from time import time | |
| from typing import List, Union, Tuple | |
| import numpy as np | |
| import torch | |
def generate_bounding_boxes(mask, bbox_size=(192, 192, 192), stride: Union[List[int], Tuple[int, int, int], str] = (16, 16, 16), margin=(10, 10, 10), max_depth=5, current_depth=0):
    """
    Generate overlapping bounding boxes that cover a 3D segmentation mask.

    Candidate box centers are taken from the non-zero voxels of ``mask`` on a
    regular grid (spacing ``stride``) anchored at the object's minimum corner.
    A greedy set cover then repeatedly picks the candidate whose box interior
    (the box shrunk by ``margin`` on every side and clipped to the mask
    extent) contains the most still-uncovered voxels. Voxels left over once
    the candidates are exhausted are handled by recursing with a halved
    stride or, when few remain, by ``random_sampling_fallback``.

    Parameters:
    - mask: 3D PyTorch tensor; non-zero entries are the voxels to cover. The tensor is not modified.
    - bbox_size: Tuple or list of three integers specifying the size of bounding boxes per dimension (x, y, z)
    - stride: Tuple or list of three integers specifying the spacing of candidate centers per dimension,
      or the string 'auto' to use a quarter of the object's extent per dimension (at least 1)
    - margin: Tuple or list of three integers; voxels closer than this to a box face do not count as covered.
      Must satisfy bbox_size - 2 * margin >= 1 per dimension.
    - max_depth: Maximum recursion depth; beyond it the random sampling fallback is used
    - current_depth: Current recursion depth (used internally)

    Returns:
    - List of boxes. Each box is [[x_min, x_max], [y_min, y_max], [z_min, z_max]] with plain ints, every
      pair being a half-open interval [min, max). Boxes always have the full bbox_size and are NOT clipped
      to the mask, so coordinates can be negative or exceed mask.shape.

    Raises:
    - ValueError: if the margin leaves no interior in some dimension, or stride is a string other than 'auto'.
    """
    # Ensure bbox_size and margin are lists
    bbox_size = list(bbox_size)
    margin = list(margin)

    # With an empty interior no box can ever cover a voxel, so the search
    # could only end in a crash deep inside the fallback. Fail early instead.
    if any(bs - 2 * m < 1 for bs, m in zip(bbox_size, margin)):
        raise ValueError(
            f"margin {margin} leaves no interior for bbox_size {bbox_size}; "
            "need bbox_size - 2 * margin >= 1 in every dimension"
        )
    if isinstance(stride, str) and stride != 'auto':
        raise ValueError(f"stride must be a sequence of three ints or 'auto', got {stride!r}")

    # Prevent infinite recursion
    if current_depth > max_depth:
        return random_sampling_fallback(mask, bbox_size, margin, 25)

    # Compute half sizes for each dimension
    half_size = [bs // 2 for bs in bbox_size]
    # End offsets make the box exactly bbox_size wide, also for odd sizes (e.g. 193 - 96 = 97)
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]

    # Step 1: Find all object voxels
    object_voxels = torch.nonzero(mask, as_tuple=False)
    if object_voxels.numel() == 0:
        return []

    # Step 2: Compute the object's bounding box to limit potential centers
    lower = object_voxels.min(dim=0)[0].tolist()
    upper = object_voxels.max(dim=0)[0].tolist()
    if isinstance(stride, str):
        stride = [max(1, round((hi - lo) / 4)) for lo, hi in zip(lower, upper)]
    stride = list(stride)

    # Step 3: Candidate centers are the object voxels on the stride grid.
    # A strided slice followed by nonzero visits the same voxels, in the same
    # x-major order, as three nested Python loops would.
    grid = mask[
        lower[0]:upper[0] + 1:stride[0],
        lower[1]:upper[1] + 1:stride[1],
        lower[2]:upper[2] + 1:stride[2],
    ]
    grid_hits = torch.nonzero(grid, as_tuple=False)
    if grid_hits.shape[0] == 0:
        # The grid missed the object entirely: retry on a finer grid
        return generate_bounding_boxes(
            mask, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
        )
    # Map grid indices back to voxel coordinates
    potential_centers = (
        grid_hits * torch.tensor(stride, device=mask.device)
        + torch.tensor(lower, device=mask.device)
    )

    # Step 4: Greedy set cover algorithm
    uncovered = mask.clone().byte()  # working copy; the caller's mask stays untouched
    shape = mask.shape
    bboxes = []
    while len(potential_centers) > 0 and uncovered.any():
        best_center = None
        best_covered = 0
        best_bounds = None
        # Find the center that covers the most uncovered voxels. Working on
        # plain ints avoids per-element tensor arithmetic in the inner loop.
        for center in potential_centers.tolist():
            starts = [max(0, c - hs + m) for c, hs, m in zip(center, half_size, margin)]
            ends = [min(n, c + eo - m) for c, eo, m, n in zip(center, end_offset, margin, shape)]
            num_covered = uncovered[
                starts[0]:ends[0],
                starts[1]:ends[1],
                starts[2]:ends[2]
            ].sum().item()
            # Strict comparison: ties keep the earliest candidate
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_bounds = (starts, ends)

        # If no new voxels are covered, stop
        if best_covered == 0:
            break

        # Add the best bounding box (full size, not clipped, margin not applied)
        bboxes.append([
            [c - hs, c + eo] for c, hs, eo in zip(best_center, half_size, end_offset)
        ])

        # Mark voxels as covered, respecting the margin
        starts, ends = best_bounds
        uncovered[
            starts[0]:ends[0],
            starts[1]:ends[1],
            starts[2]:ends[2],
        ] = 0

        # Drop every candidate that is now covered (this includes the one just used)
        still_uncovered = uncovered[tuple(potential_centers.T)] > 0
        potential_centers = potential_centers[still_uncovered]

    # Step 5: Cover remaining voxels using uncovered as the mask
    if uncovered.any():
        small_remainder = int(np.prod([i // 3 for i in bbox_size]))
        if uncovered.sum().item() < small_remainder:
            # Few voxels left: random sampling is cheaper than another grid search
            bboxes.extend(random_sampling_fallback(uncovered, bbox_size, margin, 25))
        else:
            remaining_bboxes = generate_bounding_boxes(
                uncovered, bbox_size, [max(1, s // 2) for s in stride], margin, max_depth, current_depth + 1
            )
            bboxes.extend(remaining_bboxes)
    return bboxes
def random_sampling_fallback(mask: torch.Tensor, bbox_size=(192, 192, 192), margin=(10, 10, 10), n_samples: int = 25):
    """
    Cover the non-zero voxels of a 3D mask with boxes centered on randomly sampled voxels.

    While any voxel is non-zero, ``n_samples`` of the remaining non-zero voxels
    are drawn with ``np.random.choice`` (with replacement). The one whose box
    interior (the box shrunk by ``margin`` on every side and clipped to the
    mask extent) contains the most non-zero voxels becomes the next box center.

    NOTE: ``mask`` is modified in place. Covered voxels are set to 0, so the
    mask is all zeros when the function returns. Pass a clone to keep the
    original.

    Parameters:
    - mask: 3D PyTorch tensor; non-zero entries are the voxels to cover
    - bbox_size: Tuple or list of three integers specifying the size of bounding boxes per dimension
    - margin: Tuple or list of three integers; voxels closer than this to a box face do not count as covered.
      Must satisfy bbox_size - 2 * margin >= 1 per dimension.
    - n_samples: Number of random candidate centers evaluated per box (at least 1)

    Returns:
    - List of boxes. Each box is [[x_min, x_max], [y_min, y_max], [z_min, z_max]] with plain ints, every
      pair being a half-open interval [min, max). Boxes are not clipped to the mask.

    Raises:
    - ValueError: if n_samples < 1 or the margin leaves no interior in some dimension.
    """
    # Both conditions would leave no valid candidate to pick; reject them
    # explicitly instead of failing later on an unset best center.
    if n_samples < 1:
        raise ValueError(f"n_samples must be at least 1, got {n_samples}")
    if any(bs - 2 * m < 1 for bs, m in zip(bbox_size, margin)):
        raise ValueError(
            f"margin {list(margin)} leaves no interior for bbox_size {list(bbox_size)}; "
            "need bbox_size - 2 * margin >= 1 in every dimension"
        )

    half_size = [bs // 2 for bs in bbox_size]
    # End offsets make the box exactly bbox_size wide, also for odd sizes (e.g. 193 - 96 = 97)
    end_offset = [bs - hs for bs, hs in zip(bbox_size, half_size)]
    shape = mask.shape

    bboxes = []
    while mask.any():
        indices = torch.nonzero(mask)  # nx3
        best_center = None
        best_covered = 0
        best_bounds = None
        # Find the sampled center that covers the most uncovered voxels
        for _ in range(n_samples):
            idx = np.random.choice(len(indices))
            # Plain ints keep the returned boxes free of 0-dim tensors
            center = indices[idx].tolist()
            starts = [max(0, c - hs + m) for c, hs, m in zip(center, half_size, margin)]
            ends = [min(n, c + eo - m) for c, eo, m, n in zip(center, end_offset, margin, shape)]
            num_covered = mask[
                starts[0]:ends[0],
                starts[1]:ends[1],
                starts[2]:ends[2]
            ].sum().item()
            # Strict comparison: ties keep the earliest sample
            if num_covered > best_covered:
                best_covered = num_covered
                best_center = center
                best_bounds = (starts, ends)

        # Add the best bounding box (full size, not clipped, margin not applied)
        bboxes.append([
            [c - hs, c + eo] for c, hs, eo in zip(best_center, half_size, end_offset)
        ])

        # Mark voxels as covered, respecting the margin. The chosen center lies
        # inside this window, so every iteration clears at least one voxel.
        starts, ends = best_bounds
        mask[
            starts[0]:ends[0],
            starts[1]:ends[1],
            starts[2]:ends[2],
        ] = 0
    return bboxes
if __name__ == '__main__':
    # Small timing demo: cover a cubic object with odd-sized boxes.
    times = []
    torch.set_num_threads(8)
    # Use the first GPU when one is available, otherwise run on the CPU
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    for _ in range(1):
        st = time()
        mask = torch.zeros((256, 256, 256), dtype=torch.uint8, device=device)
        mask[50:150, 50:150, 50:150] = 1  # A cubic object
        # Generate bounding boxes with an odd size to test.
        # stride='auto' is only understood by generate_bounding_boxes;
        # random_sampling_fallback has no stride parameter.
        bboxes = generate_bounding_boxes(
            mask,
            bbox_size=(193, 193, 193),  # Odd size
            stride='auto',
            margin=(10, 10, 10)
        )
        # Print results
        print(f"Number of bounding boxes: {len(bboxes)}")
        end = time()
        times.append(end - st)
    print(times)