| | |
| | import logging |
| |
|
| | import torch |
| |
|
try:
    from cc_torch import get_connected_components
except ImportError:
    # cc_torch is optional: without it, CUDA inputs to `connected_components`
    # are routed to the Triton implementation instead. Only a debug log is emitted.
    HAS_CC_TORCH = False
    logging.debug(
        "cc_torch not found. Consider installing for better performance. Command line:"
        " pip install git+https://github.com/ronghanghu/cc_torch.git"
    )
else:
    HAS_CC_TORCH = True
| |
|
| |
|
def connected_components_cpu_single(values: torch.Tensor):
    """Label the connected components of a single 2-D mask on the CPU.

    Every non-zero element of ``values`` is foreground and zero is background.
    Labelling is delegated to ``skimage.measure.label`` with its default
    connectivity.

    Args:
        values: A 2-D (H, W) tensor. It is binarised and copied to the CPU.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Two int64 CPU tensors of shape (H, W):
            - labels: 0 for background, 1..num for the components.
            - counts: the pixel count of the component each pixel belongs to
              (0 for background pixels).
    """
    assert values.dim() == 2
    from skimage.measure import label

    # Binarise first. skimage's `label` keeps touching pixels with different
    # values in separate regions, but this module treats every non-zero value
    # as the same foreground.
    mask = (values != 0).cpu().numpy()
    labels_np, num = label(mask, return_num=True)
    # The integer dtype skimage returns can vary with its input dtype/version;
    # normalise so callers always get int64.
    labels = torch.from_numpy(labels_np).to(torch.int64)

    # Component sizes in one pass, instead of building one full-image mask per
    # component.
    sizes = torch.bincount(labels.flatten(), minlength=num + 1)
    sizes[0] = 0  # background pixels report a count of 0
    counts = sizes[labels]
    return labels, counts
| |
|
| |
|
def connected_components_cpu(input_tensor: torch.Tensor):
    """Run CPU connected-components labelling over a batch of 2-D masks.

    Each image is labelled independently by ``connected_components_cpu_single``,
    and the per-image results are stacked and moved back to the input's device.

    Args:
        input_tensor: A (B, H, W) or (B, 1, H, W) tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(labels, counts)``, both viewed as
        the input's original shape. An empty batch (B == 0) yields zero-size
        int64 tensors of that shape.

    Raises:
        AssertionError: if the input is neither (B, H, W) nor (B, 1, H, W).
    """
    out_shape = input_tensor.shape
    if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
        input_tensor = input_tensor.squeeze(1)
    else:
        assert (
            input_tensor.dim() == 3
        ), "Input tensor must be (B, H, W) or (B, 1, H, W)."

    if input_tensor.shape[0] == 0:
        # torch.stack rejects an empty list, so an empty batch is handled here.
        empty = torch.zeros(out_shape, dtype=torch.int64, device=input_tensor.device)
        return empty, empty.clone()

    # Iterating a tensor walks its first (batch) dimension.
    per_image = [connected_components_cpu_single(image) for image in input_tensor]
    labels_tensor = torch.stack([lab for lab, _ in per_image], dim=0).to(
        input_tensor.device
    )
    counts_tensor = torch.stack([cnt for _, cnt in per_image], dim=0).to(
        input_tensor.device
    )
    return labels_tensor.view(out_shape), counts_tensor.view(out_shape)
| |
|
| |
|
def connected_components(input_tensor: torch.Tensor):
    """
    Computes connected components labeling on a batch of 2D tensors, using the best available backend.

    Backend selection:
        - CUDA input and ``cc_torch`` installed: ``cc_torch.get_connected_components``
          on a 0/1 uint8 copy of the input.
        - CUDA input without ``cc_torch``: ``connected_components_triton`` from
          ``sam3.perflib.triton.connected_components`` (imported on first use).
        - Any other input: ``connected_components_cpu``.

    Args:
        input_tensor (torch.Tensor): A BxHxW integer tensor or Bx1xHxW. Non-zero values are
            considered foreground. Bool tensor also accepted.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - A tensor with dense labels. Background is 0.
            - A tensor with the size of the connected component for each pixel.
        A BxHxW input is unsqueezed to Bx1xHxW before dispatch, so the CPU backend
        returns Bx1xHxW tensors for both input layouts. The GPU backends receive
        the same 4-D tensor (their output shape is assumed to match it — TODO
        confirm).

    Raises:
        AssertionError: if the input is neither (B, H, W) nor (B, 1, H, W).
    """
    if input_tensor.dim() == 3:
        input_tensor = input_tensor.unsqueeze(1)

    assert (
        input_tensor.dim() == 4 and input_tensor.shape[1] == 1
    ), "Input tensor must be (B, H, W) or (B, 1, H, W)."

    if not input_tensor.is_cuda:
        return connected_components_cpu(input_tensor)

    if HAS_CC_TORCH:
        # Binarise before narrowing to uint8. A plain cast wraps integers that
        # are multiples of 256 and truncates floats in (0, 1) to 0, which would
        # turn foreground pixels into background.
        return get_connected_components((input_tensor != 0).to(torch.uint8))

    # Triton fallback when cc_torch is not installed; imported only on this path.
    from sam3.perflib.triton.connected_components import (
        connected_components_triton,
    )

    return connected_components_triton(input_tensor)
| |
|